Image-Text-to-Text
Transformers
Safetensors
English
step3p7
text-generation
vision-language
unsloth - multimodal - moe
conversational
custom_code
Instructions to use unsloth/Step-3.7-Flash with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use unsloth/Step-3.7-Flash with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("image-text-to-text", model="unsloth/Step-3.7-Flash", trust_remote_code=True) messages = [ { "role": "user", "content": [ {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"}, {"type": "text", "text": "What animal is on the candy?"} ] }, ] pipe(text=messages)# Load model directly from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained("unsloth/Step-3.7-Flash", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use unsloth/Step-3.7-Flash with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "unsloth/Step-3.7-Flash" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "unsloth/Step-3.7-Flash", "messages": [ { "role": "user", "content": [ { "type": "text", "text": "Describe this image in one sentence." }, { "type": "image_url", "image_url": { "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" } } ] } ] }'Use Docker
docker model run hf.co/unsloth/Step-3.7-Flash
- SGLang
How to use unsloth/Step-3.7-Flash with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "unsloth/Step-3.7-Flash" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "unsloth/Step-3.7-Flash", "messages": [ { "role": "user", "content": [ { "type": "text", "text": "Describe this image in one sentence." }, { "type": "image_url", "image_url": { "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" } } ] } ] }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "unsloth/Step-3.7-Flash" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "unsloth/Step-3.7-Flash", "messages": [ { "role": "user", "content": [ { "type": "text", "text": "Describe this image in one sentence." }, { "type": "image_url", "image_url": { "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" } } ] } ] }' - Docker Model Runner
How to use unsloth/Step-3.7-Flash with Docker Model Runner:
docker model run hf.co/unsloth/Step-3.7-Flash
Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- README.md +421 -0
- assets/benchmarks.png +3 -0
- chat_template.jinja +89 -0
- config.json +410 -0
- configuration_step3p7.py +207 -0
- model-00001.safetensors +3 -0
- model-00002.safetensors +3 -0
- model-00003.safetensors +3 -0
- model-00004.safetensors +3 -0
- model-00005.safetensors +3 -0
- model-00006.safetensors +3 -0
- model-00007.safetensors +3 -0
- model-00008.safetensors +3 -0
- model-00009.safetensors +3 -0
- model-00010.safetensors +3 -0
- model-00011.safetensors +3 -0
- model-00012.safetensors +3 -0
- model-00013.safetensors +3 -0
- model-00014.safetensors +3 -0
- model-00015.safetensors +3 -0
- model-00016.safetensors +3 -0
- model-00017.safetensors +3 -0
- model-00018.safetensors +3 -0
- model-00019.safetensors +3 -0
- model-00020.safetensors +3 -0
- model-00021.safetensors +3 -0
- model-00022.safetensors +3 -0
- model-00023.safetensors +3 -0
- model-00024.safetensors +3 -0
- model-vit-00001.safetensors +3 -0
- model-vit-00002.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_step3p7.py +1405 -0
- processing_step3.py +475 -0
- processor_config.json +6 -0
- special_tokens_map.json +23 -0
- tokenizer.json +0 -0
- tokenizer_config.json +22 -0
- vision_encoder.py +452 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model:
|
| 3 |
+
- stepfun-ai/Step-3.7-Flash
|
| 4 |
+
license: apache-2.0
|
| 5 |
+
library_name: transformers
|
| 6 |
+
pipeline_tag: image-text-to-text
|
| 7 |
+
language:
|
| 8 |
+
- en
|
| 9 |
+
tags:
|
| 10 |
+
- vision-language
|
| 11 |
+
- unsloth
|
| 12 |
+
- multimodal
|
| 13 |
+
- moe
|
| 14 |
+
---
|
| 15 |
+
<div>
|
| 16 |
+
<p style="margin-top: 0;margin-bottom: 0;">
|
| 17 |
+
<em><a href="https://docs.unsloth.ai/basics/unsloth-dynamic-v2.0-gguf">Unsloth Dynamic 2.0</a> achieves superior accuracy & outperforms other leading quants.</em>
|
| 18 |
+
</p>
|
| 19 |
+
<div style="display: flex; gap: 5px; align-items: center; ">
|
| 20 |
+
<a href="https://github.com/unslothai/unsloth/">
|
| 21 |
+
<img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="133">
|
| 22 |
+
</a>
|
| 23 |
+
<a href="https://discord.gg/unsloth">
|
| 24 |
+
<img src="https://github.com/unslothai/unsloth/raw/main/images/Discord%20button.png" width="173">
|
| 25 |
+
</a>
|
| 26 |
+
<a href="https://docs.unsloth.ai/">
|
| 27 |
+
<img src="https://raw.githubusercontent.com/unslothai/unsloth/refs/heads/main/images/documentation%20green%20button.png" width="143">
|
| 28 |
+
</a>
|
| 29 |
+
</div>
|
| 30 |
+
</div>
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
**[ModelPage]**: https://static.stepfun.com/blog/step-3.7-flash/
|
| 34 |
+
|
| 35 |
+
## 1. Introduction
|
| 36 |
+
|
| 37 |
+
Step 3.7 Flash is a 198B-parameter sparse Mixture-of-Experts (MoE) vision-language model that combines a 196B-parameter language backbone with a 1.8B-parameter vision encoder for native image understanding. Engineered for high-frequency production workloads, it activates approximately 11B parameters per token and delivers a throughput of up to 400 tokens per second. Step 3.7 Flash supports a 256k context window and offers three selectable reasoning levels (low, medium, and high) so developers can easily balance speed, cost, and cognitive depth.
|
| 38 |
+
|
| 39 |
+
We built Step 3.7 Flash for developers who need to scale agentic workflows that combine perception, search, and reasoning. It is designed to handle intensive tasks such as parsing massive financial reports in one pass, running multi-step search loops with cross-source verification, or operating concurrent coding agents in high-throughput pipelines.
|
| 40 |
+
|
| 41 |
+
## 2. Capabilities & Performance
|
| 42 |
+
|
| 43 |
+
### Multimodal Perception and Verification
|
| 44 |
+
|
| 45 |
+
The model delivers top-tier visual intelligence, securing first place on SimpleVQA (Search) with a 79.2 and achieving frontier parity on V* (Python) at 95.3. These metrics reflect strong visual grounding and retrieval-augmented reasoning beyond basic image description. The model accurately processes dense visual interfaces, such as UI wireframes, application GUIs, and data charts, to map them into structured code. When it encounters an incomplete visual asset, it can independently identify missing data and execute lookups to verify context before returning a factually verified conclusion.
|
| 46 |
+
|
| 47 |
+
### Workflow Integrity and Tool Orchestration
|
| 48 |
+
|
| 49 |
+
Execution reliability is critical for autonomous agents. Step 3.7 Flash leads the ClawEval-1.1 benchmark with a score of 67.1, which significantly outperforms the next closest competitor at 59.8. This performance demonstrates high resistance to adversarial traps and strict adherence to system policies during multi-turn orchestration. Backed by scores of 49.5 on Toolathlon and 48.1 on HLE w. Tool, this profile ensures high trajectory integrity. Step 3.7 Flash reliably interacts with external APIs and executes long-horizon workflows without drifting from instructions or violating system constraints.
|
| 50 |
+
|
| 51 |
+
### Code Engineering and Professional Baselines
|
| 52 |
+
|
| 53 |
+
Step 3.7 Flash is built for live engineering tasks and secured a definitive second-place finish on SWE-Bench PRO with a score of 56.3. It can independently trace multi-file repositories, isolate bugs from raw issue reports, and generate functional patches that pass automated unit tests. While evaluations like Terminal-Bench 2.1 (59.5) and GDPVal-AA (45.8) show clear areas for future optimization compared to the absolute peak of the cohort, they establish a dependable baseline for system interactions and structured professional deliverables.
|
| 54 |
+
|
| 55 |
+

|
| 56 |
+
|
| 57 |
+
## 3. Pricing
|
| 58 |
+
|
| 59 |
+
| Token Type | Price |
|
| 60 |
+
|---|---|
|
| 61 |
+
| Input (cache miss) | $0.20 / M tokens |
|
| 62 |
+
| Input (cache hit) | $0.04 / M tokens |
|
| 63 |
+
| Output | $1.15 / M tokens |
|
| 64 |
+
|
| 65 |
+
## 4. Availability, Deployment, and Ecosystem
|
| 66 |
+
- Availability: Step 3.7 Flash is available on the StepFun Open Platform — [platform.stepfun.ai](https://platform.stepfun.ai) (Global) and [platform.stepfun.com](https://platform.stepfun.com) (China), OpenRouter, and NVIDIA NIM. StepFun is also partnering with DeepInfra, Fireworks AI, and Modal to expand availability soon.
|
| 67 |
+
- Deployment: Step 3.7 Flash supports flexible deployment across cloud, data center, and local environments. For large-scale production and enterprise use cases, Step 3.7 Flash can be deployed on modern data center infrastructure. For local and workstation scenarios, it can also run on high-memory devices such as NVIDIA DGX Station, AMD Ryzen AI Max+ 395-based systems, and Mac Studio / Macbook Pro devices with at least 128GB unified memory.
|
| 68 |
+
- Ecosystem: Step 3.7 Flash is supported across popular open-source infrastructure for both inference and model development. For inference and serving, developers can use vLLM, SGLang, Hugging Face Transformers, and llama.cpp. For model development & customization workflows, StepFun model support has landed in the NVIDIA Nemo ecosystem, including AutoModel, Megatron Core and Megatron Bridge. Step 3.7 Flash is also available as an NVIDIA NIM inference microservice for on-prem, cloud, or hybrid deployment.
|
| 69 |
+
|
| 70 |
+
## 5. Examples
|
| 71 |
+
|
| 72 |
+
You can get started with Step 3.7 Flash in minutes using StepFun's API or via other inference providers.
|
| 73 |
+
|
| 74 |
+
> Pick the right `base_url` for your region. StepFun operates two regional platforms with separate API hosts. The `base_url` you pass to the OpenAI client must match the platform where your API key was issued, otherwise requests will be rejected as unauthorized.
|
| 75 |
+
>
|
| 76 |
+
> - **Global**: [platform.stepfun.ai](https://platform.stepfun.ai) — `base_url=https://api.stepfun.ai/v1`
|
| 77 |
+
> - **China**: [platform.stepfun.com](https://platform.stepfun.com) — `base_url=https://api.stepfun.com/v1`
|
| 78 |
+
>
|
| 79 |
+
> To avoid hard-coding the wrong region, the examples below read both the API key and base URL from environment variables. Export them once before running:
|
| 80 |
+
>
|
| 81 |
+
> ```bash
|
| 82 |
+
> export STEP_API_KEY="sk-..."
|
| 83 |
+
> export STEP_BASE_URL="https://api.stepfun.ai/v1" # use https://api.stepfun.com/v1 for the China platform
|
| 84 |
+
> ```
|
| 85 |
+
|
| 86 |
+
### 5.1 Chat Example
|
| 87 |
+
|
| 88 |
+
```python
|
| 89 |
+
import os
|
| 90 |
+
from openai import OpenAI
|
| 91 |
+
|
| 92 |
+
client = OpenAI(
|
| 93 |
+
api_key=os.environ["STEP_API_KEY"],
|
| 94 |
+
base_url=os.environ["STEP_BASE_URL"],
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
completion = client.chat.completions.create(
|
| 98 |
+
model="step-3.7-flash",
|
| 99 |
+
messages=[
|
| 100 |
+
{
|
| 101 |
+
"role": "system",
|
| 102 |
+
"content": "You are an AI assistant provided by StepFun. You are good at Chinese, English, and many other languages, and you can see, think, and act to help users get things done.",
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
"role": "user",
|
| 106 |
+
"content": "Introduce StepFun's artificial intelligence capabilities."
|
| 107 |
+
},
|
| 108 |
+
],
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
print(completion)
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
### 5.2 Text and Image Input Example
|
| 115 |
+
|
| 116 |
+
```python
|
| 117 |
+
import os
|
| 118 |
+
from openai import OpenAI
|
| 119 |
+
|
| 120 |
+
client = OpenAI(
|
| 121 |
+
api_key=os.environ["STEP_API_KEY"],
|
| 122 |
+
base_url=os.environ["STEP_BASE_URL"],
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
completion = client.chat.completions.create(
|
| 126 |
+
model="step-3.7-flash",
|
| 127 |
+
messages=[
|
| 128 |
+
{
|
| 129 |
+
"role": "user",
|
| 130 |
+
"content": [
|
| 131 |
+
{"type": "text", "text": "What is in this picture?"},
|
| 132 |
+
{
|
| 133 |
+
"type": "image_url",
|
| 134 |
+
"image_url": {"url": "https://example.com/photo.jpg"},
|
| 135 |
+
},
|
| 136 |
+
],
|
| 137 |
+
},
|
| 138 |
+
],
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
print(completion)
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
## 6. Local Deployment
|
| 145 |
+
|
| 146 |
+
Step 3.7 Flash is optimized for local inference and supports industry-standard backends including vLLM, SGLang, Hugging Face Transformers and llama.cpp.
|
| 147 |
+
|
| 148 |
+
### 6.1 vLLM
|
| 149 |
+
|
| 150 |
+
We recommend using StepFun's prebuilt vLLM Docker image with Step 3.7 support.
|
| 151 |
+
|
| 152 |
+
1. Install vLLM.
|
| 153 |
+
|
| 154 |
+
```bash
|
| 155 |
+
# via Docker
|
| 156 |
+
docker pull vllm/vllm-openai:stepfun37
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
2. Launch the server.
|
| 160 |
+
|
| 161 |
+
- For FP8 model
|
| 162 |
+
```bash
|
| 163 |
+
vllm serve <MODEL_PATH_OR_HF_ID> \
|
| 164 |
+
--served-model-name step3p7-flash \
|
| 165 |
+
--tensor-parallel-size 8 \
|
| 166 |
+
--enable-expert-parallel \
|
| 167 |
+
--disable-cascade-attn \
|
| 168 |
+
--reasoning-parser step3p5 \
|
| 169 |
+
--enable-auto-tool-choice \
|
| 170 |
+
--tool-call-parser step3p5 \
|
| 171 |
+
--speculative_config '{"method": "mtp", "num_speculative_tokens": 3}' \
|
| 172 |
+
--trust-remote-code
|
| 173 |
+
```
|
| 174 |
+
- For BF16 model
|
| 175 |
+
```bash
|
| 176 |
+
vllm serve <MODEL_PATH_OR_HF_ID> \
|
| 177 |
+
--served-model-name step3p7-flash-bf16 \
|
| 178 |
+
--tensor-parallel-size 8 \
|
| 179 |
+
--enable-expert-parallel \
|
| 180 |
+
--disable-cascade-attn \
|
| 181 |
+
--reasoning-parser step3p5 \
|
| 182 |
+
--enable-auto-tool-choice \
|
| 183 |
+
--tool-call-parser step3p5 \
|
| 184 |
+
--speculative_config '{"method": "mtp", "num_speculative_tokens": 3}' \
|
| 185 |
+
--trust-remote-code
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
- For NVFP4 model
|
| 189 |
+
Compared to standard precisions, running the FP4 quantized version requires modelopt activation and FP8 KV Cache alignment.
|
| 190 |
+
```bash
|
| 191 |
+
python3 -m vllm.entrypoints.openai.api_server \
|
| 192 |
+
--host 0.0.0.0 \
|
| 193 |
+
--port ${PORT} \
|
| 194 |
+
--model stepfun-ai/Step-3.7-Flash-NVFP4 \
|
| 195 |
+
--served-model-name step3p7 \
|
| 196 |
+
--tensor-parallel-size 4 \
|
| 197 |
+
--gpu-memory-utilization 0.9 \
|
| 198 |
+
--enable-expert-parallel \
|
| 199 |
+
--trust-remote-code \
|
| 200 |
+
--quantization modelopt \
|
| 201 |
+
--kv-cache-dtype fp8 \
|
| 202 |
+
--max-model-len 8192 \
|
| 203 |
+
--reasoning-parser step3p5 \
|
| 204 |
+
--enable-auto-tool-choice \
|
| 205 |
+
--tool-call-parser step3p5 \
|
| 206 |
+
--async-scheduling
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
### 6.2 SGLang
|
| 210 |
+
|
| 211 |
+
1. Install SGLang.
|
| 212 |
+
|
| 213 |
+
```bash
|
| 214 |
+
# via Docker
|
| 215 |
+
docker pull lmsysorg/sglang:dev-step-3.7-flash
|
| 216 |
+
|
| 217 |
+
# or from source (pip)
|
| 218 |
+
pip install "sglang[all] @ git+https://github.com/sgl-project/sglang.git"
|
| 219 |
+
```
|
| 220 |
+
|
| 221 |
+
2. Launch the server.
|
| 222 |
+
|
| 223 |
+
> **Note:** For Blackwell GPUs, `--mm-attention-backend fa4` may be used.
|
| 224 |
+
|
| 225 |
+
- For BF16 model
|
| 226 |
+
|
| 227 |
+
```bash
|
| 228 |
+
sglang serve --model-path stepfun-ai/Step-3.7-Flash \
|
| 229 |
+
--tp 8 \
|
| 230 |
+
--reasoning-parser step3p5 \
|
| 231 |
+
--tool-call-parser step3p5 \
|
| 232 |
+
--enable-multimodal \
|
| 233 |
+
--speculative-algorithm EAGLE \
|
| 234 |
+
--speculative-num-steps 3 \
|
| 235 |
+
--speculative-eagle-topk 1 \
|
| 236 |
+
--speculative-num-draft-tokens 4 \
|
| 237 |
+
--enable-multi-layer-eagle \
|
| 238 |
+
--trust-remote-code \
|
| 239 |
+
--host 0.0.0.0 \
|
| 240 |
+
--port 8000
|
| 241 |
+
```
|
| 242 |
+
|
| 243 |
+
- For FP8 model
|
| 244 |
+
|
| 245 |
+
```bash
|
| 246 |
+
sglang serve --model-path stepfun-ai/Step-3.7-Flash-FP8 \
|
| 247 |
+
--tp 8 \
|
| 248 |
+
--ep 4 \
|
| 249 |
+
--reasoning-parser step3p5 \
|
| 250 |
+
--tool-call-parser step3p5 \
|
| 251 |
+
--enable-multimodal \
|
| 252 |
+
--speculative-algorithm EAGLE \
|
| 253 |
+
--speculative-num-steps 3 \
|
| 254 |
+
--speculative-eagle-topk 1 \
|
| 255 |
+
--speculative-num-draft-tokens 4 \
|
| 256 |
+
--enable-multi-layer-eagle \
|
| 257 |
+
--trust-remote-code \
|
| 258 |
+
--host 0.0.0.0 \
|
| 259 |
+
--port 8000
|
| 260 |
+
```
|
| 261 |
+
|
| 262 |
+
- For NVFP4 model
|
| 263 |
+
|
| 264 |
+
```bash
|
| 265 |
+
sglang serve --model-path stepfun-ai/Step-3.7-Flash-NVFP4 \
|
| 266 |
+
--tp 4 --ep 4 \
|
| 267 |
+
--moe-runner-backend flashinfer_trtllm \
|
| 268 |
+
--kv-cache-dtype fp8_e4m3 \
|
| 269 |
+
--quantization modelopt_fp4 \
|
| 270 |
+
--trust-remote-code \
|
| 271 |
+
--reasoning-parser step3p5 \
|
| 272 |
+
--tool-call-parser step3p5 \
|
| 273 |
+
--attention-backend trtllm_mha
|
| 274 |
+
```
|
| 275 |
+
|
| 276 |
+
### 6.3 Transformers (Debug / Verification)
|
| 277 |
+
|
| 278 |
+
Use this snippet for quick functional verification. For high-throughput serving, use vLLM or SGLang.
|
| 279 |
+
|
| 280 |
+
> **Note:** Deployment of this model requires `transformers` 5.0 or later.
|
| 281 |
+
|
| 282 |
+
```python
|
| 283 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
| 284 |
+
|
| 285 |
+
MODEL_PATH = "<MODEL_PATH_OR_HF_ID>"
|
| 286 |
+
|
| 287 |
+
# 1. Setup
|
| 288 |
+
processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
| 289 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 290 |
+
MODEL_PATH,
|
| 291 |
+
device_map="auto",
|
| 292 |
+
dtype="auto",
|
| 293 |
+
trust_remote_code=True
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# 2. Prepare Input
|
| 297 |
+
messages = [
|
| 298 |
+
{
|
| 299 |
+
"role": "user",
|
| 300 |
+
"content": [
|
| 301 |
+
{"type": "image", "url": "https://example.com/photo.jpg"},
|
| 302 |
+
{"type": "text", "text": "What is in this picture?"}
|
| 303 |
+
]
|
| 304 |
+
},
|
| 305 |
+
]
|
| 306 |
+
inputs = processor.apply_chat_template(
|
| 307 |
+
messages,
|
| 308 |
+
tokenize=True,
|
| 309 |
+
add_generation_prompt=True,
|
| 310 |
+
return_dict=True,
|
| 311 |
+
return_tensors="pt",
|
| 312 |
+
).to(model.device)
|
| 313 |
+
|
| 314 |
+
# 3. Generate
|
| 315 |
+
generated_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
|
| 316 |
+
output_text = processor.decode(generated_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
| 317 |
+
|
| 318 |
+
print(output_text)
|
| 319 |
+
```
|
| 320 |
+
|
| 321 |
+
### 6.4 llama.cpp
|
| 322 |
+
|
| 323 |
+
**System Requirements**
|
| 324 |
+
|
| 325 |
+
GGUF Model Weights:
|
| 326 |
+
|
| 327 |
+
| Component | Quantization | File Size |
|
| 328 |
+
|---|---|---|
|
| 329 |
+
| Language Model | Q4_K_S | 111.5 GB |
|
| 330 |
+
| Language Model | IQ4_XS | 104.99 GB |
|
| 331 |
+
| Language Model | Q3_K_L | 102.5 GB |
|
| 332 |
+
| Multimodal Projector | FP16 | 3.97 GB |
|
| 333 |
+
|
| 334 |
+
- **Runtime Overhead:** ~7 GB
|
| 335 |
+
- **Minimum unified memory / VRAM:** 120 GB (e.g., Mac Studio, NVIDIA DGX Station, AMD Ryzen AI Max+ 395)
|
| 336 |
+
- **Recommended:** 128 GB unified memory
|
| 337 |
+
|
| 338 |
+
**Steps**
|
| 339 |
+
|
| 340 |
+
1. Use llama.cpp:
|
| 341 |
+
|
| 342 |
+
```bash
|
| 343 |
+
git clone https://github.com/stepfun-ai/llama.cpp.git
|
| 344 |
+
cd llama.cpp
|
| 345 |
+
git checkout -b step3.7 origin/step3.7
|
| 346 |
+
```
|
| 347 |
+
|
| 348 |
+
2. Build llama.cpp on Mac:
|
| 349 |
+
|
| 350 |
+
```bash
|
| 351 |
+
cmake -B build-macos -S . \
|
| 352 |
+
-DCMAKE_BUILD_TYPE=Release \
|
| 353 |
+
-DBUILD_SHARED_LIBS=ON \
|
| 354 |
+
-DLLAMA_BUILD_SERVER=ON \
|
| 355 |
+
-DLLAMA_BUILD_TESTS=ON \
|
| 356 |
+
-DGGML_METAL=ON \
|
| 357 |
+
-DGGML_METAL_EMBED_LIBRARY=ON \
|
| 358 |
+
-DGGML_BLAS=ON \
|
| 359 |
+
-DGGML_BLAS_VENDOR=Apple \
|
| 360 |
+
-DGGML_ACCELERATE=ON \
|
| 361 |
+
-DGGML_NATIVE=ON
|
| 362 |
+
cmake --build build-macos -j8
|
| 363 |
+
```
|
| 364 |
+
|
| 365 |
+
3. Build llama.cpp on DGX-Spark:
|
| 366 |
+
|
| 367 |
+
```bash
|
| 368 |
+
cmake -S . -B build-cuda \
|
| 369 |
+
-DCMAKE_BUILD_TYPE=Release \
|
| 370 |
+
-DGGML_CUDA=ON \
|
| 371 |
+
-DGGML_CUDA_GRAPHS=ON \
|
| 372 |
+
-DGGML_CUDA_FORCE_MMQ=ON \
|
| 373 |
+
-DLLAMA_OPENSSL=OFF \
|
| 374 |
+
-DLLAMA_BUILD_COMMON=ON \
|
| 375 |
+
-DLLAMA_BUILD_TOOLS=ON \
|
| 376 |
+
-DLLAMA_BUILD_SERVER=ON \
|
| 377 |
+
-DLLAMA_BUILD_EXAMPLES=OFF \
|
| 378 |
+
-DLLAMA_BUILD_TESTS=OFF
|
| 379 |
+
cmake --build build-cuda -j8
|
| 380 |
+
```
|
| 381 |
+
|
| 382 |
+
4. Build llama.cpp on AMD Windows:
|
| 383 |
+
|
| 384 |
+
```bash
|
| 385 |
+
cmake -S . -B build-vulkan \
|
| 386 |
+
-DCMAKE_BUILD_TYPE=Release \
|
| 387 |
+
-DGGML_VULKAN=ON \
|
| 388 |
+
-DGGML_NATIVE=ON \
|
| 389 |
+
-DLLAMA_BUILD_SERVER=ON \
|
| 390 |
+
-DLLAMA_BUILD_UI=OFF \
|
| 391 |
+
-DLLAMA_BUILD_TOOLS=ON
|
| 392 |
+
cmake --build build-vulkan -j8
|
| 393 |
+
```
|
| 394 |
+
|
| 395 |
+
5. Run with `llama-cli`:
|
| 396 |
+
|
| 397 |
+
```bash
|
| 398 |
+
./llama-cli -m Step3.7_Q4_K_S.gguf -b 2048 -ub 2048 -fa on --temp 1.0 -p "What's your name?"
|
| 399 |
+
```
|
| 400 |
+
|
| 401 |
+
6. Test performance with `llama-batched-bench`:
|
| 402 |
+
|
| 403 |
+
```bash
|
| 404 |
+
./llama-batched-bench -m step3.7_Q4_K_S.gguf -c 32768 -b 2048 -ub 2048 -npp 0,2048,8192,16384,32768 -ntg 128 -npl 1
|
| 405 |
+
```
|
| 406 |
+
|
| 407 |
+
## 7. Using Step 3.7 Flash on Agent Platforms
|
| 408 |
+
|
| 409 |
+
You can use Step 3.7 Flash on Agent platforms such as Hermes Agent, OpenClaw, Kilo Code, and more.
|
| 410 |
+
|
| 411 |
+
## 8. Getting in Touch
|
| 412 |
+
|
| 413 |
+
As we work to shape the future of AGI by expanding broad model capabilities, we want to ensure we are solving the right problems. We invite you to be part of this continuous feedback loop — your insights directly influence our priorities.
|
| 414 |
+
|
| 415 |
+
- **Join the Conversation:** Our [Discord](https://discord.gg/RcMJhNVAQc) community is the primary hub for brainstorming future architectures, proposing capabilities, and getting early access updates 🚀
|
| 416 |
+
- **Report Friction:** Encountering limitations? You can open an issue or start a discussion on GitHub / HuggingFace, or flag it directly in our Discord support channels.
|
| 417 |
+
|
| 418 |
+
## 📄 License
|
| 419 |
+
|
| 420 |
+
This project is open-sourced under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0).
|
| 421 |
+
|
assets/benchmarks.png
ADDED
|
Git LFS Details
|
chat_template.jinja
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{% macro render_message_content(message) %}{% if message.content is none %}{{- '' }}{% elif message.content is string %}{{- message.content }}{% elif message.content is mapping %}{{- message.content['value'] if 'value' in message.content else message.content['text'] }}{% elif message.content is iterable %}{% set ns = namespace(needs_text_separator=false) %}{% for item in message.content %}{% if item.type == 'text' %}{% if ns.needs_text_separator %}{{- ' ' }}{% endif %}{{- item['value'] if 'value' in item else item['text'] }}{% set ns.needs_text_separator = true %}{% elif item.type == 'image' %}<im_patch>{% set ns.needs_text_separator = false %}{% endif %}{% endfor %}{% endif %}{% endmacro %}
|
| 2 |
+
{{bos_token}}{%- if tools %}
|
| 3 |
+
{{- '<|im_start|>system\n' }}
|
| 4 |
+
{%- if reasoning_effort is defined %}
|
| 5 |
+
{{- "Reasoning: " + reasoning_effort + '\n\n' }}
|
| 6 |
+
{%- endif %}
|
| 7 |
+
{%- if messages[0].role == 'system' %}
|
| 8 |
+
{{- render_message_content(messages[0]) + '\n\n' }}
|
| 9 |
+
{%- endif %}
|
| 10 |
+
{{- "# Tools\n\nYou have access to the following functions in JSONSchema format:\n\n<tools>" }}
|
| 11 |
+
{%- for tool in tools %}
|
| 12 |
+
{{- "\n" }}
|
| 13 |
+
{{- tool | tojson(ensure_ascii=False) }}
|
| 14 |
+
{%- endfor %}
|
| 15 |
+
{{- "\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...>\n...\n</function> block must be nested within <tool_call>\n...\n</tool_call> XML tags\n- Required parameters MUST be specified\n</IMPORTANT><|im_end|>\n" }}
|
| 16 |
+
{%- else %}
|
| 17 |
+
{%- if messages[0].role == 'system' %}
|
| 18 |
+
{{- '<|im_start|>system\n' }}
|
| 19 |
+
{%- if reasoning_effort is defined %}
|
| 20 |
+
{{- "Reasoning: " + reasoning_effort + '\n\n' }}
|
| 21 |
+
{%- endif %}
|
| 22 |
+
{{- render_message_content(messages[0]) + '<|im_end|>\n' }}
|
| 23 |
+
{%- elif reasoning_effort is defined %}
|
| 24 |
+
{{- '<|im_start|>system\n' + "Reasoning: " + reasoning_effort + '\n\n' + '<|im_end|>\n' }}
|
| 25 |
+
{%- endif %}
|
| 26 |
+
{%- endif %}
|
| 27 |
+
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
| 28 |
+
{%- for message in messages[::-1] %}
|
| 29 |
+
{%- set index = (messages|length - 1) - loop.index0 %}
|
| 30 |
+
{%- if ns.multi_step_tool and message.role == "user" and render_message_content(message) is string and not(render_message_content(message).startswith('<tool_response>') and render_message_content(message).endswith('</tool_response>')) %}
|
| 31 |
+
{%- set ns.multi_step_tool = false %}
|
| 32 |
+
{%- set ns.last_query_index = index %}
|
| 33 |
+
{%- endif %}
|
| 34 |
+
{%- endfor %}
|
| 35 |
+
{%- for message in messages %}
|
| 36 |
+
{%- set content = render_message_content(message) %}
|
| 37 |
+
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
| 38 |
+
{%- set role_name = 'observation' if (message.role == "system" and not loop.first and message.name == 'observation') else message.role %}
|
| 39 |
+
{{- '<|im_start|>' + role_name + '\n' + content + '<|im_end|>' + '\n' }}
|
| 40 |
+
{%- elif message.role == "assistant" %}
|
| 41 |
+
{%- if message.reasoning_content is string %}
|
| 42 |
+
{%- set reasoning_content = message.reasoning_content %}
|
| 43 |
+
{%- else %}
|
| 44 |
+
{%- if '</think>' in content %}
|
| 45 |
+
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
| 46 |
+
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
| 47 |
+
{%- else %}
|
| 48 |
+
{%- set reasoning_content = '' %}
|
| 49 |
+
{%- endif %}
|
| 50 |
+
{%- endif %}
|
| 51 |
+
{%- if loop.index0 > ns.last_query_index %}
|
| 52 |
+
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n' + content }}
|
| 53 |
+
{%- else %}
|
| 54 |
+
{{- '<|im_start|>' + message.role + '\n' + content }}
|
| 55 |
+
{%- endif %}
|
| 56 |
+
{%- if message.tool_calls %}
|
| 57 |
+
{%- for tool_call in message.tool_calls %}
|
| 58 |
+
{%- if tool_call.function is defined %}
|
| 59 |
+
{%- set tool_call = tool_call.function %}
|
| 60 |
+
{%- endif %}
|
| 61 |
+
{{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
|
| 62 |
+
{%- if tool_call.arguments is defined %}
|
| 63 |
+
{%- set arguments = tool_call.arguments | fromjson if tool_call.arguments is string else tool_call.arguments %}
|
| 64 |
+
{%- for args_name, args_value in arguments|items %}
|
| 65 |
+
{{- '<parameter=' + args_name + '>\n' }}
|
| 66 |
+
{%- set args_value = args_value | tojson(ensure_ascii=False) | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
|
| 67 |
+
{{- args_value }}
|
| 68 |
+
{{- '\n</parameter>\n' }}
|
| 69 |
+
{%- endfor %}
|
| 70 |
+
{%- endif %}
|
| 71 |
+
{{- '</function>\n</tool_call>' }}
|
| 72 |
+
{%- endfor %}
|
| 73 |
+
{%- endif %}
|
| 74 |
+
{{- '<|im_end|>\n' }}
|
| 75 |
+
{%- elif message.role == "tool" %}
|
| 76 |
+
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
| 77 |
+
{{- '<|im_start|>tool_response\n' }}
|
| 78 |
+
{%- endif %}
|
| 79 |
+
{{- '<tool_response>' }}
|
| 80 |
+
{{- content }}
|
| 81 |
+
{{- '</tool_response>' }}
|
| 82 |
+
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
| 83 |
+
{{- '<|im_end|>\n' }}
|
| 84 |
+
{%- endif %}
|
| 85 |
+
{%- endif %}
|
| 86 |
+
{%- endfor %}
|
| 87 |
+
{%- if add_generation_prompt %}
|
| 88 |
+
{{- '<|im_start|>assistant\n<think>\n' }}
|
| 89 |
+
{%- endif %}
|
config.json
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Step3p7ForConditionalGeneration"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_step3p7.Step3p7Config",
|
| 7 |
+
"AutoModelForCausalLM": "modeling_step3p7.Step3p7ForConditionalGeneration",
|
| 8 |
+
"AutoProcessor": "processing_step3.Step3VLProcessor"
|
| 9 |
+
},
|
| 10 |
+
"hidden_size": 4096,
|
| 11 |
+
"im_end_token": "<im_end>",
|
| 12 |
+
"im_patch_token": "<im_patch>",
|
| 13 |
+
"im_start_token": "<im_start>",
|
| 14 |
+
"image_token_id": 128001,
|
| 15 |
+
"image_token_len": 169,
|
| 16 |
+
"max_position_embeddings": 262144,
|
| 17 |
+
"model_type": "step3p7",
|
| 18 |
+
"pad_token_id": 2,
|
| 19 |
+
"patch_token_len": 81,
|
| 20 |
+
"projector_bias": false,
|
| 21 |
+
"text_config": {
|
| 22 |
+
"architectures": [
|
| 23 |
+
"Step3p5ForCausalLM"
|
| 24 |
+
],
|
| 25 |
+
"att_impl_type": "GQA",
|
| 26 |
+
"attention_dropout": 0.0,
|
| 27 |
+
"attention_other_setting": {
|
| 28 |
+
"attention_type": "sliding_attention",
|
| 29 |
+
"head_dim": 128,
|
| 30 |
+
"num_attention_groups": 8,
|
| 31 |
+
"num_attention_heads": 96,
|
| 32 |
+
"true_head_dim": 128
|
| 33 |
+
},
|
| 34 |
+
"bos_token_id": 0,
|
| 35 |
+
"torch_dtype": "bfloat16",
|
| 36 |
+
"eos_token_id": [
|
| 37 |
+
1,
|
| 38 |
+
2,
|
| 39 |
+
128007
|
| 40 |
+
],
|
| 41 |
+
"head_dim": 128,
|
| 42 |
+
"hidden_size": 4096,
|
| 43 |
+
"intermediate_size": 11264,
|
| 44 |
+
"layer_types": [
|
| 45 |
+
"full_attention",
|
| 46 |
+
"sliding_attention",
|
| 47 |
+
"sliding_attention",
|
| 48 |
+
"sliding_attention",
|
| 49 |
+
"full_attention",
|
| 50 |
+
"sliding_attention",
|
| 51 |
+
"sliding_attention",
|
| 52 |
+
"sliding_attention",
|
| 53 |
+
"full_attention",
|
| 54 |
+
"sliding_attention",
|
| 55 |
+
"sliding_attention",
|
| 56 |
+
"sliding_attention",
|
| 57 |
+
"full_attention",
|
| 58 |
+
"sliding_attention",
|
| 59 |
+
"sliding_attention",
|
| 60 |
+
"sliding_attention",
|
| 61 |
+
"full_attention",
|
| 62 |
+
"sliding_attention",
|
| 63 |
+
"sliding_attention",
|
| 64 |
+
"sliding_attention",
|
| 65 |
+
"full_attention",
|
| 66 |
+
"sliding_attention",
|
| 67 |
+
"sliding_attention",
|
| 68 |
+
"sliding_attention",
|
| 69 |
+
"full_attention",
|
| 70 |
+
"sliding_attention",
|
| 71 |
+
"sliding_attention",
|
| 72 |
+
"sliding_attention",
|
| 73 |
+
"full_attention",
|
| 74 |
+
"sliding_attention",
|
| 75 |
+
"sliding_attention",
|
| 76 |
+
"sliding_attention",
|
| 77 |
+
"full_attention",
|
| 78 |
+
"sliding_attention",
|
| 79 |
+
"sliding_attention",
|
| 80 |
+
"sliding_attention",
|
| 81 |
+
"full_attention",
|
| 82 |
+
"sliding_attention",
|
| 83 |
+
"sliding_attention",
|
| 84 |
+
"sliding_attention",
|
| 85 |
+
"full_attention",
|
| 86 |
+
"sliding_attention",
|
| 87 |
+
"sliding_attention",
|
| 88 |
+
"sliding_attention",
|
| 89 |
+
"full_attention",
|
| 90 |
+
"sliding_attention",
|
| 91 |
+
"sliding_attention",
|
| 92 |
+
"sliding_attention"
|
| 93 |
+
],
|
| 94 |
+
"max_position_embeddings": 262144,
|
| 95 |
+
"max_seq_len": 262144,
|
| 96 |
+
"model_type": "step3p5",
|
| 97 |
+
"moe_every_n_layer": 1,
|
| 98 |
+
"moe_intermediate_size": 1280,
|
| 99 |
+
"moe_layer_offset": 0,
|
| 100 |
+
"moe_layers_enum": "3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44",
|
| 101 |
+
"moe_num_experts": 288,
|
| 102 |
+
"moe_router_activation": "sigmoid",
|
| 103 |
+
"moe_router_scaling_factor": 3.0,
|
| 104 |
+
"moe_top_k": 8,
|
| 105 |
+
"need_fp32_gate": true,
|
| 106 |
+
"norm_expert_weight": true,
|
| 107 |
+
"num_attention_groups": 8,
|
| 108 |
+
"num_attention_heads": 64,
|
| 109 |
+
"num_hidden_layers": 45,
|
| 110 |
+
"num_nextn_predict_layers": 3,
|
| 111 |
+
"pad_token_id": 1,
|
| 112 |
+
"partial_rotary_factors": [
|
| 113 |
+
0.5,
|
| 114 |
+
1.0,
|
| 115 |
+
1.0,
|
| 116 |
+
1.0,
|
| 117 |
+
0.5,
|
| 118 |
+
1.0,
|
| 119 |
+
1.0,
|
| 120 |
+
1.0,
|
| 121 |
+
0.5,
|
| 122 |
+
1.0,
|
| 123 |
+
1.0,
|
| 124 |
+
1.0,
|
| 125 |
+
0.5,
|
| 126 |
+
1.0,
|
| 127 |
+
1.0,
|
| 128 |
+
1.0,
|
| 129 |
+
0.5,
|
| 130 |
+
1.0,
|
| 131 |
+
1.0,
|
| 132 |
+
1.0,
|
| 133 |
+
0.5,
|
| 134 |
+
1.0,
|
| 135 |
+
1.0,
|
| 136 |
+
1.0,
|
| 137 |
+
0.5,
|
| 138 |
+
1.0,
|
| 139 |
+
1.0,
|
| 140 |
+
1.0,
|
| 141 |
+
0.5,
|
| 142 |
+
1.0,
|
| 143 |
+
1.0,
|
| 144 |
+
1.0,
|
| 145 |
+
0.5,
|
| 146 |
+
1.0,
|
| 147 |
+
1.0,
|
| 148 |
+
1.0,
|
| 149 |
+
0.5,
|
| 150 |
+
1.0,
|
| 151 |
+
1.0,
|
| 152 |
+
1.0,
|
| 153 |
+
0.5,
|
| 154 |
+
1.0,
|
| 155 |
+
1.0,
|
| 156 |
+
1.0,
|
| 157 |
+
0.5,
|
| 158 |
+
1.0,
|
| 159 |
+
1.0,
|
| 160 |
+
1.0
|
| 161 |
+
],
|
| 162 |
+
"rms_norm_eps": 1e-05,
|
| 163 |
+
"rope_parameters": {
|
| 164 |
+
"factor": 2.0,
|
| 165 |
+
"high_freq_factor": 32.0,
|
| 166 |
+
"low_freq_factor": 1.0,
|
| 167 |
+
"original_max_position_embeddings": 131072,
|
| 168 |
+
"rope_theta": [
|
| 169 |
+
5000000.0,
|
| 170 |
+
10000.0,
|
| 171 |
+
10000.0,
|
| 172 |
+
10000.0,
|
| 173 |
+
5000000.0,
|
| 174 |
+
10000.0,
|
| 175 |
+
10000.0,
|
| 176 |
+
10000.0,
|
| 177 |
+
5000000.0,
|
| 178 |
+
10000.0,
|
| 179 |
+
10000.0,
|
| 180 |
+
10000.0,
|
| 181 |
+
5000000.0,
|
| 182 |
+
10000.0,
|
| 183 |
+
10000.0,
|
| 184 |
+
10000.0,
|
| 185 |
+
5000000.0,
|
| 186 |
+
10000.0,
|
| 187 |
+
10000.0,
|
| 188 |
+
10000.0,
|
| 189 |
+
5000000.0,
|
| 190 |
+
10000.0,
|
| 191 |
+
10000.0,
|
| 192 |
+
10000.0,
|
| 193 |
+
5000000.0,
|
| 194 |
+
10000.0,
|
| 195 |
+
10000.0,
|
| 196 |
+
10000.0,
|
| 197 |
+
5000000.0,
|
| 198 |
+
10000.0,
|
| 199 |
+
10000.0,
|
| 200 |
+
10000.0,
|
| 201 |
+
5000000.0,
|
| 202 |
+
10000.0,
|
| 203 |
+
10000.0,
|
| 204 |
+
10000.0,
|
| 205 |
+
5000000.0,
|
| 206 |
+
10000.0,
|
| 207 |
+
10000.0,
|
| 208 |
+
10000.0,
|
| 209 |
+
5000000.0,
|
| 210 |
+
10000.0,
|
| 211 |
+
10000.0,
|
| 212 |
+
10000.0,
|
| 213 |
+
5000000.0,
|
| 214 |
+
10000.0,
|
| 215 |
+
10000.0,
|
| 216 |
+
10000.0
|
| 217 |
+
],
|
| 218 |
+
"rope_type": "llama3"
|
| 219 |
+
},
|
| 220 |
+
"rope_theta": [
|
| 221 |
+
5000000.0,
|
| 222 |
+
10000.0,
|
| 223 |
+
10000.0,
|
| 224 |
+
10000.0,
|
| 225 |
+
5000000.0,
|
| 226 |
+
10000.0,
|
| 227 |
+
10000.0,
|
| 228 |
+
10000.0,
|
| 229 |
+
5000000.0,
|
| 230 |
+
10000.0,
|
| 231 |
+
10000.0,
|
| 232 |
+
10000.0,
|
| 233 |
+
5000000.0,
|
| 234 |
+
10000.0,
|
| 235 |
+
10000.0,
|
| 236 |
+
10000.0,
|
| 237 |
+
5000000.0,
|
| 238 |
+
10000.0,
|
| 239 |
+
10000.0,
|
| 240 |
+
10000.0,
|
| 241 |
+
5000000.0,
|
| 242 |
+
10000.0,
|
| 243 |
+
10000.0,
|
| 244 |
+
10000.0,
|
| 245 |
+
5000000.0,
|
| 246 |
+
10000.0,
|
| 247 |
+
10000.0,
|
| 248 |
+
10000.0,
|
| 249 |
+
5000000.0,
|
| 250 |
+
10000.0,
|
| 251 |
+
10000.0,
|
| 252 |
+
10000.0,
|
| 253 |
+
5000000.0,
|
| 254 |
+
10000.0,
|
| 255 |
+
10000.0,
|
| 256 |
+
10000.0,
|
| 257 |
+
5000000.0,
|
| 258 |
+
10000.0,
|
| 259 |
+
10000.0,
|
| 260 |
+
10000.0,
|
| 261 |
+
5000000.0,
|
| 262 |
+
10000.0,
|
| 263 |
+
10000.0,
|
| 264 |
+
10000.0,
|
| 265 |
+
5000000.0,
|
| 266 |
+
10000.0,
|
| 267 |
+
10000.0,
|
| 268 |
+
10000.0
|
| 269 |
+
],
|
| 270 |
+
"share_expert_dim": 1280,
|
| 271 |
+
"sink": false,
|
| 272 |
+
"sliding_window": 512,
|
| 273 |
+
"swiglu_limits": [
|
| 274 |
+
0.0,
|
| 275 |
+
0.0,
|
| 276 |
+
0.0,
|
| 277 |
+
0.0,
|
| 278 |
+
0.0,
|
| 279 |
+
0.0,
|
| 280 |
+
0.0,
|
| 281 |
+
0.0,
|
| 282 |
+
0.0,
|
| 283 |
+
0.0,
|
| 284 |
+
0.0,
|
| 285 |
+
0.0,
|
| 286 |
+
0.0,
|
| 287 |
+
0.0,
|
| 288 |
+
0.0,
|
| 289 |
+
0.0,
|
| 290 |
+
0.0,
|
| 291 |
+
0.0,
|
| 292 |
+
0.0,
|
| 293 |
+
0.0,
|
| 294 |
+
0.0,
|
| 295 |
+
0.0,
|
| 296 |
+
0.0,
|
| 297 |
+
0.0,
|
| 298 |
+
0.0,
|
| 299 |
+
0.0,
|
| 300 |
+
0.0,
|
| 301 |
+
0.0,
|
| 302 |
+
0.0,
|
| 303 |
+
0.0,
|
| 304 |
+
0.0,
|
| 305 |
+
0.0,
|
| 306 |
+
0.0,
|
| 307 |
+
0.0,
|
| 308 |
+
0.0,
|
| 309 |
+
0.0,
|
| 310 |
+
0.0,
|
| 311 |
+
0.0,
|
| 312 |
+
0.0,
|
| 313 |
+
0.0,
|
| 314 |
+
0.0,
|
| 315 |
+
0.0,
|
| 316 |
+
0.0,
|
| 317 |
+
7,
|
| 318 |
+
7,
|
| 319 |
+
0.0,
|
| 320 |
+
0.0,
|
| 321 |
+
0.0
|
| 322 |
+
],
|
| 323 |
+
"swiglu_limits_shared": [
|
| 324 |
+
0.0,
|
| 325 |
+
0.0,
|
| 326 |
+
0.0,
|
| 327 |
+
0.0,
|
| 328 |
+
0.0,
|
| 329 |
+
0.0,
|
| 330 |
+
0.0,
|
| 331 |
+
0.0,
|
| 332 |
+
0.0,
|
| 333 |
+
0.0,
|
| 334 |
+
0.0,
|
| 335 |
+
0.0,
|
| 336 |
+
0.0,
|
| 337 |
+
0.0,
|
| 338 |
+
0.0,
|
| 339 |
+
0.0,
|
| 340 |
+
0.0,
|
| 341 |
+
0.0,
|
| 342 |
+
0.0,
|
| 343 |
+
0.0,
|
| 344 |
+
0.0,
|
| 345 |
+
0.0,
|
| 346 |
+
0.0,
|
| 347 |
+
0.0,
|
| 348 |
+
0.0,
|
| 349 |
+
0.0,
|
| 350 |
+
0.0,
|
| 351 |
+
0.0,
|
| 352 |
+
0.0,
|
| 353 |
+
0.0,
|
| 354 |
+
0.0,
|
| 355 |
+
0.0,
|
| 356 |
+
0.0,
|
| 357 |
+
0.0,
|
| 358 |
+
0.0,
|
| 359 |
+
0.0,
|
| 360 |
+
0.0,
|
| 361 |
+
0.0,
|
| 362 |
+
0.0,
|
| 363 |
+
0.0,
|
| 364 |
+
0.0,
|
| 365 |
+
0.0,
|
| 366 |
+
0.0,
|
| 367 |
+
16,
|
| 368 |
+
16,
|
| 369 |
+
0.0,
|
| 370 |
+
0.0,
|
| 371 |
+
0.0
|
| 372 |
+
],
|
| 373 |
+
"use_head_wise_attn_gate": true,
|
| 374 |
+
"use_mfa": false,
|
| 375 |
+
"use_moe": true,
|
| 376 |
+
"use_moe_router_bias": true,
|
| 377 |
+
"use_qk_norm": false,
|
| 378 |
+
"use_rope_layers": [],
|
| 379 |
+
"vocab_size": 128896,
|
| 380 |
+
"yarn_only_types": [
|
| 381 |
+
"full_attention"
|
| 382 |
+
]
|
| 383 |
+
},
|
| 384 |
+
"transformers_version": "5.10.0.dev0",
|
| 385 |
+
"understand_projector_stride": 2,
|
| 386 |
+
"unsloth_fixed": true,
|
| 387 |
+
"use_im_start_end": "true",
|
| 388 |
+
"vision_config": {
|
| 389 |
+
"heads": 16,
|
| 390 |
+
"hidden_act": "quick_gelu",
|
| 391 |
+
"image_size": 728,
|
| 392 |
+
"layer_norm_eps": 1e-05,
|
| 393 |
+
"layers": 47,
|
| 394 |
+
"ls_init_value": 0.1,
|
| 395 |
+
"mlp_ratio": 5.833333333333333,
|
| 396 |
+
"model_type": "perception_encoder",
|
| 397 |
+
"num_channels": 3,
|
| 398 |
+
"output_dim": null,
|
| 399 |
+
"patch_size": 14,
|
| 400 |
+
"pool_type": "none",
|
| 401 |
+
"ues_cls_token": false,
|
| 402 |
+
"use_abs_posemb": true,
|
| 403 |
+
"use_cls_token": false,
|
| 404 |
+
"use_ln_post": false,
|
| 405 |
+
"use_ln_pre": true,
|
| 406 |
+
"use_rope2d": true,
|
| 407 |
+
"width": 1536
|
| 408 |
+
},
|
| 409 |
+
"vision_select_layer": -1
|
| 410 |
+
}
|
configuration_step3p7.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Optional, Sequence, Union
|
| 2 |
+
|
| 3 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 4 |
+
|
| 5 |
+
class StepRoboticsVisionEncoderConfig(PretrainedConfig):
|
| 6 |
+
model_type = "perception_encoder"
|
| 7 |
+
|
| 8 |
+
def __init__(
|
| 9 |
+
self,
|
| 10 |
+
width=1536,
|
| 11 |
+
layers=47,
|
| 12 |
+
heads=16,
|
| 13 |
+
num_channels=3,
|
| 14 |
+
image_size=728,
|
| 15 |
+
mlp_ratio = 8960/1536,
|
| 16 |
+
patch_size=14,
|
| 17 |
+
hidden_act="quick_gelu",
|
| 18 |
+
layer_norm_eps=1e-5,
|
| 19 |
+
ues_cls_token=False,
|
| 20 |
+
use_cls_token: Optional[bool] = None,
|
| 21 |
+
use_ln_pre=True,
|
| 22 |
+
use_ln_post=False,
|
| 23 |
+
use_abs_posemb=True,
|
| 24 |
+
use_rope2d=True,
|
| 25 |
+
ls_init_value=0.1,
|
| 26 |
+
**kwargs,
|
| 27 |
+
):
|
| 28 |
+
self.width = width
|
| 29 |
+
self.layers = layers
|
| 30 |
+
self.heads = heads
|
| 31 |
+
self.num_channels = num_channels
|
| 32 |
+
self.patch_size = patch_size
|
| 33 |
+
self.image_size = image_size
|
| 34 |
+
self.mlp_ratio = mlp_ratio
|
| 35 |
+
self.layer_norm_eps = layer_norm_eps
|
| 36 |
+
self.hidden_act = hidden_act
|
| 37 |
+
if use_cls_token is None:
|
| 38 |
+
use_cls_token = ues_cls_token
|
| 39 |
+
self.ues_cls_token = use_cls_token
|
| 40 |
+
self.use_cls_token = use_cls_token
|
| 41 |
+
self.use_ln_pre = use_ln_pre
|
| 42 |
+
self.ls_init_value = ls_init_value
|
| 43 |
+
self.use_ln_post = use_ln_post
|
| 44 |
+
self.use_abs_posemb = use_abs_posemb
|
| 45 |
+
self.use_rope2d = use_rope2d
|
| 46 |
+
super().__init__(**kwargs)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class Step3p7TextConfig(PretrainedConfig):
|
| 50 |
+
model_type = "step3p5"
|
| 51 |
+
architectures = ["Step3p5ForCausalLM"]
|
| 52 |
+
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
hidden_size: int = 4096,
|
| 56 |
+
intermediate_size: int = 11264,
|
| 57 |
+
num_attention_heads: int = 64,
|
| 58 |
+
num_attention_groups: int = 8,
|
| 59 |
+
num_hidden_layers: int = 45,
|
| 60 |
+
max_seq_len: int = 128000,
|
| 61 |
+
vocab_size: int = 128815,
|
| 62 |
+
rms_norm_eps: float = 1e-5,
|
| 63 |
+
moe_intermediate_size: int = 1280,
|
| 64 |
+
moe_num_experts: int = 288,
|
| 65 |
+
moe_top_k: int = 8,
|
| 66 |
+
rope_theta: float = 10000,
|
| 67 |
+
rope_scaling: Optional[dict[str, Any]] = None,
|
| 68 |
+
max_position_embeddings: int = 128000,
|
| 69 |
+
share_expert_dims: int = 1280,
|
| 70 |
+
share_expert_dim: Optional[int] = None,
|
| 71 |
+
head_dim: int = 128,
|
| 72 |
+
norm_expert_weight: bool = True,
|
| 73 |
+
layer_types: list[str] = None,
|
| 74 |
+
sliding_window: Optional[int] = None,
|
| 75 |
+
pad_token_id: int = 1,
|
| 76 |
+
attention_dropout: float = 0.0,
|
| 77 |
+
use_head_wise_attn_gate: bool = False,
|
| 78 |
+
use_moe_router_bias: bool = False,
|
| 79 |
+
moe_router_activation: str = "softmax",
|
| 80 |
+
moe_router_scaling_factor: float = 1.0,
|
| 81 |
+
need_fp32_gate: bool = False,
|
| 82 |
+
attention_other_setting: Optional[dict[str, Any]] = None,
|
| 83 |
+
swiglu_limits: Optional[list[Optional[float]]] = None,
|
| 84 |
+
swiglu_limits_shared: Optional[list[Optional[float]]] = None,
|
| 85 |
+
use_rope_layers: Optional[list[bool]] = None,
|
| 86 |
+
yarn_only_types: Optional[list[str]] = None,
|
| 87 |
+
moe_layers_enum: tuple[int] = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
|
| 88 |
+
15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
|
| 89 |
+
25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
|
| 90 |
+
35, 36, 37, 38, 39, 40, 41, 42, 43, 44),
|
| 91 |
+
**kwargs,
|
| 92 |
+
) -> None:
|
| 93 |
+
torch_dtype = kwargs.get("torch_dtype")
|
| 94 |
+
trim_layer_types = _normalize_per_layer_values(layer_types,
|
| 95 |
+
num_hidden_layers)
|
| 96 |
+
if isinstance(rope_scaling, dict):
|
| 97 |
+
rope_scaling = dict(rope_scaling)
|
| 98 |
+
if share_expert_dim is None:
|
| 99 |
+
share_expert_dim = share_expert_dims
|
| 100 |
+
self.hidden_size = hidden_size
|
| 101 |
+
self.intermediate_size = intermediate_size
|
| 102 |
+
self.num_attention_heads = num_attention_heads
|
| 103 |
+
self.num_attention_groups = num_attention_groups
|
| 104 |
+
self.num_hidden_layers = num_hidden_layers
|
| 105 |
+
self.max_seq_len = max_seq_len
|
| 106 |
+
self.vocab_size = vocab_size
|
| 107 |
+
self.rms_norm_eps = rms_norm_eps
|
| 108 |
+
self.moe_intermediate_size = moe_intermediate_size
|
| 109 |
+
self.moe_num_experts = moe_num_experts
|
| 110 |
+
self.moe_top_k = moe_top_k
|
| 111 |
+
self.rope_theta = rope_theta
|
| 112 |
+
self.rope_scaling = rope_scaling
|
| 113 |
+
self.max_position_embeddings = max_position_embeddings
|
| 114 |
+
self.share_expert_dim = share_expert_dim
|
| 115 |
+
self.head_dim = head_dim
|
| 116 |
+
self.norm_expert_weight = norm_expert_weight
|
| 117 |
+
self.moe_layers_enum = moe_layers_enum
|
| 118 |
+
self.layer_types = trim_layer_types
|
| 119 |
+
self.sliding_window = sliding_window
|
| 120 |
+
self.pad_token_id = pad_token_id
|
| 121 |
+
self.attention_dropout = attention_dropout
|
| 122 |
+
self.use_head_wise_attn_gate = use_head_wise_attn_gate
|
| 123 |
+
self.use_moe_router_bias = use_moe_router_bias
|
| 124 |
+
self.moe_router_activation = moe_router_activation
|
| 125 |
+
self.moe_router_scaling_factor = moe_router_scaling_factor
|
| 126 |
+
self.need_fp32_gate = need_fp32_gate
|
| 127 |
+
self.attention_other_setting = attention_other_setting
|
| 128 |
+
self.swiglu_limits = swiglu_limits
|
| 129 |
+
self.swiglu_limits_shared = swiglu_limits_shared
|
| 130 |
+
self.use_rope_layers = use_rope_layers
|
| 131 |
+
self.yarn_only_types = yarn_only_types
|
| 132 |
+
super().__init__(**kwargs)
|
| 133 |
+
if torch_dtype is not None:
|
| 134 |
+
self.torch_dtype = torch_dtype
|
| 135 |
+
self.layer_types = layer_types
|
| 136 |
+
|
| 137 |
+
def to_dict(self):
|
| 138 |
+
output = super().to_dict()
|
| 139 |
+
torch_dtype = getattr(self, "torch_dtype", None)
|
| 140 |
+
if torch_dtype is not None:
|
| 141 |
+
output["torch_dtype"] = torch_dtype
|
| 142 |
+
return output
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _normalize_per_layer_values(
|
| 146 |
+
values: Optional[Sequence[Any]],
|
| 147 |
+
num_hidden_layers: int,
|
| 148 |
+
) -> Optional[list[Any]]:
|
| 149 |
+
if values is None:
|
| 150 |
+
return None
|
| 151 |
+
normalized = list(values)
|
| 152 |
+
if not normalized:
|
| 153 |
+
return normalized
|
| 154 |
+
if len(normalized) < num_hidden_layers:
|
| 155 |
+
normalized.extend([normalized[-1]] *
|
| 156 |
+
(num_hidden_layers - len(normalized)))
|
| 157 |
+
# Some checkpoints keep MTP/spec layer entries after the decoder layers.
|
| 158 |
+
# This config only builds num_hidden_layers decoder layers, and HF strict
|
| 159 |
+
# validation requires per-layer fields to match that decoder count.
|
| 160 |
+
return normalized[:num_hidden_layers]
|
| 161 |
+
|
| 162 |
+
class Step3p7Config(PretrainedConfig):
|
| 163 |
+
# This loader is a compatibility shim for original Step VL checkpoints
|
| 164 |
+
# whose top-level config model_type is `step3p7`.
|
| 165 |
+
model_type = "step3p7"
|
| 166 |
+
|
| 167 |
+
def __init__(
|
| 168 |
+
self,
|
| 169 |
+
vision_config: Optional[Union[dict, StepRoboticsVisionEncoderConfig]] = None,
|
| 170 |
+
text_config: Optional[Union[dict, Step3p7TextConfig]] = None,
|
| 171 |
+
understand_projector_stride: int = 2,
|
| 172 |
+
projector_bias: bool = False,
|
| 173 |
+
image_token_id: int = 151679,
|
| 174 |
+
**kwargs,
|
| 175 |
+
) -> None:
|
| 176 |
+
shared_rope_scaling = kwargs.get("rope_scaling")
|
| 177 |
+
if isinstance(shared_rope_scaling, dict):
|
| 178 |
+
shared_rope_scaling = dict(shared_rope_scaling)
|
| 179 |
+
|
| 180 |
+
if vision_config is None:
|
| 181 |
+
vision_config = StepRoboticsVisionEncoderConfig()
|
| 182 |
+
elif isinstance(vision_config, dict):
|
| 183 |
+
vision_config = StepRoboticsVisionEncoderConfig(**vision_config)
|
| 184 |
+
self.vision_config = vision_config
|
| 185 |
+
|
| 186 |
+
if text_config is None:
|
| 187 |
+
text_config = Step3p7TextConfig(rope_scaling=shared_rope_scaling)
|
| 188 |
+
elif isinstance(text_config, dict):
|
| 189 |
+
text_config = dict(text_config)
|
| 190 |
+
if shared_rope_scaling is not None and "rope_scaling" not in text_config:
|
| 191 |
+
text_config["rope_scaling"] = shared_rope_scaling
|
| 192 |
+
text_config = Step3p7TextConfig(**text_config)
|
| 193 |
+
elif shared_rope_scaling is not None and text_config.rope_scaling is None:
|
| 194 |
+
text_config.rope_scaling = dict(shared_rope_scaling)
|
| 195 |
+
self.text_config = text_config
|
| 196 |
+
|
| 197 |
+
rope_scaling = kwargs.get("rope_scaling")
|
| 198 |
+
if isinstance(rope_scaling, dict):
|
| 199 |
+
kwargs["rope_scaling"] = dict(rope_scaling)
|
| 200 |
+
|
| 201 |
+
self.understand_projector_stride = understand_projector_stride
|
| 202 |
+
self.projector_bias = projector_bias
|
| 203 |
+
self.hidden_size = text_config.hidden_size
|
| 204 |
+
self.max_position_embeddings = text_config.max_position_embeddings
|
| 205 |
+
self.image_token_id = image_token_id
|
| 206 |
+
# Help Auto classes find the correct implementation when saving/loading.
|
| 207 |
+
super().__init__(**kwargs)
|
model-00001.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5a2d47133d0ffa22f50a24ad4974c559c1b31f26f5baca24fc4f4dfe198b46c6
|
| 3 |
+
size 924094096
|
model-00002.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:67c13067deed696b62763643b7d531fd2cfde4c6e81cfcaba5460551e510d0af
|
| 3 |
+
size 9808156008
|
model-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6f3567584681f4d2792e4d949c9440198f792a5afd93220d3770b509728b6ef1
|
| 3 |
+
size 18557475928
|
model-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d035fb813758ed63f1d537bbf41f6cbb2c5c8eb05f187de18a448c7766a64960
|
| 3 |
+
size 18624846944
|
model-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f9a2c0daa3a49fc88e53e0b6419f2e4db7e412f40760488d49ca0f834fe83725
|
| 3 |
+
size 18557475928
|
model-00006.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7fee76c5fb28547ad0d4094a0bae7755a292dd439cc23b054210a24c965b093f
|
| 3 |
+
size 18624846976
|
model-00007.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ccad5d228ec280d95419fbbcf2590f2cdfc4c932a7249a7669dc7f509dc7fe66
|
| 3 |
+
size 18557475968
|
model-00008.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4d537acabde8deace533c23df8e43268f1423b41e7b6e27c79232955283f4e44
|
| 3 |
+
size 18624846976
|
model-00009.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:48be665fd9bce6e2fdac06d03a1a9916794fce4231b03009e6a4cfca1055a2c9
|
| 3 |
+
size 18557475968
|
model-00010.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dd61c7f6d62725005a07fe778dc572b9642972054424b2a12d1494e7ca241d91
|
| 3 |
+
size 18624846976
|
model-00011.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:51c5fe0dce035dd7fc01333fe3ba0fff46e65412ad7a71c09fa8e2992b8d26a7
|
| 3 |
+
size 18557475968
|
model-00012.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0f3e890ede3949af958a72da0beb99db6834853ee22978eb7782a600d013abac
|
| 3 |
+
size 18624846976
|
model-00013.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:98802ed9091498df2ef7a73b2697f5ac275a64892d984b9045a0a99f7b459c78
|
| 3 |
+
size 18557475968
|
model-00014.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:459e5814b710f888b6763385fb179d52f746f59e702dd165f0c5d5cc73417b03
|
| 3 |
+
size 18624846976
|
model-00015.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:13a51f345afa384b930387d40ac79ed6614f02129d61a9714e213f726970f47c
|
| 3 |
+
size 18557475968
|
model-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3475a9dcaff31af71b6183371f8e355bdedea5f4dbb1ade6e84dcfe28ddc9517
|
| 3 |
+
size 18624846976
|
model-00017.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:92917af53ef59cd99d43d49de2ffcbec3d21db7ebc59107a66aa2438da2eca14
|
| 3 |
+
size 18557475968
|
model-00018.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:aba73fb3d39556bba83fe864f7a7b60e8b2085204b074101500531e69525ee4f
|
| 3 |
+
size 18624846976
|
model-00019.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:617c98c96871403936caa0dcea602e7650cb947493555c142dc80e6c991adad8
|
| 3 |
+
size 18557475968
|
model-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1ccea8f04adaeeb446b8def20c6042c96f6da4eb68da6bf2a76bacf65350e4e9
|
| 3 |
+
size 18624846976
|
model-00021.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:af8c9ca65f1830163f6d5741569b4dd4c62468a1c21556e7b760e303bc3b7818
|
| 3 |
+
size 18557475968
|
model-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0cc5137141b5e2522fd3e69a4c828a0dbb602569ab8a0afcce5151b06800339f
|
| 3 |
+
size 18624846976
|
model-00023.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:05c2c2a08df421f617794e137429246a6ea60dd908fc691263242a12325dae7f
|
| 3 |
+
size 9245052456
|
model-00024.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7688adfc7748c12fdc8504187c57fe6ec6005798a02defc0d3372f921b1400a1
|
| 3 |
+
size 6968188464
|
model-vit-00001.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:22aa3f3679feffb57c2fb0bc885db0f5613db3536efef5d4b0984e8d769f6017
|
| 3 |
+
size 1613990904
|
model-vit-00002.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1f63ca4700a4184459d3ddb3a86c54a62914d359cedfddcfc14739ae782be082
|
| 3 |
+
size 2348122376
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_step3p7.py
ADDED
|
@@ -0,0 +1,1405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
import copy
|
| 16 |
+
import inspect
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Callable, Literal, Optional, Tuple, TypedDict, Union
|
| 19 |
+
|
| 20 |
+
from PIL import Image
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
from transformers.activations import ACT2FN
|
| 26 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 27 |
+
from transformers.generation import GenerationMixin
|
| 28 |
+
from transformers.masking_utils import (
|
| 29 |
+
create_causal_mask,
|
| 30 |
+
create_sliding_window_causal_mask,
|
| 31 |
+
)
|
| 32 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 33 |
+
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 34 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
| 35 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 36 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 37 |
+
from transformers.processing_utils import Unpack
|
| 38 |
+
from transformers.utils import TransformersKwargs, can_return_tuple, logging
|
| 39 |
+
from .configuration_step3p7 import Step3p7Config, Step3p7TextConfig
|
| 40 |
+
from .vision_encoder import StepRoboticsVisionEncoder
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
logger = logging.get_logger(__name__)
|
| 44 |
+
_MASK_INPUT_EMBEDS_ARG = (
|
| 45 |
+
"inputs_embeds"
|
| 46 |
+
if "inputs_embeds" in inspect.signature(create_causal_mask).parameters
|
| 47 |
+
else "input_embeds"
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
__all__ = [
|
| 51 |
+
"Step3p7Model",
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class StepVLImagePixelInputs(TypedDict):
|
| 56 |
+
type: Literal["pixel_values"]
|
| 57 |
+
pixel_values: torch.Tensor
|
| 58 |
+
patch_pixel_values: Optional[torch.Tensor]
|
| 59 |
+
num_patches: list[int]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class StepVLImageEmbeddingInputs(TypedDict):
|
| 63 |
+
type: Literal["image_embeds"]
|
| 64 |
+
image_embeds: torch.Tensor
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
StepVLImageInputs = Union[StepVLImagePixelInputs, StepVLImageEmbeddingInputs]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _flatten_embeddings(embeddings) -> torch.Tensor:
|
| 71 |
+
"""
|
| 72 |
+
Recursively flattens and concatenates NestedTensors on all but the last
|
| 73 |
+
dimension.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
if isinstance(embeddings, torch.Tensor):
|
| 77 |
+
# Flatten all but the last dimension.
|
| 78 |
+
return embeddings.flatten(0, -2)
|
| 79 |
+
|
| 80 |
+
return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
|
| 81 |
+
|
| 82 |
+
def _embedding_count_expression(embeddings) -> str:
|
| 83 |
+
"""
|
| 84 |
+
Constructs a debugging representation of the number of embeddings in the
|
| 85 |
+
NestedTensors.
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
if isinstance(embeddings, torch.Tensor):
|
| 89 |
+
return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
|
| 90 |
+
|
| 91 |
+
return " + ".join(_embedding_count_expression(inner) for inner in embeddings)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _merge_multimodal_embeddings(
|
| 95 |
+
inputs_embeds: torch.Tensor,
|
| 96 |
+
is_multimodal: torch.Tensor,
|
| 97 |
+
multimodal_embeddings,
|
| 98 |
+
) -> torch.Tensor:
|
| 99 |
+
"""
|
| 100 |
+
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
|
| 101 |
+
positions in ``inputs_embeds`` corresponding to placeholder tokens in
|
| 102 |
+
``input_ids``.
|
| 103 |
+
Note:
|
| 104 |
+
This updates ``inputs_embeds`` in place.
|
| 105 |
+
"""
|
| 106 |
+
num_expected_tokens = is_multimodal.sum().item()
|
| 107 |
+
assert isinstance(num_expected_tokens, int)
|
| 108 |
+
|
| 109 |
+
flattened = _flatten_embeddings(multimodal_embeddings)
|
| 110 |
+
if flattened.shape[0] != num_expected_tokens:
|
| 111 |
+
expr = _embedding_count_expression(multimodal_embeddings)
|
| 112 |
+
raise ValueError(
|
| 113 |
+
f"Attempted to assign {expr} = {flattened.shape[0]} "
|
| 114 |
+
f"multimodal tokens to {num_expected_tokens} placeholders"
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
is_multimodal = is_multimodal.to(inputs_embeds.device)
|
| 118 |
+
flattened = flattened.to(inputs_embeds.device)
|
| 119 |
+
inputs_embeds[is_multimodal] = flattened
|
| 120 |
+
return inputs_embeds
|
| 121 |
+
|
| 122 |
+
def merge_multimodal_embeddings(
|
| 123 |
+
input_ids: torch.Tensor,
|
| 124 |
+
inputs_embeds: torch.Tensor,
|
| 125 |
+
multimodal_embeddings,
|
| 126 |
+
placeholder_token_id: Union[int, list[int]],
|
| 127 |
+
) -> torch.Tensor:
|
| 128 |
+
"""
|
| 129 |
+
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
|
| 130 |
+
positions in ``inputs_embeds`` corresponding to placeholder tokens in
|
| 131 |
+
``input_ids``.
|
| 132 |
+
|
| 133 |
+
``placeholder_token_id`` can be a list of token ids (e.g, token ids
|
| 134 |
+
of img_start, img_break, and img_end tokens) when needed: This means
|
| 135 |
+
the order of these tokens in the ``input_ids`` MUST MATCH the order of
|
| 136 |
+
their embeddings in ``multimodal_embeddings`` since we need to
|
| 137 |
+
slice-merge instead of individually scattering.
|
| 138 |
+
For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
|
| 139 |
+
- T is text token
|
| 140 |
+
- S is image start token
|
| 141 |
+
- I is image embedding token
|
| 142 |
+
- B is image break token
|
| 143 |
+
- E is image end token.
|
| 144 |
+
|
| 145 |
+
Then the image embeddings (that correspond to I's) from vision encoder
|
| 146 |
+
must be padded with embeddings of S, B, and E in the same order of
|
| 147 |
+
input_ids for a correct embedding merge.
|
| 148 |
+
Note:
|
| 149 |
+
This updates ``inputs_embeds`` in place.
|
| 150 |
+
"""
|
| 151 |
+
if isinstance(placeholder_token_id, list):
|
| 152 |
+
placeholder_token_id = torch.tensor(
|
| 153 |
+
placeholder_token_id, device=input_ids.device
|
| 154 |
+
)
|
| 155 |
+
return _merge_multimodal_embeddings(
|
| 156 |
+
inputs_embeds,
|
| 157 |
+
torch.isin(input_ids, placeholder_token_id),
|
| 158 |
+
multimodal_embeddings,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
return _merge_multimodal_embeddings(
|
| 162 |
+
inputs_embeds,
|
| 163 |
+
(input_ids == placeholder_token_id),
|
| 164 |
+
multimodal_embeddings,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class Step3p7PreTrainedModel(PreTrainedModel):
|
| 169 |
+
# Link this model family to its configuration class so PreTrainedModel.from_pretrained
|
| 170 |
+
# can load the config instead of failing with a NoneType error.
|
| 171 |
+
config_class = Step3p7Config
|
| 172 |
+
supports_gradient_checkpointing = True
|
| 173 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 174 |
+
_keys_to_ignore_on_load_unexpected = [
|
| 175 |
+
r"model\.layers\.45\.*",
|
| 176 |
+
r"model\.layers\.46\.*",
|
| 177 |
+
r"model\.layers\.47\.*",
|
| 178 |
+
]
|
| 179 |
+
_supports_flash_attn = False
|
| 180 |
+
_supports_sdpa = True
|
| 181 |
+
_supports_flex_attn = True
|
| 182 |
+
_supports_static_cache = True
|
| 183 |
+
_supports_attention_backend = True
|
| 184 |
+
|
| 185 |
+
@classmethod
|
| 186 |
+
def from_pretrained(
|
| 187 |
+
cls, pretrained_model_name_or_path, *model_args, **kwargs
|
| 188 |
+
):
|
| 189 |
+
key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
|
| 190 |
+
if key_mapping is not None and kwargs.get("key_mapping") is None:
|
| 191 |
+
# Transformers only applies checkpoint renaming when key_mapping is
|
| 192 |
+
# passed explicitly; inheriting the class attribute alone is not enough.
|
| 193 |
+
kwargs["key_mapping"] = copy.deepcopy(key_mapping)
|
| 194 |
+
return super().from_pretrained(
|
| 195 |
+
pretrained_model_name_or_path, *model_args, **kwargs
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class Step3p7RotaryEmbedding(nn.Module):
|
| 200 |
+
def __init__(self, config: Step3p7TextConfig, device=None, layer_idx=None):
|
| 201 |
+
super().__init__()
|
| 202 |
+
self.layer_idx = layer_idx
|
| 203 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 204 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 205 |
+
|
| 206 |
+
rope_theta = config.rope_theta
|
| 207 |
+
if isinstance(rope_theta, list):
|
| 208 |
+
rope_theta = rope_theta[0 if layer_idx is None else layer_idx]
|
| 209 |
+
|
| 210 |
+
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
|
| 211 |
+
partial_rotary_factors = getattr(config, "partial_rotary_factors", None)
|
| 212 |
+
if partial_rotary_factors is not None:
|
| 213 |
+
partial_rotary_factor = partial_rotary_factors[
|
| 214 |
+
0 if layer_idx is None else layer_idx
|
| 215 |
+
]
|
| 216 |
+
|
| 217 |
+
self.rope_theta = rope_theta
|
| 218 |
+
self.partial_rotary_factor = partial_rotary_factor
|
| 219 |
+
|
| 220 |
+
self.config = copy.copy(config)
|
| 221 |
+
self.config.rope_theta = rope_theta
|
| 222 |
+
self.config.partial_rotary_factor = partial_rotary_factor
|
| 223 |
+
|
| 224 |
+
if config.rope_parameters is not None:
|
| 225 |
+
self.config.rope_parameters = copy.deepcopy(config.rope_parameters)
|
| 226 |
+
self.config.rope_parameters["rope_theta"] = rope_theta
|
| 227 |
+
self.config.rope_parameters["partial_rotary_factor"] = (
|
| 228 |
+
partial_rotary_factor
|
| 229 |
+
)
|
| 230 |
+
self.rope_type = self.config.rope_parameters.get(
|
| 231 |
+
"rope_type", self.config.rope_parameters.get("type")
|
| 232 |
+
)
|
| 233 |
+
else:
|
| 234 |
+
self.rope_type = "default"
|
| 235 |
+
|
| 236 |
+
self.rope_init_fn = self.compute_default_rope_parameters
|
| 237 |
+
if self.rope_type != "default":
|
| 238 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 239 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(
|
| 240 |
+
self.config, device
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 244 |
+
self.original_inv_freq = self.inv_freq
|
| 245 |
+
|
| 246 |
+
@torch.no_grad()
|
| 247 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 248 |
+
def forward(self, x, position_ids):
|
| 249 |
+
inv_freq_expanded = (
|
| 250 |
+
self.inv_freq[None, :, None]
|
| 251 |
+
.float()
|
| 252 |
+
.expand(position_ids.shape[0], -1, 1)
|
| 253 |
+
.to(x.device)
|
| 254 |
+
)
|
| 255 |
+
position_ids_expanded = position_ids[:, None, :].float().to(x.device)
|
| 256 |
+
|
| 257 |
+
device_type = (
|
| 258 |
+
x.device.type
|
| 259 |
+
if isinstance(x.device.type, str) and x.device.type != "mps"
|
| 260 |
+
else "cpu"
|
| 261 |
+
)
|
| 262 |
+
with torch.autocast(
|
| 263 |
+
device_type=device_type, enabled=False
|
| 264 |
+
): # Force float32
|
| 265 |
+
freqs = (
|
| 266 |
+
inv_freq_expanded.float() @ position_ids_expanded.float()
|
| 267 |
+
).transpose(1, 2)
|
| 268 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 269 |
+
cos = emb.cos() * self.attention_scaling
|
| 270 |
+
sin = emb.sin() * self.attention_scaling
|
| 271 |
+
|
| 272 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 273 |
+
|
| 274 |
+
@staticmethod
|
| 275 |
+
def compute_default_rope_parameters(
|
| 276 |
+
config: Step3p7TextConfig | None = None,
|
| 277 |
+
device: Optional["torch.device"] = None,
|
| 278 |
+
) -> tuple["torch.Tensor", float]:
|
| 279 |
+
"""
|
| 280 |
+
Computes the inverse frequencies according to the original RoPE implementation
|
| 281 |
+
Args:
|
| 282 |
+
config ([`~transformers.PreTrainedConfig`]):
|
| 283 |
+
The model configuration.
|
| 284 |
+
device (`torch.device`):
|
| 285 |
+
The device to use for initialization of the inverse frequencies.
|
| 286 |
+
seq_len (`int`, *optional*):
|
| 287 |
+
The current sequence length. Unused for this type of RoPE.
|
| 288 |
+
Returns:
|
| 289 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 290 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 291 |
+
"""
|
| 292 |
+
base = config.rope_theta
|
| 293 |
+
partial_rotary_factor = getattr(
|
| 294 |
+
config, "partial_rotary_factor", 1.0
|
| 295 |
+
)
|
| 296 |
+
head_dim = (
|
| 297 |
+
getattr(config, "head_dim", None)
|
| 298 |
+
or config.hidden_size // config.num_attention_heads
|
| 299 |
+
)
|
| 300 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 301 |
+
|
| 302 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
| 303 |
+
|
| 304 |
+
# Compute the inverse frequencies
|
| 305 |
+
inv_freq = 1.0 / (
|
| 306 |
+
base
|
| 307 |
+
** (
|
| 308 |
+
torch.arange(0, dim, 2, dtype=torch.int64).to(
|
| 309 |
+
device=device, dtype=torch.float
|
| 310 |
+
)
|
| 311 |
+
/ dim
|
| 312 |
+
)
|
| 313 |
+
)
|
| 314 |
+
return inv_freq, attention_factor
|
| 315 |
+
|
| 316 |
+
def rotate_half(x):
|
| 317 |
+
"""Rotates half the hidden dims of the input."""
|
| 318 |
+
x1 = x[..., :x.shape[-1] // 2]
|
| 319 |
+
x2 = x[..., x.shape[-1] // 2:]
|
| 320 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 324 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 325 |
+
|
| 326 |
+
Args:
|
| 327 |
+
q (`torch.Tensor`): The query tensor.
|
| 328 |
+
k (`torch.Tensor`): The key tensor.
|
| 329 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 330 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 331 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 332 |
+
Deprecated and unused.
|
| 333 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 334 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 335 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 336 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 337 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 338 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 339 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 340 |
+
Returns:
|
| 341 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 342 |
+
"""
|
| 343 |
+
rotary_dim = cos.shape[-1]
|
| 344 |
+
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
|
| 345 |
+
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
|
| 346 |
+
|
| 347 |
+
# Apply rotary embeddings on the first half or full tensor
|
| 348 |
+
q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
|
| 349 |
+
k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
|
| 350 |
+
|
| 351 |
+
# Concatenate back to full shape
|
| 352 |
+
q_embed = torch.cat([q_embed, q_pass], dim=-1)
|
| 353 |
+
k_embed = torch.cat([k_embed, k_pass], dim=-1)
|
| 354 |
+
return q_embed, k_embed
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 358 |
+
"""
|
| 359 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 360 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 361 |
+
"""
|
| 362 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 363 |
+
if n_rep == 1:
|
| 364 |
+
return hidden_states
|
| 365 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
| 366 |
+
batch, num_key_value_heads, n_rep, slen, head_dim
|
| 367 |
+
)
|
| 368 |
+
return hidden_states.reshape(
|
| 369 |
+
batch, num_key_value_heads * n_rep, slen, head_dim
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
# Adapted from transformers.models.llama.modeling_llama.eager_attention_forward.
|
| 374 |
+
# Llama4 does not cast attention weights to fp32 here.
|
| 375 |
+
def eager_attention_forward(
|
| 376 |
+
module: nn.Module,
|
| 377 |
+
query: torch.Tensor,
|
| 378 |
+
key: torch.Tensor,
|
| 379 |
+
value: torch.Tensor,
|
| 380 |
+
attention_mask: Optional[torch.Tensor],
|
| 381 |
+
scaling: float,
|
| 382 |
+
dropout: float = 0.0,
|
| 383 |
+
**kwargs,
|
| 384 |
+
):
|
| 385 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 386 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 387 |
+
# breakpoint()
|
| 388 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 389 |
+
if attention_mask is not None:
|
| 390 |
+
causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
|
| 391 |
+
attn_weights = attn_weights + causal_mask
|
| 392 |
+
|
| 393 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
| 394 |
+
attn_weights = nn.functional.dropout(
|
| 395 |
+
attn_weights, p=dropout, training=module.training
|
| 396 |
+
)
|
| 397 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 398 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 399 |
+
|
| 400 |
+
return attn_output, attn_weights
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
@dataclass
|
| 404 |
+
class Step3p7CausalLMOutputWithPast(ModelOutput):
|
| 405 |
+
r"""
|
| 406 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 407 |
+
Language modeling loss (for next-token prediction).
|
| 408 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 409 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 410 |
+
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 411 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 412 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
| 413 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 414 |
+
`past_key_values` input) to speed up sequential decoding.
|
| 415 |
+
"""
|
| 416 |
+
|
| 417 |
+
loss: Optional[torch.FloatTensor] = None
|
| 418 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 419 |
+
logits: torch.FloatTensor = None
|
| 420 |
+
past_key_values: Optional[list[torch.FloatTensor]] = None
|
| 421 |
+
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
| 422 |
+
attentions: Optional[tuple[torch.FloatTensor]] = None
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
class Step3p7MLP(nn.Module):
|
| 426 |
+
def __init__(self, config, intermediate_size=None, swiglu_limit=None):
|
| 427 |
+
super().__init__()
|
| 428 |
+
self.config = config
|
| 429 |
+
self.hidden_size = config.hidden_size
|
| 430 |
+
self.intermediate_size = (
|
| 431 |
+
intermediate_size
|
| 432 |
+
if intermediate_size is not None
|
| 433 |
+
else config.intermediate_size
|
| 434 |
+
)
|
| 435 |
+
self.gate_proj = nn.Linear(self.hidden_size,
|
| 436 |
+
self.intermediate_size,
|
| 437 |
+
bias=False)
|
| 438 |
+
self.up_proj = nn.Linear(self.hidden_size,
|
| 439 |
+
self.intermediate_size,
|
| 440 |
+
bias=False)
|
| 441 |
+
self.down_proj = nn.Linear(self.intermediate_size,
|
| 442 |
+
self.hidden_size,
|
| 443 |
+
bias=False)
|
| 444 |
+
self.act_fn = ACT2FN["silu"]
|
| 445 |
+
self.limit = swiglu_limit
|
| 446 |
+
|
| 447 |
+
def forward(self, x):
|
| 448 |
+
up = self.up_proj(x)
|
| 449 |
+
gate = self.act_fn(self.gate_proj(x))
|
| 450 |
+
if self.limit is not None:
|
| 451 |
+
gate = gate.clamp(min=None, max=self.limit)
|
| 452 |
+
up = up.clamp(min=-self.limit, max=self.limit)
|
| 453 |
+
|
| 454 |
+
return self.down_proj(gate * up)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def sigmoid_routing_function(gating_output: torch.Tensor, topk: int,
|
| 458 |
+
renormalize: bool):
|
| 459 |
+
gating_output = gating_output.float()
|
| 460 |
+
gate_prob = torch.sigmoid(gating_output)
|
| 461 |
+
gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
|
| 462 |
+
topk_prob, indices = torch.topk(gate_prob, k=topk, dim=1)
|
| 463 |
+
expert_topk_weight = topk_prob
|
| 464 |
+
if renormalize:
|
| 465 |
+
expert_topk_weight = expert_topk_weight / torch.sum(
|
| 466 |
+
expert_topk_weight, dim=-1, keepdim=True)
|
| 467 |
+
return expert_topk_weight, indices
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def softmax_routing_function(gating_output: torch.Tensor, top_k: int,
|
| 471 |
+
renormalize: bool):
|
| 472 |
+
gating_output = gating_output.float()
|
| 473 |
+
gate_prob = torch.softmax(gating_output, dim=-1)
|
| 474 |
+
gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
|
| 475 |
+
topk_prob, indices = torch.topk(gate_prob, k=top_k, dim=1)
|
| 476 |
+
expert_topk_weight = topk_prob
|
| 477 |
+
if renormalize:
|
| 478 |
+
expert_topk_weight = expert_topk_weight / torch.sum(
|
| 479 |
+
expert_topk_weight, dim=-1, keepdim=True)
|
| 480 |
+
return expert_topk_weight, indices.to(torch.int32)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
class MoELinear(nn.Module):
|
| 484 |
+
|
| 485 |
+
def __init__(self, num_experts, in_features, out_features):
|
| 486 |
+
super().__init__()
|
| 487 |
+
self.num_experts = num_experts
|
| 488 |
+
self.in_features = in_features
|
| 489 |
+
self.out_features = out_features
|
| 490 |
+
self.weight = nn.Parameter(
|
| 491 |
+
torch.empty(num_experts, out_features, in_features))
|
| 492 |
+
|
| 493 |
+
def forward(self, x, expert_id):
|
| 494 |
+
x = F.linear(x.float(), self.weight[expert_id].float())
|
| 495 |
+
return x
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
class Step3p7MoEMLP(nn.Module):
|
| 499 |
+
|
| 500 |
+
def __init__(self, config, swiglu_limit=None):
|
| 501 |
+
super().__init__()
|
| 502 |
+
self.num_experts = config.moe_num_experts
|
| 503 |
+
self.top_k = config.moe_top_k
|
| 504 |
+
self.hidden_size = config.hidden_size
|
| 505 |
+
self.moe_intermediate_size = config.moe_intermediate_size
|
| 506 |
+
|
| 507 |
+
self.use_moe_router_bias = config.use_moe_router_bias
|
| 508 |
+
if self.use_moe_router_bias:
|
| 509 |
+
self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts,
|
| 510 |
+
dtype=torch.float32),
|
| 511 |
+
requires_grad=False)
|
| 512 |
+
self.custom_routing_function = self.router_bias_func
|
| 513 |
+
elif config.moe_router_activation == "sigmoid":
|
| 514 |
+
self.custom_routing_function = sigmoid_routing_function
|
| 515 |
+
else:
|
| 516 |
+
self.custom_routing_function = None
|
| 517 |
+
self.need_fp32_gate = config.need_fp32_gate
|
| 518 |
+
self.routed_scaling_factor = getattr(config,
|
| 519 |
+
"moe_router_scaling_factor", 1.0)
|
| 520 |
+
|
| 521 |
+
# gating
|
| 522 |
+
self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
|
| 523 |
+
|
| 524 |
+
self.act_fn = ACT2FN["silu"]
|
| 525 |
+
self.limit = swiglu_limit
|
| 526 |
+
|
| 527 |
+
self.up_proj = MoELinear(self.num_experts, self.hidden_size,
|
| 528 |
+
self.moe_intermediate_size)
|
| 529 |
+
self.gate_proj = MoELinear(self.num_experts, self.hidden_size,
|
| 530 |
+
self.moe_intermediate_size)
|
| 531 |
+
self.down_proj = MoELinear(self.num_experts,
|
| 532 |
+
self.moe_intermediate_size,
|
| 533 |
+
self.hidden_size)
|
| 534 |
+
|
| 535 |
+
def router_bias_func(self, gating_output: torch.Tensor, topk: int,
|
| 536 |
+
renormalize: bool):
|
| 537 |
+
gate_prob = torch.sigmoid(gating_output.float())
|
| 538 |
+
gate_prob_with_bias = gate_prob + self.router_bias.unsqueeze(0)
|
| 539 |
+
_, indices = torch.topk(gate_prob_with_bias, k=topk, dim=1)
|
| 540 |
+
topk_prob = torch.gather(gate_prob, 1, indices)
|
| 541 |
+
expert_topk_weight = topk_prob
|
| 542 |
+
if renormalize:
|
| 543 |
+
expert_topk_weight = expert_topk_weight / (
|
| 544 |
+
torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20)
|
| 545 |
+
return expert_topk_weight, indices
|
| 546 |
+
|
| 547 |
+
def get_expert_output(self, inputs: torch.Tensor, expert_id):
|
| 548 |
+
#if self.limit is None:
|
| 549 |
+
up = self.up_proj(inputs, expert_id)
|
| 550 |
+
gate = self.act_fn(self.gate_proj(inputs, expert_id))
|
| 551 |
+
if self.limit is not None:
|
| 552 |
+
gate = gate.clamp(min=None, max=self.limit)
|
| 553 |
+
up = up.clamp(min=-self.limit, max=self.limit)
|
| 554 |
+
|
| 555 |
+
return self.down_proj(gate * up, expert_id)
|
| 556 |
+
|
| 557 |
+
def forward(self, hidden_states):
|
| 558 |
+
""" """
|
| 559 |
+
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
| 560 |
+
hidden_states = hidden_states.view(-1, hidden_dim)
|
| 561 |
+
if self.need_fp32_gate:
|
| 562 |
+
router_logits = torch.matmul(
|
| 563 |
+
hidden_states.to(torch.float32),
|
| 564 |
+
self.gate.weight.t().to(torch.float32),
|
| 565 |
+
)
|
| 566 |
+
else:
|
| 567 |
+
# router_logits: (batch * sequence_length, n_experts)
|
| 568 |
+
router_logits = self.gate(hidden_states)
|
| 569 |
+
|
| 570 |
+
if self.custom_routing_function:
|
| 571 |
+
routing_weights, selected_experts = self.custom_routing_function(
|
| 572 |
+
router_logits, self.top_k, renormalize=True)
|
| 573 |
+
else:
|
| 574 |
+
routing_weights = F.softmax(router_logits,
|
| 575 |
+
dim=1,
|
| 576 |
+
dtype=torch.float)
|
| 577 |
+
routing_weights, selected_experts = torch.topk(routing_weights,
|
| 578 |
+
self.top_k,
|
| 579 |
+
dim=-1)
|
| 580 |
+
|
| 581 |
+
routing_weights = routing_weights * self.routed_scaling_factor
|
| 582 |
+
|
| 583 |
+
final_hidden_states = torch.zeros(
|
| 584 |
+
(batch_size * sequence_length, hidden_dim),
|
| 585 |
+
dtype=hidden_states.dtype,
|
| 586 |
+
device=hidden_states.device)
|
| 587 |
+
|
| 588 |
+
# One hot encode the selected experts to create an expert mask
|
| 589 |
+
# this will be used to easily index which expert is going to be sollicitated
|
| 590 |
+
expert_mask = torch.nn.functional.one_hot(
|
| 591 |
+
selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
| 592 |
+
|
| 593 |
+
# Loop over all available experts in the model and perform the computation on each expert
|
| 594 |
+
for expert_idx in range(self.num_experts):
|
| 595 |
+
idx, top_x = torch.where(expert_mask[expert_idx])
|
| 596 |
+
|
| 597 |
+
# Index the correct hidden states and compute the expert hidden state for
|
| 598 |
+
# the current expert. We need to make sure to multiply the output hidden
|
| 599 |
+
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
|
| 600 |
+
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
|
| 601 |
+
current_hidden_states = (
|
| 602 |
+
self.get_expert_output(current_state, expert_idx) *
|
| 603 |
+
routing_weights[top_x, idx, None])
|
| 604 |
+
|
| 605 |
+
# However `index_add_` only support torch tensors for indexing so we'll use
|
| 606 |
+
# the `top_x` tensor here.
|
| 607 |
+
final_hidden_states.index_add_(
|
| 608 |
+
0, top_x, current_hidden_states.to(hidden_states.dtype))
|
| 609 |
+
final_hidden_states = final_hidden_states.reshape(
|
| 610 |
+
batch_size, sequence_length, hidden_dim)
|
| 611 |
+
return final_hidden_states
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
class Step3p7RMSNorm(nn.Module):
|
| 615 |
+
|
| 616 |
+
def __init__(
|
| 617 |
+
self,
|
| 618 |
+
hidden_size: int,
|
| 619 |
+
eps: float = 1e-5,
|
| 620 |
+
) -> None:
|
| 621 |
+
super().__init__()
|
| 622 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 623 |
+
self.variance_epsilon = eps
|
| 624 |
+
|
| 625 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 626 |
+
dtype = x.dtype
|
| 627 |
+
x = x.float()
|
| 628 |
+
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
| 629 |
+
normed = x * torch.rsqrt(variance + self.variance_epsilon)
|
| 630 |
+
normed = normed * (self.weight.float() + 1)
|
| 631 |
+
return normed.to(dtype)
|
| 632 |
+
class Step3p7Attention(nn.Module):
|
| 633 |
+
|
| 634 |
+
def __init__(self, config: Step3p7TextConfig, layer_idx):
|
| 635 |
+
super().__init__()
|
| 636 |
+
self.config = config
|
| 637 |
+
self.layer_idx = layer_idx
|
| 638 |
+
self.num_attention_heads = config.num_attention_heads
|
| 639 |
+
self.num_key_value_heads = config.num_attention_groups
|
| 640 |
+
|
| 641 |
+
layer_types = getattr(config, "layer_types", [])
|
| 642 |
+
if layer_types:
|
| 643 |
+
enable_sliding_window = layer_types[
|
| 644 |
+
self.layer_idx] == "sliding_attention"
|
| 645 |
+
else:
|
| 646 |
+
enable_sliding_window = self.layer_idx % 2 == 0
|
| 647 |
+
|
| 648 |
+
yarn_only_types = getattr(config, "yarn_only_types", None)
|
| 649 |
+
if yarn_only_types and layer_types[
|
| 650 |
+
self.layer_idx] not in yarn_only_types:
|
| 651 |
+
config.rope_parameters = None
|
| 652 |
+
else:
|
| 653 |
+
config.rope_parameters = getattr(config, "rope_scaling", None)
|
| 654 |
+
|
| 655 |
+
self.sliding_window = config.sliding_window
|
| 656 |
+
if enable_sliding_window:
|
| 657 |
+
self.num_attention_heads = config.attention_other_setting[
|
| 658 |
+
"num_attention_heads"]
|
| 659 |
+
self.num_key_value_heads = config.attention_other_setting[
|
| 660 |
+
"num_attention_groups"]
|
| 661 |
+
|
| 662 |
+
if self.sliding_window is not None and enable_sliding_window:
|
| 663 |
+
self.sliding_window = (self.sliding_window)
|
| 664 |
+
else:
|
| 665 |
+
self.sliding_window = None
|
| 666 |
+
self.head_dim = getattr(config, "head_dim",
|
| 667 |
+
config.hidden_size // self.num_attention_heads)
|
| 668 |
+
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
|
| 669 |
+
|
| 670 |
+
self.rotary_emb = Step3p7RotaryEmbedding(config, layer_idx=layer_idx)
|
| 671 |
+
|
| 672 |
+
self.q_size = self.num_attention_heads * self.head_dim
|
| 673 |
+
self.kv_size = self.num_key_value_heads * self.head_dim
|
| 674 |
+
self.scaling = self.head_dim**-0.5
|
| 675 |
+
|
| 676 |
+
self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=False)
|
| 677 |
+
self.k_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
|
| 678 |
+
self.v_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
|
| 679 |
+
self.o_proj = nn.Linear(self.q_size, config.hidden_size, bias=False)
|
| 680 |
+
self.attention_dropout = getattr(config, "attention_dropout", 0.0)
|
| 681 |
+
self.q_norm = Step3p7RMSNorm(self.head_dim,
|
| 682 |
+
eps=config.rms_norm_eps)
|
| 683 |
+
self.k_norm = Step3p7RMSNorm(self.head_dim,
|
| 684 |
+
eps=config.rms_norm_eps)
|
| 685 |
+
|
| 686 |
+
self.use_head_wise_attn_gate = config.use_head_wise_attn_gate
|
| 687 |
+
if self.use_head_wise_attn_gate:
|
| 688 |
+
self.g_proj = nn.Linear(config.hidden_size,
|
| 689 |
+
self.num_attention_heads,
|
| 690 |
+
bias=False)
|
| 691 |
+
|
| 692 |
+
self.use_rope = True
|
| 693 |
+
use_rope_layers = getattr(config, "use_rope_layers", None)
|
| 694 |
+
if use_rope_layers:
|
| 695 |
+
self.use_rope = use_rope_layers[self.layer_idx]
|
| 696 |
+
|
| 697 |
+
def forward(
|
| 698 |
+
self,
|
| 699 |
+
hidden_states: torch.Tensor,
|
| 700 |
+
attention_mask: Optional[torch.Tensor],
|
| 701 |
+
past_key_value: Optional[Cache] = None,
|
| 702 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 703 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 704 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 705 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
|
| 706 |
+
Optional[Tuple[torch.Tensor]]]:
|
| 707 |
+
input_shape = hidden_states.shape[:-1]
|
| 708 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 709 |
+
|
| 710 |
+
query_states = self.q_norm(
|
| 711 |
+
self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
| 712 |
+
key_states = self.k_norm(
|
| 713 |
+
self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
| 714 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(
|
| 715 |
+
1, 2)
|
| 716 |
+
if self.use_head_wise_attn_gate:
|
| 717 |
+
gate_states = self.g_proj(hidden_states)
|
| 718 |
+
cos, sin = self.rotary_emb(hidden_states, position_ids)
|
| 719 |
+
|
| 720 |
+
# cos, sin = position_embeddings
|
| 721 |
+
query_states, key_states = apply_rotary_pos_emb(
|
| 722 |
+
query_states, key_states, cos, sin)
|
| 723 |
+
|
| 724 |
+
# query_states, key_states = apply_rotary_pos_emb(query_norm_states, key_norm_states, cos, sin)
|
| 725 |
+
if past_key_value is not None:
|
| 726 |
+
# sin and cos are specific to RoPE models; position_ids needed for the static cache
|
| 727 |
+
cache_kwargs = {
|
| 728 |
+
"sin": sin,
|
| 729 |
+
"cos": cos,
|
| 730 |
+
"cache_position": cache_position
|
| 731 |
+
}
|
| 732 |
+
key_states, value_states = past_key_value.update(
|
| 733 |
+
key_states, value_states, self.layer_idx, cache_kwargs)
|
| 734 |
+
|
| 735 |
+
attention_interface: Callable = eager_attention_forward
|
| 736 |
+
# TODO: considering FP8;
|
| 737 |
+
# RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype,
|
| 738 |
+
# but got attn_mask.dtype: long int and query.dtype: c10::BFloat16 instead.
|
| 739 |
+
if self.config._attn_implementation != "eager":
|
| 740 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[
|
| 741 |
+
self.config._attn_implementation]
|
| 742 |
+
|
| 743 |
+
attn_output, attn_weights = attention_interface(
|
| 744 |
+
self,
|
| 745 |
+
query_states,
|
| 746 |
+
key_states,
|
| 747 |
+
value_states,
|
| 748 |
+
attention_mask,
|
| 749 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 750 |
+
scaling=self.scaling,
|
| 751 |
+
sliding_window=self.sliding_window, # main diff with Llama
|
| 752 |
+
**kwargs,
|
| 753 |
+
)
|
| 754 |
+
attn_output = attn_output.reshape(*input_shape, -1)
|
| 755 |
+
if self.use_head_wise_attn_gate:
|
| 756 |
+
output = attn_output.view(
|
| 757 |
+
*attn_output.shape[:-1], self.num_attention_heads,
|
| 758 |
+
self.head_dim) * gate_states.unsqueeze(-1).sigmoid()
|
| 759 |
+
attn_output = output.view(*attn_output.shape)
|
| 760 |
+
attn_output = self.o_proj(attn_output)
|
| 761 |
+
|
| 762 |
+
return attn_output, attn_weights
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
class Step3p7DecoderLayer(GradientCheckpointingLayer):
|
| 766 |
+
|
| 767 |
+
def __init__(self, config, layer_idx):
|
| 768 |
+
super().__init__()
|
| 769 |
+
self.hidden_size = config.hidden_size
|
| 770 |
+
self.layer_idx = layer_idx
|
| 771 |
+
self.self_attn = Step3p7Attention(config, layer_idx)
|
| 772 |
+
layer_types = getattr(config, "layer_types", None) or []
|
| 773 |
+
if layer_types:
|
| 774 |
+
self.attention_type = layer_types[layer_idx]
|
| 775 |
+
else:
|
| 776 |
+
self.attention_type = (
|
| 777 |
+
"sliding_attention" if layer_idx % 2 == 0 else "full_attention"
|
| 778 |
+
)
|
| 779 |
+
|
| 780 |
+
moe_layers_enum = getattr(config, "moe_layers_enum", None)
|
| 781 |
+
if moe_layers_enum is not None:
|
| 782 |
+
if isinstance(moe_layers_enum, str):
|
| 783 |
+
moe_layers_idx = [
|
| 784 |
+
int(i) for i in moe_layers_enum.split(',') if i.strip()
|
| 785 |
+
]
|
| 786 |
+
else:
|
| 787 |
+
moe_layers_idx = [int(i) for i in moe_layers_enum]
|
| 788 |
+
else:
|
| 789 |
+
moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
|
| 790 |
+
self.is_moe_layer = layer_idx in moe_layers_idx
|
| 791 |
+
self.use_moe = False
|
| 792 |
+
|
| 793 |
+
if config.swiglu_limits_shared and config.swiglu_limits_shared[
|
| 794 |
+
layer_idx] is not None and config.swiglu_limits_shared[
|
| 795 |
+
layer_idx] != 0:
|
| 796 |
+
swiglu_limit_shared = config.swiglu_limits_shared[layer_idx]
|
| 797 |
+
else:
|
| 798 |
+
swiglu_limit_shared = None
|
| 799 |
+
if config.swiglu_limits and config.swiglu_limits[
|
| 800 |
+
layer_idx] is not None and config.swiglu_limits[layer_idx] != 0:
|
| 801 |
+
swiglu_limit = config.swiglu_limits[layer_idx]
|
| 802 |
+
else:
|
| 803 |
+
swiglu_limit = None
|
| 804 |
+
if self.is_moe_layer:
|
| 805 |
+
self.moe = Step3p7MoEMLP(config, swiglu_limit=swiglu_limit) #
|
| 806 |
+
self.share_expert = Step3p7MLP(
|
| 807 |
+
config,
|
| 808 |
+
intermediate_size=config.share_expert_dim,
|
| 809 |
+
swiglu_limit=swiglu_limit_shared)
|
| 810 |
+
self.use_moe = True
|
| 811 |
+
else:
|
| 812 |
+
self.mlp = Step3p7MLP(config,
|
| 813 |
+
intermediate_size=config.intermediate_size,
|
| 814 |
+
swiglu_limit=swiglu_limit_shared)
|
| 815 |
+
|
| 816 |
+
self.input_layernorm = Step3p7RMSNorm(
|
| 817 |
+
config.hidden_size,
|
| 818 |
+
eps=config.rms_norm_eps)
|
| 819 |
+
self.post_attention_layernorm = Step3p7RMSNorm(
|
| 820 |
+
config.hidden_size,
|
| 821 |
+
eps=config.rms_norm_eps)
|
| 822 |
+
|
| 823 |
+
def forward(
|
| 824 |
+
self,
|
| 825 |
+
hidden_states: torch.Tensor,
|
| 826 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 827 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 828 |
+
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
| 829 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 830 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 831 |
+
) -> torch.FloatTensor:
|
| 832 |
+
residual = hidden_states
|
| 833 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 834 |
+
hidden_states, _ = self.self_attn(
|
| 835 |
+
hidden_states=hidden_states,
|
| 836 |
+
attention_mask=attention_mask,
|
| 837 |
+
position_ids=position_ids,
|
| 838 |
+
past_key_value=past_key_value,
|
| 839 |
+
cache_position=cache_position,
|
| 840 |
+
**kwargs,
|
| 841 |
+
)
|
| 842 |
+
hidden_states = residual + hidden_states
|
| 843 |
+
|
| 844 |
+
# Fully Connected
|
| 845 |
+
residual = hidden_states
|
| 846 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 847 |
+
if self.use_moe:
|
| 848 |
+
share_output = self.share_expert(hidden_states)
|
| 849 |
+
moe_output = self.moe(hidden_states)
|
| 850 |
+
ffn_output = moe_output + share_output
|
| 851 |
+
else:
|
| 852 |
+
ffn_output = self.mlp(hidden_states)
|
| 853 |
+
if isinstance(ffn_output, tuple):
|
| 854 |
+
hidden_states, _ = ffn_output
|
| 855 |
+
else:
|
| 856 |
+
hidden_states = ffn_output
|
| 857 |
+
|
| 858 |
+
hidden_states = residual + hidden_states
|
| 859 |
+
return hidden_states
|
| 860 |
+
|
| 861 |
+
|
| 862 |
+
class Step3p7TextPreTrainedModel(PreTrainedModel):
|
| 863 |
+
# Link this model family to its configuration class so PreTrainedModel.from_pretrained
|
| 864 |
+
# can load the config instead of failing with a NoneType error.
|
| 865 |
+
config_class = Step3p7TextConfig
|
| 866 |
+
supports_gradient_checkpointing = True
|
| 867 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 868 |
+
_keys_to_ignore_on_load_unexpected = [
|
| 869 |
+
r"model\.layers\.45\.*",
|
| 870 |
+
r"model\.layers\.46\.*",
|
| 871 |
+
r"model\.layers\.47\.*",
|
| 872 |
+
]
|
| 873 |
+
_supports_flash_attn = False
|
| 874 |
+
_supports_sdpa = True
|
| 875 |
+
_supports_flex_attn = True
|
| 876 |
+
_supports_static_cache = True
|
| 877 |
+
_supports_attention_backend = True
|
| 878 |
+
|
| 879 |
+
|
| 880 |
+
class Step3p7TextModel(Step3p7TextPreTrainedModel, GenerationMixin):
|
| 881 |
+
_no_split_modules = ["Step3p7DecoderLayer"]
|
| 882 |
+
base_model_prefix = "model"
|
| 883 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 884 |
+
config: Step3p7TextConfig
|
| 885 |
+
|
| 886 |
+
def __init__(self, config: Step3p7TextConfig):
|
| 887 |
+
super().__init__(config)
|
| 888 |
+
self.padding_idx = config.pad_token_id
|
| 889 |
+
self.vocab_size = config.vocab_size
|
| 890 |
+
|
| 891 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
|
| 892 |
+
self.padding_idx)
|
| 893 |
+
self.layers = nn.ModuleList([
|
| 894 |
+
Step3p7DecoderLayer(config, layer_idx)
|
| 895 |
+
for layer_idx in range(config.num_hidden_layers)
|
| 896 |
+
])
|
| 897 |
+
self.norm = Step3p7RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 898 |
+
self.gradient_checkpointing = False
|
| 899 |
+
layer_types = self.config.layer_types or []
|
| 900 |
+
self.has_sliding_layers = (not layer_types or
|
| 901 |
+
"sliding_attention" in layer_types)
|
| 902 |
+
|
| 903 |
+
# Initialize weights and apply final processing
|
| 904 |
+
self.post_init()
|
| 905 |
+
|
| 906 |
+
|
| 907 |
+
def get_input_embeddings(self, input_ids):
|
| 908 |
+
return self.embed_tokens(input_ids)
|
| 909 |
+
|
| 910 |
+
@can_return_tuple
|
| 911 |
+
def forward(
|
| 912 |
+
self,
|
| 913 |
+
input_ids: torch.LongTensor = None,
|
| 914 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 915 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 916 |
+
past_key_values: Optional[Cache] = None,
|
| 917 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 918 |
+
use_cache: Optional[bool] = None,
|
| 919 |
+
output_attentions: Optional[bool] = None,
|
| 920 |
+
output_hidden_states: Optional[bool] = None,
|
| 921 |
+
return_dict: Optional[bool] = None,
|
| 922 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 923 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 924 |
+
) -> Union[tuple, BaseModelOutputWithPast]:
|
| 925 |
+
output_attentions = (
|
| 926 |
+
output_attentions
|
| 927 |
+
if output_attentions is not None
|
| 928 |
+
else self.config.output_attentions
|
| 929 |
+
)
|
| 930 |
+
output_hidden_states = (
|
| 931 |
+
output_hidden_states
|
| 932 |
+
if output_hidden_states is not None
|
| 933 |
+
else self.config.output_hidden_states
|
| 934 |
+
)
|
| 935 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 936 |
+
return_dict = (
|
| 937 |
+
return_dict
|
| 938 |
+
if return_dict is not None
|
| 939 |
+
else getattr(self.config, "return_dict", True)
|
| 940 |
+
)
|
| 941 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 942 |
+
raise ValueError(
|
| 943 |
+
"You must specify exactly one of input_ids or inputs_embeds")
|
| 944 |
+
|
| 945 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 946 |
+
logger.warning_once(
|
| 947 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
| 948 |
+
)
|
| 949 |
+
use_cache = False
|
| 950 |
+
|
| 951 |
+
if inputs_embeds is None:
|
| 952 |
+
inputs_embeds = self.embed_tokens(
|
| 953 |
+
input_ids.to(self.embed_tokens.weight.device))
|
| 954 |
+
|
| 955 |
+
if use_cache and past_key_values is None:
|
| 956 |
+
past_key_values = DynamicCache()
|
| 957 |
+
|
| 958 |
+
if cache_position is None:
|
| 959 |
+
past_seen_tokens = past_key_values.get_seq_length(
|
| 960 |
+
) if past_key_values is not None else 0
|
| 961 |
+
cache_position = torch.arange(past_seen_tokens,
|
| 962 |
+
past_seen_tokens +
|
| 963 |
+
inputs_embeds.shape[1],
|
| 964 |
+
device=inputs_embeds.device)
|
| 965 |
+
|
| 966 |
+
if position_ids is None:
|
| 967 |
+
position_ids = cache_position.unsqueeze(0)
|
| 968 |
+
|
| 969 |
+
hidden_states = inputs_embeds
|
| 970 |
+
|
| 971 |
+
# It may already have been prepared by e.g. `generate`
|
| 972 |
+
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
| 973 |
+
# Prepare mask arguments
|
| 974 |
+
mask_kwargs = {
|
| 975 |
+
"config": self.config,
|
| 976 |
+
"attention_mask": attention_mask,
|
| 977 |
+
"past_key_values": past_key_values,
|
| 978 |
+
"position_ids": position_ids,
|
| 979 |
+
}
|
| 980 |
+
mask_kwargs[_MASK_INPUT_EMBEDS_ARG] = inputs_embeds
|
| 981 |
+
# Create the masks
|
| 982 |
+
causal_mask_mapping = {
|
| 983 |
+
"full_attention": create_causal_mask(**mask_kwargs),
|
| 984 |
+
}
|
| 985 |
+
|
| 986 |
+
# The sliding window alternating layers are not always activated depending on the config
|
| 987 |
+
if self.has_sliding_layers:
|
| 988 |
+
causal_mask_mapping[
|
| 989 |
+
"sliding_attention"] = create_sliding_window_causal_mask(
|
| 990 |
+
**mask_kwargs)
|
| 991 |
+
|
| 992 |
+
# # create position embeddings to be shared across the decoder layers
|
| 993 |
+
# decoder layers
|
| 994 |
+
all_hidden_states = () if output_hidden_states else None
|
| 995 |
+
all_self_attns = () if output_attentions else None
|
| 996 |
+
for decoder_layer in self.layers[:self.config.num_hidden_layers]:
|
| 997 |
+
if output_hidden_states:
|
| 998 |
+
all_hidden_states += (hidden_states, )
|
| 999 |
+
|
| 1000 |
+
layer_outputs = decoder_layer(
|
| 1001 |
+
hidden_states,
|
| 1002 |
+
attention_mask=causal_mask_mapping[
|
| 1003 |
+
decoder_layer.attention_type],
|
| 1004 |
+
position_ids=position_ids,
|
| 1005 |
+
past_key_value=past_key_values,
|
| 1006 |
+
output_attentions=output_attentions,
|
| 1007 |
+
use_cache=use_cache,
|
| 1008 |
+
cache_position=cache_position,
|
| 1009 |
+
**kwargs,
|
| 1010 |
+
)
|
| 1011 |
+
|
| 1012 |
+
hidden_states = layer_outputs
|
| 1013 |
+
|
| 1014 |
+
hidden_states = self.norm(hidden_states)
|
| 1015 |
+
|
| 1016 |
+
return BaseModelOutputWithPast(
|
| 1017 |
+
last_hidden_state=hidden_states,
|
| 1018 |
+
past_key_values=past_key_values if use_cache else None,
|
| 1019 |
+
hidden_states=all_hidden_states,
|
| 1020 |
+
attentions=all_self_attns,
|
| 1021 |
+
)
|
| 1022 |
+
|
| 1023 |
+
|
| 1024 |
+
class Step3p7Model(Step3p7PreTrainedModel, GenerationMixin):
|
| 1025 |
+
config: Step3p7Config
|
| 1026 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1027 |
+
base_model_prefix = ""
|
| 1028 |
+
|
| 1029 |
+
def __init__(self, config: Step3p7Config):
|
| 1030 |
+
super().__init__(config)
|
| 1031 |
+
self.vision_model = StepRoboticsVisionEncoder(config.vision_config)
|
| 1032 |
+
self.language_model = Step3p7TextModel(config.text_config)
|
| 1033 |
+
self.vocab_size = config.text_config.vocab_size
|
| 1034 |
+
self.vit_large_projector = nn.Linear(
|
| 1035 |
+
config.vision_config.width * 4,
|
| 1036 |
+
config.text_config.hidden_size,
|
| 1037 |
+
bias=config.projector_bias)
|
| 1038 |
+
self.image_placeholder_token_id = config.image_token_id
|
| 1039 |
+
|
| 1040 |
+
# Initialize weights and apply final processing
|
| 1041 |
+
self.post_init()
|
| 1042 |
+
|
| 1043 |
+
def get_input_embeddings(
|
| 1044 |
+
self,
|
| 1045 |
+
input_ids: torch.Tensor,
|
| 1046 |
+
multimodal_embeddings = None,
|
| 1047 |
+
) -> torch.Tensor:
|
| 1048 |
+
# breakpoint()
|
| 1049 |
+
input_ids = input_ids.squeeze(0)
|
| 1050 |
+
if multimodal_embeddings is None:
|
| 1051 |
+
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
| 1052 |
+
else:
|
| 1053 |
+
is_text = input_ids != self.config.image_token_id
|
| 1054 |
+
text_ids = input_ids[is_text]
|
| 1055 |
+
text_embeds = self.language_model.get_input_embeddings(text_ids)
|
| 1056 |
+
|
| 1057 |
+
inputs_embeds = torch.empty(input_ids.shape[0],
|
| 1058 |
+
text_embeds.shape[-1],
|
| 1059 |
+
dtype=text_embeds.dtype,
|
| 1060 |
+
device=text_embeds.device)
|
| 1061 |
+
inputs_embeds[is_text] = text_embeds
|
| 1062 |
+
inputs_embeds = merge_multimodal_embeddings(
|
| 1063 |
+
input_ids, inputs_embeds, multimodal_embeddings,
|
| 1064 |
+
self.config.image_token_id)
|
| 1065 |
+
inputs_embeds = inputs_embeds.unsqueeze(0)
|
| 1066 |
+
return inputs_embeds
|
| 1067 |
+
|
| 1068 |
+
|
| 1069 |
+
def set_input_embeddings(self, value):
|
| 1070 |
+
return self.language_model.set_input_embeddings(value)
|
| 1071 |
+
|
| 1072 |
+
def set_decoder(self, decoder):
|
| 1073 |
+
self.language_model = decoder
|
| 1074 |
+
|
| 1075 |
+
def get_decoder(self):
|
| 1076 |
+
return self.language_model
|
| 1077 |
+
|
| 1078 |
+
def _parse_and_validate_image_input(
|
| 1079 |
+
self, **kwargs: object) -> Optional[StepVLImageInputs]:
|
| 1080 |
+
pixel_values = kwargs.pop("pixel_values", None)
|
| 1081 |
+
patch_pixel_values = kwargs.pop("patch_pixel_values", None)
|
| 1082 |
+
num_patches = kwargs.pop("num_patches", None)
|
| 1083 |
+
image_embeds = kwargs.pop("image_embeds", None)
|
| 1084 |
+
|
| 1085 |
+
if pixel_values is None and image_embeds is None:
|
| 1086 |
+
return None
|
| 1087 |
+
|
| 1088 |
+
if pixel_values is not None:
|
| 1089 |
+
# pixel_values = flatten_bn(pixel_values, concat=True)
|
| 1090 |
+
if pixel_values.dim() >= 3:
|
| 1091 |
+
pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
|
| 1092 |
+
if patch_pixel_values is not None:
|
| 1093 |
+
# patch_pixel_values = flatten_bn(patch_pixel_values,
|
| 1094 |
+
# concat=True)
|
| 1095 |
+
patch_pixel_values = patch_pixel_values.view(
|
| 1096 |
+
-1, *patch_pixel_values.shape[-3:])
|
| 1097 |
+
# Handle empty patch_pixel_values by setting to None
|
| 1098 |
+
if patch_pixel_values.shape[0] == 0:
|
| 1099 |
+
patch_pixel_values = None
|
| 1100 |
+
|
| 1101 |
+
return StepVLImagePixelInputs(
|
| 1102 |
+
type="pixel_values",
|
| 1103 |
+
pixel_values=pixel_values.to(self.dtype).to(self.device),
|
| 1104 |
+
patch_pixel_values=patch_pixel_values.to(self.dtype).to(
|
| 1105 |
+
self.device) if patch_pixel_values is not None else None,
|
| 1106 |
+
num_patches=num_patches,
|
| 1107 |
+
)
|
| 1108 |
+
|
| 1109 |
+
if image_embeds is not None:
|
| 1110 |
+
if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
|
| 1111 |
+
image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
|
| 1112 |
+
else:
|
| 1113 |
+
raise ValueError(
|
| 1114 |
+
f"Unexpected shape for image_embeds: {image_embeds.shape}")
|
| 1115 |
+
|
| 1116 |
+
return StepVLImageEmbeddingInputs(
|
| 1117 |
+
type="image_embeds",
|
| 1118 |
+
image_embeds=image_embeds.to(self.dtype).to(self.device),
|
| 1119 |
+
)
|
| 1120 |
+
return None
|
| 1121 |
+
|
| 1122 |
+
def _process_image_features(self,
|
| 1123 |
+
image_features: torch.Tensor) -> torch.Tensor:
|
| 1124 |
+
B, P = image_features.shape[:2]
|
| 1125 |
+
HW = int(P ** 0.5)
|
| 1126 |
+
image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
|
| 1127 |
+
image_features = self.vision_model.vit_downsampler1(image_features)
|
| 1128 |
+
image_features = self.vision_model.vit_downsampler2(image_features)
|
| 1129 |
+
|
| 1130 |
+
B, C, HW, HW = image_features.shape
|
| 1131 |
+
image_features = image_features.view(B, -1, HW * HW).permute(0, 2, 1)
|
| 1132 |
+
image_features = self.vit_large_projector(image_features)
|
| 1133 |
+
return image_features
|
| 1134 |
+
|
| 1135 |
+
def _get_vision_model_output(self,
|
| 1136 |
+
input_tensor: torch.Tensor) -> torch.Tensor:
|
| 1137 |
+
return self.vision_model(input_tensor)
|
| 1138 |
+
|
| 1139 |
+
def _process_image_input(
|
| 1140 |
+
self, image_input: StepVLImageInputs) -> tuple[torch.Tensor, ...]:
|
| 1141 |
+
|
| 1142 |
+
if image_input["type"] == "image_embeds":
|
| 1143 |
+
image_features = image_input["image_embeds"]
|
| 1144 |
+
else:
|
| 1145 |
+
image_features = self._get_vision_model_output(
|
| 1146 |
+
image_input["pixel_values"])
|
| 1147 |
+
patch_image_features = self._get_vision_model_output(
|
| 1148 |
+
image_input["patch_pixel_values"]
|
| 1149 |
+
) if image_input["patch_pixel_values"] is not None else None
|
| 1150 |
+
num_patches = image_input["num_patches"]
|
| 1151 |
+
|
| 1152 |
+
image_features = self._process_image_features(image_features)
|
| 1153 |
+
patch_image_features = self._process_image_features(
|
| 1154 |
+
patch_image_features) if patch_image_features is not None else None
|
| 1155 |
+
|
| 1156 |
+
merged_image_features = []
|
| 1157 |
+
cur_patch_idx = 0
|
| 1158 |
+
for i, num_patch in enumerate(num_patches):
|
| 1159 |
+
cur_feature = []
|
| 1160 |
+
if num_patch > 0:
|
| 1161 |
+
patch_slice = patch_image_features[
|
| 1162 |
+
cur_patch_idx:cur_patch_idx + num_patch]
|
| 1163 |
+
cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
|
| 1164 |
+
cur_feature.append(image_features[i].view(
|
| 1165 |
+
-1, image_features.shape[-1]))
|
| 1166 |
+
cur_patch_idx += num_patch
|
| 1167 |
+
merged_image_features.append(
|
| 1168 |
+
torch.cat(cur_feature) if len(cur_feature) >
|
| 1169 |
+
1 else cur_feature[0])
|
| 1170 |
+
|
| 1171 |
+
return merged_image_features
|
| 1172 |
+
|
| 1173 |
+
def get_multimodal_embeddings(self, **kwargs):
|
| 1174 |
+
# breakpoint()
|
| 1175 |
+
image_input = self._parse_and_validate_image_input(**kwargs)
|
| 1176 |
+
if image_input is None:
|
| 1177 |
+
return None
|
| 1178 |
+
vision_embeddings = self._process_image_input(image_input)
|
| 1179 |
+
return vision_embeddings
|
| 1180 |
+
|
| 1181 |
+
@can_return_tuple
|
| 1182 |
+
def forward(
|
| 1183 |
+
self,
|
| 1184 |
+
input_ids: torch.LongTensor = None,
|
| 1185 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1186 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1187 |
+
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
| 1188 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1189 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1190 |
+
use_cache: Optional[bool] = None,
|
| 1191 |
+
output_attentions: Optional[bool] = None,
|
| 1192 |
+
output_hidden_states: Optional[bool] = None,
|
| 1193 |
+
return_dict: Optional[bool] = None,
|
| 1194 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1195 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1196 |
+
images: Optional[list[Image.Image]] = None,
|
| 1197 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 1198 |
+
) -> Union[tuple, Step3p7CausalLMOutputWithPast]:
|
| 1199 |
+
r"""
|
| 1200 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1201 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1202 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 1203 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1204 |
+
Example:
|
| 1205 |
+
```python
|
| 1206 |
+
>>> from transformers import AutoTokenizer, Llama4ForCausalLM
|
| 1207 |
+
>>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
|
| 1208 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
|
| 1209 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 1210 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 1211 |
+
>>> # Generate
|
| 1212 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 1213 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1214 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 1215 |
+
```"""
|
| 1216 |
+
output_attentions = (
|
| 1217 |
+
output_attentions
|
| 1218 |
+
if output_attentions is not None
|
| 1219 |
+
else self.config.output_attentions
|
| 1220 |
+
)
|
| 1221 |
+
output_hidden_states = (
|
| 1222 |
+
output_hidden_states
|
| 1223 |
+
if output_hidden_states is not None
|
| 1224 |
+
else self.config.output_hidden_states
|
| 1225 |
+
)
|
| 1226 |
+
return_dict = (
|
| 1227 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 1228 |
+
)
|
| 1229 |
+
|
| 1230 |
+
if inputs_embeds is None:
|
| 1231 |
+
input_ids = input_ids
|
| 1232 |
+
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
| 1233 |
+
inputs_embeds = self.get_input_embeddings(input_ids,
|
| 1234 |
+
vision_embeddings)
|
| 1235 |
+
input_ids = None
|
| 1236 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 1237 |
+
outputs = self.language_model(
|
| 1238 |
+
input_ids=None,
|
| 1239 |
+
position_ids=position_ids,
|
| 1240 |
+
attention_mask=attention_mask,
|
| 1241 |
+
past_key_values=past_key_values,
|
| 1242 |
+
inputs_embeds=inputs_embeds,
|
| 1243 |
+
use_cache=use_cache,
|
| 1244 |
+
output_attentions=output_attentions,
|
| 1245 |
+
output_hidden_states=output_hidden_states,
|
| 1246 |
+
return_dict=True,
|
| 1247 |
+
cache_position=cache_position,
|
| 1248 |
+
**kwargs,
|
| 1249 |
+
)
|
| 1250 |
+
|
| 1251 |
+
output = Step3p7CausalLMOutputWithPast(
|
| 1252 |
+
last_hidden_state=outputs.last_hidden_state,
|
| 1253 |
+
past_key_values=outputs.past_key_values,
|
| 1254 |
+
attentions=outputs.attentions,
|
| 1255 |
+
)
|
| 1256 |
+
return output if return_dict else output.to_tuple()
|
| 1257 |
+
|
| 1258 |
+
|
| 1259 |
+
class Step3p7ForConditionalGeneration(Step3p7PreTrainedModel, GenerationMixin):
|
| 1260 |
+
_checkpoint_conversion_mapping = {
|
| 1261 |
+
"^vision_model": "model.vision_model",
|
| 1262 |
+
r"^model(?!\.(language_model|vision_model))": "model.language_model",
|
| 1263 |
+
"^vit_large_projector": "model.vit_large_projector",
|
| 1264 |
+
}
|
| 1265 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1266 |
+
config: Step3p7Config
|
| 1267 |
+
|
| 1268 |
+
def __init__(self, config: Step3p7Config):
|
| 1269 |
+
super().__init__(config)
|
| 1270 |
+
self.model = Step3p7Model(config)
|
| 1271 |
+
self.lm_head = nn.Linear(config.hidden_size,
|
| 1272 |
+
config.text_config.vocab_size,
|
| 1273 |
+
bias=False)
|
| 1274 |
+
|
| 1275 |
+
self.post_init()
|
| 1276 |
+
|
| 1277 |
+
def get_input_embeddings(self):
|
| 1278 |
+
return self.model.get_input_embeddings()
|
| 1279 |
+
|
| 1280 |
+
def set_input_embeddings(self, value):
|
| 1281 |
+
self.model.set_input_embeddings(value)
|
| 1282 |
+
|
| 1283 |
+
def get_output_embeddings(self):
|
| 1284 |
+
return self.model.get_output_embeddings()
|
| 1285 |
+
|
| 1286 |
+
def set_output_embeddings(self, new_embeddings):
|
| 1287 |
+
self.model.set_output_embeddings(new_embeddings)
|
| 1288 |
+
|
| 1289 |
+
def set_decoder(self, decoder):
|
| 1290 |
+
self.model.set_decoder(decoder)
|
| 1291 |
+
|
| 1292 |
+
def get_decoder(self):
|
| 1293 |
+
return self.model.get_decoder()
|
| 1294 |
+
|
| 1295 |
+
@property
|
| 1296 |
+
def language_model(self):
|
| 1297 |
+
return self.model.language_model
|
| 1298 |
+
|
| 1299 |
+
@property
|
| 1300 |
+
def visual(self):
|
| 1301 |
+
return self.model.vision_model
|
| 1302 |
+
|
| 1303 |
+
def forward(
|
| 1304 |
+
self,
|
| 1305 |
+
input_ids: torch.LongTensor = None,
|
| 1306 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 1307 |
+
num_patches=None,
|
| 1308 |
+
patch_pixel_values=None,
|
| 1309 |
+
patch_newline_mask=None,
|
| 1310 |
+
image_embeds: Optional[torch.FloatTensor] = None,
|
| 1311 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1312 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1313 |
+
past_key_values: Optional[Cache] = None,
|
| 1314 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1315 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1316 |
+
use_cache: Optional[bool] = None,
|
| 1317 |
+
output_attentions: Optional[bool] = None,
|
| 1318 |
+
output_hidden_states: Optional[bool] = None,
|
| 1319 |
+
return_dict: Optional[bool] = None,
|
| 1320 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1321 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 1322 |
+
) -> Union[tuple, Step3p7CausalLMOutputWithPast]:
|
| 1323 |
+
output_attentions = (
|
| 1324 |
+
output_attentions
|
| 1325 |
+
if output_attentions is not None
|
| 1326 |
+
else self.config.output_attentions
|
| 1327 |
+
)
|
| 1328 |
+
output_hidden_states = (
|
| 1329 |
+
output_hidden_states
|
| 1330 |
+
if output_hidden_states is not None
|
| 1331 |
+
else self.config.output_hidden_states
|
| 1332 |
+
)
|
| 1333 |
+
|
| 1334 |
+
outputs = self.model(
|
| 1335 |
+
input_ids=input_ids,
|
| 1336 |
+
num_patches=num_patches,
|
| 1337 |
+
patch_pixel_values=patch_pixel_values,
|
| 1338 |
+
patch_newline_mask=patch_newline_mask,
|
| 1339 |
+
position_ids=position_ids,
|
| 1340 |
+
attention_mask=attention_mask,
|
| 1341 |
+
past_key_values=past_key_values,
|
| 1342 |
+
inputs_embeds=inputs_embeds,
|
| 1343 |
+
use_cache=use_cache,
|
| 1344 |
+
output_attentions=output_attentions,
|
| 1345 |
+
output_hidden_states=output_hidden_states,
|
| 1346 |
+
return_dict=return_dict,
|
| 1347 |
+
cache_position=cache_position,
|
| 1348 |
+
**kwargs,
|
| 1349 |
+
)
|
| 1350 |
+
|
| 1351 |
+
hidden_states = outputs.last_hidden_state
|
| 1352 |
+
logits = self.lm_head(hidden_states)
|
| 1353 |
+
|
| 1354 |
+
los = None
|
| 1355 |
+
if labels is not None:
|
| 1356 |
+
loss = self.loss_function(
|
| 1357 |
+
logits=logits, labels=labels, vocab_size=self.config.vocab_size
|
| 1358 |
+
)
|
| 1359 |
+
|
| 1360 |
+
return Step3p7CausalLMOutputWithPast(
|
| 1361 |
+
logits=logits,
|
| 1362 |
+
)
|
| 1363 |
+
|
| 1364 |
+
|
| 1365 |
+
def prepare_inputs_for_generation(
|
| 1366 |
+
self,
|
| 1367 |
+
input_ids,
|
| 1368 |
+
past_key_values=None,
|
| 1369 |
+
inputs_embeds=None,
|
| 1370 |
+
pixel_values=None,
|
| 1371 |
+
patch_pixel_values=None,
|
| 1372 |
+
num_patches=None,
|
| 1373 |
+
image_embeds=None,
|
| 1374 |
+
attention_mask=None,
|
| 1375 |
+
cache_position=None,
|
| 1376 |
+
logits_to_keep=None,
|
| 1377 |
+
**kwargs,
|
| 1378 |
+
):
|
| 1379 |
+
model_inputs = super().prepare_inputs_for_generation(
|
| 1380 |
+
input_ids,
|
| 1381 |
+
past_key_values=past_key_values,
|
| 1382 |
+
inputs_embeds=inputs_embeds,
|
| 1383 |
+
attention_mask=attention_mask,
|
| 1384 |
+
cache_position=cache_position,
|
| 1385 |
+
logits_to_keep=logits_to_keep,
|
| 1386 |
+
**kwargs,
|
| 1387 |
+
)
|
| 1388 |
+
|
| 1389 |
+
generation_cache_position = model_inputs.get("cache_position", cache_position)
|
| 1390 |
+
is_prefill = past_key_values is None
|
| 1391 |
+
if generation_cache_position is not None and generation_cache_position.numel() > 0:
|
| 1392 |
+
is_prefill = generation_cache_position[0].item() == 0
|
| 1393 |
+
|
| 1394 |
+
if is_prefill:
|
| 1395 |
+
# During cached decoding, input ids no longer contain image tokens,
|
| 1396 |
+
# so pixel values should only be passed at the first step.
|
| 1397 |
+
model_inputs["pixel_values"] = pixel_values
|
| 1398 |
+
|
| 1399 |
+
return model_inputs
|
| 1400 |
+
|
| 1401 |
+
def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
|
| 1402 |
+
if key.startswith("language_model."):
|
| 1403 |
+
return key[len("language_model.") :], True
|
| 1404 |
+
|
| 1405 |
+
return key, False
|
processing_step3.py
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import BaseImageProcessor, ImageProcessingMixin
|
| 2 |
+
from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
|
| 3 |
+
import math
|
| 4 |
+
from typing import Iterable, Optional, Tuple, List, TypedDict, Literal, Union, overload
|
| 5 |
+
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torchvision
|
| 10 |
+
from torch import nn
|
| 11 |
+
from torch.nn import functional as F, LayerNorm
|
| 12 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 13 |
+
from transformers.activations import ACT2FN
|
| 14 |
+
from torchvision import transforms
|
| 15 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 16 |
+
from transformers.feature_extraction_utils import BatchFeature, TensorType
|
| 17 |
+
from transformers.image_utils import ImageInput
|
| 18 |
+
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
| 19 |
+
from transformers.tokenization_utils_tokenizers import TokenizersBackend
|
| 20 |
+
from math import ceil
|
| 21 |
+
from itertools import product
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
MAX_IMAGE_SIZE: int = 3024
|
| 26 |
+
|
| 27 |
+
class Step3VLImagePixelInputs(TypedDict):
|
| 28 |
+
type: Literal["pixel_values"]
|
| 29 |
+
pixel_values: torch.Tensor
|
| 30 |
+
patch_pixel_values: Optional[torch.Tensor]
|
| 31 |
+
num_patches: list[int]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Step3VLImageEmbeddingInputs(TypedDict):
|
| 35 |
+
type: Literal["image_embeds"]
|
| 36 |
+
image_embeds: torch.Tensor
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class GPUToTensor(torch.nn.Module):
|
| 43 |
+
|
| 44 |
+
def forward(self, raw_image: Union[np.ndarray,
|
| 45 |
+
Image.Image]) -> torch.Tensor:
|
| 46 |
+
if isinstance(raw_image, Image.Image):
|
| 47 |
+
return transforms.ToTensor()(raw_image)
|
| 48 |
+
if raw_image.ndim == 2:
|
| 49 |
+
raw_image = raw_image[:, :, None].repeat(3, -1)
|
| 50 |
+
if torch.cuda.is_available():
|
| 51 |
+
device = torch.device("cuda")
|
| 52 |
+
else:
|
| 53 |
+
device = torch.device("cpu")
|
| 54 |
+
image_tensor = torch.from_numpy(raw_image).to(device)
|
| 55 |
+
image_tensor = torch.permute(image_tensor, (2, 0, 1)).contiguous()
|
| 56 |
+
if image_tensor.dtype == torch.uint8:
|
| 57 |
+
image_tensor = image_tensor.to(torch.float32).div(255)
|
| 58 |
+
return image_tensor
|
| 59 |
+
|
| 60 |
+
class Step3VisionProcessor(BaseImageProcessor):
|
| 61 |
+
|
| 62 |
+
def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
|
| 63 |
+
mean = [0.48145466, 0.4578275, 0.40821073]
|
| 64 |
+
std = [0.26862954, 0.26130258, 0.27577711]
|
| 65 |
+
patch_size = patch_size if patch_size is not None else size
|
| 66 |
+
|
| 67 |
+
self.transform = transforms.Compose([
|
| 68 |
+
GPUToTensor(),
|
| 69 |
+
transforms.Normalize(mean, std),
|
| 70 |
+
transforms.Resize(
|
| 71 |
+
(size, size),
|
| 72 |
+
interpolation=InterpolationMode.BICUBIC if interpolation_mode
|
| 73 |
+
== "bicubic" else InterpolationMode.BILINEAR,
|
| 74 |
+
antialias=True),
|
| 75 |
+
])
|
| 76 |
+
|
| 77 |
+
self.patch_transform = transforms.Compose([
|
| 78 |
+
GPUToTensor(),
|
| 79 |
+
transforms.Normalize(mean, std),
|
| 80 |
+
transforms.Resize(
|
| 81 |
+
(patch_size, patch_size),
|
| 82 |
+
interpolation=InterpolationMode.BICUBIC if interpolation_mode
|
| 83 |
+
== "bicubic" else InterpolationMode.BILINEAR,
|
| 84 |
+
antialias=True),
|
| 85 |
+
]) if patch_size is not None else None
|
| 86 |
+
|
| 87 |
+
def __call__(self, image, is_patch=False):
|
| 88 |
+
if is_patch:
|
| 89 |
+
return {"pixel_values": self.patch_transform(image).unsqueeze(0)}
|
| 90 |
+
else:
|
| 91 |
+
return {"pixel_values": self.transform(image).unsqueeze(0)}
|
| 92 |
+
|
| 93 |
+
class ImagePatcher:
|
| 94 |
+
def determine_window_size(self, long: int, short: int) -> int:
|
| 95 |
+
if long <= 728:
|
| 96 |
+
return short if long / short > 1.5 else 0
|
| 97 |
+
return min(short, 504) if long / short > 4 else 504
|
| 98 |
+
def slide_window(
|
| 99 |
+
self,
|
| 100 |
+
width: int,
|
| 101 |
+
height: int,
|
| 102 |
+
sizes: list[tuple[int, int]],
|
| 103 |
+
steps: list[tuple[int, int]],
|
| 104 |
+
img_rate_thr: float = 0.6,
|
| 105 |
+
) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
|
| 106 |
+
assert 1 >= img_rate_thr >= 0, "The `in_rate_thr` should lie in 0~1"
|
| 107 |
+
windows = []
|
| 108 |
+
# Sliding windows.
|
| 109 |
+
for size, step in zip(sizes, steps):
|
| 110 |
+
size_w, size_h = size
|
| 111 |
+
step_w, step_h = step
|
| 112 |
+
|
| 113 |
+
x_num = 1 if width <= size_w else ceil((width - size_w) / step_w +
|
| 114 |
+
1)
|
| 115 |
+
x_start = [step_w * i for i in range(x_num)]
|
| 116 |
+
if len(x_start) > 1 and x_start[-1] + size_w > width:
|
| 117 |
+
x_start[-1] = width - size_w
|
| 118 |
+
|
| 119 |
+
y_num = 1 if height <= size_h else ceil((height - size_h) /
|
| 120 |
+
step_h + 1)
|
| 121 |
+
y_start = [step_h * i for i in range(y_num)]
|
| 122 |
+
if len(y_start) > 1 and y_start[-1] + size_h > height:
|
| 123 |
+
y_start[-1] = height - size_h
|
| 124 |
+
|
| 125 |
+
start = np.array(list(product(y_start, x_start)), dtype=int)
|
| 126 |
+
start[:, [0, 1]] = start[:, [1, 0]]
|
| 127 |
+
windows.append(np.concatenate([start, start + size], axis=1))
|
| 128 |
+
windows = np.concatenate(windows, axis=0)
|
| 129 |
+
|
| 130 |
+
return [(int(box[0]), int(box[1]), int(box[2] - box[0]),
|
| 131 |
+
int(box[3] - box[1])) for box in windows], (x_num, y_num)
|
| 132 |
+
|
| 133 |
+
def square_pad(self, img: Image.Image) -> Image.Image:
|
| 134 |
+
w, h = img.size
|
| 135 |
+
if w == h:
|
| 136 |
+
return img
|
| 137 |
+
size = max(w, h)
|
| 138 |
+
padded = Image.new(img.mode, (size, size), 0)
|
| 139 |
+
padded.paste(img, (0, 0))
|
| 140 |
+
return padded
|
| 141 |
+
|
| 142 |
+
def get_image_size_for_padding(self, img_width: int,
|
| 143 |
+
img_height: int) -> tuple[int, int]:
|
| 144 |
+
ratio = img_width / img_height
|
| 145 |
+
if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
|
| 146 |
+
new_size = max(img_height, img_width)
|
| 147 |
+
return new_size, new_size
|
| 148 |
+
return img_width, img_height
|
| 149 |
+
|
| 150 |
+
def get_image_size_for_preprocess(self, img_width: int,
|
| 151 |
+
img_height: int) -> tuple[int, int]:
|
| 152 |
+
|
| 153 |
+
if max(img_height, img_width) > MAX_IMAGE_SIZE:
|
| 154 |
+
scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width)
|
| 155 |
+
img_width = int(img_width * scale_factor)
|
| 156 |
+
img_height = int(img_height * scale_factor)
|
| 157 |
+
return img_width, img_height
|
| 158 |
+
|
| 159 |
+
def get_image_size_for_crop(self, img_width: int, img_height: int,
|
| 160 |
+
window_size: int):
|
| 161 |
+
w_ratio = img_width / window_size
|
| 162 |
+
h_ratio = img_height / window_size
|
| 163 |
+
|
| 164 |
+
if w_ratio < 1:
|
| 165 |
+
width_new = img_width
|
| 166 |
+
else:
|
| 167 |
+
decimal_w = w_ratio - img_width // window_size
|
| 168 |
+
w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
|
| 169 |
+
width_new = window_size * w_ratio
|
| 170 |
+
if h_ratio < 1:
|
| 171 |
+
height_new = img_height
|
| 172 |
+
else:
|
| 173 |
+
decimal_h = h_ratio - img_height // window_size
|
| 174 |
+
h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
|
| 175 |
+
height_new = window_size * h_ratio
|
| 176 |
+
return int(width_new), int(height_new)
|
| 177 |
+
|
| 178 |
+
def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
|
| 179 |
+
target = img.crop((j, i, j + tw, i + th))
|
| 180 |
+
return target
|
| 181 |
+
|
| 182 |
+
def get_num_patches(self, img_width: int,
|
| 183 |
+
img_height: int) -> tuple[int, int]:
|
| 184 |
+
img_width, img_height = self.get_image_size_for_padding(
|
| 185 |
+
img_width, img_height)
|
| 186 |
+
img_width, img_height = self.get_image_size_for_preprocess(
|
| 187 |
+
img_width, img_height)
|
| 188 |
+
window_size = self.determine_window_size(max(img_height, img_width),
|
| 189 |
+
min(img_height, img_width))
|
| 190 |
+
if window_size == 0:
|
| 191 |
+
return 0, 0
|
| 192 |
+
else:
|
| 193 |
+
img_width, img_height = self.get_image_size_for_crop(
|
| 194 |
+
img_width, img_height, window_size)
|
| 195 |
+
center_list, (x_num, y_num) = self.slide_window(
|
| 196 |
+
img_width, img_height, [(window_size, window_size)],
|
| 197 |
+
[(window_size, window_size)])
|
| 198 |
+
full_rows = (len(center_list) - 1) // x_num + 1
|
| 199 |
+
if len(center_list) > 0 and len(center_list) % x_num == 0:
|
| 200 |
+
full_rows -= 1
|
| 201 |
+
return len(center_list), full_rows
|
| 202 |
+
|
| 203 |
+
def __call__(
|
| 204 |
+
self, img: Image.Image
|
| 205 |
+
) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
|
| 206 |
+
img_width, img_height = img.size
|
| 207 |
+
new_img_width, new_img_height = self.get_image_size_for_padding(
|
| 208 |
+
img_width, img_height)
|
| 209 |
+
if new_img_width != img_width or new_img_height != img_height:
|
| 210 |
+
img = self.square_pad(img)
|
| 211 |
+
img_width, img_height = img.size
|
| 212 |
+
|
| 213 |
+
new_img_width, new_img_height = self.get_image_size_for_preprocess(
|
| 214 |
+
img_width, img_height)
|
| 215 |
+
img = img.resize((new_img_width, new_img_height),
|
| 216 |
+
Image.Resampling.BILINEAR)
|
| 217 |
+
window_size = self.determine_window_size(
|
| 218 |
+
max(new_img_height, new_img_width),
|
| 219 |
+
min(new_img_height, new_img_width))
|
| 220 |
+
# return img, [], None
|
| 221 |
+
if window_size == 0:
|
| 222 |
+
return img, [], None
|
| 223 |
+
else:
|
| 224 |
+
new_img_width, new_img_height = self.get_image_size_for_crop(
|
| 225 |
+
new_img_width, new_img_height, window_size)
|
| 226 |
+
if (new_img_width, new_img_height) != (img_width, img_height):
|
| 227 |
+
img_for_crop = img.resize((new_img_width, new_img_height),
|
| 228 |
+
Image.Resampling.BILINEAR)
|
| 229 |
+
else:
|
| 230 |
+
img_for_crop = img
|
| 231 |
+
|
| 232 |
+
patches = []
|
| 233 |
+
newlines = []
|
| 234 |
+
center_list, (x_num, y_num) = self.slide_window(
|
| 235 |
+
new_img_width, new_img_height, [(window_size, window_size)],
|
| 236 |
+
[(window_size, window_size)])
|
| 237 |
+
for patch_id, center_lf_point in enumerate(center_list):
|
| 238 |
+
x, y, patch_w, patch_h = center_lf_point
|
| 239 |
+
big_patch = self.patch_crop(img_for_crop, y, x, patch_h,
|
| 240 |
+
patch_w)
|
| 241 |
+
patches.append(big_patch)
|
| 242 |
+
if (patch_id + 1) % x_num == 0:
|
| 243 |
+
newlines.append(patch_id)
|
| 244 |
+
|
| 245 |
+
if newlines and newlines[-1] == len(patches) - 1:
|
| 246 |
+
newlines.pop()
|
| 247 |
+
|
| 248 |
+
return img, patches, [i in newlines for i in range(len(patches))] if len(patches) > 0 else None
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class Step3VLProcessor(ProcessorMixin):
|
| 254 |
+
# Align ProcessorMixin with our custom components.
|
| 255 |
+
# We only have an image processor (not a feature extractor) plus a tokenizer.
|
| 256 |
+
attributes = ["tokenizer"]
|
| 257 |
+
tokenizer_class = "AutoTokenizer"
|
| 258 |
+
|
| 259 |
+
@classmethod
|
| 260 |
+
def _load_tokenizer_from_pretrained(
|
| 261 |
+
cls, sub_processor_type, pretrained_model_name_or_path, subfolder="", **kwargs
|
| 262 |
+
):
|
| 263 |
+
return TokenizersBackend.from_pretrained(
|
| 264 |
+
pretrained_model_name_or_path,
|
| 265 |
+
subfolder=subfolder,
|
| 266 |
+
**kwargs,
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
def __init__(
|
| 270 |
+
self,
|
| 271 |
+
tokenizer=None,
|
| 272 |
+
chat_template=None,
|
| 273 |
+
**kwargs
|
| 274 |
+
) -> None:
|
| 275 |
+
self.image_size = 728
|
| 276 |
+
self.patch_size = 504
|
| 277 |
+
|
| 278 |
+
self.image_preprocessor = Step3VisionProcessor(self.image_size,
|
| 279 |
+
"bilinear",
|
| 280 |
+
self.patch_size)
|
| 281 |
+
|
| 282 |
+
self.num_image_feature_size = 169
|
| 283 |
+
self.num_patch_feature_size = 81
|
| 284 |
+
self.image_token = "<im_patch>"
|
| 285 |
+
self.image_feature_placeholder = (self.image_token *
|
| 286 |
+
self.num_image_feature_size)
|
| 287 |
+
self.patch_feature_placeholder = (self.image_token *
|
| 288 |
+
self.num_patch_feature_size)
|
| 289 |
+
super().__init__(tokenizer=tokenizer, chat_template=chat_template, **kwargs)
|
| 290 |
+
self.patcher = ImagePatcher()
|
| 291 |
+
|
| 292 |
+
@property
|
| 293 |
+
def image_token_id(self) -> int:
|
| 294 |
+
return self.tokenizer.get_vocab()[self.image_token]
|
| 295 |
+
|
| 296 |
+
def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
|
| 297 |
+
num_patches, num_newlines = self.patcher.get_num_patches(
|
| 298 |
+
img_width, img_height)
|
| 299 |
+
|
| 300 |
+
return num_patches * (
|
| 301 |
+
self.num_patch_feature_size +
|
| 302 |
+
2) + self.num_image_feature_size + 2 + num_newlines
|
| 303 |
+
|
| 304 |
+
def _split_images(self,
|
| 305 |
+
images: list[Image.Image]) -> list[ImageWithPatches]:
|
| 306 |
+
result = []
|
| 307 |
+
for img in images:
|
| 308 |
+
result.append(self.patcher(img))
|
| 309 |
+
return result
|
| 310 |
+
|
| 311 |
+
def _convert_images_to_pixel_values(
|
| 312 |
+
self,
|
| 313 |
+
images: list[Image.Image],
|
| 314 |
+
is_patch: bool = False,
|
| 315 |
+
) -> list[torch.Tensor]:
|
| 316 |
+
return [
|
| 317 |
+
self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
|
| 318 |
+
for img in images
|
| 319 |
+
]
|
| 320 |
+
|
| 321 |
+
def _get_patch_repl(
|
| 322 |
+
self,
|
| 323 |
+
num_patches: int,
|
| 324 |
+
patch_newline_mask: list[bool] | None,
|
| 325 |
+
) -> tuple[str, list[int]]:
|
| 326 |
+
text = ""
|
| 327 |
+
token_ids = []
|
| 328 |
+
for i in range(num_patches):
|
| 329 |
+
assert len(patch_newline_mask) == num_patches
|
| 330 |
+
text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
|
| 331 |
+
token_ids.extend(
|
| 332 |
+
[self.tokenizer.convert_tokens_to_ids("<patch_start>")] +
|
| 333 |
+
[self.image_token_id] * self.num_patch_feature_size +
|
| 334 |
+
[self.tokenizer.convert_tokens_to_ids("<patch_end>")])
|
| 335 |
+
if patch_newline_mask and patch_newline_mask[i]:
|
| 336 |
+
text += "<patch_newline>"
|
| 337 |
+
token_ids.append(
|
| 338 |
+
self.tokenizer.convert_tokens_to_ids("<patch_newline>"))
|
| 339 |
+
return text, token_ids
|
| 340 |
+
|
| 341 |
+
def _get_image_repl(
|
| 342 |
+
self,
|
| 343 |
+
num_images: int,
|
| 344 |
+
) -> tuple[str, list[int]]:
|
| 345 |
+
text = f"<im_start>{self.image_feature_placeholder}<im_end>"
|
| 346 |
+
token_ids = [
|
| 347 |
+
self.tokenizer.convert_tokens_to_ids("<im_start>")
|
| 348 |
+
] + [self.image_token_id] * self.num_image_feature_size + [
|
| 349 |
+
self.tokenizer.convert_tokens_to_ids("<im_end>")
|
| 350 |
+
]
|
| 351 |
+
return text * num_images, token_ids * num_images
|
| 352 |
+
|
| 353 |
+
def _get_image_repl_features(
|
| 354 |
+
self,
|
| 355 |
+
num_images: int,
|
| 356 |
+
num_patches: int,
|
| 357 |
+
patch_new_line_idx: Optional[list[bool]],
|
| 358 |
+
) -> tuple[str, list[int]]:
|
| 359 |
+
if num_patches > 0:
|
| 360 |
+
patch_repl, patch_repl_ids = self._get_patch_repl(
|
| 361 |
+
num_patches, patch_new_line_idx)
|
| 362 |
+
else:
|
| 363 |
+
patch_repl = ""
|
| 364 |
+
patch_repl_ids = []
|
| 365 |
+
image_repl, image_repl_ids = self._get_image_repl(num_images)
|
| 366 |
+
return patch_repl + image_repl, patch_repl_ids + image_repl_ids
|
| 367 |
+
|
| 368 |
+
def replace_placeholder(self, text: str, placeholder: str,
|
| 369 |
+
repls: list[str]) -> str:
|
| 370 |
+
parts = text.split(placeholder)
|
| 371 |
+
|
| 372 |
+
if len(parts) - 1 != len(repls):
|
| 373 |
+
raise ValueError(
|
| 374 |
+
"The number of placeholders does not match the number of replacements." # noqa: E501
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
result = [parts[0]]
|
| 378 |
+
for i, repl in enumerate(repls):
|
| 379 |
+
result.append(repl)
|
| 380 |
+
result.append(parts[i + 1])
|
| 381 |
+
|
| 382 |
+
return "".join(result)
|
| 383 |
+
|
| 384 |
+
def __call__(
|
| 385 |
+
self,
|
| 386 |
+
text: Optional[Union[str, list[str]]] = None,
|
| 387 |
+
images: ImageInput | None = None,
|
| 388 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 389 |
+
**kwargs,
|
| 390 |
+
) -> BatchFeature:
|
| 391 |
+
|
| 392 |
+
if images is not None:
|
| 393 |
+
images = self.image_preprocessor.fetch_images(images)
|
| 394 |
+
if text is None:
|
| 395 |
+
text = []
|
| 396 |
+
if not isinstance(text, list):
|
| 397 |
+
text = [text]
|
| 398 |
+
if images is None:
|
| 399 |
+
images = []
|
| 400 |
+
elif not isinstance(images, list):
|
| 401 |
+
images = [images]
|
| 402 |
+
elif isinstance(images[0], list):
|
| 403 |
+
images = images[0]
|
| 404 |
+
|
| 405 |
+
if len(images) == 0:
|
| 406 |
+
image_inputs = {}
|
| 407 |
+
text_inputs = self.tokenizer(text)
|
| 408 |
+
else:
|
| 409 |
+
splitted_images_data = self._split_images(images)
|
| 410 |
+
pixel_values_lst = []
|
| 411 |
+
patch_pixel_values_lst = []
|
| 412 |
+
patch_newline_mask_lst = []
|
| 413 |
+
image_repl_str_lst = []
|
| 414 |
+
image_repl_ids_lst = []
|
| 415 |
+
num_patches = []
|
| 416 |
+
for raw_img, img_patches, patch_newline_mask in splitted_images_data: # noqa: E501
|
| 417 |
+
pixel_values_lst.extend(
|
| 418 |
+
self._convert_images_to_pixel_values([raw_img]))
|
| 419 |
+
|
| 420 |
+
if len(img_patches) > 0:
|
| 421 |
+
patch_pixel_values_lst.extend(
|
| 422 |
+
self._convert_images_to_pixel_values(img_patches,
|
| 423 |
+
is_patch=True))
|
| 424 |
+
num_patches.append(len(img_patches))
|
| 425 |
+
|
| 426 |
+
image_repl_str, image_repl_ids = self._get_image_repl_features(
|
| 427 |
+
1, len(img_patches), patch_newline_mask)
|
| 428 |
+
image_repl_str_lst.append(image_repl_str)
|
| 429 |
+
image_repl_ids_lst.extend(image_repl_ids)
|
| 430 |
+
|
| 431 |
+
if patch_newline_mask is not None:
|
| 432 |
+
patch_newline_mask_lst.extend(patch_newline_mask)
|
| 433 |
+
|
| 434 |
+
image_inputs = {
|
| 435 |
+
"pixel_values": torch.cat(pixel_values_lst),
|
| 436 |
+
"num_patches": num_patches,
|
| 437 |
+
}
|
| 438 |
+
if patch_pixel_values_lst:
|
| 439 |
+
image_inputs["patch_pixel_values"] = torch.cat(
|
| 440 |
+
patch_pixel_values_lst)
|
| 441 |
+
if patch_newline_mask_lst:
|
| 442 |
+
image_inputs["patch_newline_mask"] = torch.tensor(
|
| 443 |
+
patch_newline_mask_lst, dtype=torch.bool)
|
| 444 |
+
|
| 445 |
+
text = [
|
| 446 |
+
self.replace_placeholder(t, self.image_token,
|
| 447 |
+
image_repl_str_lst) for t in text
|
| 448 |
+
]
|
| 449 |
+
text_inputs = self.tokenizer(text)
|
| 450 |
+
|
| 451 |
+
return BatchFeature(
|
| 452 |
+
{
|
| 453 |
+
**text_inputs,
|
| 454 |
+
**image_inputs,
|
| 455 |
+
},
|
| 456 |
+
tensor_type=return_tensors,
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
|
| 460 |
+
def batch_decode(self, *args, **kwargs):
|
| 461 |
+
"""
|
| 462 |
+
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
| 463 |
+
refer to the docstring of this method for more information.
|
| 464 |
+
"""
|
| 465 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 466 |
+
|
| 467 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
|
| 468 |
+
def decode(self, *args, **kwargs):
|
| 469 |
+
"""
|
| 470 |
+
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
| 471 |
+
the docstring of this method for more information.
|
| 472 |
+
"""
|
| 473 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 474 |
+
|
| 475 |
+
__all__ = ["Step3VLProcessor"]
|
processor_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoProcessor": "processing_step3.Step3VLProcessor"
|
| 4 |
+
},
|
| 5 |
+
"processor_class": "Step3VLProcessor"
|
| 6 |
+
}
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<|begin▁of▁sentence|>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "<|im_end|>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "<|end▁of▁sentence|>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
}
|
| 23 |
+
}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": null,
|
| 3 |
+
"auto_map": {
|
| 4 |
+
"AutoProcessor": "processing_step3.Step3VLProcessor"
|
| 5 |
+
},
|
| 6 |
+
"backend": "tokenizers",
|
| 7 |
+
"bos_token": "<|begin▁of▁sentence|>",
|
| 8 |
+
"clean_up_tokenization_spaces": false,
|
| 9 |
+
"eos_token": "<|im_end|>",
|
| 10 |
+
"is_local": true,
|
| 11 |
+
"legacy": true,
|
| 12 |
+
"local_files_only": false,
|
| 13 |
+
"model_max_length": 262144,
|
| 14 |
+
"pad_token": "<|▁pad▁|>",
|
| 15 |
+
"padding_side": "left",
|
| 16 |
+
"processor_class": "Step3VLProcessor",
|
| 17 |
+
"sp_model_kwargs": {},
|
| 18 |
+
"tokenizer_class": "TokenizersBackend",
|
| 19 |
+
"unk_token": null,
|
| 20 |
+
"use_default_system_prompt": false,
|
| 21 |
+
"chat_template": "{% macro render_message_content(message) %}{% if message.content is none %}{{- '' }}{% elif message.content is string %}{{- message.content }}{% elif message.content is mapping %}{{- message.content['value'] if 'value' in message.content else message.content['text'] }}{% elif message.content is iterable %}{% set ns = namespace(needs_text_separator=false) %}{% for item in message.content %}{% if item.type == 'text' %}{% if ns.needs_text_separator %}{{- ' ' }}{% endif %}{{- item['value'] if 'value' in item else item['text'] }}{% set ns.needs_text_separator = true %}{% elif item.type == 'image' %}<im_patch>{% set ns.needs_text_separator = false %}{% endif %}{% endfor %}{% endif %}{% endmacro %}\n{{bos_token}}{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if reasoning_effort is defined %}\n {{- \"Reasoning: \" + reasoning_effort + '\\n\\n' }}\n {%- endif %}\n {%- if messages[0].role == 'system' %}\n {{- render_message_content(messages[0]) + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou have access to the following functions in JSONSchema format:\\n\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson(ensure_ascii=False) }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nIf you choose to call a function ONLY reply in the following format with NO suffix:\\n\\n<tool_call>\\n<function=example_function_name>\\n<parameter=example_parameter_1>\\nvalue_1\\n</parameter>\\n<parameter=example_parameter_2>\\nThis is the value for the second parameter\\nthat can span\\nmultiple lines\\n</parameter>\\n</function>\\n</tool_call>\\n\\n<IMPORTANT>\\nReminder:\\n- Function calls MUST follow the specified format: an inner <function=...>\\n...\\n</function> block must be nested within <tool_call>\\n...\\n</tool_call> XML tags\\n- Required parameters MUST be specified\\n</IMPORTANT><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' }}\n {%- if reasoning_effort is defined %}\n {{- \"Reasoning: \" + reasoning_effort + '\\n\\n' }}\n {%- endif %}\n {{- render_message_content(messages[0]) + '<|im_end|>\\n' }}\n {%- elif reasoning_effort is defined %}\n {{- '<|im_start|>system\\n' + \"Reasoning: \" + reasoning_effort + '\\n\\n' + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and render_message_content(message) is string and not(render_message_content(message).startswith('<tool_response>') and render_message_content(message).endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- set content = render_message_content(message) %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {%- set role_name = 'observation' if (message.role == \"system\" and not loop.first and message.name == 'observation') else message.role %}\n {{- '<|im_start|>' + role_name + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- else %}\n {%- set reasoning_content = '' %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content + '\\n</think>\\n' + content }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n<function=' + tool_call.name + '>\\n' }}\n {%- if tool_call.arguments is defined %}\n {%- set arguments = tool_call.arguments | fromjson if tool_call.arguments is string else tool_call.arguments %}\n {%- for args_name, args_value in arguments|items %}\n {{- '<parameter=' + args_name + '>\\n' }}\n {%- set args_value = args_value | tojson(ensure_ascii=False) | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}\n {{- args_value }}\n {{- '\\n</parameter>\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '</function>\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>tool_response\\n' }}\n {%- endif %}\n {{- '<tool_response>' }}\n {{- content }}\n {{- '</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n<think>\\n' }}\n{%- endif %}\n"
|
| 22 |
+
}
|
vision_encoder.py
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Literal, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from transformers.activations import ACT2FN
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
from .configuration_step3p7 import StepRoboticsVisionEncoderConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 14 |
+
"""Rotate last dimension halves (used by RoPE)."""
|
| 15 |
+
x = x.reshape(*x.shape[:-1], -1, 2)
|
| 16 |
+
x1, x2 = x.unbind(dim=-1)
|
| 17 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 18 |
+
return x.reshape(*x.shape[:-2], -1)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def apply_rotary_emb(freqs: torch.Tensor,
|
| 22 |
+
t: torch.Tensor,
|
| 23 |
+
start_index: int = 0,
|
| 24 |
+
scale: float = 1.0,
|
| 25 |
+
seq_dim: int = -2) -> torch.Tensor:
|
| 26 |
+
"""Apply 2D rotary embeddings to queries / keys."""
|
| 27 |
+
dtype = t.dtype
|
| 28 |
+
|
| 29 |
+
if t.ndim == 3:
|
| 30 |
+
seq_len = t.shape[seq_dim]
|
| 31 |
+
freqs = freqs[-seq_len:]
|
| 32 |
+
|
| 33 |
+
rot_dim = freqs.shape[-1]
|
| 34 |
+
end_index = start_index + rot_dim
|
| 35 |
+
assert rot_dim <= t.shape[-1], (
|
| 36 |
+
f"feature dimension {t.shape[-1]} is too small for rot_dim {rot_dim}")
|
| 37 |
+
|
| 38 |
+
t_left, t, t_right = (
|
| 39 |
+
t[..., :start_index],
|
| 40 |
+
t[..., start_index:end_index],
|
| 41 |
+
t[..., end_index:],
|
| 42 |
+
)
|
| 43 |
+
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
|
| 44 |
+
out = torch.cat((t_left, t, t_right), dim=-1)
|
| 45 |
+
return out.type(dtype)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class EncoderRope2D(nn.Module):
|
| 49 |
+
"""Cacheable 2D rotary positional embedding."""
|
| 50 |
+
|
| 51 |
+
def __init__(
|
| 52 |
+
self,
|
| 53 |
+
dim: int,
|
| 54 |
+
max_grid_height: int,
|
| 55 |
+
max_grid_width: int,
|
| 56 |
+
use_cls_token: bool = False,
|
| 57 |
+
theta: Union[int, float] = 10000,
|
| 58 |
+
max_freq: int = 10,
|
| 59 |
+
num_freqs: int = 1,
|
| 60 |
+
theta_rescale_factor: float = 1.0,
|
| 61 |
+
):
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.dim = dim
|
| 64 |
+
self.max_grid_height = max_grid_height
|
| 65 |
+
self.max_grid_width = max_grid_width
|
| 66 |
+
self.use_cls_token = use_cls_token
|
| 67 |
+
self.theta = theta * theta_rescale_factor**(dim / (dim - 2))
|
| 68 |
+
self.max_freq = max_freq
|
| 69 |
+
self.num_freqs = num_freqs
|
| 70 |
+
cache = self._compute_2d_freqs()
|
| 71 |
+
self.register_buffer("freqs_cache", cache, persistent=False)
|
| 72 |
+
|
| 73 |
+
def _compute_inv_freq(self, base: Union[int, float],
|
| 74 |
+
dim: int) -> torch.Tensor:
|
| 75 |
+
|
| 76 |
+
freqs = 1.0 / (base**(
|
| 77 |
+
torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
| 78 |
+
return freqs
|
| 79 |
+
|
| 80 |
+
def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
|
| 81 |
+
freqs = torch.einsum("..., f -> ... f", t.type(inv_freq.dtype),
|
| 82 |
+
inv_freq)
|
| 83 |
+
freqs = freqs.repeat_interleave(2, dim=-1)
|
| 84 |
+
return freqs
|
| 85 |
+
|
| 86 |
+
def _compute_2d_freqs(self) -> torch.Tensor:
|
| 87 |
+
grid_h_range = torch.arange(self.max_grid_height, dtype=torch.float)
|
| 88 |
+
grid_w_range = torch.arange(self.max_grid_width, dtype=torch.float)
|
| 89 |
+
if self.use_cls_token:
|
| 90 |
+
grid_h_range += 1
|
| 91 |
+
grid_w_range += 1
|
| 92 |
+
inv_freq = self._compute_inv_freq(self.theta, self.dim // 2)
|
| 93 |
+
freqs_h = self._compute_freqs(grid_h_range, inv_freq)[:, None].expand(
|
| 94 |
+
self.max_grid_height, self.max_grid_width, -1)
|
| 95 |
+
freqs_w = self._compute_freqs(grid_w_range, inv_freq)[None, :].expand(
|
| 96 |
+
self.max_grid_height, self.max_grid_width, -1)
|
| 97 |
+
freqs = torch.cat([freqs_w, freqs_h], dim=-1).reshape(
|
| 98 |
+
self.max_grid_height * self.max_grid_width, -1)
|
| 99 |
+
if self.use_cls_token:
|
| 100 |
+
freqs = torch.cat([torch.zeros(1, freqs.shape[-1]), freqs], dim=0)
|
| 101 |
+
freqs = freqs[None, None, ...]
|
| 102 |
+
return freqs
|
| 103 |
+
|
| 104 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor,
|
| 105 |
+
grid_hw: tuple[int, int]):
|
| 106 |
+
# If grid matches cached shape we reuse directly to avoid recomputation.
|
| 107 |
+
if grid_hw[0] != self.max_grid_height or grid_hw[1] != self.max_grid_width:
|
| 108 |
+
rows = torch.arange(grid_hw[0], device=q.device).view(-1, 1)
|
| 109 |
+
cols = torch.arange(grid_hw[1], device=q.device).view(1, -1)
|
| 110 |
+
positions = (rows * self.max_grid_width + cols).reshape(-1).to(
|
| 111 |
+
torch.long)
|
| 112 |
+
if self.use_cls_token:
|
| 113 |
+
positions = torch.cat(
|
| 114 |
+
[torch.zeros(1, device=q.device), positions + 1], dim=0)
|
| 115 |
+
freqs = self.freqs_cache.index_select(2, positions)
|
| 116 |
+
else:
|
| 117 |
+
freqs = self.freqs_cache
|
| 118 |
+
q = apply_rotary_emb(freqs, q)
|
| 119 |
+
k = apply_rotary_emb(freqs, k)
|
| 120 |
+
return q, k
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class EncoderLayerScale(nn.Module):
|
| 124 |
+
"""Per-channel residual scaling used when ls_init_value is set."""
|
| 125 |
+
|
| 126 |
+
def __init__(self, dim: int, init_values: float):
|
| 127 |
+
super().__init__()
|
| 128 |
+
self.gamma = nn.Parameter(torch.full((dim,), init_values))
|
| 129 |
+
|
| 130 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # (B, L, D)
|
| 131 |
+
return hidden_states * self.gamma
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class EncoderMLP(nn.Module):
|
| 135 |
+
"""Feed-forward network used inside each transformer block."""
|
| 136 |
+
|
| 137 |
+
def __init__(self, hidden_size: int, intermediate_size: int,
|
| 138 |
+
hidden_act: str):
|
| 139 |
+
super().__init__()
|
| 140 |
+
self.c_fc = nn.Linear(hidden_size, intermediate_size, bias=True)
|
| 141 |
+
self.act_fn = ACT2FN[hidden_act]
|
| 142 |
+
self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=True)
|
| 143 |
+
|
| 144 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 145 |
+
|
| 146 |
+
hidden_states = self.c_proj(self.act_fn(self.c_fc(hidden_states)))
|
| 147 |
+
return hidden_states
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class EncoderVisionAttention(nn.Module):
|
| 151 |
+
"""Multi-head self attention with optional 2D RoPE."""
|
| 152 |
+
|
| 153 |
+
def __init__(
|
| 154 |
+
self,
|
| 155 |
+
hidden_size: int,
|
| 156 |
+
num_heads: int,
|
| 157 |
+
max_grid_height: int,
|
| 158 |
+
max_grid_width: int,
|
| 159 |
+
use_cls_token: bool = False,
|
| 160 |
+
use_rope2d: bool = True,
|
| 161 |
+
rope_theta: Union[int, float] = 10000,
|
| 162 |
+
rope_max_freq: int = 10,
|
| 163 |
+
rope_num_freqs: int = 1,
|
| 164 |
+
rope_theta_rescale_factor: float = 1.0,
|
| 165 |
+
rope_freqs_for: Literal["lang", "pixel", "constant"] = "lang",
|
| 166 |
+
):
|
| 167 |
+
super().__init__()
|
| 168 |
+
if hidden_size % num_heads != 0:
|
| 169 |
+
raise ValueError(
|
| 170 |
+
f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})."
|
| 171 |
+
)
|
| 172 |
+
self.num_heads = num_heads
|
| 173 |
+
self.head_dim = hidden_size // num_heads
|
| 174 |
+
self.scale = self.head_dim**-0.5
|
| 175 |
+
self.in_proj_weight = nn.Parameter(torch.zeros(hidden_size * 3, hidden_size))
|
| 176 |
+
self.in_proj_bias = nn.Parameter(torch.zeros(hidden_size * 3))
|
| 177 |
+
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
|
| 178 |
+
|
| 179 |
+
self.rope = None
|
| 180 |
+
if use_rope2d:
|
| 181 |
+
self.rope = EncoderRope2D(
|
| 182 |
+
dim=self.head_dim,
|
| 183 |
+
max_grid_height=max_grid_height,
|
| 184 |
+
max_grid_width=max_grid_width,
|
| 185 |
+
use_cls_token=use_cls_token,
|
| 186 |
+
theta=rope_theta,
|
| 187 |
+
max_freq=rope_max_freq,
|
| 188 |
+
num_freqs=rope_num_freqs,
|
| 189 |
+
theta_rescale_factor=rope_theta_rescale_factor,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
def forward(self, hidden_states: torch.Tensor, grid_hw: tuple[int, int]) -> torch.Tensor:
|
| 193 |
+
bsz, seq_len, _ = hidden_states.shape
|
| 194 |
+
qkv = F.linear(
|
| 195 |
+
hidden_states,
|
| 196 |
+
self.in_proj_weight,
|
| 197 |
+
self.in_proj_bias,
|
| 198 |
+
)
|
| 199 |
+
q, k, v = qkv.chunk(3, dim=-1)
|
| 200 |
+
|
| 201 |
+
q = q.view(bsz, seq_len, self.num_heads,
|
| 202 |
+
self.head_dim).transpose(1, 2)
|
| 203 |
+
k = k.view(bsz, seq_len, self.num_heads,
|
| 204 |
+
self.head_dim).transpose(1, 2)
|
| 205 |
+
if self.rope is not None:
|
| 206 |
+
q, k = self.rope(q, k, grid_hw=grid_hw)
|
| 207 |
+
v = v.view(bsz, seq_len, self.num_heads,
|
| 208 |
+
self.head_dim).transpose(1, 2)
|
| 209 |
+
|
| 210 |
+
attn_output = F.scaled_dot_product_attention(
|
| 211 |
+
q, k, v, is_causal=False, scale=self.scale)
|
| 212 |
+
attn_output = attn_output.transpose(1, 2).reshape(
|
| 213 |
+
bsz, seq_len, self.num_heads * self.head_dim)
|
| 214 |
+
return self.out_proj(attn_output)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class EncoderVisionBlock(nn.Module):
|
| 218 |
+
"""A single Vision Transformer block (self-attention + MLP)."""
|
| 219 |
+
|
| 220 |
+
def __init__(
|
| 221 |
+
self,
|
| 222 |
+
hidden_size: int,
|
| 223 |
+
num_heads: int,
|
| 224 |
+
mlp_ratio: float,
|
| 225 |
+
hidden_act: str,
|
| 226 |
+
layer_norm_eps: float,
|
| 227 |
+
ls_init_value: Optional[float] = None,
|
| 228 |
+
max_grid_height: Optional[int] = None,
|
| 229 |
+
max_grid_width: Optional[int] = None,
|
| 230 |
+
use_cls_token: bool = False,
|
| 231 |
+
use_rope2d: bool = True,
|
| 232 |
+
rope_kwargs: Optional[dict] = None,
|
| 233 |
+
):
|
| 234 |
+
super().__init__()
|
| 235 |
+
rope_kwargs = rope_kwargs or {}
|
| 236 |
+
self.attn = EncoderVisionAttention(
|
| 237 |
+
hidden_size,
|
| 238 |
+
num_heads,
|
| 239 |
+
max_grid_height=max_grid_height,
|
| 240 |
+
max_grid_width=max_grid_width,
|
| 241 |
+
use_cls_token=use_cls_token,
|
| 242 |
+
use_rope2d=use_rope2d,
|
| 243 |
+
**rope_kwargs,
|
| 244 |
+
)
|
| 245 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
|
| 246 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
|
| 247 |
+
|
| 248 |
+
intermediate = int(hidden_size * mlp_ratio)
|
| 249 |
+
self.mlp = EncoderMLP(hidden_size, intermediate, hidden_act)
|
| 250 |
+
|
| 251 |
+
self.ls_1 = EncoderLayerScale(hidden_size, ls_init_value)
|
| 252 |
+
self.ls_2 = EncoderLayerScale(hidden_size, ls_init_value)
|
| 253 |
+
|
| 254 |
+
def forward(self, hidden_states: torch.Tensor,
|
| 255 |
+
grid_hw: tuple[int, int]) -> torch.Tensor:
|
| 256 |
+
# breakpoint()
|
| 257 |
+
residual = hidden_states
|
| 258 |
+
hidden_states = self.ln_1(hidden_states)
|
| 259 |
+
hidden_states = self.attn(hidden_states, grid_hw=grid_hw)
|
| 260 |
+
hidden_states = residual + self.ls_1(hidden_states)
|
| 261 |
+
|
| 262 |
+
residual = hidden_states
|
| 263 |
+
hidden_states = self.ln_2(hidden_states)
|
| 264 |
+
hidden_states = self.mlp(hidden_states)
|
| 265 |
+
hidden_states = residual + self.ls_2(hidden_states)
|
| 266 |
+
return hidden_states
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class EncoderVisionTransformer(nn.Module):
|
| 270 |
+
"""Stack of encoder blocks parameterised by Step35VisionEncoderConfig."""
|
| 271 |
+
|
| 272 |
+
def __init__(
|
| 273 |
+
self,
|
| 274 |
+
embed_dim: int,
|
| 275 |
+
depth: int,
|
| 276 |
+
num_heads: int,
|
| 277 |
+
mlp_ratio: float,
|
| 278 |
+
hidden_act: str,
|
| 279 |
+
layer_norm_eps: float,
|
| 280 |
+
ls_init_value: Optional[float] = None,
|
| 281 |
+
max_grid_height: Optional[int] = None,
|
| 282 |
+
max_grid_width: Optional[int] = None,
|
| 283 |
+
use_cls_token: bool = False,
|
| 284 |
+
use_rope2d: bool = True,
|
| 285 |
+
rope_kwargs: Optional[dict] = None,
|
| 286 |
+
):
|
| 287 |
+
super().__init__()
|
| 288 |
+
self.layers = depth
|
| 289 |
+
rope_kwargs = rope_kwargs or {}
|
| 290 |
+
self.resblocks = nn.ModuleList([
|
| 291 |
+
EncoderVisionBlock(embed_dim, num_heads, mlp_ratio, hidden_act,
|
| 292 |
+
layer_norm_eps,
|
| 293 |
+
max_grid_height=max_grid_height,
|
| 294 |
+
max_grid_width=max_grid_width,
|
| 295 |
+
use_cls_token=use_cls_token,
|
| 296 |
+
use_rope2d=use_rope2d,
|
| 297 |
+
ls_init_value=ls_init_value,
|
| 298 |
+
rope_kwargs=rope_kwargs)
|
| 299 |
+
for _ in range(depth)
|
| 300 |
+
])
|
| 301 |
+
|
| 302 |
+
def forward(self,
|
| 303 |
+
hidden_states: torch.Tensor,
|
| 304 |
+
grid_hw: tuple[int, int]) -> torch.Tensor:
|
| 305 |
+
for block in self.resblocks:
|
| 306 |
+
hidden_states = block(hidden_states, grid_hw=grid_hw)
|
| 307 |
+
return hidden_states
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class StepRoboticsVisionEncoder(nn.Module):
|
| 311 |
+
"""
|
| 312 |
+
Vision encoder built from StepRoboticsVisionEncoderConfig.
|
| 313 |
+
|
| 314 |
+
The encoder performs patch embedding followed by a stack of transformer
|
| 315 |
+
blocks. Only the config fields defined in StepRoboticsVisionEncoderConfig (and
|
| 316 |
+
StepRoboticVLConfig.vision_config) are expected.
|
| 317 |
+
"""
|
| 318 |
+
|
| 319 |
+
def __init__(self, config: StepRoboticsVisionEncoderConfig):
|
| 320 |
+
super().__init__()
|
| 321 |
+
self.config = config
|
| 322 |
+
|
| 323 |
+
# Align commonly used attributes so downstream code (e.g. StepRoboticVL)
|
| 324 |
+
# can access them without extra renaming.
|
| 325 |
+
self.hidden_size = config.width
|
| 326 |
+
self.num_heads = config.heads
|
| 327 |
+
self.num_hidden_layers = config.layers
|
| 328 |
+
self.patch_size = config.patch_size
|
| 329 |
+
self.image_size = config.image_size
|
| 330 |
+
self.use_cls_token = getattr(config, "use_cls_token", False)
|
| 331 |
+
self.use_rope2d = getattr(config, "use_rope2d", True)
|
| 332 |
+
self.use_abs_posemb = getattr(config, "use_abs_posemb", True)
|
| 333 |
+
self.layer_norm_eps = config.layer_norm_eps
|
| 334 |
+
self.mlp_ratio = getattr(config, "mlp_ratio", 8960 / 1536)
|
| 335 |
+
self.ls_init_value = getattr(config, "ls_init_value", None)
|
| 336 |
+
self.hidden_act = config.hidden_act
|
| 337 |
+
self.use_ln_pre = getattr(config, "use_ln_pre", False)
|
| 338 |
+
self.use_ln_post = getattr(config, "use_ln_post", True)
|
| 339 |
+
|
| 340 |
+
# Patch embedding.
|
| 341 |
+
self.conv1 = nn.Conv2d(in_channels=config.num_channels,
|
| 342 |
+
out_channels=self.hidden_size,
|
| 343 |
+
kernel_size=self.patch_size,
|
| 344 |
+
stride=self.patch_size,
|
| 345 |
+
bias=False)
|
| 346 |
+
|
| 347 |
+
self.ln_pre = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) if self.use_ln_pre else nn.Identity()
|
| 348 |
+
self.ln_post = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) if self.use_ln_post else nn.Identity()
|
| 349 |
+
|
| 350 |
+
grid_size = self.image_size // self.patch_size
|
| 351 |
+
self.base_grid = (grid_size, grid_size)
|
| 352 |
+
|
| 353 |
+
if self.use_cls_token:
|
| 354 |
+
self.class_embedding = nn.Parameter(
|
| 355 |
+
torch.randn(self.hidden_size) * (self.hidden_size**-0.5))
|
| 356 |
+
else:
|
| 357 |
+
self.class_embedding = None
|
| 358 |
+
|
| 359 |
+
if self.use_abs_posemb:
|
| 360 |
+
self.posemb_grid_size = self.image_size // self.patch_size
|
| 361 |
+
self.positional_embedding = nn.Parameter(
|
| 362 |
+
(self.hidden_size**-0.5) * torch.randn(
|
| 363 |
+
int(self.use_cls_token) + self.posemb_grid_size**2,
|
| 364 |
+
self.hidden_size,
|
| 365 |
+
))
|
| 366 |
+
|
| 367 |
+
self.transformer = EncoderVisionTransformer(
|
| 368 |
+
embed_dim=self.hidden_size,
|
| 369 |
+
depth=self.num_hidden_layers,
|
| 370 |
+
num_heads=self.num_heads,
|
| 371 |
+
mlp_ratio=self.mlp_ratio,
|
| 372 |
+
hidden_act=self.hidden_act,
|
| 373 |
+
layer_norm_eps=self.layer_norm_eps,
|
| 374 |
+
ls_init_value=self.ls_init_value,
|
| 375 |
+
max_grid_height=self.base_grid[0],
|
| 376 |
+
max_grid_width=self.base_grid[1],
|
| 377 |
+
use_cls_token=self.use_cls_token,
|
| 378 |
+
use_rope2d=self.use_rope2d,
|
| 379 |
+
rope_kwargs={
|
| 380 |
+
"rope_theta": getattr(config, "rope_theta", 10000),
|
| 381 |
+
"rope_max_freq": getattr(config, "rope_max_freq", 10),
|
| 382 |
+
"rope_num_freqs": getattr(config, "rope_num_freqs", 1),
|
| 383 |
+
"rope_theta_rescale_factor":
|
| 384 |
+
getattr(config, "rope_theta_rescale_factor", 1.0),
|
| 385 |
+
"rope_freqs_for": getattr(config, "rope_freqs_for", "lang"),
|
| 386 |
+
},
|
| 387 |
+
)
|
| 388 |
+
self.vit_downsampler1 = nn.Conv2d(self.hidden_size,
|
| 389 |
+
self.hidden_size * 2,
|
| 390 |
+
kernel_size=3,
|
| 391 |
+
stride=2,
|
| 392 |
+
padding=1)
|
| 393 |
+
self.vit_downsampler2 = nn.Conv2d(self.hidden_size * 2,
|
| 394 |
+
self.hidden_size * 4,
|
| 395 |
+
kernel_size=3,
|
| 396 |
+
stride=2,
|
| 397 |
+
padding=1)
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def sample_abs_posemb(self, grid_h: int, grid_w: int):
|
| 401 |
+
if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
|
| 402 |
+
return self.positional_embedding[None, ...]
|
| 403 |
+
|
| 404 |
+
pos_embed = self.positional_embedding
|
| 405 |
+
if self.use_cls_token:
|
| 406 |
+
cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
|
| 407 |
+
|
| 408 |
+
pos_embed = (pos_embed.reshape(1, self.posemb_grid_size,
|
| 409 |
+
self.posemb_grid_size,
|
| 410 |
+
-1).permute(0, 3, 1, 2).contiguous())
|
| 411 |
+
pos_embed = F.interpolate(pos_embed,
|
| 412 |
+
size=(grid_h, grid_w),
|
| 413 |
+
mode="bilinear",
|
| 414 |
+
align_corners=False)
|
| 415 |
+
pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.hidden_size)
|
| 416 |
+
|
| 417 |
+
if self.use_cls_token:
|
| 418 |
+
pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)
|
| 419 |
+
|
| 420 |
+
return pos_embed[None, ...]
|
| 421 |
+
|
| 422 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
| 423 |
+
"""
|
| 424 |
+
Args:
|
| 425 |
+
pixel_values: Image tensor of shape (B, C, H, W).
|
| 426 |
+
layer_idx: Negative indices stop after a given block (e.g., -1 uses all blocks).
|
| 427 |
+
strip_cls_token: If True and cls token is used, remove it from output.
|
| 428 |
+
"""
|
| 429 |
+
bsz, _, height, width = pixel_values.shape
|
| 430 |
+
grid_h, grid_w = height // self.patch_size, width // self.patch_size
|
| 431 |
+
|
| 432 |
+
hidden_state = self.conv1(pixel_values) # (B, D, Gh, Gw)
|
| 433 |
+
hidden_state = hidden_state.flatten(2).transpose(1, 2) # (B, Gh*Gw, D)
|
| 434 |
+
|
| 435 |
+
if self.use_cls_token:
|
| 436 |
+
cls_token = self.class_embedding.view(1, 1,
|
| 437 |
+
-1).expand(bsz, -1, -1)
|
| 438 |
+
hidden_state = torch.cat([cls_token, hidden_state], dim=1)
|
| 439 |
+
|
| 440 |
+
if self.use_abs_posemb:
|
| 441 |
+
pos_emb = self.sample_abs_posemb(grid_h, grid_w)
|
| 442 |
+
hidden_state = hidden_state + pos_emb
|
| 443 |
+
hidden_state = self.ln_pre(hidden_state)
|
| 444 |
+
hidden_state = self.transformer(hidden_state, grid_hw=(grid_h, grid_w))
|
| 445 |
+
|
| 446 |
+
if self.use_ln_post:
|
| 447 |
+
hidden_state = self.ln_post(hidden_state)
|
| 448 |
+
|
| 449 |
+
if self.use_cls_token:
|
| 450 |
+
hidden_state = hidden_state[:, 1:, :]
|
| 451 |
+
|
| 452 |
+
return hidden_state
|