Robotics
Transformers
Safetensors
molmoact2
image-text-to-text
so100
so101
custom_code
8-bit precision
Instructions to use OpenRAL/rskill-molmoact2-so101-nf4 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use OpenRAL/rskill-molmoact2-so101-nf4 with Transformers:
# Load model directly from transformers import AutoModelForImageTextToText model = AutoModelForImageTextToText.from_pretrained("OpenRAL/rskill-molmoact2-so101-nf4", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
Commit ·
e01c114
0
Parent(s):
Duplicate from AdrianLlopart/rskill-molmoact2-so101-nf4
Browse files- .gitattributes +38 -0
- README.md +192 -0
- config.json +153 -0
- configuration_molmoact2.py +543 -0
- generation_config.json +6 -0
- image_processing_molmoact2.py +546 -0
- inference.py +768 -0
- model.safetensors +3 -0
- modeling_molmoact2.py +0 -0
- norm_stats.json +202 -0
- processing_molmoact2.py +418 -0
- processor_config.json +85 -0
- quantization_metadata.json +14 -0
- tokenizer.json +3 -0
- tokenizer_config.json +34 -0
- video_processing_molmoact2.py +969 -0
.gitattributes
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/sample_realsense_top_rgb.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/sample_realsense_side_rgb.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: transformers
|
| 3 |
+
tags:
|
| 4 |
+
- molmoact2
|
| 5 |
+
- robotics
|
| 6 |
+
- image-text-to-text
|
| 7 |
+
- so100
|
| 8 |
+
- so101
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
<img src="assets/MolmoAct2.svg" alt="MolmoAct Logo" height="50">
|
| 12 |
+
|
| 13 |
+
# **MolmoAct2-SO100_101**
|
| 14 |
+
|
| 15 |
+
MolmoAct2 is an open vision-language-action model for robot control. It builds on Molmo2-ER and attaches a flow-matching continuous action expert that conditions on the VLM key-value cache through a per-layer connection.
|
| 16 |
+
|
| 17 |
+
This checkpoint is fine-tuned on the SO-100/101 mixture with absolute joint-pose control and annotated language instructions. It is intended for both further fine-tuning and SO-100/101 policy inference.
|
| 18 |
+
|
| 19 |
+
## Quick Links
|
| 20 |
+
|
| 21 |
+
- 📂 Models: [Models](https://huggingface.co/collections/allenai/molmoact2-models), [Finetuned Models](https://huggingface.co/collections/allenai/molmoact2-finetuned-models)
|
| 22 |
+
- 📂 Datasets: [MolmoAct2-BimanualYAM Dataset](https://huggingface.co/collections/allenai/molmoact2-datasets), [MolmoAct2 Datasets](https://huggingface.co/collections/allenai/molmoact2-datasets), [Molmo2-ER Datasets](https://huggingface.co/collections/allenai/molmo2-er-datasets)
|
| 23 |
+
- 📄 Paper: [arXiv:2605.02881](https://arxiv.org/abs/2605.02881)
|
| 24 |
+
- 💻 Code: [allenai/molmoact2](https://github.com/allenai/molmoact2)
|
| 25 |
+
- 🎥 Blog Post: [MolmoAct2](https://allenai.org/blog/molmoact2)
|
| 26 |
+
|
| 27 |
+
## Intended Use
|
| 28 |
+
|
| 29 |
+
Use this checkpoint for SO-100/101 inference or for further fine-tuning. Dataset normalization metadata is stored in `norm_stats.json`. pass `norm_tag="so100_so101_molmoact2"` at inference time.
|
| 30 |
+
|
| 31 |
+
Continuous action prediction is the intended and recommended inference mode. Discrete action prediction is exposed for parity and debugging, but we use continuous actions by default.
|
| 32 |
+
|
| 33 |
+
## Install
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
pip install torch transformers pillow numpy huggingface_hub
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
## Sample Input
|
| 40 |
+
|
| 41 |
+
This sample comes from `Beegbrain/pick_lemon_and_drop_in_bowl`, episode 0, frame 0. Camera order for this checkpoint does not matter. random camera order is acceptable.
|
| 42 |
+
|
| 43 |
+
| Realsense Top RGB | Realsense Side RGB |
|
| 44 |
+
| --- | --- |
|
| 45 |
+
|  |  |
|
| 46 |
+
|
| 47 |
+
```python
|
| 48 |
+
from huggingface_hub import hf_hub_download
|
| 49 |
+
from PIL import Image
|
| 50 |
+
import numpy as np
|
| 51 |
+
|
| 52 |
+
repo_id = "allenai/MolmoAct2-SO100_101"
|
| 53 |
+
|
| 54 |
+
top_rgb = Image.open(
|
| 55 |
+
hf_hub_download(repo_id, "assets/sample_realsense_top_rgb.png")
|
| 56 |
+
).convert("RGB")
|
| 57 |
+
side_rgb = Image.open(
|
| 58 |
+
hf_hub_download(repo_id, "assets/sample_realsense_side_rgb.png")
|
| 59 |
+
).convert("RGB")
|
| 60 |
+
|
| 61 |
+
task = "Move the arm towards the lemon, grasp it, lift it up, and drop it into the red bowl."
|
| 62 |
+
robot_state = np.array(
|
| 63 |
+
[
|
| 64 |
+
-0.52734375,
|
| 65 |
+
189.140625,
|
| 66 |
+
181.40625,
|
| 67 |
+
60.64453125,
|
| 68 |
+
-3.603515625,
|
| 69 |
+
1.0971786975860596,
|
| 70 |
+
],
|
| 71 |
+
dtype=np.float32,
|
| 72 |
+
)
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
## Continuous Actions
|
| 76 |
+
|
| 77 |
+
```python
|
| 78 |
+
import numpy as np
|
| 79 |
+
import torch
|
| 80 |
+
from huggingface_hub import hf_hub_download
|
| 81 |
+
from PIL import Image
|
| 82 |
+
from transformers import AutoModelForImageTextToText, AutoProcessor
|
| 83 |
+
|
| 84 |
+
repo_id = "allenai/MolmoAct2-SO100_101"
|
| 85 |
+
|
| 86 |
+
top_rgb = Image.open(
|
| 87 |
+
hf_hub_download(repo_id, "assets/sample_realsense_top_rgb.png")
|
| 88 |
+
).convert("RGB")
|
| 89 |
+
side_rgb = Image.open(
|
| 90 |
+
hf_hub_download(repo_id, "assets/sample_realsense_side_rgb.png")
|
| 91 |
+
).convert("RGB")
|
| 92 |
+
task = "Move the arm towards the lemon, grasp it, lift it up, and drop it into the red bowl."
|
| 93 |
+
robot_state = np.array(
|
| 94 |
+
[
|
| 95 |
+
-0.52734375,
|
| 96 |
+
189.140625,
|
| 97 |
+
181.40625,
|
| 98 |
+
60.64453125,
|
| 99 |
+
-3.603515625,
|
| 100 |
+
1.0971786975860596,
|
| 101 |
+
],
|
| 102 |
+
dtype=np.float32,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
|
| 106 |
+
model = AutoModelForImageTextToText.from_pretrained(
|
| 107 |
+
repo_id,
|
| 108 |
+
trust_remote_code=True,
|
| 109 |
+
dtype=torch.float32,
|
| 110 |
+
).to("cuda").eval()
|
| 111 |
+
|
| 112 |
+
out = model.predict_action(
|
| 113 |
+
processor=processor,
|
| 114 |
+
images=[top_rgb, side_rgb],
|
| 115 |
+
task=task,
|
| 116 |
+
state=robot_state,
|
| 117 |
+
norm_tag="so100_so101_molmoact2",
|
| 118 |
+
inference_action_mode="continuous",
|
| 119 |
+
enable_depth_reasoning=False,
|
| 120 |
+
num_steps=10,
|
| 121 |
+
normalize_language=True,
|
| 122 |
+
enable_cuda_graph=True,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
actions = out.actions
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
MolmoAct2 was trained with mixed precision. For our reported experiments, we ran inference in `float32`. This path uses the most GPU memory: roughly 26GB with CUDA graph enabled, or around 24GB without CUDA graph.
|
| 129 |
+
|
| 130 |
+
If you have a GPU with less memory, you can run inference with `bfloat16` instead:
|
| 131 |
+
|
| 132 |
+
```python
|
| 133 |
+
model = AutoModelForImageTextToText.from_pretrained(
|
| 134 |
+
repo_id,
|
| 135 |
+
trust_remote_code=True,
|
| 136 |
+
dtype=torch.bfloat16,
|
| 137 |
+
).to("cuda").eval()
|
| 138 |
+
|
| 139 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
| 140 |
+
out = model.predict_action(...)
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
Using `bfloat16` is much more memory efficient and can run under 16GB of GPU memory in our tests. It usually does not hurt performance much.
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
Images may be PIL images or RGB arrays. Camera order does not need to be fixed for this checkpoint. random camera order is acceptable. `state` is the raw robot state, and actions are returned in robot scale.
|
| 147 |
+
|
| 148 |
+
`normalize_language=True` is the default. It lowercases the task string and removes trailing sentence punctuation to match training preprocessing. Set it to `False` if you need to preserve the task text exactly.
|
| 149 |
+
|
| 150 |
+
`enable_cuda_graph=True` is the default. The first few calls can be slow because the model warms up and captures CUDA graphs. run several random warm-up calls before measuring deployment latency. `num_steps` controls the continuous flow solver and defaults to the checkpoint config value, 10.
|
| 151 |
+
|
| 152 |
+
Depth reasoning is disabled for this checkpoint. Calling `enable_depth_reasoning=True` will raise an error.
|
| 153 |
+
|
| 154 |
+
## Discrete Actions
|
| 155 |
+
|
| 156 |
+
Discrete action inference requires a caller-provided action tokenizer. It is not saved in this repository. Discrete mode decodes action tokens directly. the continuous action expert is not used.
|
| 157 |
+
|
| 158 |
+
```python
|
| 159 |
+
action_tokenizer = AutoProcessor.from_pretrained(
|
| 160 |
+
"allenai/MolmoAct2-FAST-Tokenizer",
|
| 161 |
+
trust_remote_code=True,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
out = model.predict_action(
|
| 165 |
+
processor=processor,
|
| 166 |
+
images=[top_rgb, side_rgb],
|
| 167 |
+
task=task,
|
| 168 |
+
state=robot_state,
|
| 169 |
+
norm_tag="so100_so101_molmoact2",
|
| 170 |
+
inference_action_mode="discrete",
|
| 171 |
+
action_tokenizer=action_tokenizer,
|
| 172 |
+
enable_depth_reasoning=False,
|
| 173 |
+
)
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
## Model and Hardware Safety
|
| 177 |
+
|
| 178 |
+
MolmoAct2 generate robot actions from visual observations and language instructions, but their behavior may vary across embodiments, environments, and hardware configurations. Users should carefully validate model outputs before deployment, especially when operating physical robots or other actuated systems. Where possible, actions should be monitored through interpretable intermediate outputs (adaptive depth map), simulation rollouts, action limits, or other safety checks before execution on hardware. The model’s action space should be bounded by the training data, robot controller limits, and task-specific safety constraints, including limits on speed, workspace, torque, and contact force. Users should follow the hardware manufacturer’s safety guidelines, use appropriate emergency-stop mechanisms, and operate the system only in a safely configured environment with human supervision.
|
| 179 |
+
|
| 180 |
+
## Citation
|
| 181 |
+
|
| 182 |
+
```bibtex
|
| 183 |
+
@misc{fang2026molmoact2actionreasoningmodels,
|
| 184 |
+
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
|
| 185 |
+
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
|
| 186 |
+
year={2026},
|
| 187 |
+
eprint={2605.02881},
|
| 188 |
+
archivePrefix={arXiv},
|
| 189 |
+
primaryClass={cs.RO},
|
| 190 |
+
url={https://arxiv.org/abs/2605.02881},
|
| 191 |
+
}
|
| 192 |
+
```
|
config.json
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"action_end_token_id": 151933,
|
| 3 |
+
"action_expert_config": {
|
| 4 |
+
"attn_dropout": 0.0,
|
| 5 |
+
"causal_attn": false,
|
| 6 |
+
"context_layer_norm": true,
|
| 7 |
+
"dropout": 0.0,
|
| 8 |
+
"ffn_multiple_of": 256,
|
| 9 |
+
"hidden_size": 768,
|
| 10 |
+
"mlp_ratio": 4.0,
|
| 11 |
+
"model_type": "molmoact2_action_expert",
|
| 12 |
+
"num_heads": 8,
|
| 13 |
+
"num_layers": 36,
|
| 14 |
+
"qk_norm": true,
|
| 15 |
+
"qk_norm_eps": 1e-06,
|
| 16 |
+
"rope": true,
|
| 17 |
+
"timestep_embed_dim": 256
|
| 18 |
+
},
|
| 19 |
+
"action_expert_depth_gate": false,
|
| 20 |
+
"action_expert_depth_gate_init_bias": -4.0,
|
| 21 |
+
"action_expert_depth_gate_per_layer": false,
|
| 22 |
+
"action_mode": "both",
|
| 23 |
+
"max_action_horizon": 30,
|
| 24 |
+
"action_output_token_id": 151931,
|
| 25 |
+
"action_start_token_id": 151932,
|
| 26 |
+
"action_token_start_id": 151934,
|
| 27 |
+
"adapter_config": {
|
| 28 |
+
"attention_dropout": 0.0,
|
| 29 |
+
"attn_implementation": "sdpa",
|
| 30 |
+
"float32_attention": true,
|
| 31 |
+
"head_dim": 72,
|
| 32 |
+
"hidden_act": "silu",
|
| 33 |
+
"hidden_size": 1152,
|
| 34 |
+
"image_feature_dropout": 0.0,
|
| 35 |
+
"initializer_range": 0.02,
|
| 36 |
+
"intermediate_size": 9728,
|
| 37 |
+
"model_type": "molmoact2",
|
| 38 |
+
"num_attention_heads": 16,
|
| 39 |
+
"num_key_value_heads": 16,
|
| 40 |
+
"pooling_attention_mask": true,
|
| 41 |
+
"residual_dropout": 0.0,
|
| 42 |
+
"text_hidden_size": 2560,
|
| 43 |
+
"vit_layers": [
|
| 44 |
+
-3,
|
| 45 |
+
-9
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
"add_action_expert": true,
|
| 49 |
+
"add_control_tokens": true,
|
| 50 |
+
"add_setup_tokens": true,
|
| 51 |
+
"architectures": [
|
| 52 |
+
"MolmoAct2ForConditionalGeneration"
|
| 53 |
+
],
|
| 54 |
+
"auto_map": {
|
| 55 |
+
"AutoConfig": "configuration_molmoact2.MolmoAct2Config",
|
| 56 |
+
"AutoModelForImageTextToText": "modeling_molmoact2.MolmoAct2ForConditionalGeneration"
|
| 57 |
+
},
|
| 58 |
+
"depth_end_token_id": null,
|
| 59 |
+
"depth_mode": 2,
|
| 60 |
+
"depth_output_token_id": null,
|
| 61 |
+
"depth_start_token_id": null,
|
| 62 |
+
"depth_token_start_id": null,
|
| 63 |
+
"dtype": "float32",
|
| 64 |
+
"enable_depth_reasoning": false,
|
| 65 |
+
"flow_matching_beta_alpha": 1.0,
|
| 66 |
+
"flow_matching_beta_beta": 1.5,
|
| 67 |
+
"flow_matching_cutoff": 1.0,
|
| 68 |
+
"flow_matching_num_steps": 10,
|
| 69 |
+
"flow_matching_time_offset": 0.001,
|
| 70 |
+
"flow_matching_time_scale": 0.999,
|
| 71 |
+
"frame_end_token_id": 154632,
|
| 72 |
+
"frame_start_token_id": 154631,
|
| 73 |
+
"image_col_id": 154627,
|
| 74 |
+
"image_end_token_id": 154625,
|
| 75 |
+
"image_high_res_id": 154626,
|
| 76 |
+
"image_low_res_id": 154630,
|
| 77 |
+
"image_patch_id": 154626,
|
| 78 |
+
"image_start_token_id": 154624,
|
| 79 |
+
"initializer_range": 0.02,
|
| 80 |
+
"low_res_image_start_token_id": 154628,
|
| 81 |
+
"mask_action_dim_padding": true,
|
| 82 |
+
"max_action_dim": 32,
|
| 83 |
+
"model_type": "molmoact2",
|
| 84 |
+
"n_obs_steps": 1,
|
| 85 |
+
"norm_stats_filename": "norm_stats.json",
|
| 86 |
+
"num_action_tokens": 2048,
|
| 87 |
+
"num_depth_codes": 100,
|
| 88 |
+
"num_depth_tokens": 0,
|
| 89 |
+
"num_state_tokens": 256,
|
| 90 |
+
"state_end_token_id": 151674,
|
| 91 |
+
"state_format": "discrete",
|
| 92 |
+
"state_start_token_id": 151673,
|
| 93 |
+
"state_token_start_id": 151675,
|
| 94 |
+
"text_config": {
|
| 95 |
+
"additional_vocab_size": 128,
|
| 96 |
+
"attention_dropout": 0.0,
|
| 97 |
+
"attn_implementation": "sdpa",
|
| 98 |
+
"embedding_dropout": 0.0,
|
| 99 |
+
"head_dim": 128,
|
| 100 |
+
"hidden_act": "silu",
|
| 101 |
+
"hidden_size": 2560,
|
| 102 |
+
"initializer_range": 0.02,
|
| 103 |
+
"intermediate_size": 9728,
|
| 104 |
+
"layer_norm_eps": 1e-06,
|
| 105 |
+
"max_position_embeddings": 16384,
|
| 106 |
+
"model_type": "molmoact2_text",
|
| 107 |
+
"norm_after": false,
|
| 108 |
+
"num_attention_heads": 32,
|
| 109 |
+
"num_hidden_layers": 36,
|
| 110 |
+
"num_key_value_heads": 8,
|
| 111 |
+
"qk_norm_type": "qwen3",
|
| 112 |
+
"qkv_bias": false,
|
| 113 |
+
"residual_dropout": 0.0,
|
| 114 |
+
"rope_parameters": {
|
| 115 |
+
"rope_theta": 5000000.0,
|
| 116 |
+
"rope_type": "default"
|
| 117 |
+
},
|
| 118 |
+
"rope_scaling_layers": null,
|
| 119 |
+
"rope_theta": 5000000.0,
|
| 120 |
+
"tie_word_embeddings": false,
|
| 121 |
+
"use_cache": true,
|
| 122 |
+
"use_qk_norm": true,
|
| 123 |
+
"vocab_size": 154624
|
| 124 |
+
},
|
| 125 |
+
"tie_word_embeddings": false,
|
| 126 |
+
"transformers_version": "5.3.0",
|
| 127 |
+
"use_frame_special_tokens": true,
|
| 128 |
+
"vit_config": {
|
| 129 |
+
"attention_dropout": 0.0,
|
| 130 |
+
"attn_implementation": "sdpa",
|
| 131 |
+
"float32_attention": true,
|
| 132 |
+
"head_dim": 72,
|
| 133 |
+
"hidden_act": "gelu_pytorch_tanh",
|
| 134 |
+
"hidden_size": 1152,
|
| 135 |
+
"image_default_input_size": [
|
| 136 |
+
378,
|
| 137 |
+
378
|
| 138 |
+
],
|
| 139 |
+
"image_num_pos": 729,
|
| 140 |
+
"image_patch_size": 14,
|
| 141 |
+
"initializer_range": 0.02,
|
| 142 |
+
"intermediate_size": 4304,
|
| 143 |
+
"layer_norm_eps": 1e-06,
|
| 144 |
+
"model_type": "molmoact2",
|
| 145 |
+
"num_attention_heads": 16,
|
| 146 |
+
"num_hidden_layers": 27,
|
| 147 |
+
"num_key_value_heads": 16,
|
| 148 |
+
"residual_dropout": 0.0
|
| 149 |
+
},
|
| 150 |
+
"bos_token_id": 151645,
|
| 151 |
+
"eos_token_id": 151645,
|
| 152 |
+
"pad_token_id": 151643
|
| 153 |
+
}
|
configuration_molmoact2.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MolmoAct2 configuration
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Optional, Any
|
| 6 |
+
|
| 7 |
+
from transformers import PretrainedConfig
|
| 8 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 9 |
+
from transformers.utils import logging
|
| 10 |
+
|
| 11 |
+
logger = logging.get_logger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class MolmoAct2VitConfig(PretrainedConfig):
|
| 15 |
+
r"""
|
| 16 |
+
This is the configuration class to store the configuration of a [`MolmoAct2VisionTransformer`].
|
| 17 |
+
It is used to instantiate a `MolmoAct2VisionTransformer` according to the specified arguments,
|
| 18 |
+
defining the model architecture.
|
| 19 |
+
|
| 20 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 21 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 22 |
+
|
| 23 |
+
Example:
|
| 24 |
+
```python
|
| 25 |
+
>>> from transformers import MolmoAct2VitConfig, MolmoAct2VisionTransformer
|
| 26 |
+
|
| 27 |
+
>>> # Initializing a MolmoAct2VitConfig
|
| 28 |
+
>>> configuration = MolmoAct2VitConfig()
|
| 29 |
+
|
| 30 |
+
>>> # Initializing a MolmoAct2VisionTransformer (with random weights)
|
| 31 |
+
>>> model = MolmoAct2VisionTransformer(configuration)
|
| 32 |
+
|
| 33 |
+
>>> # Accessing the model configuration
|
| 34 |
+
>>> configuration = model.config
|
| 35 |
+
```"""
|
| 36 |
+
|
| 37 |
+
model_type = "molmoact2"
|
| 38 |
+
base_config_key = "vit_config"
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
hidden_size: int = 1152,
|
| 43 |
+
intermediate_size: int = 4304,
|
| 44 |
+
num_hidden_layers: int = 27,
|
| 45 |
+
num_attention_heads: int = 16,
|
| 46 |
+
num_key_value_heads: int = 16,
|
| 47 |
+
head_dim: int = 72,
|
| 48 |
+
hidden_act: str = "gelu_pytorch_tanh",
|
| 49 |
+
layer_norm_eps: float = 1e-6,
|
| 50 |
+
image_default_input_size: tuple[int, int] = (378, 378),
|
| 51 |
+
image_patch_size: int = 14,
|
| 52 |
+
image_num_pos: int = 577,
|
| 53 |
+
attention_dropout: float = 0.0,
|
| 54 |
+
residual_dropout: float = 0.0,
|
| 55 |
+
initializer_range: float = 0.02,
|
| 56 |
+
float32_attention: bool = True,
|
| 57 |
+
attn_implementation: str = "eager",
|
| 58 |
+
**kwargs,
|
| 59 |
+
):
|
| 60 |
+
self.attn_implementation = attn_implementation
|
| 61 |
+
super().__init__(
|
| 62 |
+
attn_implementation=attn_implementation,
|
| 63 |
+
**kwargs
|
| 64 |
+
)
|
| 65 |
+
self.hidden_size = hidden_size
|
| 66 |
+
self.intermediate_size = intermediate_size
|
| 67 |
+
self.num_hidden_layers = num_hidden_layers
|
| 68 |
+
self.num_attention_heads = num_attention_heads
|
| 69 |
+
self.num_key_value_heads = num_key_value_heads
|
| 70 |
+
self.head_dim = head_dim
|
| 71 |
+
self.hidden_act = hidden_act
|
| 72 |
+
self.layer_norm_eps = layer_norm_eps
|
| 73 |
+
self.image_default_input_size = image_default_input_size
|
| 74 |
+
self.image_patch_size = image_patch_size
|
| 75 |
+
self.image_num_pos = image_num_pos
|
| 76 |
+
self.attention_dropout = attention_dropout
|
| 77 |
+
self.residual_dropout = residual_dropout
|
| 78 |
+
self.initializer_range = initializer_range
|
| 79 |
+
self.float32_attention = float32_attention
|
| 80 |
+
|
| 81 |
+
@property
|
| 82 |
+
def image_num_patch(self):
|
| 83 |
+
h, w = self.image_default_input_size
|
| 84 |
+
return h // self.image_patch_size, w // self.image_patch_size
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class MolmoAct2AdapterConfig(PretrainedConfig):
|
| 88 |
+
r"""
|
| 89 |
+
This is the configuration class to store the configuration of MolmoAct2Adapter. With MolmoAct2VitConfig,
|
| 90 |
+
It is used to instantiate an MolmoAct2VisionBackbone according to the specified arguments,
|
| 91 |
+
defining the model architecture.
|
| 92 |
+
|
| 93 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 94 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 95 |
+
|
| 96 |
+
Example:
|
| 97 |
+
|
| 98 |
+
```python
|
| 99 |
+
>>> from transformers import MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2VisionBackbone
|
| 100 |
+
|
| 101 |
+
>>> # Initializing a MolmoAct2VitConfig and a MolmoAct2AdapterConfig
|
| 102 |
+
>>> vit_config = MolmoAct2VitConfig()
|
| 103 |
+
>>> adapter_config = MolmoPoolingConfig()
|
| 104 |
+
|
| 105 |
+
>>> # Initializing a MolmoAct2VisionBackbone (with random weights)
|
| 106 |
+
>>> model = MolmoAct2VisionBackbone(vit_config, adapter_config)
|
| 107 |
+
|
| 108 |
+
>>> # Accessing the model configuration
|
| 109 |
+
>>> vit_configuration = model.vit_config
|
| 110 |
+
>>> adapter_configuration = model.adapter_config
|
| 111 |
+
```"""
|
| 112 |
+
|
| 113 |
+
model_type = "molmoact2"
|
| 114 |
+
base_config_key = "adapter_config"
|
| 115 |
+
|
| 116 |
+
def __init__(
|
| 117 |
+
self,
|
| 118 |
+
vit_layers: tuple = (-3, -9),
|
| 119 |
+
pooling_attention_mask: bool = False,
|
| 120 |
+
hidden_size: int = 1152,
|
| 121 |
+
num_attention_heads: int = 16,
|
| 122 |
+
num_key_value_heads: int = 16,
|
| 123 |
+
head_dim: int = 72,
|
| 124 |
+
float32_attention: bool = True,
|
| 125 |
+
attention_dropout: float = 0.0,
|
| 126 |
+
residual_dropout: float = 0.0,
|
| 127 |
+
hidden_act: str = "silu",
|
| 128 |
+
intermediate_size: int = 18944,
|
| 129 |
+
text_hidden_size: int = 3584,
|
| 130 |
+
image_feature_dropout: float = 0.0,
|
| 131 |
+
initializer_range: float = 0.02,
|
| 132 |
+
attn_implementation: str = "eager",
|
| 133 |
+
**kwargs,
|
| 134 |
+
):
|
| 135 |
+
self.attn_implementation = attn_implementation
|
| 136 |
+
super().__init__(
|
| 137 |
+
attn_implementation=attn_implementation,
|
| 138 |
+
**kwargs
|
| 139 |
+
)
|
| 140 |
+
self.vit_layers = vit_layers
|
| 141 |
+
self.pooling_attention_mask = pooling_attention_mask
|
| 142 |
+
self.hidden_size = hidden_size
|
| 143 |
+
self.num_attention_heads = num_attention_heads
|
| 144 |
+
self.num_key_value_heads = num_key_value_heads
|
| 145 |
+
self.head_dim = head_dim
|
| 146 |
+
self.float32_attention = float32_attention
|
| 147 |
+
self.attention_dropout = attention_dropout
|
| 148 |
+
self.residual_dropout = residual_dropout
|
| 149 |
+
self.hidden_act = hidden_act
|
| 150 |
+
self.intermediate_size = intermediate_size
|
| 151 |
+
self.text_hidden_size = text_hidden_size
|
| 152 |
+
self.image_feature_dropout = image_feature_dropout
|
| 153 |
+
self.initializer_range = initializer_range
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class MolmoAct2TextConfig(PretrainedConfig):
|
| 157 |
+
r"""
|
| 158 |
+
This is the configuration class to store the configuration of a [`MolmoAct2TextModel`]. It is used to instantiate a
|
| 159 |
+
`MolmoAct2TextModel` according to the specified arguments, defining the model architecture.
|
| 160 |
+
|
| 161 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 162 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 163 |
+
|
| 164 |
+
Example:
|
| 165 |
+
```python
|
| 166 |
+
>>> from transformers import MolmoAct2TextConfig, MolmoAct2TextModel
|
| 167 |
+
|
| 168 |
+
>>> # Initializing a MolmoAct2TextConfig
|
| 169 |
+
>>> configuration = MolmoAct2TextConfig()
|
| 170 |
+
|
| 171 |
+
>>> # Initializing a MolmoAct2TextModel (with random weights)
|
| 172 |
+
>>> model = MolmoAct2TextModel(configuration)
|
| 173 |
+
|
| 174 |
+
>>> # Accessing the model configuration
|
| 175 |
+
>>> configuration = model.config
|
| 176 |
+
```"""
|
| 177 |
+
|
| 178 |
+
model_type = "molmoact2_text"
|
| 179 |
+
base_config_key = "text_config"
|
| 180 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 181 |
+
base_model_tp_plan = {
|
| 182 |
+
"blocks.*.self_attn.att_proj": "colwise",
|
| 183 |
+
"blocks.*.self_attn.attn_out": "rowwise",
|
| 184 |
+
"blocks.*.mlp.ff_proj": "colwise",
|
| 185 |
+
"blocks.*.mlp.ff_out": "rowwise",
|
| 186 |
+
}
|
| 187 |
+
base_model_pp_plan = {
|
| 188 |
+
"wte": (["input_ids"], ["inputs_embeds"]),
|
| 189 |
+
"blocks": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 190 |
+
"ln_f": (["hidden_states"], ["hidden_states"]),
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
def __init__(
|
| 194 |
+
self,
|
| 195 |
+
hidden_size: int = 3584,
|
| 196 |
+
num_attention_heads: int = 28,
|
| 197 |
+
num_key_value_heads: Optional[int] = 4,
|
| 198 |
+
head_dim: int = 128,
|
| 199 |
+
vocab_size: int = 152064,
|
| 200 |
+
additional_vocab_size: int = 128,
|
| 201 |
+
qkv_bias: bool = True,
|
| 202 |
+
num_hidden_layers: int = 48,
|
| 203 |
+
intermediate_size: int = 18944,
|
| 204 |
+
hidden_act: str = "silu",
|
| 205 |
+
embedding_dropout: float=0.0,
|
| 206 |
+
attention_dropout: float=0.0,
|
| 207 |
+
residual_dropout: float = 0.0,
|
| 208 |
+
max_position_embeddings: int = 4096,
|
| 209 |
+
rope_theta: float = 1000000.0,
|
| 210 |
+
rope_scaling: dict[str, Any] = None,
|
| 211 |
+
rope_scaling_layers: Optional[list[int]] = None,
|
| 212 |
+
use_qk_norm: bool = False,
|
| 213 |
+
qk_norm_type: str = "olmo",
|
| 214 |
+
layer_norm_eps: int = 1e-6,
|
| 215 |
+
norm_after: bool = False,
|
| 216 |
+
initializer_range: float = 0.02,
|
| 217 |
+
use_cache=True,
|
| 218 |
+
tie_word_embeddings=False,
|
| 219 |
+
attn_implementation: str = "eager",
|
| 220 |
+
**kwargs,
|
| 221 |
+
):
|
| 222 |
+
self.attn_implementation = attn_implementation
|
| 223 |
+
super().__init__(
|
| 224 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 225 |
+
attn_implementation=attn_implementation,
|
| 226 |
+
**kwargs
|
| 227 |
+
)
|
| 228 |
+
self.hidden_size = hidden_size
|
| 229 |
+
self.num_attention_heads = num_attention_heads
|
| 230 |
+
if num_key_value_heads is None:
|
| 231 |
+
num_key_value_heads = num_attention_heads
|
| 232 |
+
self.num_key_value_heads = num_key_value_heads
|
| 233 |
+
self.head_dim = head_dim
|
| 234 |
+
self.vocab_size = vocab_size
|
| 235 |
+
self.additional_vocab_size = additional_vocab_size
|
| 236 |
+
self.qkv_bias = qkv_bias
|
| 237 |
+
self.num_hidden_layers = num_hidden_layers
|
| 238 |
+
self.intermediate_size = intermediate_size
|
| 239 |
+
self.hidden_act = hidden_act
|
| 240 |
+
self.embedding_dropout = embedding_dropout
|
| 241 |
+
self.attention_dropout = attention_dropout
|
| 242 |
+
self.residual_dropout = residual_dropout
|
| 243 |
+
self.max_position_embeddings = max_position_embeddings
|
| 244 |
+
self.rope_theta = rope_theta
|
| 245 |
+
self.rope_scaling = rope_scaling
|
| 246 |
+
self.rope_scaling_layers = rope_scaling_layers
|
| 247 |
+
self.use_qk_norm = use_qk_norm
|
| 248 |
+
self.qk_norm_type = qk_norm_type
|
| 249 |
+
self.layer_norm_eps = layer_norm_eps
|
| 250 |
+
self.norm_after = norm_after
|
| 251 |
+
self.initializer_range = initializer_range
|
| 252 |
+
self.use_cache = use_cache
|
| 253 |
+
|
| 254 |
+
# Validate the correctness of rotary position embeddings parameters
|
| 255 |
+
rope_config_validation(self)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class MolmoAct2ActionExpertConfig(PretrainedConfig):
|
| 259 |
+
r"""Configuration for the MolmoAct2 modern action expert."""
|
| 260 |
+
|
| 261 |
+
model_type = "molmoact2_action_expert"
|
| 262 |
+
base_config_key = "action_expert_config"
|
| 263 |
+
|
| 264 |
+
def __init__(
|
| 265 |
+
self,
|
| 266 |
+
max_action_horizon: int = 32,
|
| 267 |
+
max_action_dim: int = 32,
|
| 268 |
+
hidden_size: int = 1024,
|
| 269 |
+
num_layers: int = 32,
|
| 270 |
+
num_heads: int = 16,
|
| 271 |
+
mlp_ratio: float = 8.0 / 3.0,
|
| 272 |
+
ffn_multiple_of: int = 256,
|
| 273 |
+
timestep_embed_dim: int = 256,
|
| 274 |
+
dropout: float = 0.0,
|
| 275 |
+
attn_dropout: float = 0.0,
|
| 276 |
+
context_layer_norm: bool = True,
|
| 277 |
+
qk_norm: bool = True,
|
| 278 |
+
qk_norm_eps: float = 1e-6,
|
| 279 |
+
rope: bool = True,
|
| 280 |
+
causal_attn: bool = False,
|
| 281 |
+
**kwargs,
|
| 282 |
+
):
|
| 283 |
+
super().__init__(**kwargs)
|
| 284 |
+
self.max_action_horizon = max_action_horizon
|
| 285 |
+
self.max_action_dim = max_action_dim
|
| 286 |
+
self.hidden_size = hidden_size
|
| 287 |
+
self.num_layers = num_layers
|
| 288 |
+
self.num_heads = num_heads
|
| 289 |
+
self.mlp_ratio = mlp_ratio
|
| 290 |
+
self.ffn_multiple_of = ffn_multiple_of
|
| 291 |
+
self.timestep_embed_dim = timestep_embed_dim
|
| 292 |
+
self.dropout = dropout
|
| 293 |
+
self.attn_dropout = attn_dropout
|
| 294 |
+
self.context_layer_norm = context_layer_norm
|
| 295 |
+
self.qk_norm = qk_norm
|
| 296 |
+
self.qk_norm_eps = qk_norm_eps
|
| 297 |
+
self.rope = rope
|
| 298 |
+
self.causal_attn = causal_attn
|
| 299 |
+
|
| 300 |
+
def to_dict(self):
|
| 301 |
+
output = super().to_dict()
|
| 302 |
+
# These are derived from the parent MolmoAct2Config for HF exports. Keeping
|
| 303 |
+
# them out of the public nested config avoids duplicated sources of truth.
|
| 304 |
+
output.pop("max_action_horizon", None)
|
| 305 |
+
output.pop("max_action_dim", None)
|
| 306 |
+
return output
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class MolmoAct2Config(PretrainedConfig):
|
| 310 |
+
r"""
|
| 311 |
+
This is the configuration class to store the configuration of a [`MolmoAct2ForConditionalGeneration`].
|
| 312 |
+
It is used to instantiate an MolmoAct2 model according to the specified arguments, defining the model architecture.
|
| 313 |
+
|
| 314 |
+
Example:
|
| 315 |
+
|
| 316 |
+
```python
|
| 317 |
+
>>> from transformers import MolmoAct2Config, MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2TextConfig
|
| 318 |
+
|
| 319 |
+
>>> # Initializing a MolmoAct2VitConfig
|
| 320 |
+
>>> vit_config = MolmoAct2VitConfig()
|
| 321 |
+
|
| 322 |
+
>>> # Initializing a MolmoAct2AdapterConfig
|
| 323 |
+
>>> adapter_config = MolmoAct2AdapterConfig()
|
| 324 |
+
|
| 325 |
+
>>> # Initializing a MolmoAct2TextConfig
|
| 326 |
+
>>> text_config = MolmoAct2TextConfig()
|
| 327 |
+
|
| 328 |
+
>>> # Initializing a MolmoAct2Config
|
| 329 |
+
>>> configuration = MolmoAct2Config(
|
| 330 |
+
>>> vit_config=vit_config,
|
| 331 |
+
>>> adapter_config=adapter_config,
|
| 332 |
+
>>> text_config=text_config,
|
| 333 |
+
>>> image_start_token_id=151936,
|
| 334 |
+
>>> image_end_token_id=151937,
|
| 335 |
+
>>> image_patch_id=151938,
|
| 336 |
+
>>> image_col_id=151939,
|
| 337 |
+
>>> low_res_image_start_token_id=151940,
|
| 338 |
+
>>> image_low_res_id=151942,
|
| 339 |
+
>>> frame_start_token_id=151943,
|
| 340 |
+
>>> frame_end_token_id=151944,
|
| 341 |
+
>>> )
|
| 342 |
+
|
| 343 |
+
>>> # Initializing a model
|
| 344 |
+
>>> model = MolmoAct2ForConditionalGeneration(configuration)
|
| 345 |
+
|
| 346 |
+
>>> # Accessing the model configuration
|
| 347 |
+
>>> configuration = model.config
|
| 348 |
+
```"""
|
| 349 |
+
|
| 350 |
+
model_type = "molmoact2"
|
| 351 |
+
sub_configs = {
|
| 352 |
+
"text_config": MolmoAct2TextConfig,
|
| 353 |
+
"vit_config": MolmoAct2VitConfig,
|
| 354 |
+
"adapter_config": MolmoAct2AdapterConfig,
|
| 355 |
+
"action_expert_config": MolmoAct2ActionExpertConfig,
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
def __init__(
|
| 359 |
+
self,
|
| 360 |
+
vit_config: MolmoAct2VitConfig = None,
|
| 361 |
+
adapter_config: MolmoAct2AdapterConfig = None,
|
| 362 |
+
text_config: MolmoAct2TextConfig = None,
|
| 363 |
+
action_expert_config: MolmoAct2ActionExpertConfig = None,
|
| 364 |
+
image_start_token_id: int = None,
|
| 365 |
+
low_res_image_start_token_id: int = None,
|
| 366 |
+
image_end_token_id: int = None,
|
| 367 |
+
image_low_res_id: int = None,
|
| 368 |
+
image_patch_id: int = None,
|
| 369 |
+
image_col_id: int = None,
|
| 370 |
+
frame_start_token_id: int = None,
|
| 371 |
+
frame_end_token_id: int = None,
|
| 372 |
+
use_frame_special_tokens: bool = True,
|
| 373 |
+
initializer_range: float = 0.02,
|
| 374 |
+
add_action_expert: bool = True,
|
| 375 |
+
max_action_dim: int = 32,
|
| 376 |
+
max_action_horizon: int = 30,
|
| 377 |
+
n_obs_steps: int = 30,
|
| 378 |
+
action_mode: str = "both",
|
| 379 |
+
state_format: str = "discrete",
|
| 380 |
+
flow_matching_num_steps: int = 10,
|
| 381 |
+
flow_matching_cutoff: float = 1.0,
|
| 382 |
+
flow_matching_time_offset: float = 0.001,
|
| 383 |
+
flow_matching_time_scale: float = 0.999,
|
| 384 |
+
flow_matching_beta_alpha: float = 1.0,
|
| 385 |
+
flow_matching_beta_beta: float = 1.5,
|
| 386 |
+
mask_action_dim_padding: bool = True,
|
| 387 |
+
enable_depth_reasoning: bool = False,
|
| 388 |
+
depth_mode: int = 2,
|
| 389 |
+
num_depth_codes: int = 100,
|
| 390 |
+
action_expert_depth_gate: bool = False,
|
| 391 |
+
action_expert_depth_gate_per_layer: bool = False,
|
| 392 |
+
action_expert_depth_gate_init_bias: float = -4.0,
|
| 393 |
+
action_output_token_id: int = None,
|
| 394 |
+
action_start_token_id: int = None,
|
| 395 |
+
action_end_token_id: int = None,
|
| 396 |
+
action_token_start_id: int = None,
|
| 397 |
+
num_action_tokens: int = 0,
|
| 398 |
+
depth_output_token_id: int = None,
|
| 399 |
+
depth_start_token_id: int = None,
|
| 400 |
+
depth_end_token_id: int = None,
|
| 401 |
+
depth_token_start_id: int = None,
|
| 402 |
+
num_depth_tokens: int = 0,
|
| 403 |
+
state_start_token_id: int = None,
|
| 404 |
+
state_end_token_id: int = None,
|
| 405 |
+
state_token_start_id: int = None,
|
| 406 |
+
num_state_tokens: int = 0,
|
| 407 |
+
add_setup_tokens: bool = True,
|
| 408 |
+
add_control_tokens: bool = True,
|
| 409 |
+
norm_stats_filename: str = "norm_stats.json",
|
| 410 |
+
**kwargs,
|
| 411 |
+
):
|
| 412 |
+
super().__init__(**kwargs)
|
| 413 |
+
if vit_config is None:
|
| 414 |
+
self.vit_config = MolmoAct2VitConfig()
|
| 415 |
+
elif isinstance(vit_config, dict):
|
| 416 |
+
self.vit_config = MolmoAct2VitConfig(**vit_config)
|
| 417 |
+
else:
|
| 418 |
+
self.vit_config = vit_config
|
| 419 |
+
if adapter_config is None:
|
| 420 |
+
self.adapter_config = MolmoAct2AdapterConfig()
|
| 421 |
+
elif isinstance(adapter_config, dict):
|
| 422 |
+
self.adapter_config = MolmoAct2AdapterConfig(**adapter_config)
|
| 423 |
+
else:
|
| 424 |
+
self.adapter_config = adapter_config
|
| 425 |
+
if text_config is None:
|
| 426 |
+
self.text_config = MolmoAct2TextConfig()
|
| 427 |
+
elif isinstance(text_config, dict):
|
| 428 |
+
self.text_config = MolmoAct2TextConfig(**text_config)
|
| 429 |
+
else:
|
| 430 |
+
self.text_config = text_config
|
| 431 |
+
self.add_action_expert = bool(add_action_expert)
|
| 432 |
+
if not self.add_action_expert:
|
| 433 |
+
self.action_expert_config = None
|
| 434 |
+
elif action_expert_config is None:
|
| 435 |
+
self.action_expert_config = MolmoAct2ActionExpertConfig(
|
| 436 |
+
max_action_horizon=max_action_horizon,
|
| 437 |
+
max_action_dim=max_action_dim,
|
| 438 |
+
num_layers=self.text_config.num_hidden_layers,
|
| 439 |
+
)
|
| 440 |
+
elif isinstance(action_expert_config, dict):
|
| 441 |
+
self.action_expert_config = MolmoAct2ActionExpertConfig(**action_expert_config)
|
| 442 |
+
else:
|
| 443 |
+
self.action_expert_config = action_expert_config
|
| 444 |
+
if self.add_action_expert:
|
| 445 |
+
self.action_expert_config.max_action_dim = int(max_action_dim)
|
| 446 |
+
self.action_expert_config.max_action_horizon = int(max_action_horizon)
|
| 447 |
+
self._validate_release_action_config(
|
| 448 |
+
state_format=state_format,
|
| 449 |
+
)
|
| 450 |
+
self.image_start_token_id = image_start_token_id
|
| 451 |
+
self.low_res_image_start_token_id = low_res_image_start_token_id
|
| 452 |
+
self.image_end_token_id = image_end_token_id
|
| 453 |
+
self.image_low_res_id = image_low_res_id
|
| 454 |
+
self.image_high_res_id = image_patch_id
|
| 455 |
+
self.image_patch_id = image_patch_id
|
| 456 |
+
self.image_col_id = image_col_id
|
| 457 |
+
self.frame_start_token_id = frame_start_token_id
|
| 458 |
+
self.frame_end_token_id = frame_end_token_id
|
| 459 |
+
self.use_frame_special_tokens = use_frame_special_tokens
|
| 460 |
+
self.initializer_range = initializer_range
|
| 461 |
+
self.max_action_dim = max_action_dim
|
| 462 |
+
self.max_action_horizon = max_action_horizon
|
| 463 |
+
self.n_obs_steps = n_obs_steps
|
| 464 |
+
self.action_mode = action_mode
|
| 465 |
+
self.state_format = state_format
|
| 466 |
+
self.flow_matching_num_steps = flow_matching_num_steps
|
| 467 |
+
self.flow_matching_cutoff = flow_matching_cutoff
|
| 468 |
+
self.flow_matching_time_offset = flow_matching_time_offset
|
| 469 |
+
self.flow_matching_time_scale = flow_matching_time_scale
|
| 470 |
+
self.flow_matching_beta_alpha = flow_matching_beta_alpha
|
| 471 |
+
self.flow_matching_beta_beta = flow_matching_beta_beta
|
| 472 |
+
self.mask_action_dim_padding = mask_action_dim_padding
|
| 473 |
+
self.enable_depth_reasoning = enable_depth_reasoning
|
| 474 |
+
self.depth_mode = depth_mode
|
| 475 |
+
self.num_depth_codes = num_depth_codes
|
| 476 |
+
self.action_expert_depth_gate = action_expert_depth_gate
|
| 477 |
+
self.action_expert_depth_gate_per_layer = action_expert_depth_gate_per_layer
|
| 478 |
+
self.action_expert_depth_gate_init_bias = action_expert_depth_gate_init_bias
|
| 479 |
+
self.action_output_token_id = action_output_token_id
|
| 480 |
+
self.action_start_token_id = action_start_token_id
|
| 481 |
+
self.action_end_token_id = action_end_token_id
|
| 482 |
+
self.action_token_start_id = action_token_start_id
|
| 483 |
+
self.num_action_tokens = num_action_tokens
|
| 484 |
+
self.depth_output_token_id = depth_output_token_id
|
| 485 |
+
self.depth_start_token_id = depth_start_token_id
|
| 486 |
+
self.depth_end_token_id = depth_end_token_id
|
| 487 |
+
self.depth_token_start_id = depth_token_start_id
|
| 488 |
+
self.num_depth_tokens = num_depth_tokens
|
| 489 |
+
self.state_start_token_id = state_start_token_id
|
| 490 |
+
self.state_end_token_id = state_end_token_id
|
| 491 |
+
self.state_token_start_id = state_token_start_id
|
| 492 |
+
self.num_state_tokens = num_state_tokens
|
| 493 |
+
self.add_setup_tokens = add_setup_tokens
|
| 494 |
+
self.add_control_tokens = add_control_tokens
|
| 495 |
+
self.norm_stats_filename = norm_stats_filename
|
| 496 |
+
|
| 497 |
+
@staticmethod
|
| 498 |
+
def _validate_release_action_config(
|
| 499 |
+
*,
|
| 500 |
+
state_format: str,
|
| 501 |
+
) -> None:
|
| 502 |
+
if state_format != "discrete":
|
| 503 |
+
raise ValueError("MolmoAct2 HF export supports only state_format='discrete'.")
|
| 504 |
+
|
| 505 |
+
@property
|
| 506 |
+
def image_num_patch(self):
|
| 507 |
+
assert self.vit_config is not None
|
| 508 |
+
return self.vit_config.image_num_patch
|
| 509 |
+
|
| 510 |
+
@property
|
| 511 |
+
def num_attention_heads(self):
|
| 512 |
+
return self.text_config.num_attention_heads
|
| 513 |
+
|
| 514 |
+
@property
|
| 515 |
+
def num_key_value_heads(self):
|
| 516 |
+
return self.text_config.num_key_value_heads
|
| 517 |
+
|
| 518 |
+
@property
|
| 519 |
+
def head_dim(self):
|
| 520 |
+
return self.text_config.head_dim
|
| 521 |
+
|
| 522 |
+
@property
|
| 523 |
+
def num_hidden_layers(self):
|
| 524 |
+
return self.text_config.num_hidden_layers
|
| 525 |
+
|
| 526 |
+
@property
|
| 527 |
+
def hidden_size(self):
|
| 528 |
+
return self.text_config.hidden_size
|
| 529 |
+
|
| 530 |
+
@property
|
| 531 |
+
def vocab_size(self):
|
| 532 |
+
return self.text_config.vocab_size
|
| 533 |
+
|
| 534 |
+
@property
|
| 535 |
+
def max_position_embeddings(self):
|
| 536 |
+
return self.text_config.max_position_embeddings
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
MolmoAct2VitConfig.register_for_auto_class()
|
| 540 |
+
MolmoAct2AdapterConfig.register_for_auto_class()
|
| 541 |
+
MolmoAct2TextConfig.register_for_auto_class()
|
| 542 |
+
MolmoAct2ActionExpertConfig.register_for_auto_class()
|
| 543 |
+
MolmoAct2Config.register_for_auto_class()
|
generation_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 151645,
|
| 3 |
+
"eos_token_id": 151645,
|
| 4 |
+
"pad_token_id": 151643,
|
| 5 |
+
"transformers_version": "5.3.0"
|
| 6 |
+
}
|
image_processing_molmoact2.py
ADDED
|
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Image processor class for MolmoAct2"""
|
| 2 |
+
from typing import Optional, Union
|
| 3 |
+
import numpy as np
|
| 4 |
+
import einops
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision.transforms
|
| 7 |
+
|
| 8 |
+
from transformers.image_utils import (
|
| 9 |
+
IMAGENET_STANDARD_MEAN,
|
| 10 |
+
IMAGENET_STANDARD_STD,
|
| 11 |
+
ImageInput,
|
| 12 |
+
PILImageResampling,
|
| 13 |
+
make_flat_list_of_images,
|
| 14 |
+
valid_images,
|
| 15 |
+
to_numpy_array,
|
| 16 |
+
)
|
| 17 |
+
from transformers.image_transforms import convert_to_rgb
|
| 18 |
+
from transformers.processing_utils import ImagesKwargs
|
| 19 |
+
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
|
| 20 |
+
from transformers.utils import logging
|
| 21 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 22 |
+
from transformers.utils import TensorType, logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def normalize_image(
|
| 29 |
+
image: np.ndarray,
|
| 30 |
+
image_mean: list[float],
|
| 31 |
+
image_std: list[float],
|
| 32 |
+
) -> np.ndarray:
|
| 33 |
+
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
|
| 34 |
+
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
|
| 35 |
+
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
|
| 36 |
+
image /= np.array(image_std, dtype=np.float32)[None, None, :]
|
| 37 |
+
return image
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def resize_image(
|
| 41 |
+
image: np.ndarray,
|
| 42 |
+
desired_output_size: list[int],
|
| 43 |
+
resample: PILImageResampling,
|
| 44 |
+
) -> np.ndarray:
|
| 45 |
+
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
|
| 46 |
+
dtype = image.dtype
|
| 47 |
+
if torch.is_floating_point(image):
|
| 48 |
+
in_min = 0.0
|
| 49 |
+
in_max = 1.0
|
| 50 |
+
resized = torchvision.transforms.Resize(
|
| 51 |
+
desired_output_size,
|
| 52 |
+
resample,
|
| 53 |
+
antialias=False,
|
| 54 |
+
)(image)
|
| 55 |
+
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
|
| 56 |
+
else:
|
| 57 |
+
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(image.dtype)
|
| 58 |
+
in_min = 0.0
|
| 59 |
+
in_max = 255.0
|
| 60 |
+
resized = torchvision.transforms.Resize(
|
| 61 |
+
desired_output_size,
|
| 62 |
+
resample,
|
| 63 |
+
antialias=False,
|
| 64 |
+
)(image)
|
| 65 |
+
resized = torch.clip(resized, 0, 255).to(dtype)
|
| 66 |
+
|
| 67 |
+
resized = resized.to(torch.float32)
|
| 68 |
+
resized = (resized - in_min) / (in_max - in_min)
|
| 69 |
+
|
| 70 |
+
resized = torch.permute(resized, [1, 2, 0]).numpy()
|
| 71 |
+
|
| 72 |
+
return resized
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def select_tiling(h, w, patch_size, max_num_crops):
|
| 76 |
+
"""Divide in image of size [w, h] in up to max_num_patches of size patch_size"""
|
| 77 |
+
original_size = np.stack([h, w]) # [1, 2]
|
| 78 |
+
original_res = h * w
|
| 79 |
+
tilings = []
|
| 80 |
+
for i in range(1, max_num_crops + 1):
|
| 81 |
+
for j in range(1, max_num_crops + 1):
|
| 82 |
+
if i*j <= max_num_crops:
|
| 83 |
+
tilings.append((i, j))
|
| 84 |
+
# sort so argmin and argmax favour smaller tilings in the event of a tie
|
| 85 |
+
tilings.sort(key=lambda x: (x[0]*x[1], x[0]))
|
| 86 |
+
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
|
| 87 |
+
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
|
| 88 |
+
|
| 89 |
+
# How much we would need to scale the image to fit exactly in each tiling
|
| 90 |
+
original_size = np.stack([h, w], dtype=np.float32) # [1, 2]
|
| 91 |
+
|
| 92 |
+
# The original size can be zero in rare cases if the image is smaller than the margin
|
| 93 |
+
# In those cases letting the scale become infinite means the tiling is based on the
|
| 94 |
+
# other side, or falls back to the smallest tiling
|
| 95 |
+
with np.errstate(divide='ignore'):
|
| 96 |
+
required_scale_d = candidate_resolutions.astype(np.float32) / original_size,
|
| 97 |
+
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
|
| 98 |
+
if np.all(required_scale < 1):
|
| 99 |
+
# We are forced to downscale, so try to minimize the amount of downscaling
|
| 100 |
+
ix = np.argmax(required_scale)
|
| 101 |
+
else:
|
| 102 |
+
# Pick the resolution that required the least upscaling so that it most closely fits the image
|
| 103 |
+
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
|
| 104 |
+
ix = np.argmin(required_scale)
|
| 105 |
+
return candidate_tilings[ix]
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def build_resized_image(
|
| 109 |
+
image: np.ndarray,
|
| 110 |
+
base_image_input_size: list[int],
|
| 111 |
+
resample: PILImageResampling,
|
| 112 |
+
image_mean: list[float],
|
| 113 |
+
image_std: list[float],
|
| 114 |
+
image_patch_size: int,
|
| 115 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 116 |
+
resized = resize_image(
|
| 117 |
+
image, base_image_input_size, resample,
|
| 118 |
+
)
|
| 119 |
+
resized = normalize_image(resized, image_mean, image_std)
|
| 120 |
+
if len(resized.shape) == 3:
|
| 121 |
+
resized = np.expand_dims(resized, 0)
|
| 122 |
+
crop_patch_w = base_image_input_size[1] // image_patch_size
|
| 123 |
+
crop_patch_h = base_image_input_size[0] // image_patch_size
|
| 124 |
+
resize_idx = np.arange(crop_patch_w*crop_patch_h).reshape([crop_patch_h, crop_patch_w])
|
| 125 |
+
return resized, resize_idx
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def build_overlapping_crops(
|
| 129 |
+
image: np.ndarray,
|
| 130 |
+
max_crops: int,
|
| 131 |
+
overlap_margins: list[int],
|
| 132 |
+
base_image_input_size: list[int],
|
| 133 |
+
resample: PILImageResampling,
|
| 134 |
+
image_mean: list[float],
|
| 135 |
+
image_std: list[float],
|
| 136 |
+
image_patch_size: int,
|
| 137 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 138 |
+
"""Decompose an image into a set of overlapping crops
|
| 139 |
+
|
| 140 |
+
:return crop_arr: [n_crops, h, w, 3] The crops
|
| 141 |
+
:return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
|
| 142 |
+
the crops were extracted from, what patch in `crop_arr` it corresponds to
|
| 143 |
+
"""
|
| 144 |
+
original_image_h, original_image_w = image.shape[:2]
|
| 145 |
+
crop_size = base_image_input_size[0]
|
| 146 |
+
assert base_image_input_size[0] == base_image_input_size[1]
|
| 147 |
+
|
| 148 |
+
left_margin, right_margin = overlap_margins
|
| 149 |
+
total_margin_pixels = image_patch_size * (right_margin + left_margin) # pixels removed per dim
|
| 150 |
+
crop_patches = base_image_input_size[0] // image_patch_size # patches per crop dim
|
| 151 |
+
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
|
| 152 |
+
crop_window_size = crop_window_patches * image_patch_size
|
| 153 |
+
crop_patch_w = base_image_input_size[1] // image_patch_size
|
| 154 |
+
crop_patch_h = base_image_input_size[0] // image_patch_size
|
| 155 |
+
original_image_h, original_image_w = image.shape[:2]
|
| 156 |
+
crop_size = base_image_input_size[0]
|
| 157 |
+
|
| 158 |
+
# Decide how to tile the image, to account for the overlap margins we compute the tiling
|
| 159 |
+
# as if we had an image without the margins and were using a crop size without the margins
|
| 160 |
+
tiling = select_tiling(
|
| 161 |
+
original_image_h - total_margin_pixels,
|
| 162 |
+
original_image_w - total_margin_pixels,
|
| 163 |
+
crop_window_size,
|
| 164 |
+
max_crops,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
src = resize_image(
|
| 168 |
+
image,
|
| 169 |
+
[tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels],
|
| 170 |
+
resample,
|
| 171 |
+
)
|
| 172 |
+
src = normalize_image(src, image_mean, image_std)
|
| 173 |
+
|
| 174 |
+
# Now we have to split the image into crops, and track what patches came from
|
| 175 |
+
# where in `patch_idx_arr`
|
| 176 |
+
n_crops = tiling[0] * tiling[1]
|
| 177 |
+
crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
|
| 178 |
+
patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
|
| 179 |
+
on_crop = 0
|
| 180 |
+
for i in range(tiling[0]):
|
| 181 |
+
# Slide over `src` by `crop_window_size` steps, but extract crops of size `crops_size`
|
| 182 |
+
# which results in overlapping crop windows
|
| 183 |
+
y0 = i*crop_window_size
|
| 184 |
+
for j in range(tiling[1]):
|
| 185 |
+
x0 = j*crop_window_size
|
| 186 |
+
crop_arr[on_crop] = src[y0:y0+crop_size, x0:x0+crop_size]
|
| 187 |
+
patch_idx = np.arange(crop_patch_w*crop_patch_h).reshape(crop_patch_h, crop_patch_w)
|
| 188 |
+
patch_idx += on_crop * crop_patch_h * crop_patch_w
|
| 189 |
+
|
| 190 |
+
# Mask out idx that are in the overlap region
|
| 191 |
+
if i != 0:
|
| 192 |
+
patch_idx[:left_margin, :] = -1
|
| 193 |
+
if j != 0:
|
| 194 |
+
patch_idx[:, :left_margin] = -1
|
| 195 |
+
if i != tiling[0]-1:
|
| 196 |
+
patch_idx[-right_margin:, :] = -1
|
| 197 |
+
if j != tiling[1]-1:
|
| 198 |
+
patch_idx[:, -right_margin:] = -1
|
| 199 |
+
patch_idx_arr[on_crop] = patch_idx
|
| 200 |
+
on_crop += 1
|
| 201 |
+
|
| 202 |
+
# `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
|
| 203 |
+
# so it is ordered left-to-right order
|
| 204 |
+
patch_idx_arr = np.reshape(
|
| 205 |
+
patch_idx_arr,
|
| 206 |
+
[tiling[0], tiling[1], crop_patch_h, crop_patch_w]
|
| 207 |
+
)
|
| 208 |
+
patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
|
| 209 |
+
patch_idx_arr = np.reshape(patch_idx_arr, [-1])
|
| 210 |
+
|
| 211 |
+
# Now get the parts not in the overlap region, so it should map each patch in `src`
|
| 212 |
+
# to the correct patch it should come from in `crop_arr`
|
| 213 |
+
patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
|
| 214 |
+
src.shape[0]//image_patch_size,
|
| 215 |
+
src.shape[1]//image_patch_size,
|
| 216 |
+
)
|
| 217 |
+
return crop_arr, patch_idx_arr
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
|
| 221 |
+
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
|
| 222 |
+
if len(array.shape) == 3:
|
| 223 |
+
n_crops, h, w = array.shape
|
| 224 |
+
h_patches = h//patch_size
|
| 225 |
+
w_patches = w//patch_size
|
| 226 |
+
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
|
| 227 |
+
array = np.transpose(array, [0, 1, 3, 2, 4])
|
| 228 |
+
array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size])
|
| 229 |
+
return array
|
| 230 |
+
else:
|
| 231 |
+
n_crops, h, w, c = array.shape
|
| 232 |
+
h_patches = h//patch_size
|
| 233 |
+
w_patches = w//patch_size
|
| 234 |
+
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
|
| 235 |
+
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
|
| 236 |
+
array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size*c])
|
| 237 |
+
return array
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def arange_for_pooling(
|
| 241 |
+
idx_arr: np.ndarray,
|
| 242 |
+
pool_h: int,
|
| 243 |
+
pool_w: int,
|
| 244 |
+
) -> np.ndarray:
|
| 245 |
+
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
|
| 246 |
+
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
|
| 247 |
+
idx_arr = np.pad(idx_arr, [[h_pad//2, (h_pad+1)//2], [w_pad//2, (w_pad+1)//2]],
|
| 248 |
+
mode='constant',constant_values=-1)
|
| 249 |
+
return einops.rearrange(
|
| 250 |
+
idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def image_to_patches_and_grids(
|
| 254 |
+
image: np.ndarray,
|
| 255 |
+
max_crops: int,
|
| 256 |
+
overlap_margins: list[int],
|
| 257 |
+
base_image_input_size: list[int],
|
| 258 |
+
resample: PILImageResampling,
|
| 259 |
+
image_mean: list[float],
|
| 260 |
+
image_std: list[float],
|
| 261 |
+
image_patch_size: int,
|
| 262 |
+
image_pooling_w: int,
|
| 263 |
+
image_pooling_h: int,
|
| 264 |
+
crop_mode: str = "overlap-and-resize-c2",
|
| 265 |
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 266 |
+
"""
|
| 267 |
+
:return image_grids, the shape of each (low-res, high-res) image after pooling
|
| 268 |
+
:return crops, the image crops to processes with the ViT
|
| 269 |
+
:return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
|
| 270 |
+
patches in `crops` to pool for that token, masked with -1
|
| 271 |
+
"""
|
| 272 |
+
if isinstance(base_image_input_size, int):
|
| 273 |
+
base_image_input_size = (base_image_input_size, base_image_input_size)
|
| 274 |
+
|
| 275 |
+
base_image_input_d = image_patch_size
|
| 276 |
+
pooling_w = image_pooling_w
|
| 277 |
+
pooling_h = image_pooling_h
|
| 278 |
+
crop_patch_w = base_image_input_size[1] // base_image_input_d
|
| 279 |
+
crop_patch_h = base_image_input_size[0] // base_image_input_d
|
| 280 |
+
|
| 281 |
+
if crop_mode == "resize":
|
| 282 |
+
resized, resize_idx = build_resized_image(
|
| 283 |
+
image,
|
| 284 |
+
base_image_input_size,
|
| 285 |
+
resample,
|
| 286 |
+
image_mean,
|
| 287 |
+
image_std,
|
| 288 |
+
image_patch_size,
|
| 289 |
+
)
|
| 290 |
+
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
| 291 |
+
resized_h, resized_w = resize_idx.shape[:2]
|
| 292 |
+
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
|
| 293 |
+
image_grid = [np.array([resized_h, resized_w, 0, 0])]
|
| 294 |
+
return (
|
| 295 |
+
np.stack(image_grid, 0),
|
| 296 |
+
batch_pixels_to_patches(resized, image_patch_size),
|
| 297 |
+
resize_idx,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
if crop_mode not in {"overlap-and-resize-c2", "overlap-and-resize"}:
|
| 301 |
+
raise ValueError(f"Unsupported MolmoAct2 image crop_mode {crop_mode!r}.")
|
| 302 |
+
|
| 303 |
+
crop_arr, patch_idx_arr = build_overlapping_crops(
|
| 304 |
+
image,
|
| 305 |
+
max_crops,
|
| 306 |
+
overlap_margins,
|
| 307 |
+
base_image_input_size,
|
| 308 |
+
resample,
|
| 309 |
+
image_mean,
|
| 310 |
+
image_std,
|
| 311 |
+
image_patch_size,
|
| 312 |
+
)
|
| 313 |
+
pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
|
| 314 |
+
h, w = pooling_idx.shape[:2]
|
| 315 |
+
pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w])
|
| 316 |
+
|
| 317 |
+
# Finally do the same for the global image
|
| 318 |
+
resized, resize_idx = build_resized_image(
|
| 319 |
+
image,
|
| 320 |
+
base_image_input_size,
|
| 321 |
+
resample,
|
| 322 |
+
image_mean,
|
| 323 |
+
image_std,
|
| 324 |
+
image_patch_size,
|
| 325 |
+
)
|
| 326 |
+
crop_arr = np.concatenate([resized, crop_arr], 0)
|
| 327 |
+
|
| 328 |
+
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
| 329 |
+
resized_h, resized_w = resize_idx.shape[:2]
|
| 330 |
+
resize_idx = resize_idx.reshape([-1, pooling_h*pooling_w])
|
| 331 |
+
|
| 332 |
+
# Global image goes first, so the order of patches in previous crops gets increased
|
| 333 |
+
pooling_idx = np.where(
|
| 334 |
+
pooling_idx >= 0,
|
| 335 |
+
pooling_idx + crop_patch_h*crop_patch_w,
|
| 336 |
+
-1
|
| 337 |
+
)
|
| 338 |
+
pooling_idx = np.concatenate([resize_idx, pooling_idx])
|
| 339 |
+
image_grid = [np.array([resized_h, resized_w, h, w])]
|
| 340 |
+
|
| 341 |
+
return (
|
| 342 |
+
np.stack(image_grid, 0),
|
| 343 |
+
batch_pixels_to_patches(crop_arr, image_patch_size),
|
| 344 |
+
pooling_idx
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
|
| 349 |
+
max_crops: Optional[int]
|
| 350 |
+
overlap_margins: Optional[list[int]]
|
| 351 |
+
crop_mode: Optional[str]
|
| 352 |
+
patch_size: Optional[int]
|
| 353 |
+
pooling_size: Optional[list[int]]
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class MolmoAct2ImageProcessor(BaseImageProcessor):
|
| 357 |
+
r"""
|
| 358 |
+
Constructs a MolmoAct2 image processor that preprocesses images for the model.
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
size (`dict[str, int]` *optional*, defaults to `{"height": 378, "width": 378}`):
|
| 362 |
+
Size of the image after resizing.
|
| 363 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
| 364 |
+
Resampling filter to use when resizing the image.
|
| 365 |
+
image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 366 |
+
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 367 |
+
image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 368 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 369 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
| 370 |
+
Whether to convert the image to RGB.
|
| 371 |
+
max_crops (`int`, *optional*, defaults to `8`):
|
| 372 |
+
Maximum number of crops to use per image.
|
| 373 |
+
overlap_margins (`list[int]`, *optional*, defaults to `[4, 4]`):
|
| 374 |
+
Overlap margins to use.
|
| 375 |
+
patch_size (`int`, *optional*, defaults to 14):
|
| 376 |
+
The spatial patch size of the vision encoder.
|
| 377 |
+
pooling_size (`list[int]`, *optional*, defaults to `[2, 2]`):
|
| 378 |
+
The pooling size of the vision adapter.
|
| 379 |
+
"""
|
| 380 |
+
|
| 381 |
+
model_input_names = ["pixel_values", "image_token_pooling", "image_grids", "image_num_crops"]
|
| 382 |
+
|
| 383 |
+
def __init__(
|
| 384 |
+
self,
|
| 385 |
+
size: Optional[dict[str, int]] = None,
|
| 386 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 387 |
+
image_mean: Optional[Union[float, list[float]]] = None,
|
| 388 |
+
image_std: Optional[Union[float, list[float]]] = None,
|
| 389 |
+
do_convert_rgb: bool = True,
|
| 390 |
+
max_crops: int = 8,
|
| 391 |
+
overlap_margins: list[int] = [4, 4],
|
| 392 |
+
crop_mode: str = "overlap-and-resize-c2",
|
| 393 |
+
patch_size: int = 14,
|
| 394 |
+
pooling_size: list[int] = [2, 2],
|
| 395 |
+
**kwargs,
|
| 396 |
+
) -> None:
|
| 397 |
+
super().__init__(**kwargs)
|
| 398 |
+
size = size if size is not None else {"height": 378, "width": 378}
|
| 399 |
+
size = get_size_dict(size, default_to_square=True)
|
| 400 |
+
self.size = size
|
| 401 |
+
|
| 402 |
+
self.resample = resample
|
| 403 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
| 404 |
+
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
| 405 |
+
self.do_convert_rgb = do_convert_rgb
|
| 406 |
+
|
| 407 |
+
self.max_crops = max_crops
|
| 408 |
+
self.overlap_margins = overlap_margins
|
| 409 |
+
self.crop_mode = crop_mode
|
| 410 |
+
self.patch_size = patch_size
|
| 411 |
+
self.pooling_size = pooling_size
|
| 412 |
+
|
| 413 |
+
def preprocess(
|
| 414 |
+
self,
|
| 415 |
+
images: ImageInput,
|
| 416 |
+
size: Optional[dict[str, int]] = None,
|
| 417 |
+
resample: Optional[PILImageResampling] = None,
|
| 418 |
+
image_mean: Optional[Union[float, list[float]]] = None,
|
| 419 |
+
image_std: Optional[Union[float, list[float]]] = None,
|
| 420 |
+
do_convert_rgb: Optional[bool] = None,
|
| 421 |
+
max_crops: Optional[int] = None,
|
| 422 |
+
overlap_margins: Optional[list[int]] = None,
|
| 423 |
+
crop_mode: Optional[str] = None,
|
| 424 |
+
patch_size: Optional[int] = None,
|
| 425 |
+
pooling_size: Optional[list[int]] = None,
|
| 426 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 427 |
+
**kwargs,
|
| 428 |
+
) -> BatchFeature:
|
| 429 |
+
"""
|
| 430 |
+
Args:
|
| 431 |
+
images (`ImageInput`):
|
| 432 |
+
Image to preprocess.
|
| 433 |
+
size (`dict[str, int]`, *optional*, defaults to `self.size`):
|
| 434 |
+
Size of the image after resizing.
|
| 435 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
| 436 |
+
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
|
| 437 |
+
has an effect if `do_resize` is set to `True`.
|
| 438 |
+
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
|
| 439 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 440 |
+
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
|
| 441 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
| 442 |
+
`True`.
|
| 443 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 444 |
+
Whether to convert the image to RGB.
|
| 445 |
+
max_crops (`int`, *optional*, defaults to `self.max_crops`):
|
| 446 |
+
Maximum number of crops to use per image.
|
| 447 |
+
overlap_margins (`list[int]`, *optional*, defaults to `self.overlap_margins`):
|
| 448 |
+
Overlap margins to use.
|
| 449 |
+
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
| 450 |
+
The spatial patch size of the vision encoder.
|
| 451 |
+
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
|
| 452 |
+
The pooling size of the vision adapter.
|
| 453 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 454 |
+
The type of tensors to return. Can be one of:
|
| 455 |
+
- Unset: Return a list of `np.ndarray`.
|
| 456 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 457 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 458 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 459 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 460 |
+
|
| 461 |
+
Returns:
|
| 462 |
+
A `BatchFeature` containing the following keys:
|
| 463 |
+
- `pixel_values`: The preprocessed images.
|
| 464 |
+
- `image_token_pooling`: The indices of the patches in `crops` to pool for each token in `image_tokens`.
|
| 465 |
+
- `image_grids`: The image grids.
|
| 466 |
+
- `image_num_crops`: The number of crops for each image.
|
| 467 |
+
"""
|
| 468 |
+
if size is not None:
|
| 469 |
+
if "height" not in size or "width" not in size:
|
| 470 |
+
raise ValueError("size must contain 'height' and 'width' keys.")
|
| 471 |
+
else:
|
| 472 |
+
size = {**self.size}
|
| 473 |
+
|
| 474 |
+
base_image_input_size = [size["height"], size["width"]]
|
| 475 |
+
|
| 476 |
+
resample = resample or self.resample
|
| 477 |
+
image_mean = image_mean or self.image_mean
|
| 478 |
+
image_std = image_std or self.image_std
|
| 479 |
+
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
|
| 480 |
+
|
| 481 |
+
max_crops = max_crops or self.max_crops
|
| 482 |
+
overlap_margins = overlap_margins or self.overlap_margins
|
| 483 |
+
crop_mode = crop_mode or self.crop_mode
|
| 484 |
+
patch_size = patch_size or self.patch_size
|
| 485 |
+
pooling_size = pooling_size or self.pooling_size
|
| 486 |
+
|
| 487 |
+
image_pooling_h, image_pooling_w = pooling_size
|
| 488 |
+
|
| 489 |
+
if images is not None:
|
| 490 |
+
images = self.fetch_images(images)
|
| 491 |
+
images = make_flat_list_of_images(images)
|
| 492 |
+
|
| 493 |
+
if images is not None and not valid_images(images):
|
| 494 |
+
raise ValueError(
|
| 495 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 496 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
if do_convert_rgb:
|
| 500 |
+
images = [convert_to_rgb(image) for image in images]
|
| 501 |
+
|
| 502 |
+
# All transformations expect numpy arrays.
|
| 503 |
+
images = [to_numpy_array(image) for image in images]
|
| 504 |
+
|
| 505 |
+
data = {}
|
| 506 |
+
if images is not None:
|
| 507 |
+
batch_grids = []
|
| 508 |
+
batch_crops = []
|
| 509 |
+
batch_pooled_patches_idx = []
|
| 510 |
+
batch_num_crops = []
|
| 511 |
+
|
| 512 |
+
for image in images:
|
| 513 |
+
image_grid, crops, pooled_idx = image_to_patches_and_grids(
|
| 514 |
+
image,
|
| 515 |
+
max_crops,
|
| 516 |
+
overlap_margins,
|
| 517 |
+
base_image_input_size,
|
| 518 |
+
resample,
|
| 519 |
+
image_mean,
|
| 520 |
+
image_std,
|
| 521 |
+
patch_size,
|
| 522 |
+
image_pooling_w,
|
| 523 |
+
image_pooling_h,
|
| 524 |
+
crop_mode,
|
| 525 |
+
)
|
| 526 |
+
batch_grids.append(image_grid)
|
| 527 |
+
batch_crops.append(crops)
|
| 528 |
+
batch_pooled_patches_idx.append(pooled_idx)
|
| 529 |
+
batch_num_crops.append(crops.shape[0])
|
| 530 |
+
|
| 531 |
+
pixel_values = np.concatenate(batch_crops, 0)
|
| 532 |
+
image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
|
| 533 |
+
image_grids = np.concatenate(batch_grids, 0)
|
| 534 |
+
image_num_crops = np.array(batch_num_crops)
|
| 535 |
+
|
| 536 |
+
data.update(
|
| 537 |
+
pixel_values=pixel_values,
|
| 538 |
+
image_token_pooling=image_token_pooling,
|
| 539 |
+
image_grids=image_grids,
|
| 540 |
+
image_num_crops=image_num_crops,
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
return BatchFeature(data, tensor_type=return_tensors)
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
MolmoAct2ImageProcessor.register_for_auto_class()
|
inference.py
ADDED
|
@@ -0,0 +1,768 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inference utilities for MolmoAct2"""
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Iterable, Optional, Sequence, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
from transformers.cache_utils import Cache
|
| 9 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class _ActionFlowInputs:
|
| 14 |
+
trajectory: torch.Tensor
|
| 15 |
+
context: Any
|
| 16 |
+
modulations: Sequence[Any]
|
| 17 |
+
action_dim_is_pad: Optional[torch.Tensor]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class _ActionFlowCudaGraph:
|
| 22 |
+
key: Tuple[Any, ...]
|
| 23 |
+
graph: torch.cuda.CUDAGraph
|
| 24 |
+
static_inputs: _ActionFlowInputs
|
| 25 |
+
output: torch.Tensor
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class _DepthDecodeCudaGraphLayerStage:
|
| 30 |
+
residual: torch.Tensor
|
| 31 |
+
query: torch.Tensor
|
| 32 |
+
key: torch.Tensor
|
| 33 |
+
value: torch.Tensor
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class _DepthDecodeCudaGraphPostStage:
|
| 38 |
+
graph: torch.cuda.CUDAGraph
|
| 39 |
+
attn_context: torch.Tensor
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class _DepthDecodeCudaGraph:
|
| 44 |
+
cache_key: Tuple[Any, ...]
|
| 45 |
+
pre_graph: torch.cuda.CUDAGraph
|
| 46 |
+
token_ids: torch.Tensor
|
| 47 |
+
cos: torch.Tensor
|
| 48 |
+
sin: torch.Tensor
|
| 49 |
+
positions: torch.Tensor
|
| 50 |
+
stages: Sequence[_DepthDecodeCudaGraphLayerStage]
|
| 51 |
+
post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
|
| 52 |
+
output: torch.Tensor
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclass
|
| 56 |
+
class _DepthDecodeCudaGraphSpec:
|
| 57 |
+
eligible: bool
|
| 58 |
+
cache_key_prefix: Tuple[Any, ...]
|
| 59 |
+
num_hidden_layers: int
|
| 60 |
+
head_dim: int
|
| 61 |
+
num_attention_heads: int
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _cache_seq_len_int(past_key_values: Optional[Cache]) -> int:
|
| 65 |
+
if past_key_values is None:
|
| 66 |
+
return 0
|
| 67 |
+
seq_len = past_key_values.get_seq_length()
|
| 68 |
+
if torch.is_tensor(seq_len):
|
| 69 |
+
return int(seq_len.item())
|
| 70 |
+
return int(seq_len)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _cache_max_len_int(past_key_values: Optional[Cache]) -> int:
|
| 74 |
+
if past_key_values is None:
|
| 75 |
+
return -1
|
| 76 |
+
max_len = past_key_values.get_max_cache_shape()
|
| 77 |
+
if torch.is_tensor(max_len):
|
| 78 |
+
return int(max_len.item())
|
| 79 |
+
return int(max_len)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _iter_cache_key_values(
|
| 83 |
+
past_key_values: Cache,
|
| 84 |
+
) -> Iterable[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]:
|
| 85 |
+
layers = getattr(past_key_values, "layers", None)
|
| 86 |
+
if layers is not None:
|
| 87 |
+
for layer in layers:
|
| 88 |
+
yield getattr(layer, "keys", None), getattr(layer, "values", None)
|
| 89 |
+
return
|
| 90 |
+
for layer in past_key_values:
|
| 91 |
+
yield layer[0], layer[1]
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class _DepthDecodeStaticLayerCache:
|
| 95 |
+
is_compileable = False
|
| 96 |
+
is_sliding = False
|
| 97 |
+
|
| 98 |
+
def __init__(self, max_cache_len: int) -> None:
|
| 99 |
+
self.max_cache_len = int(max_cache_len)
|
| 100 |
+
self.cumulative_length = 0
|
| 101 |
+
self.keys: Optional[torch.Tensor] = None
|
| 102 |
+
self.values: Optional[torch.Tensor] = None
|
| 103 |
+
|
| 104 |
+
def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
|
| 105 |
+
bsz, n_heads = key_states.shape[:2]
|
| 106 |
+
self.keys = torch.empty(
|
| 107 |
+
(bsz, n_heads, self.max_cache_len, key_states.shape[-1]),
|
| 108 |
+
dtype=key_states.dtype,
|
| 109 |
+
device=key_states.device,
|
| 110 |
+
)
|
| 111 |
+
self.values = torch.empty(
|
| 112 |
+
(bsz, n_heads, self.max_cache_len, value_states.shape[-1]),
|
| 113 |
+
dtype=value_states.dtype,
|
| 114 |
+
device=value_states.device,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def update(
|
| 118 |
+
self,
|
| 119 |
+
key_states: torch.Tensor,
|
| 120 |
+
value_states: torch.Tensor,
|
| 121 |
+
*args,
|
| 122 |
+
**kwargs,
|
| 123 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 124 |
+
if self.keys is None:
|
| 125 |
+
self._allocate(key_states, value_states)
|
| 126 |
+
start = self.cumulative_length
|
| 127 |
+
end = start + key_states.shape[-2]
|
| 128 |
+
if end > self.max_cache_len:
|
| 129 |
+
raise RuntimeError(
|
| 130 |
+
f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}."
|
| 131 |
+
)
|
| 132 |
+
self.keys[:, :, start:end, :].copy_(key_states)
|
| 133 |
+
self.values[:, :, start:end, :].copy_(value_states)
|
| 134 |
+
self.cumulative_length = end
|
| 135 |
+
return self.keys[:, :, :end, :], self.values[:, :, :end, :]
|
| 136 |
+
|
| 137 |
+
def get_seq_length(self) -> int:
|
| 138 |
+
return self.cumulative_length
|
| 139 |
+
|
| 140 |
+
def get_max_cache_shape(self) -> int:
|
| 141 |
+
return -1
|
| 142 |
+
|
| 143 |
+
def reset(self) -> None:
|
| 144 |
+
self.cumulative_length = 0
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class _DepthDecodeStaticCache(Cache):
|
| 148 |
+
def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None:
|
| 149 |
+
text_config = config.get_text_config(decoder=True)
|
| 150 |
+
super().__init__(
|
| 151 |
+
layers=[
|
| 152 |
+
_DepthDecodeStaticLayerCache(max_cache_len=max_cache_len)
|
| 153 |
+
for _ in range(text_config.num_hidden_layers)
|
| 154 |
+
]
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
def get_seq_length(self, layer_idx: int = 0) -> int:
|
| 158 |
+
return self.layers[layer_idx].get_seq_length()
|
| 159 |
+
|
| 160 |
+
def get_max_cache_shape(self, layer_idx: int = 0) -> int:
|
| 161 |
+
return self.layers[layer_idx].get_max_cache_shape()
|
| 162 |
+
|
| 163 |
+
def reset(self) -> None:
|
| 164 |
+
for layer in self.layers:
|
| 165 |
+
layer.reset()
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class ActionCudaGraphManager:
|
| 169 |
+
def __init__(self, model: Any) -> None:
|
| 170 |
+
self.model = model
|
| 171 |
+
self.enabled = True
|
| 172 |
+
self.action_flow_graph: Optional[_ActionFlowCudaGraph] = None
|
| 173 |
+
|
| 174 |
+
def set_enabled(self, enabled: bool) -> None:
|
| 175 |
+
self.enabled = bool(enabled)
|
| 176 |
+
|
| 177 |
+
def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool:
|
| 178 |
+
action_model = self.model
|
| 179 |
+
if not self.enabled:
|
| 180 |
+
return False
|
| 181 |
+
if action_model.training or action_model._require_action_expert().training:
|
| 182 |
+
return False
|
| 183 |
+
if inputs.trajectory.device.type != "cuda":
|
| 184 |
+
return False
|
| 185 |
+
|
| 186 |
+
def all_on_cuda():
|
| 187 |
+
yield inputs.trajectory
|
| 188 |
+
for k, v in inputs.context.kv_contexts:
|
| 189 |
+
yield k
|
| 190 |
+
yield v
|
| 191 |
+
for t in (
|
| 192 |
+
inputs.context.cross_mask,
|
| 193 |
+
inputs.context.self_mask,
|
| 194 |
+
inputs.context.valid_action,
|
| 195 |
+
inputs.action_dim_is_pad,
|
| 196 |
+
):
|
| 197 |
+
if t is not None:
|
| 198 |
+
yield t
|
| 199 |
+
if inputs.context.rope_cache is not None:
|
| 200 |
+
yield from inputs.context.rope_cache
|
| 201 |
+
for step in inputs.modulations:
|
| 202 |
+
yield step.conditioning
|
| 203 |
+
for block_modulation in step.block_modulations:
|
| 204 |
+
yield from block_modulation
|
| 205 |
+
yield from step.final_modulation
|
| 206 |
+
|
| 207 |
+
return all(t.device.type == "cuda" for t in all_on_cuda())
|
| 208 |
+
|
| 209 |
+
def run_action_flow(
|
| 210 |
+
self,
|
| 211 |
+
inputs: _ActionFlowInputs,
|
| 212 |
+
steps: int,
|
| 213 |
+
run_loop,
|
| 214 |
+
) -> torch.Tensor:
|
| 215 |
+
key = _cuda_graph_key(inputs, steps)
|
| 216 |
+
cache = self.action_flow_graph
|
| 217 |
+
if cache is None or cache.key != key:
|
| 218 |
+
static_inputs = _clone_static_inputs(inputs)
|
| 219 |
+
graph, output = _capture_cuda_graph(
|
| 220 |
+
lambda: run_loop(static_inputs, steps),
|
| 221 |
+
inputs.trajectory.device,
|
| 222 |
+
after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory),
|
| 223 |
+
)
|
| 224 |
+
cache = _ActionFlowCudaGraph(
|
| 225 |
+
key=key,
|
| 226 |
+
graph=graph,
|
| 227 |
+
static_inputs=static_inputs,
|
| 228 |
+
output=output,
|
| 229 |
+
)
|
| 230 |
+
self.action_flow_graph = cache
|
| 231 |
+
else:
|
| 232 |
+
_copy_inputs_(cache.static_inputs, inputs)
|
| 233 |
+
|
| 234 |
+
cache.graph.replay()
|
| 235 |
+
return cache.output.clone()
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class DepthDecodeCudaGraphManager:
|
| 239 |
+
def __init__(self, model: Any) -> None:
|
| 240 |
+
self.model = model
|
| 241 |
+
self.backbone = model.model
|
| 242 |
+
self.enabled = True
|
| 243 |
+
self.graph: Optional[_DepthDecodeCudaGraph] = None
|
| 244 |
+
self.graph_spec: Optional[_DepthDecodeCudaGraphSpec] = None
|
| 245 |
+
|
| 246 |
+
def set_enabled(self, enabled: bool) -> None:
|
| 247 |
+
self.enabled = bool(enabled)
|
| 248 |
+
|
| 249 |
+
def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache:
|
| 250 |
+
return _DepthDecodeStaticCache(
|
| 251 |
+
config=self.model.config.text_config,
|
| 252 |
+
max_cache_len=max_cache_len,
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec:
|
| 256 |
+
static = self.graph_spec
|
| 257 |
+
if static is None:
|
| 258 |
+
cfg = self.backbone.transformer.config
|
| 259 |
+
rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None)
|
| 260 |
+
static = _DepthDecodeCudaGraphSpec(
|
| 261 |
+
eligible=(
|
| 262 |
+
not cfg.norm_after
|
| 263 |
+
and cfg.rope_scaling_layers is None
|
| 264 |
+
and getattr(rotary_emb, "rope_type", None) == "default"
|
| 265 |
+
and cfg._attn_implementation == "sdpa"
|
| 266 |
+
),
|
| 267 |
+
cache_key_prefix=(
|
| 268 |
+
cfg.hidden_size,
|
| 269 |
+
cfg.num_attention_heads,
|
| 270 |
+
cfg.num_key_value_heads,
|
| 271 |
+
cfg.head_dim,
|
| 272 |
+
cfg.num_hidden_layers,
|
| 273 |
+
cfg.use_qk_norm,
|
| 274 |
+
cfg.qk_norm_type,
|
| 275 |
+
cfg._attn_implementation,
|
| 276 |
+
),
|
| 277 |
+
num_hidden_layers=cfg.num_hidden_layers,
|
| 278 |
+
head_dim=cfg.head_dim,
|
| 279 |
+
num_attention_heads=cfg.num_attention_heads,
|
| 280 |
+
)
|
| 281 |
+
self.graph_spec = static
|
| 282 |
+
return static
|
| 283 |
+
|
| 284 |
+
def can_use(
|
| 285 |
+
self,
|
| 286 |
+
next_input_ids: torch.Tensor,
|
| 287 |
+
*,
|
| 288 |
+
past_key_values: Cache,
|
| 289 |
+
attention_bias: torch.Tensor,
|
| 290 |
+
) -> bool:
|
| 291 |
+
if (
|
| 292 |
+
not self.enabled
|
| 293 |
+
or self.model.training
|
| 294 |
+
or self.backbone.transformer.training
|
| 295 |
+
):
|
| 296 |
+
return False
|
| 297 |
+
if next_input_ids.device.type != "cuda":
|
| 298 |
+
return False
|
| 299 |
+
if (
|
| 300 |
+
next_input_ids.ndim != 2
|
| 301 |
+
or next_input_ids.shape[0] != 1
|
| 302 |
+
or next_input_ids.shape[1] != 1
|
| 303 |
+
):
|
| 304 |
+
return False
|
| 305 |
+
if not isinstance(past_key_values, _DepthDecodeStaticCache):
|
| 306 |
+
return False
|
| 307 |
+
if (
|
| 308 |
+
not torch.is_tensor(attention_bias)
|
| 309 |
+
or attention_bias.device != next_input_ids.device
|
| 310 |
+
):
|
| 311 |
+
return False
|
| 312 |
+
return self._depth_decode_spec().eligible
|
| 313 |
+
|
| 314 |
+
def _depth_decode_key(
|
| 315 |
+
self,
|
| 316 |
+
next_input_ids: torch.Tensor,
|
| 317 |
+
attention_bias: torch.Tensor,
|
| 318 |
+
) -> Tuple[Any, ...]:
|
| 319 |
+
device = next_input_ids.device
|
| 320 |
+
return (
|
| 321 |
+
self._depth_decode_spec().cache_key_prefix,
|
| 322 |
+
device.type,
|
| 323 |
+
device.index,
|
| 324 |
+
self.model.lm_head.weight.dtype,
|
| 325 |
+
attention_bias.shape[-1],
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
def _select_depth_decode_rope(
|
| 329 |
+
self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int
|
| 330 |
+
) -> None:
|
| 331 |
+
emb = self.backbone.transformer.rotary_emb
|
| 332 |
+
cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
|
| 333 |
+
sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])
|
| 334 |
+
|
| 335 |
+
def _depth_decode_pre_layer(
|
| 336 |
+
self,
|
| 337 |
+
layer_idx: int,
|
| 338 |
+
hidden_states: torch.Tensor,
|
| 339 |
+
cos: torch.Tensor,
|
| 340 |
+
sin: torch.Tensor,
|
| 341 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 342 |
+
block = self.backbone.transformer.blocks[layer_idx]
|
| 343 |
+
attention = block.self_attn
|
| 344 |
+
residual = hidden_states
|
| 345 |
+
hidden_states = block.attn_norm(hidden_states)
|
| 346 |
+
|
| 347 |
+
input_shape = hidden_states.shape[:-1]
|
| 348 |
+
hidden_shape = (*input_shape, -1, attention.head_dim)
|
| 349 |
+
qkv = attention.att_proj(hidden_states)
|
| 350 |
+
query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1)
|
| 351 |
+
value_states = value_states.view(hidden_shape)
|
| 352 |
+
|
| 353 |
+
apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None
|
| 354 |
+
norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3"
|
| 355 |
+
|
| 356 |
+
if apply_qk_norm and not norm_after_view:
|
| 357 |
+
query_states = attention.q_norm(query_states)
|
| 358 |
+
key_states = attention.k_norm(key_states)
|
| 359 |
+
|
| 360 |
+
query_states = query_states.view(hidden_shape)
|
| 361 |
+
key_states = key_states.view(hidden_shape)
|
| 362 |
+
|
| 363 |
+
if norm_after_view:
|
| 364 |
+
query_states = attention.q_norm(query_states)
|
| 365 |
+
key_states = attention.k_norm(key_states)
|
| 366 |
+
|
| 367 |
+
query_states = query_states.transpose(1, 2)
|
| 368 |
+
key_states = key_states.transpose(1, 2)
|
| 369 |
+
value_states = value_states.transpose(1, 2)
|
| 370 |
+
query_states, key_states = _apply_rotary_pos_emb(
|
| 371 |
+
query_states, key_states, cos, sin
|
| 372 |
+
)
|
| 373 |
+
return residual, query_states, key_states, value_states
|
| 374 |
+
|
| 375 |
+
def _depth_decode_pre0(
|
| 376 |
+
self,
|
| 377 |
+
token_ids: torch.Tensor,
|
| 378 |
+
cos: torch.Tensor,
|
| 379 |
+
sin: torch.Tensor,
|
| 380 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 381 |
+
inputs_embeds = self.model._embed_base_tokens(token_ids)
|
| 382 |
+
return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)
|
| 383 |
+
|
| 384 |
+
def _depth_decode_post_layer(
|
| 385 |
+
self,
|
| 386 |
+
layer_idx: int,
|
| 387 |
+
residual: torch.Tensor,
|
| 388 |
+
attn_context: torch.Tensor,
|
| 389 |
+
) -> torch.Tensor:
|
| 390 |
+
block = self.backbone.transformer.blocks[layer_idx]
|
| 391 |
+
attention = block.self_attn
|
| 392 |
+
input_shape = residual.shape[:-1]
|
| 393 |
+
attn_output = attn_context.reshape(*input_shape, -1).contiguous()
|
| 394 |
+
attn_output = attention.attn_out(attn_output)
|
| 395 |
+
hidden_states = residual + block.dropout(attn_output)
|
| 396 |
+
|
| 397 |
+
residual = hidden_states
|
| 398 |
+
hidden_states = block.ff_norm(hidden_states)
|
| 399 |
+
hidden_states = block.mlp(hidden_states)
|
| 400 |
+
hidden_states = residual + block.dropout(hidden_states)
|
| 401 |
+
return hidden_states
|
| 402 |
+
|
| 403 |
+
def _depth_decode_post_and_pre_next(
|
| 404 |
+
self,
|
| 405 |
+
layer_idx: int,
|
| 406 |
+
residual: torch.Tensor,
|
| 407 |
+
attn_context: torch.Tensor,
|
| 408 |
+
cos: torch.Tensor,
|
| 409 |
+
sin: torch.Tensor,
|
| 410 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 411 |
+
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
|
| 412 |
+
return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)
|
| 413 |
+
|
| 414 |
+
def _depth_decode_last_post(
|
| 415 |
+
self,
|
| 416 |
+
layer_idx: int,
|
| 417 |
+
residual: torch.Tensor,
|
| 418 |
+
attn_context: torch.Tensor,
|
| 419 |
+
) -> torch.Tensor:
|
| 420 |
+
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
|
| 421 |
+
return self.backbone.transformer.ln_f(hidden_states)
|
| 422 |
+
|
| 423 |
+
def _build_depth_decode_graph(
|
| 424 |
+
self,
|
| 425 |
+
next_input_ids: torch.Tensor,
|
| 426 |
+
*,
|
| 427 |
+
past_length: int,
|
| 428 |
+
attention_bias: torch.Tensor,
|
| 429 |
+
) -> _DepthDecodeCudaGraph:
|
| 430 |
+
text_config = self.backbone.transformer.config
|
| 431 |
+
device = next_input_ids.device
|
| 432 |
+
dtype = self.model.lm_head.weight.dtype
|
| 433 |
+
static = self._depth_decode_spec()
|
| 434 |
+
num_layers = static.num_hidden_layers
|
| 435 |
+
head_dim = static.head_dim
|
| 436 |
+
max_cache_len = int(attention_bias.shape[-1])
|
| 437 |
+
max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
|
| 438 |
+
self.backbone.transformer.prepare_rope_cache(
|
| 439 |
+
device=device, max_seq_len=max_rope_len
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
|
| 443 |
+
cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
|
| 444 |
+
sin = torch.empty_like(cos)
|
| 445 |
+
positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
|
| 446 |
+
context_shape = (1, 1, static.num_attention_heads, head_dim)
|
| 447 |
+
|
| 448 |
+
token_ids.copy_(next_input_ids)
|
| 449 |
+
self._select_depth_decode_rope(cos, sin, past_length=past_length)
|
| 450 |
+
|
| 451 |
+
pre_graph, pre_output = _capture_cuda_graph(
|
| 452 |
+
lambda: self._depth_decode_pre0(token_ids, cos, sin),
|
| 453 |
+
device,
|
| 454 |
+
)
|
| 455 |
+
stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)]
|
| 456 |
+
post_graphs = []
|
| 457 |
+
for layer_idx in range(num_layers - 1):
|
| 458 |
+
stage = stages[-1]
|
| 459 |
+
attn_context = torch.empty(context_shape, device=device, dtype=dtype)
|
| 460 |
+
graph, output = _capture_cuda_graph(
|
| 461 |
+
lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: (
|
| 462 |
+
self._depth_decode_post_and_pre_next(
|
| 463 |
+
layer_idx,
|
| 464 |
+
stage.residual,
|
| 465 |
+
attn_context,
|
| 466 |
+
cos,
|
| 467 |
+
sin,
|
| 468 |
+
)
|
| 469 |
+
),
|
| 470 |
+
device,
|
| 471 |
+
)
|
| 472 |
+
post_graphs.append(
|
| 473 |
+
_DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context)
|
| 474 |
+
)
|
| 475 |
+
stages.append(_DepthDecodeCudaGraphLayerStage(*output))
|
| 476 |
+
|
| 477 |
+
last_stage = stages[-1]
|
| 478 |
+
last_attn_context = torch.empty(context_shape, device=device, dtype=dtype)
|
| 479 |
+
last_graph, last_output = _capture_cuda_graph(
|
| 480 |
+
lambda: self._depth_decode_last_post(
|
| 481 |
+
num_layers - 1,
|
| 482 |
+
last_stage.residual,
|
| 483 |
+
last_attn_context,
|
| 484 |
+
),
|
| 485 |
+
device,
|
| 486 |
+
)
|
| 487 |
+
post_graphs.append(
|
| 488 |
+
_DepthDecodeCudaGraphPostStage(
|
| 489 |
+
graph=last_graph, attn_context=last_attn_context
|
| 490 |
+
)
|
| 491 |
+
)
|
| 492 |
+
return _DepthDecodeCudaGraph(
|
| 493 |
+
cache_key=self._depth_decode_key(next_input_ids, attention_bias),
|
| 494 |
+
pre_graph=pre_graph,
|
| 495 |
+
token_ids=token_ids,
|
| 496 |
+
cos=cos,
|
| 497 |
+
sin=sin,
|
| 498 |
+
positions=positions,
|
| 499 |
+
stages=tuple(stages),
|
| 500 |
+
post_graphs=tuple(post_graphs),
|
| 501 |
+
output=last_output,
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
def _get_depth_decode_graph(
|
| 505 |
+
self,
|
| 506 |
+
next_input_ids: torch.Tensor,
|
| 507 |
+
*,
|
| 508 |
+
past_length: int,
|
| 509 |
+
attention_bias: torch.Tensor,
|
| 510 |
+
) -> _DepthDecodeCudaGraph:
|
| 511 |
+
key = self._depth_decode_key(next_input_ids, attention_bias)
|
| 512 |
+
decode_graph = self.graph
|
| 513 |
+
if decode_graph is None or decode_graph.cache_key != key:
|
| 514 |
+
decode_graph = self._build_depth_decode_graph(
|
| 515 |
+
next_input_ids,
|
| 516 |
+
past_length=past_length,
|
| 517 |
+
attention_bias=attention_bias,
|
| 518 |
+
)
|
| 519 |
+
self.graph = decode_graph
|
| 520 |
+
else:
|
| 521 |
+
decode_graph.token_ids.copy_(next_input_ids)
|
| 522 |
+
self._select_depth_decode_rope(
|
| 523 |
+
decode_graph.cos, decode_graph.sin, past_length=past_length
|
| 524 |
+
)
|
| 525 |
+
return decode_graph
|
| 526 |
+
|
| 527 |
+
def _run_depth_decode_attention_core(
|
| 528 |
+
self,
|
| 529 |
+
layer_idx: int,
|
| 530 |
+
stage: _DepthDecodeCudaGraphLayerStage,
|
| 531 |
+
*,
|
| 532 |
+
past_key_values: Cache,
|
| 533 |
+
attention_bias: torch.Tensor,
|
| 534 |
+
cache_position: torch.Tensor,
|
| 535 |
+
cos: torch.Tensor,
|
| 536 |
+
sin: torch.Tensor,
|
| 537 |
+
) -> torch.Tensor:
|
| 538 |
+
attention = self.backbone.transformer.blocks[layer_idx].self_attn
|
| 539 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 540 |
+
key_states, value_states = past_key_values.update(
|
| 541 |
+
stage.key,
|
| 542 |
+
stage.value,
|
| 543 |
+
layer_idx,
|
| 544 |
+
cache_kwargs,
|
| 545 |
+
)
|
| 546 |
+
key_states = _repeat_kv(key_states, attention.num_key_value_groups)
|
| 547 |
+
value_states = _repeat_kv(value_states, attention.num_key_value_groups)
|
| 548 |
+
attn_output = F.scaled_dot_product_attention(
|
| 549 |
+
stage.query,
|
| 550 |
+
key_states,
|
| 551 |
+
value_states,
|
| 552 |
+
attn_mask=attention_bias,
|
| 553 |
+
dropout_p=0.0,
|
| 554 |
+
is_causal=False,
|
| 555 |
+
)
|
| 556 |
+
return attn_output.transpose(1, 2)
|
| 557 |
+
|
| 558 |
+
def run(
|
| 559 |
+
self,
|
| 560 |
+
next_input_ids: torch.Tensor,
|
| 561 |
+
*,
|
| 562 |
+
past_key_values: Cache,
|
| 563 |
+
attention_bias: torch.Tensor,
|
| 564 |
+
past_length: int,
|
| 565 |
+
) -> Tuple[torch.Tensor, Cache]:
|
| 566 |
+
end = past_length + 1
|
| 567 |
+
decode_graph = self._get_depth_decode_graph(
|
| 568 |
+
next_input_ids,
|
| 569 |
+
past_length=past_length,
|
| 570 |
+
attention_bias=attention_bias,
|
| 571 |
+
)
|
| 572 |
+
cache_position = decode_graph.positions[past_length:end]
|
| 573 |
+
attention_bias_q = attention_bias[:, :, past_length:end, :end]
|
| 574 |
+
|
| 575 |
+
decode_graph.pre_graph.replay()
|
| 576 |
+
|
| 577 |
+
for layer_idx, post_graph in enumerate(decode_graph.post_graphs):
|
| 578 |
+
attn_context = self._run_depth_decode_attention_core(
|
| 579 |
+
layer_idx,
|
| 580 |
+
decode_graph.stages[layer_idx],
|
| 581 |
+
past_key_values=past_key_values,
|
| 582 |
+
attention_bias=attention_bias_q,
|
| 583 |
+
cache_position=cache_position,
|
| 584 |
+
cos=decode_graph.cos,
|
| 585 |
+
sin=decode_graph.sin,
|
| 586 |
+
)
|
| 587 |
+
post_graph.attn_context.copy_(attn_context)
|
| 588 |
+
post_graph.graph.replay()
|
| 589 |
+
|
| 590 |
+
return decode_graph.output, past_key_values
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
def _cuda_graph_tensor_signature(
|
| 594 |
+
tensor: Optional[torch.Tensor],
|
| 595 |
+
) -> Optional[Tuple[Any, ...]]:
|
| 596 |
+
if tensor is None:
|
| 597 |
+
return None
|
| 598 |
+
return (
|
| 599 |
+
tuple(tensor.shape),
|
| 600 |
+
tuple(tensor.stride()),
|
| 601 |
+
str(tensor.dtype),
|
| 602 |
+
str(tensor.device),
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def _cuda_graph_context_signature(context: Any) -> Tuple[Any, ...]:
|
| 607 |
+
sig = _cuda_graph_tensor_signature
|
| 608 |
+
return (
|
| 609 |
+
tuple((sig(k), sig(v)) for k, v in context.kv_contexts),
|
| 610 |
+
sig(context.cross_mask),
|
| 611 |
+
sig(context.self_mask),
|
| 612 |
+
sig(context.valid_action),
|
| 613 |
+
None
|
| 614 |
+
if context.rope_cache is None
|
| 615 |
+
else tuple(sig(t) for t in context.rope_cache),
|
| 616 |
+
)
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> Tuple[Any, ...]:
|
| 620 |
+
sig = _cuda_graph_tensor_signature
|
| 621 |
+
return tuple(
|
| 622 |
+
(
|
| 623 |
+
sig(step.conditioning),
|
| 624 |
+
tuple(
|
| 625 |
+
tuple(sig(t) for t in block_modulation)
|
| 626 |
+
for block_modulation in step.block_modulations
|
| 627 |
+
),
|
| 628 |
+
tuple(sig(t) for t in step.final_modulation),
|
| 629 |
+
)
|
| 630 |
+
for step in modulations
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> Tuple[Any, ...]:
|
| 635 |
+
sig = _cuda_graph_tensor_signature
|
| 636 |
+
return (
|
| 637 |
+
sig(inputs.trajectory),
|
| 638 |
+
_cuda_graph_context_signature(inputs.context),
|
| 639 |
+
_cuda_graph_modulation_signature(inputs.modulations),
|
| 640 |
+
sig(inputs.action_dim_is_pad),
|
| 641 |
+
int(steps),
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
def _clone_static_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
| 646 |
+
if tensor is None:
|
| 647 |
+
return None
|
| 648 |
+
static = torch.empty_strided(
|
| 649 |
+
tuple(tensor.shape),
|
| 650 |
+
tuple(tensor.stride()),
|
| 651 |
+
device=tensor.device,
|
| 652 |
+
dtype=tensor.dtype,
|
| 653 |
+
)
|
| 654 |
+
static.copy_(tensor)
|
| 655 |
+
return static
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
def _clone_static_context(context: Any) -> Any:
|
| 659 |
+
rope_cache = None
|
| 660 |
+
if context.rope_cache is not None:
|
| 661 |
+
rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache)
|
| 662 |
+
return context.__class__(
|
| 663 |
+
kv_contexts=tuple(
|
| 664 |
+
(_clone_static_tensor(k), _clone_static_tensor(v))
|
| 665 |
+
for k, v in context.kv_contexts
|
| 666 |
+
),
|
| 667 |
+
cross_mask=_clone_static_tensor(context.cross_mask),
|
| 668 |
+
self_mask=_clone_static_tensor(context.self_mask),
|
| 669 |
+
valid_action=_clone_static_tensor(context.valid_action),
|
| 670 |
+
rope_cache=rope_cache,
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
|
| 675 |
+
return tuple(
|
| 676 |
+
step.__class__(
|
| 677 |
+
conditioning=_clone_static_tensor(step.conditioning),
|
| 678 |
+
block_modulations=tuple(
|
| 679 |
+
tuple(_clone_static_tensor(t) for t in block_modulation)
|
| 680 |
+
for block_modulation in step.block_modulations
|
| 681 |
+
),
|
| 682 |
+
final_modulation=tuple(
|
| 683 |
+
_clone_static_tensor(t) for t in step.final_modulation
|
| 684 |
+
),
|
| 685 |
+
)
|
| 686 |
+
for step in modulations
|
| 687 |
+
)
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
|
| 691 |
+
return _ActionFlowInputs(
|
| 692 |
+
trajectory=_clone_static_tensor(inputs.trajectory),
|
| 693 |
+
context=_clone_static_context(inputs.context),
|
| 694 |
+
modulations=_clone_static_modulations(inputs.modulations),
|
| 695 |
+
action_dim_is_pad=_clone_static_tensor(inputs.action_dim_is_pad),
|
| 696 |
+
)
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
def _copy_context_(dst: Any, src: Any) -> None:
|
| 700 |
+
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
|
| 701 |
+
dst_k.copy_(src_k)
|
| 702 |
+
dst_v.copy_(src_v)
|
| 703 |
+
if src.cross_mask is not None:
|
| 704 |
+
dst.cross_mask.copy_(src.cross_mask)
|
| 705 |
+
if src.self_mask is not None:
|
| 706 |
+
dst.self_mask.copy_(src.self_mask)
|
| 707 |
+
if src.valid_action is not None:
|
| 708 |
+
dst.valid_action.copy_(src.valid_action)
|
| 709 |
+
if src.rope_cache is not None:
|
| 710 |
+
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
|
| 711 |
+
dst_tensor.copy_(src_tensor)
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None:
|
| 715 |
+
dst.trajectory.copy_(src.trajectory)
|
| 716 |
+
_copy_context_(dst.context, src.context)
|
| 717 |
+
if src.action_dim_is_pad is not None:
|
| 718 |
+
dst.action_dim_is_pad.copy_(src.action_dim_is_pad)
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 722 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 723 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 724 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
def _apply_rotary_pos_emb(
|
| 728 |
+
q: torch.Tensor,
|
| 729 |
+
k: torch.Tensor,
|
| 730 |
+
cos: torch.Tensor,
|
| 731 |
+
sin: torch.Tensor,
|
| 732 |
+
unsqueeze_dim: int = 1,
|
| 733 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 734 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 735 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 736 |
+
q_embed = (q * cos) + (_rotate_half(q) * sin)
|
| 737 |
+
k_embed = (k * cos) + (_rotate_half(k) * sin)
|
| 738 |
+
return q_embed, k_embed
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 742 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 743 |
+
if n_rep == 1:
|
| 744 |
+
return hidden_states
|
| 745 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
| 746 |
+
batch, num_key_value_heads, n_rep, slen, head_dim
|
| 747 |
+
)
|
| 748 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
def _capture_cuda_graph(
|
| 752 |
+
fn,
|
| 753 |
+
device: torch.device,
|
| 754 |
+
*,
|
| 755 |
+
after_warmup=None,
|
| 756 |
+
) -> Tuple[torch.cuda.CUDAGraph, Any]:
|
| 757 |
+
warmup_stream = torch.cuda.Stream(device=device)
|
| 758 |
+
warmup_stream.wait_stream(torch.cuda.current_stream(device))
|
| 759 |
+
with torch.cuda.stream(warmup_stream):
|
| 760 |
+
fn()
|
| 761 |
+
torch.cuda.current_stream(device).wait_stream(warmup_stream)
|
| 762 |
+
if after_warmup is not None:
|
| 763 |
+
after_warmup()
|
| 764 |
+
|
| 765 |
+
graph = torch.cuda.CUDAGraph()
|
| 766 |
+
with torch.cuda.graph(graph):
|
| 767 |
+
output = fn()
|
| 768 |
+
return graph, output
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a62d97555bec82618aa1e7c3f143a41067ba7e8bed074d825b9b1e88391d516a
|
| 3 |
+
size 4183452369
|
modeling_molmoact2.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
norm_stats.json
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"format": "molmoact2_norm_stats.v1",
|
| 3 |
+
"norm_mode": "q01_q99",
|
| 4 |
+
"metadata_by_tag": {
|
| 5 |
+
"so100_so101_molmoact2": {
|
| 6 |
+
"action_key": "action",
|
| 7 |
+
"state_key": "observation.state",
|
| 8 |
+
"camera_keys": [],
|
| 9 |
+
"normalize_gripper": true,
|
| 10 |
+
"action_horizon": 30,
|
| 11 |
+
"n_action_steps": 30,
|
| 12 |
+
"setup_type": "single so100/so101 robotic arm in molmoact2",
|
| 13 |
+
"control_mode": "absolute joint pose",
|
| 14 |
+
"action_stats": {
|
| 15 |
+
"min": [
|
| 16 |
+
-122.607421875,
|
| 17 |
+
-270.0,
|
| 18 |
+
-269.208984375,
|
| 19 |
+
-125.771484375,
|
| 20 |
+
-269.912109375,
|
| 21 |
+
-31.57327651977539
|
| 22 |
+
],
|
| 23 |
+
"max": [
|
| 24 |
+
179.208984375,
|
| 25 |
+
219.638671875,
|
| 26 |
+
195.380859375,
|
| 27 |
+
178.9453125,
|
| 28 |
+
269.82421875,
|
| 29 |
+
119.40789031982422
|
| 30 |
+
],
|
| 31 |
+
"mean": [
|
| 32 |
+
3.343996486826433,
|
| 33 |
+
125.7905980370996,
|
| 34 |
+
120.20220128113388,
|
| 35 |
+
55.88144220174933,
|
| 36 |
+
-11.543010633027725,
|
| 37 |
+
11.25886240824774
|
| 38 |
+
],
|
| 39 |
+
"std": [
|
| 40 |
+
28.909870406169997,
|
| 41 |
+
52.25069634659296,
|
| 42 |
+
47.94432906599221,
|
| 43 |
+
36.01019142727721,
|
| 44 |
+
69.35504013212369,
|
| 45 |
+
17.116239869449775
|
| 46 |
+
],
|
| 47 |
+
"count": [
|
| 48 |
+
19619650.0
|
| 49 |
+
],
|
| 50 |
+
"q01": [
|
| 51 |
+
-42.1300246338976,
|
| 52 |
+
45.18258358164995,
|
| 53 |
+
35.40059182962813,
|
| 54 |
+
4.929781836327758,
|
| 55 |
+
-65.57568617645342,
|
| 56 |
+
-0.3016556932619033
|
| 57 |
+
],
|
| 58 |
+
"q10": [
|
| 59 |
+
-25.040070398997557,
|
| 60 |
+
68.27827215165794,
|
| 61 |
+
65.76540485606242,
|
| 62 |
+
26.58811186925123,
|
| 63 |
+
-39.81868441470048,
|
| 64 |
+
0.26123181871944706
|
| 65 |
+
],
|
| 66 |
+
"q50": [
|
| 67 |
+
3.0828094324713105,
|
| 68 |
+
124.5495736487354,
|
| 69 |
+
122.75175717637279,
|
| 70 |
+
57.77960070056314,
|
| 71 |
+
-11.094802886190045,
|
| 72 |
+
4.866634607477139
|
| 73 |
+
],
|
| 74 |
+
"q90": [
|
| 75 |
+
31.591544866079253,
|
| 76 |
+
181.76986724267596,
|
| 77 |
+
168.5741215400282,
|
| 78 |
+
82.4353358815596,
|
| 79 |
+
16.05609349144359,
|
| 80 |
+
32.12324970648343
|
| 81 |
+
],
|
| 82 |
+
"q99": [
|
| 83 |
+
48.55349563198916,
|
| 84 |
+
186.10646680077767,
|
| 85 |
+
173.6076722013997,
|
| 86 |
+
93.41056417929472,
|
| 87 |
+
43.53107398260694,
|
| 88 |
+
44.74649336930881
|
| 89 |
+
],
|
| 90 |
+
"names": [
|
| 91 |
+
"shoulder_pan",
|
| 92 |
+
"shoulder_lift",
|
| 93 |
+
"elbow_flex",
|
| 94 |
+
"wrist_flex",
|
| 95 |
+
"wrist_roll",
|
| 96 |
+
"gripper"
|
| 97 |
+
],
|
| 98 |
+
"mask": [
|
| 99 |
+
true,
|
| 100 |
+
true,
|
| 101 |
+
true,
|
| 102 |
+
true,
|
| 103 |
+
true,
|
| 104 |
+
true
|
| 105 |
+
]
|
| 106 |
+
},
|
| 107 |
+
"state_stats": {
|
| 108 |
+
"min": [
|
| 109 |
+
-115.048828125,
|
| 110 |
+
-270.0,
|
| 111 |
+
-235.8984375,
|
| 112 |
+
-113.818359375,
|
| 113 |
+
-268.9453125,
|
| 114 |
+
-8.521058082580566
|
| 115 |
+
],
|
| 116 |
+
"max": [
|
| 117 |
+
178.505859375,
|
| 118 |
+
218.49609375,
|
| 119 |
+
192.041015625,
|
| 120 |
+
207.861328125,
|
| 121 |
+
250.048828125,
|
| 122 |
+
118.2519302368164
|
| 123 |
+
],
|
| 124 |
+
"mean": [
|
| 125 |
+
3.3225097946752244,
|
| 126 |
+
124.40594064960378,
|
| 127 |
+
121.59550610749059,
|
| 128 |
+
55.903039878016074,
|
| 129 |
+
-11.41740021122887,
|
| 130 |
+
13.358497334686597
|
| 131 |
+
],
|
| 132 |
+
"std": [
|
| 133 |
+
28.79265204113751,
|
| 134 |
+
52.702867303079756,
|
| 135 |
+
47.00596021941705,
|
| 136 |
+
35.53803566355756,
|
| 137 |
+
69.12836626047817,
|
| 138 |
+
16.333280282904557
|
| 139 |
+
],
|
| 140 |
+
"count": [
|
| 141 |
+
19619650.0
|
| 142 |
+
],
|
| 143 |
+
"q01": [
|
| 144 |
+
-41.90962240941357,
|
| 145 |
+
43.66791235922949,
|
| 146 |
+
38.38770483255723,
|
| 147 |
+
5.711740446834044,
|
| 148 |
+
-63.44539045209019,
|
| 149 |
+
0.9435577790191543
|
| 150 |
+
],
|
| 151 |
+
"q10": [
|
| 152 |
+
-24.949315993050774,
|
| 153 |
+
66.30007546431412,
|
| 154 |
+
68.16816985859437,
|
| 155 |
+
27.120731646136054,
|
| 156 |
+
-39.50255020332888,
|
| 157 |
+
1.6190225837869365
|
| 158 |
+
],
|
| 159 |
+
"q50": [
|
| 160 |
+
3.066375725640164,
|
| 161 |
+
123.16482094240277,
|
| 162 |
+
124.39930058290133,
|
| 163 |
+
57.88605464633133,
|
| 164 |
+
-11.037436711677765,
|
| 165 |
+
9.241478261568748
|
| 166 |
+
],
|
| 167 |
+
"q90": [
|
| 168 |
+
31.472920732960127,
|
| 169 |
+
180.87158401301218,
|
| 170 |
+
168.5699720215359,
|
| 171 |
+
81.64709150074712,
|
| 172 |
+
15.887605114617852,
|
| 173 |
+
31.887861734718296
|
| 174 |
+
],
|
| 175 |
+
"q99": [
|
| 176 |
+
48.29435703371732,
|
| 177 |
+
185.2611055842669,
|
| 178 |
+
173.13578487933165,
|
| 179 |
+
91.78122415137209,
|
| 180 |
+
42.94491979114059,
|
| 181 |
+
44.13755601580974
|
| 182 |
+
],
|
| 183 |
+
"names": [
|
| 184 |
+
"shoulder_pan",
|
| 185 |
+
"shoulder_lift",
|
| 186 |
+
"elbow_flex",
|
| 187 |
+
"wrist_flex",
|
| 188 |
+
"wrist_roll",
|
| 189 |
+
"gripper"
|
| 190 |
+
],
|
| 191 |
+
"mask": [
|
| 192 |
+
true,
|
| 193 |
+
true,
|
| 194 |
+
true,
|
| 195 |
+
true,
|
| 196 |
+
true,
|
| 197 |
+
true
|
| 198 |
+
]
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
}
|
| 202 |
+
}
|
processing_molmoact2.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Processor class for MolmoAct2.
|
| 3 |
+
"""
|
| 4 |
+
from typing import Optional, Union
|
| 5 |
+
import dataclasses
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from transformers.image_utils import ImageInput
|
| 10 |
+
from transformers.video_utils import VideoInput
|
| 11 |
+
from transformers.processing_utils import (
|
| 12 |
+
Unpack,
|
| 13 |
+
ProcessingKwargs,
|
| 14 |
+
ProcessorMixin,
|
| 15 |
+
)
|
| 16 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 17 |
+
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
|
| 18 |
+
from transformers.utils import logging
|
| 19 |
+
|
| 20 |
+
from transformers import AutoTokenizer
|
| 21 |
+
from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
|
| 22 |
+
from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
|
| 29 |
+
IMAGE_PATCH_TOKEN = f"<im_patch>" # Where to insert high-res tokens
|
| 30 |
+
IMAGE_LOW_RES_TOKEN = f"<im_low>" # Where to insert low-res tokens
|
| 31 |
+
IM_START_TOKEN = f"<im_start>"
|
| 32 |
+
LOW_RES_IMAGE_START_TOKEN = f"<low_res_im_start>"
|
| 33 |
+
FRAME_START_TOKEN = f"<frame_start>"
|
| 34 |
+
IM_END_TOKEN = f"<im_end>"
|
| 35 |
+
FRAME_END_TOKEN= f"<frame_end>"
|
| 36 |
+
IM_COL_TOKEN = f"<im_col>"
|
| 37 |
+
IMAGE_PROMPT = "<|image|>"
|
| 38 |
+
VIDEO_PROMPT = "<|video|>"
|
| 39 |
+
|
| 40 |
+
IMAGE_TOKENS = [
|
| 41 |
+
IMAGE_PATCH_TOKEN,
|
| 42 |
+
IM_COL_TOKEN,
|
| 43 |
+
IM_START_TOKEN,
|
| 44 |
+
LOW_RES_IMAGE_START_TOKEN,
|
| 45 |
+
FRAME_START_TOKEN,
|
| 46 |
+
IM_END_TOKEN,
|
| 47 |
+
FRAME_END_TOKEN,
|
| 48 |
+
IMAGE_LOW_RES_TOKEN,
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False):
|
| 53 |
+
"""MolmoAct2 processor kwargs"""
|
| 54 |
+
images_kwargs: MolmoAct2ImagesKwargs
|
| 55 |
+
videos_kwargs: MolmoAct2VideoProcessorKwargs
|
| 56 |
+
_defaults = {
|
| 57 |
+
"text_kwargs": {
|
| 58 |
+
"padding": False,
|
| 59 |
+
"return_mm_token_type_ids": True,
|
| 60 |
+
},
|
| 61 |
+
"videos_kwargs": {"return_metadata": True},
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class MolmoAct2Processor(ProcessorMixin):
|
| 66 |
+
attributes = ["image_processor", "video_processor", "tokenizer"]
|
| 67 |
+
optional_attributes = [
|
| 68 |
+
"chat_template",
|
| 69 |
+
"time_mode",
|
| 70 |
+
"image_use_col_tokens",
|
| 71 |
+
"use_single_crop_col_tokens",
|
| 72 |
+
"use_single_crop_start_token",
|
| 73 |
+
"video_use_col_tokens",
|
| 74 |
+
"use_frame_special_tokens",
|
| 75 |
+
]
|
| 76 |
+
image_processor_class = "AutoImageProcessor"
|
| 77 |
+
video_processor_class = "AutoVideoProcessor"
|
| 78 |
+
tokenizer_class = "AutoTokenizer"
|
| 79 |
+
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
image_processor: MolmoAct2ImageProcessor = None,
|
| 83 |
+
video_processor: MolmoAct2VideoProcessor = None,
|
| 84 |
+
tokenizer: AutoTokenizer = None,
|
| 85 |
+
chat_template: Optional[str] = None,
|
| 86 |
+
image_use_col_tokens: Optional[bool] = True,
|
| 87 |
+
use_single_crop_col_tokens: Optional[bool] = None,
|
| 88 |
+
use_single_crop_start_token: Optional[bool] = True,
|
| 89 |
+
video_use_col_tokens: Optional[bool] = False,
|
| 90 |
+
use_frame_special_tokens: Optional[bool] = True,
|
| 91 |
+
**kwargs
|
| 92 |
+
) -> None:
|
| 93 |
+
super().__init__(
|
| 94 |
+
image_processor,
|
| 95 |
+
video_processor,
|
| 96 |
+
tokenizer,
|
| 97 |
+
chat_template=chat_template,
|
| 98 |
+
)
|
| 99 |
+
self.image_use_col_tokens = image_use_col_tokens
|
| 100 |
+
self.use_single_crop_col_tokens = use_single_crop_col_tokens
|
| 101 |
+
self.use_single_crop_start_token = use_single_crop_start_token
|
| 102 |
+
self.video_use_col_tokens = video_use_col_tokens
|
| 103 |
+
self.use_frame_special_tokens = use_frame_special_tokens
|
| 104 |
+
|
| 105 |
+
self.image_placeholder_token = IMAGE_PROMPT
|
| 106 |
+
self.video_placeholder_token = VIDEO_PROMPT
|
| 107 |
+
self.image_token_ids = [
|
| 108 |
+
tokenizer.convert_tokens_to_ids(token)
|
| 109 |
+
for token in IMAGE_TOKENS
|
| 110 |
+
]
|
| 111 |
+
|
| 112 |
+
def get_image_tokens(self, image_grid: np.ndarray):
|
| 113 |
+
resized_h, resized_w, height, width = image_grid
|
| 114 |
+
if int(height) == 0 or int(width) == 0:
|
| 115 |
+
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
|
| 116 |
+
use_single_crop_col_tokens = (
|
| 117 |
+
self.image_use_col_tokens
|
| 118 |
+
if self.use_single_crop_col_tokens is None
|
| 119 |
+
else self.use_single_crop_col_tokens
|
| 120 |
+
)
|
| 121 |
+
if use_single_crop_col_tokens:
|
| 122 |
+
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
| 123 |
+
joint = [
|
| 124 |
+
[IM_START_TOKEN],
|
| 125 |
+
np.tile(per_row, [resized_h]),
|
| 126 |
+
[IM_END_TOKEN],
|
| 127 |
+
]
|
| 128 |
+
return np.concatenate(joint)
|
| 129 |
+
per_row = np.full(width, IMAGE_PATCH_TOKEN)
|
| 130 |
+
if self.image_use_col_tokens:
|
| 131 |
+
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
| 132 |
+
joint = [
|
| 133 |
+
[IM_START_TOKEN],
|
| 134 |
+
np.tile(per_row, [height]),
|
| 135 |
+
[IM_END_TOKEN],
|
| 136 |
+
]
|
| 137 |
+
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
|
| 138 |
+
use_single_crop_col_tokens = (
|
| 139 |
+
self.image_use_col_tokens
|
| 140 |
+
if self.use_single_crop_col_tokens is None
|
| 141 |
+
else self.use_single_crop_col_tokens
|
| 142 |
+
)
|
| 143 |
+
image_start_token = (
|
| 144 |
+
LOW_RES_IMAGE_START_TOKEN
|
| 145 |
+
if self.use_single_crop_start_token
|
| 146 |
+
else IM_START_TOKEN
|
| 147 |
+
)
|
| 148 |
+
if use_single_crop_col_tokens:
|
| 149 |
+
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
| 150 |
+
joint = [
|
| 151 |
+
[image_start_token],
|
| 152 |
+
np.tile(per_row, [resized_h]),
|
| 153 |
+
[IM_END_TOKEN],
|
| 154 |
+
] + joint
|
| 155 |
+
|
| 156 |
+
return np.concatenate(joint)
|
| 157 |
+
|
| 158 |
+
def get_video_string(
|
| 159 |
+
self,
|
| 160 |
+
video_grid: np.ndarray,
|
| 161 |
+
timestamps: np.ndarray,
|
| 162 |
+
):
|
| 163 |
+
if self.use_frame_special_tokens:
|
| 164 |
+
start_token_id = FRAME_START_TOKEN
|
| 165 |
+
end_token_id = FRAME_END_TOKEN
|
| 166 |
+
else:
|
| 167 |
+
start_token_id = IM_START_TOKEN
|
| 168 |
+
end_token_id = IM_END_TOKEN
|
| 169 |
+
|
| 170 |
+
num_frames, h, w = video_grid
|
| 171 |
+
video_string: str = ""
|
| 172 |
+
for frame_idx, frame_time in enumerate(timestamps):
|
| 173 |
+
# `per-frame-compact` time mode
|
| 174 |
+
prev_space = " " if frame_idx > 0 else ""
|
| 175 |
+
frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens
|
| 176 |
+
|
| 177 |
+
video_string += frame_prefix
|
| 178 |
+
per_row = np.full(w, IMAGE_PATCH_TOKEN)
|
| 179 |
+
if self.video_use_col_tokens:
|
| 180 |
+
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
| 181 |
+
extra_tokens = np.tile(per_row, [h])
|
| 182 |
+
video_tokens = [
|
| 183 |
+
[start_token_id],
|
| 184 |
+
extra_tokens,
|
| 185 |
+
[end_token_id],
|
| 186 |
+
]
|
| 187 |
+
video_string += "".join(np.concatenate(video_tokens, 0))
|
| 188 |
+
|
| 189 |
+
return video_string
|
| 190 |
+
|
| 191 |
+
def insert_bos(
|
| 192 |
+
self,
|
| 193 |
+
input_ids: np.ndarray,
|
| 194 |
+
attention_mask: np.ndarray,
|
| 195 |
+
bos_token_id: int,
|
| 196 |
+
pad_token_id: int,
|
| 197 |
+
):
|
| 198 |
+
"""
|
| 199 |
+
Args:
|
| 200 |
+
input_ids: [B, S] array with left padding
|
| 201 |
+
attention_mask: [B, S] array (0 for pad, 1 for valid)
|
| 202 |
+
bos_token_id: int
|
| 203 |
+
pad_token_id: int
|
| 204 |
+
Returns:
|
| 205 |
+
input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
|
| 206 |
+
attention_mask_out: same shape as input_ids_out
|
| 207 |
+
"""
|
| 208 |
+
|
| 209 |
+
need_to_expand = len(input_ids.shape) == 1
|
| 210 |
+
if need_to_expand:
|
| 211 |
+
input_ids = input_ids[None, :]
|
| 212 |
+
attention_mask = attention_mask[None, :]
|
| 213 |
+
|
| 214 |
+
B, S = input_ids.shape
|
| 215 |
+
|
| 216 |
+
# Handle zero-length sequence
|
| 217 |
+
if S == 0:
|
| 218 |
+
new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype)
|
| 219 |
+
new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype)
|
| 220 |
+
if need_to_expand:
|
| 221 |
+
new_input_ids = new_input_ids[0]
|
| 222 |
+
new_attention_mask = new_attention_mask[0]
|
| 223 |
+
return new_input_ids, new_attention_mask
|
| 224 |
+
|
| 225 |
+
first_valid_index = (attention_mask == 1).argmax(axis=-1) # [B]
|
| 226 |
+
bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id)
|
| 227 |
+
|
| 228 |
+
if bos_already_present:
|
| 229 |
+
if need_to_expand:
|
| 230 |
+
input_ids = input_ids[0]
|
| 231 |
+
attention_mask = attention_mask[0]
|
| 232 |
+
return input_ids, attention_mask
|
| 233 |
+
else:
|
| 234 |
+
new_input_ids = np.full((B, S+1), pad_token_id, dtype=input_ids.dtype)
|
| 235 |
+
new_attention_mask = np.zeros((B, S+1), dtype=attention_mask.dtype)
|
| 236 |
+
|
| 237 |
+
src_idx = np.tile(np.arange(S), (B, 1)) # [B, S]
|
| 238 |
+
valid_mask = src_idx >= first_valid_index[:, None] # [B, S]
|
| 239 |
+
tgt_idx = src_idx + 1 # shit right
|
| 240 |
+
batch_idx = np.tile(np.arange(B)[:, None], (1, S)) # [B, S]
|
| 241 |
+
|
| 242 |
+
# flatten valid_positions
|
| 243 |
+
flat_vals = input_ids[valid_mask]
|
| 244 |
+
flat_batch = batch_idx[valid_mask]
|
| 245 |
+
flat_tgt = tgt_idx[valid_mask]
|
| 246 |
+
|
| 247 |
+
new_input_ids[flat_batch, flat_tgt] = flat_vals
|
| 248 |
+
new_attention_mask[flat_batch, flat_tgt] = 1
|
| 249 |
+
|
| 250 |
+
insert_pos = first_valid_index
|
| 251 |
+
new_input_ids[np.arange(B), insert_pos] = bos_token_id
|
| 252 |
+
new_attention_mask[np.arange(B), insert_pos] = 1
|
| 253 |
+
|
| 254 |
+
if need_to_expand:
|
| 255 |
+
new_input_ids = new_input_ids[0]
|
| 256 |
+
new_attention_mask = new_attention_mask[0]
|
| 257 |
+
|
| 258 |
+
return new_input_ids, new_attention_mask
|
| 259 |
+
|
| 260 |
+
def __call__(
|
| 261 |
+
self,
|
| 262 |
+
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
|
| 263 |
+
images: ImageInput = None,
|
| 264 |
+
videos: VideoInput = None,
|
| 265 |
+
**kwargs: Unpack[MolmoAct2ProcessorKwargs],
|
| 266 |
+
) -> BatchFeature:
|
| 267 |
+
"""
|
| 268 |
+
|
| 269 |
+
Args:
|
| 270 |
+
text (`str`, `list[str]`, `list[list[str]]`):
|
| 271 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 272 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 273 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 274 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
|
| 275 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
| 276 |
+
tensor. Both channels-first and channels-last formats are supported.
|
| 277 |
+
videos (`dict[str, Any]` or `list[dict[str, Any]]`):
|
| 278 |
+
The video or batch of videos to be prepared. Each video can be a dictionary with the following keys:
|
| 279 |
+
- `"frames"`: `np.ndarray` of shape (T, H, W, 3)
|
| 280 |
+
- `"timestamps"`: `np.ndarray` of shape (T,)
|
| 281 |
+
- `"sampled_fps"`: `float` (optional)
|
| 282 |
+
- `"sampling_augmentation"`: `str` (optional)
|
| 283 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
| 284 |
+
If set, will return tensors of a particular framework. Acceptable values are:
|
| 285 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
| 286 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
| 287 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
| 288 |
+
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
| 289 |
+
|
| 290 |
+
Returns:
|
| 291 |
+
`BatchFeature`: A [`BatchFeature`] with the following fields:
|
| 292 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
| 293 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
| 294 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`).
|
| 295 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
| 296 |
+
- **image_token_pooling** -- Indices of the patches in `image_grids` to pool for each token in `image_tokens`.
|
| 297 |
+
Returned when `images` is not `None`.
|
| 298 |
+
- **image_grids** -- Grids of images. Returned when `images` is not `None`.
|
| 299 |
+
- **image_num_crops** -- Number of crops for each image. Returned when `images` is not `None`.
|
| 300 |
+
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
|
| 301 |
+
- **video_token_pooling** -- Indices of the patches in `video_grids` to pool for each token in `video_tokens`.
|
| 302 |
+
Returned when `videos` is not `None`.
|
| 303 |
+
- **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
|
| 304 |
+
"""
|
| 305 |
+
|
| 306 |
+
output_kwargs = self._merge_kwargs(
|
| 307 |
+
MolmoAct2ProcessorKwargs,
|
| 308 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 309 |
+
**kwargs,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
if images is not None:
|
| 313 |
+
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
| 314 |
+
image_grids = image_inputs["image_grids"]
|
| 315 |
+
else:
|
| 316 |
+
image_inputs = {}
|
| 317 |
+
image_grids = None
|
| 318 |
+
|
| 319 |
+
if videos is not None:
|
| 320 |
+
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
|
| 321 |
+
video_grids = videos_inputs["video_grids"]
|
| 322 |
+
# If user has not requested video metadata, pop it
|
| 323 |
+
if "return_metadata" not in kwargs:
|
| 324 |
+
video_metadata = videos_inputs.pop("video_metadata")
|
| 325 |
+
else:
|
| 326 |
+
video_metadata = videos_inputs["video_metadata"]
|
| 327 |
+
else:
|
| 328 |
+
videos_inputs = {}
|
| 329 |
+
video_grids = None
|
| 330 |
+
|
| 331 |
+
if not isinstance(text, list):
|
| 332 |
+
text = [text]
|
| 333 |
+
|
| 334 |
+
text = text.copy() # below lines change text in-place
|
| 335 |
+
|
| 336 |
+
if image_grids is not None:
|
| 337 |
+
index = 0
|
| 338 |
+
for i in range(len(text)):
|
| 339 |
+
num_images = text[i].count(self.image_placeholder_token)
|
| 340 |
+
image_grids_i = image_grids[index:index+num_images]
|
| 341 |
+
for image_grid in image_grids_i:
|
| 342 |
+
image_tokens = self.get_image_tokens(image_grid)
|
| 343 |
+
image_string = "".join(image_tokens)
|
| 344 |
+
text[i] = text[i].replace(self.image_placeholder_token, image_string, 1)
|
| 345 |
+
index += num_images
|
| 346 |
+
|
| 347 |
+
if video_grids is not None:
|
| 348 |
+
index = 0
|
| 349 |
+
for i in range(len(text)):
|
| 350 |
+
num_videos = text[i].count(self.video_placeholder_token)
|
| 351 |
+
assert num_videos in {0, 1}, "At most one video is supported for now"
|
| 352 |
+
video_grids_i = video_grids[index:index+num_videos]
|
| 353 |
+
metadata_i = video_metadata[index:index+num_videos]
|
| 354 |
+
for video_grid, metadata in zip(video_grids_i, metadata_i):
|
| 355 |
+
video_string = self.get_video_string(
|
| 356 |
+
video_grid,
|
| 357 |
+
metadata.timestamps,
|
| 358 |
+
)
|
| 359 |
+
text[i] = text[i].replace(self.video_placeholder_token, video_string, 1)
|
| 360 |
+
index += num_videos
|
| 361 |
+
|
| 362 |
+
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
| 363 |
+
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
| 364 |
+
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
| 365 |
+
|
| 366 |
+
input_ids = text_inputs["input_ids"]
|
| 367 |
+
attention_mask = text_inputs["attention_mask"]
|
| 368 |
+
|
| 369 |
+
input_ids = np.array(input_ids)
|
| 370 |
+
attention_mask = np.array(attention_mask)
|
| 371 |
+
|
| 372 |
+
bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
|
| 373 |
+
input_ids, attention_mask = self.insert_bos(
|
| 374 |
+
input_ids, attention_mask, bos, self.tokenizer.pad_token_id
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
if return_mm_token_type_ids:
|
| 378 |
+
image_tokens = np.array(self.image_token_ids).astype(input_ids.dtype)
|
| 379 |
+
token_type_ids = np.any(input_ids[:, :, None] == image_tokens[None, None, :], axis=-1)
|
| 380 |
+
text_inputs["token_type_ids"] = token_type_ids.tolist()
|
| 381 |
+
|
| 382 |
+
text_inputs["input_ids"] = input_ids.tolist()
|
| 383 |
+
text_inputs["attention_mask"] = attention_mask.tolist()
|
| 384 |
+
|
| 385 |
+
return BatchFeature(
|
| 386 |
+
data={**text_inputs, **image_inputs, **videos_inputs},
|
| 387 |
+
tensor_type=return_tensors,
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
def post_process_image_text_to_text(
|
| 391 |
+
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
|
| 392 |
+
):
|
| 393 |
+
"""
|
| 394 |
+
Post-process the output of the model to decode the text.
|
| 395 |
+
|
| 396 |
+
Args:
|
| 397 |
+
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
| 398 |
+
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
| 399 |
+
or `(sequence_length,)`.
|
| 400 |
+
skip_special_tokens (`bool`, *optional*, defaults to `True`):
|
| 401 |
+
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
|
| 402 |
+
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
| 403 |
+
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
|
| 404 |
+
**kwargs:
|
| 405 |
+
Additional arguments to be passed to the tokenizer's `batch_decode method`.
|
| 406 |
+
|
| 407 |
+
Returns:
|
| 408 |
+
`list[str]`: The decoded text.
|
| 409 |
+
"""
|
| 410 |
+
return self.tokenizer.batch_decode(
|
| 411 |
+
generated_outputs,
|
| 412 |
+
skip_special_tokens=skip_special_tokens,
|
| 413 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 414 |
+
**kwargs,
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
MolmoAct2Processor.register_for_auto_class()
|
processor_config.json
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
|
| 4 |
+
},
|
| 5 |
+
"image_processor": {
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoImageProcessor": "image_processing_molmoact2.MolmoAct2ImageProcessor",
|
| 8 |
+
"AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
|
| 9 |
+
},
|
| 10 |
+
"crop_mode": "resize",
|
| 11 |
+
"do_convert_rgb": true,
|
| 12 |
+
"image_mean": [
|
| 13 |
+
0.5,
|
| 14 |
+
0.5,
|
| 15 |
+
0.5
|
| 16 |
+
],
|
| 17 |
+
"image_processor_type": "MolmoAct2ImageProcessor",
|
| 18 |
+
"image_std": [
|
| 19 |
+
0.5,
|
| 20 |
+
0.5,
|
| 21 |
+
0.5
|
| 22 |
+
],
|
| 23 |
+
"max_crops": 8,
|
| 24 |
+
"overlap_margins": [
|
| 25 |
+
4,
|
| 26 |
+
4
|
| 27 |
+
],
|
| 28 |
+
"patch_size": 14,
|
| 29 |
+
"pooling_size": [
|
| 30 |
+
2,
|
| 31 |
+
2
|
| 32 |
+
],
|
| 33 |
+
"resample": 2,
|
| 34 |
+
"size": {
|
| 35 |
+
"height": 378,
|
| 36 |
+
"width": 378
|
| 37 |
+
}
|
| 38 |
+
},
|
| 39 |
+
"image_use_col_tokens": true,
|
| 40 |
+
"processor_class": "MolmoAct2Processor",
|
| 41 |
+
"use_frame_special_tokens": true,
|
| 42 |
+
"use_single_crop_col_tokens": false,
|
| 43 |
+
"use_single_crop_start_token": true,
|
| 44 |
+
"video_processor": {
|
| 45 |
+
"auto_map": {
|
| 46 |
+
"AutoProcessor": "processing_molmoact2.MolmoAct2Processor",
|
| 47 |
+
"AutoVideoProcessor": "video_processing_molmoact2.MolmoAct2VideoProcessor"
|
| 48 |
+
},
|
| 49 |
+
"data_format": "channels_first",
|
| 50 |
+
"default_to_square": true,
|
| 51 |
+
"do_convert_rgb": true,
|
| 52 |
+
"do_normalize": true,
|
| 53 |
+
"do_rescale": true,
|
| 54 |
+
"do_resize": true,
|
| 55 |
+
"do_sample_frames": true,
|
| 56 |
+
"frame_sample_mode": "uniform_last_frame",
|
| 57 |
+
"image_mean": [
|
| 58 |
+
0.5,
|
| 59 |
+
0.5,
|
| 60 |
+
0.5
|
| 61 |
+
],
|
| 62 |
+
"image_std": [
|
| 63 |
+
0.5,
|
| 64 |
+
0.5,
|
| 65 |
+
0.5
|
| 66 |
+
],
|
| 67 |
+
"max_fps": 2.0,
|
| 68 |
+
"num_frames": 8,
|
| 69 |
+
"patch_size": 14,
|
| 70 |
+
"pooling_size": [
|
| 71 |
+
3,
|
| 72 |
+
3
|
| 73 |
+
],
|
| 74 |
+
"resample": 2,
|
| 75 |
+
"rescale_factor": 0.00392156862745098,
|
| 76 |
+
"return_metadata": false,
|
| 77 |
+
"sampling_fps": 2,
|
| 78 |
+
"size": {
|
| 79 |
+
"height": 378,
|
| 80 |
+
"width": 378
|
| 81 |
+
},
|
| 82 |
+
"video_processor_type": "MolmoAct2VideoProcessor"
|
| 83 |
+
},
|
| 84 |
+
"video_use_col_tokens": false
|
| 85 |
+
}
|
quantization_metadata.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"source_repo": "allenai/MolmoAct2-SO100_101",
|
| 3 |
+
"source_revision": "152569fe57914d97be91055800035f54e250d009",
|
| 4 |
+
"policy_class": "transformers:AutoModelForImageTextToText",
|
| 5 |
+
"quantization": {
|
| 6 |
+
"scheme": "nf4",
|
| 7 |
+
"backend": "bitsandbytes",
|
| 8 |
+
"compute_dtype": "bfloat16",
|
| 9 |
+
"min_params_to_quantize": 4000000,
|
| 10 |
+
"rule": "Linear modules with >=4_000_000 weight elements rewritten to bnb.nn.Linear4bit; smaller heads kept in compute_dtype (bfloat16).",
|
| 11 |
+
"runtime_status": "loader-backed (install_prequantized_linears)"
|
| 12 |
+
},
|
| 13 |
+
"dropped_state_entries": []
|
| 14 |
+
}
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d5395aefc9b1b7f0385d8c86a2f1775e5af81bdfbf9f2d97827ea37921d9f862
|
| 3 |
+
size 11983605
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"auto_map": {
|
| 4 |
+
"AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
|
| 5 |
+
},
|
| 6 |
+
"backend": "tokenizers",
|
| 7 |
+
"bos_token": "<|im_end|>",
|
| 8 |
+
"clean_up_tokenization_spaces": false,
|
| 9 |
+
"eos_token": "<|im_end|>",
|
| 10 |
+
"errors": "replace",
|
| 11 |
+
"extra_special_tokens": [
|
| 12 |
+
"<im_start>",
|
| 13 |
+
"<im_end>",
|
| 14 |
+
"<im_patch>",
|
| 15 |
+
"<im_col>",
|
| 16 |
+
"<low_res_im_start>",
|
| 17 |
+
"<|image|>",
|
| 18 |
+
"<im_low>",
|
| 19 |
+
"<frame_start>",
|
| 20 |
+
"<frame_end>",
|
| 21 |
+
"<|video|>",
|
| 22 |
+
"<|points|>",
|
| 23 |
+
"<|token_index|>",
|
| 24 |
+
"<|vit_index|>",
|
| 25 |
+
"<|vit_loc|>"
|
| 26 |
+
],
|
| 27 |
+
"is_local": false,
|
| 28 |
+
"model_max_length": 1010000,
|
| 29 |
+
"pad_token": "<|endoftext|>",
|
| 30 |
+
"processor_class": "MolmoAct2Processor",
|
| 31 |
+
"split_special_tokens": false,
|
| 32 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 33 |
+
"unk_token": null
|
| 34 |
+
}
|
video_processing_molmoact2.py
ADDED
|
@@ -0,0 +1,969 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Video processor class for MolmoAct2"""
|
| 2 |
+
from functools import partial
|
| 3 |
+
import os
|
| 4 |
+
import warnings
|
| 5 |
+
from contextlib import redirect_stdout
|
| 6 |
+
from io import BytesIO
|
| 7 |
+
from urllib.parse import urlparse
|
| 8 |
+
from typing import Optional, Union, Callable
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import requests
|
| 12 |
+
import einops
|
| 13 |
+
import torch
|
| 14 |
+
import torchvision.transforms
|
| 15 |
+
|
| 16 |
+
from transformers.image_utils import (
|
| 17 |
+
IMAGENET_STANDARD_MEAN,
|
| 18 |
+
IMAGENET_STANDARD_STD,
|
| 19 |
+
ImageInput,
|
| 20 |
+
PILImageResampling,
|
| 21 |
+
SizeDict,
|
| 22 |
+
validate_kwargs,
|
| 23 |
+
)
|
| 24 |
+
from transformers.video_utils import (
|
| 25 |
+
VideoInput,
|
| 26 |
+
is_valid_video,
|
| 27 |
+
make_batched_videos,
|
| 28 |
+
make_batched_metadata,
|
| 29 |
+
VideoMetadata,
|
| 30 |
+
)
|
| 31 |
+
from transformers.processing_utils import Unpack, VideosKwargs
|
| 32 |
+
from transformers.video_processing_utils import BaseVideoProcessor
|
| 33 |
+
from transformers.utils import logging
|
| 34 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 35 |
+
from transformers.utils import (
|
| 36 |
+
is_av_available,
|
| 37 |
+
is_decord_available,
|
| 38 |
+
is_torchcodec_available,
|
| 39 |
+
is_yt_dlp_available,
|
| 40 |
+
TensorType,
|
| 41 |
+
logging,
|
| 42 |
+
to_numpy,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
logger = logging.get_logger(__name__)
|
| 47 |
+
|
| 48 |
+
MAX_VIDEO_FPS = 8
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def normalize_image(
|
| 52 |
+
image: np.ndarray,
|
| 53 |
+
image_mean: list[float],
|
| 54 |
+
image_std: list[float],
|
| 55 |
+
) -> np.ndarray:
|
| 56 |
+
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
|
| 57 |
+
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
|
| 58 |
+
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
|
| 59 |
+
image /= np.array(image_std, dtype=np.float32)[None, None, :]
|
| 60 |
+
return image
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def resize_image(
|
| 64 |
+
image: np.ndarray,
|
| 65 |
+
desired_output_size: list[int],
|
| 66 |
+
resample: PILImageResampling,
|
| 67 |
+
) -> np.ndarray:
|
| 68 |
+
if len(image.shape) == 3:
|
| 69 |
+
is_video = False
|
| 70 |
+
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
|
| 71 |
+
else:
|
| 72 |
+
is_video = True
|
| 73 |
+
image = torch.permute(torch.from_numpy(image), [0, 3, 1, 2])
|
| 74 |
+
dtype = image.dtype
|
| 75 |
+
if torch.is_floating_point(image):
|
| 76 |
+
in_min = 0.0
|
| 77 |
+
in_max = 1.0
|
| 78 |
+
resized = torchvision.transforms.Resize(
|
| 79 |
+
desired_output_size,
|
| 80 |
+
resample,
|
| 81 |
+
antialias=False,
|
| 82 |
+
)(image)
|
| 83 |
+
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
|
| 84 |
+
else:
|
| 85 |
+
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(image.dtype)
|
| 86 |
+
in_min = 0.0
|
| 87 |
+
in_max = 255.0
|
| 88 |
+
resized = torchvision.transforms.Resize(
|
| 89 |
+
desired_output_size,
|
| 90 |
+
resample,
|
| 91 |
+
antialias=False,
|
| 92 |
+
)(image)
|
| 93 |
+
resized = torch.clip(resized, 0, 255).to(dtype)
|
| 94 |
+
|
| 95 |
+
resized = resized.to(torch.float32)
|
| 96 |
+
resized = (resized - in_min) / (in_max - in_min)
|
| 97 |
+
|
| 98 |
+
if is_video:
|
| 99 |
+
resized = torch.permute(resized, [0, 2, 3, 1]).numpy()
|
| 100 |
+
else:
|
| 101 |
+
resized = torch.permute(resized, [1, 2, 0]).numpy()
|
| 102 |
+
|
| 103 |
+
return resized
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def build_resized_image(
|
| 107 |
+
image: np.ndarray,
|
| 108 |
+
base_image_input_size: list[int],
|
| 109 |
+
resample: PILImageResampling,
|
| 110 |
+
image_mean: list[float],
|
| 111 |
+
image_std: list[float],
|
| 112 |
+
image_patch_size: int,
|
| 113 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 114 |
+
resized = resize_image(
|
| 115 |
+
image, base_image_input_size, resample,
|
| 116 |
+
)
|
| 117 |
+
resized = normalize_image(resized, image_mean, image_std)
|
| 118 |
+
if len(resized.shape) == 3:
|
| 119 |
+
resized = np.expand_dims(resized, 0)
|
| 120 |
+
crop_patch_w = base_image_input_size[1] // image_patch_size
|
| 121 |
+
crop_patch_h = base_image_input_size[0] // image_patch_size
|
| 122 |
+
resize_idx = np.arange(crop_patch_w*crop_patch_h).reshape([crop_patch_h, crop_patch_w])
|
| 123 |
+
return resized, resize_idx
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
|
| 127 |
+
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
|
| 128 |
+
if len(array.shape) == 3:
|
| 129 |
+
n_crops, h, w = array.shape
|
| 130 |
+
h_patches = h//patch_size
|
| 131 |
+
w_patches = w//patch_size
|
| 132 |
+
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
|
| 133 |
+
array = np.transpose(array, [0, 1, 3, 2, 4])
|
| 134 |
+
array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size])
|
| 135 |
+
return array
|
| 136 |
+
else:
|
| 137 |
+
n_crops, h, w, c = array.shape
|
| 138 |
+
h_patches = h//patch_size
|
| 139 |
+
w_patches = w//patch_size
|
| 140 |
+
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
|
| 141 |
+
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
|
| 142 |
+
array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size*c])
|
| 143 |
+
return array
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def arange_for_pooling(
|
| 147 |
+
idx_arr: np.ndarray,
|
| 148 |
+
pool_h: int,
|
| 149 |
+
pool_w: int,
|
| 150 |
+
) -> np.ndarray:
|
| 151 |
+
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
|
| 152 |
+
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
|
| 153 |
+
idx_arr = np.pad(idx_arr, [[h_pad//2, (h_pad+1)//2], [w_pad//2, (w_pad+1)//2]],
|
| 154 |
+
mode='constant',constant_values=-1)
|
| 155 |
+
return einops.rearrange(
|
| 156 |
+
idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def image_to_patches_and_grids(
|
| 160 |
+
image: ImageInput,
|
| 161 |
+
base_image_input_size: list[int],
|
| 162 |
+
resample: PILImageResampling,
|
| 163 |
+
image_mean: list[float],
|
| 164 |
+
image_std: list[float],
|
| 165 |
+
image_patch_size: int,
|
| 166 |
+
image_pooling_w: int,
|
| 167 |
+
image_pooling_h: int,
|
| 168 |
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 169 |
+
"""
|
| 170 |
+
:return image_grids, the shape of each image after pooling
|
| 171 |
+
:return crops, the image crops to processes with the ViT
|
| 172 |
+
:return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
|
| 173 |
+
patches in `crops` to pool for that token, masked with -1
|
| 174 |
+
"""
|
| 175 |
+
if isinstance(base_image_input_size, int):
|
| 176 |
+
base_image_input_size = (base_image_input_size, base_image_input_size)
|
| 177 |
+
|
| 178 |
+
pooling_w = image_pooling_w
|
| 179 |
+
pooling_h = image_pooling_h
|
| 180 |
+
|
| 181 |
+
resized, resize_idx = build_resized_image(
|
| 182 |
+
image,
|
| 183 |
+
base_image_input_size,
|
| 184 |
+
resample,
|
| 185 |
+
image_mean,
|
| 186 |
+
image_std,
|
| 187 |
+
image_patch_size,
|
| 188 |
+
)
|
| 189 |
+
pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
| 190 |
+
h, w = pooling_idx.shape[:2]
|
| 191 |
+
pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w])
|
| 192 |
+
image_grid = [h, w]
|
| 193 |
+
return (
|
| 194 |
+
image_grid,
|
| 195 |
+
batch_pixels_to_patches(resized, image_patch_size),
|
| 196 |
+
pooling_idx,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def get_candidate_target_fps(
|
| 201 |
+
video_fps: Union[int, float],
|
| 202 |
+
sampling_fps: Union[int, float],
|
| 203 |
+
max_fps: Union[int, float] = MAX_VIDEO_FPS,
|
| 204 |
+
) -> list[float]:
|
| 205 |
+
"""
|
| 206 |
+
Return the subset of `video_fps` factors that remain multiples of `sampling_fps`.
|
| 207 |
+
|
| 208 |
+
Examples:
|
| 209 |
+
>>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
|
| 210 |
+
[2, 6]
|
| 211 |
+
>>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
|
| 212 |
+
[1, 5]
|
| 213 |
+
>>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
|
| 214 |
+
[2]
|
| 215 |
+
>>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
|
| 216 |
+
Traceback (most recent call last):
|
| 217 |
+
...
|
| 218 |
+
ValueError: sampling_fps=2 must divide video_fps=5 to produce consistent frame steps.
|
| 219 |
+
"""
|
| 220 |
+
video_fps = int(video_fps)
|
| 221 |
+
sampling_fps = int(sampling_fps)
|
| 222 |
+
max_fps = int(max_fps)
|
| 223 |
+
|
| 224 |
+
if sampling_fps is None:
|
| 225 |
+
raise ValueError("sampling_fps must be provided")
|
| 226 |
+
if video_fps <= 0 or sampling_fps <= 0:
|
| 227 |
+
raise ValueError(f"video_fps and sampling_fps must be positive (got {video_fps}, {sampling_fps})")
|
| 228 |
+
if video_fps % sampling_fps != 0:
|
| 229 |
+
raise ValueError(f"sampling_fps={sampling_fps} must divide video_fps={video_fps}.")
|
| 230 |
+
|
| 231 |
+
candidates = []
|
| 232 |
+
for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
|
| 233 |
+
if candidate > max_fps:
|
| 234 |
+
break
|
| 235 |
+
if video_fps % candidate == 0:
|
| 236 |
+
candidates.append(float(candidate))
|
| 237 |
+
|
| 238 |
+
return candidates
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def read_video_decord(
|
| 242 |
+
video_path,
|
| 243 |
+
sample_timestamps_fn: Callable,
|
| 244 |
+
**kwargs,
|
| 245 |
+
) -> np.ndarray:
|
| 246 |
+
"""
|
| 247 |
+
Decode a video using the Decord backend.
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
video_path (`str`):
|
| 251 |
+
Path to the video file.
|
| 252 |
+
sample_timestamps_fn (`Callable`):
|
| 253 |
+
A callable function that will return timestamps at which the video should be sampled.
|
| 254 |
+
|
| 255 |
+
Returns:
|
| 256 |
+
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
| 257 |
+
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
| 258 |
+
- `VideoMetadata` object.
|
| 259 |
+
"""
|
| 260 |
+
# Lazy import from decord
|
| 261 |
+
import importlib
|
| 262 |
+
decord = importlib.import_module("decord")
|
| 263 |
+
|
| 264 |
+
vr = decord.VideoReader(uri=video_path, ctx=decord.cpu(0)) # decord has problems with gpu
|
| 265 |
+
video_fps = vr.get_avg_fps()
|
| 266 |
+
total_num_frames = len(vr)
|
| 267 |
+
time_stamps = vr.get_frame_timestamp(list(range(len(vr))))
|
| 268 |
+
duration = time_stamps[-1][1] - time_stamps[0][0]
|
| 269 |
+
|
| 270 |
+
metadata = VideoMetadata(
|
| 271 |
+
total_num_frames=int(total_num_frames),
|
| 272 |
+
fps=float(video_fps),
|
| 273 |
+
duration=float(duration),
|
| 274 |
+
video_backend="decord",
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
|
| 278 |
+
target_timestamps = np.array(target_timestamps)
|
| 279 |
+
offset = time_stamps[0, 0]
|
| 280 |
+
|
| 281 |
+
ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side='right')
|
| 282 |
+
ix = np.minimum(ix, len(time_stamps) - 1)
|
| 283 |
+
|
| 284 |
+
video = vr.get_batch(ix).asnumpy()
|
| 285 |
+
metadata.update(
|
| 286 |
+
{
|
| 287 |
+
"frames_indices": target_timestamps * video_fps,
|
| 288 |
+
"height": video.shape[1],
|
| 289 |
+
"width": video.shape[2],
|
| 290 |
+
}
|
| 291 |
+
)
|
| 292 |
+
return video, metadata
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def read_video_torchcodec(
|
| 296 |
+
video_path,
|
| 297 |
+
sample_timestamps_fn: Callable,
|
| 298 |
+
**kwargs,
|
| 299 |
+
) -> np.ndarray:
|
| 300 |
+
"""
|
| 301 |
+
Decode a video using torchcodec decoder.
|
| 302 |
+
|
| 303 |
+
Args:
|
| 304 |
+
video_path (`str`):
|
| 305 |
+
Path to the video file.
|
| 306 |
+
sample_timestamps_fn (`Callable`):
|
| 307 |
+
A callable function that will return timestamps at which the video should be sampled.
|
| 308 |
+
|
| 309 |
+
Returns:
|
| 310 |
+
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
| 311 |
+
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
| 312 |
+
- `VideoMetadata` object.
|
| 313 |
+
"""
|
| 314 |
+
# Lazy import torchcodec
|
| 315 |
+
import importlib
|
| 316 |
+
torchcodec = importlib.import_module("torchcodec")
|
| 317 |
+
|
| 318 |
+
decoder = torchcodec.decoders.VideoDecoder(
|
| 319 |
+
video_path,
|
| 320 |
+
# Interestingly `exact` mode takes less than approximate when we load the whole video
|
| 321 |
+
seek_mode="exact",
|
| 322 |
+
# Allow FFmpeg decide on the number of threads for efficiency
|
| 323 |
+
num_ffmpeg_threads=0,
|
| 324 |
+
)
|
| 325 |
+
# If the first frame starts at > 0, we effectively clip the video starting at that time
|
| 326 |
+
# since (most) video players would also skip to that time
|
| 327 |
+
time_offset = decoder.metadata.begin_stream_seconds_from_content
|
| 328 |
+
# Note this duration does assume we started playing at `time_offset`
|
| 329 |
+
duration = decoder.metadata.duration_seconds
|
| 330 |
+
|
| 331 |
+
metadata = VideoMetadata(
|
| 332 |
+
total_num_frames=decoder.metadata.num_frames,
|
| 333 |
+
fps=decoder.metadata.average_fps,
|
| 334 |
+
duration=duration,
|
| 335 |
+
video_backend="torchcodec",
|
| 336 |
+
height=decoder.metadata.height,
|
| 337 |
+
width=decoder.metadata.width,
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
|
| 341 |
+
|
| 342 |
+
# Floating point/rounding issues might cause `target_timestamps` to be very slightly
|
| 343 |
+
# out-of-bounds, to handle this we sanity check then clip them
|
| 344 |
+
assert all(x >= 0 for x in target_timestamps)
|
| 345 |
+
assert all(x < duration+1e-6 for x in target_timestamps)
|
| 346 |
+
# 1e-6 padding since torchcodec can throw out-of-bounds errors even if you ask for the
|
| 347 |
+
# exact boundary value, we should still get the first/last frame anyway
|
| 348 |
+
max_timestamp = decoder.metadata.end_stream_seconds_from_content - 1e-6
|
| 349 |
+
min_timestamp = decoder.metadata.begin_stream_seconds_from_content + 1e-6
|
| 350 |
+
# Note we avoid using numpy ops here to reduce floating precision issues
|
| 351 |
+
timestamps = [x + time_offset for x in target_timestamps]
|
| 352 |
+
timestamps = [max(min_timestamp, min(max_timestamp, x)) for x in timestamps]
|
| 353 |
+
|
| 354 |
+
video = decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1) # Convert to THWC format
|
| 355 |
+
target_timestamps = np.array(target_timestamps)
|
| 356 |
+
metadata.frames_indices = target_timestamps * metadata.fps
|
| 357 |
+
|
| 358 |
+
return video, metadata
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def read_video_pyav(
|
| 362 |
+
video_path,
|
| 363 |
+
sample_timestamps_fn: Callable,
|
| 364 |
+
**kwargs,
|
| 365 |
+
) -> np.ndarray:
|
| 366 |
+
"""
|
| 367 |
+
Decode a video using the PyAV backend.
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
video_path (`str`):
|
| 371 |
+
Path to the video file.
|
| 372 |
+
sample_timestamps_fn (`Callable`):
|
| 373 |
+
A callable function that will return timestamps at which the video should be sampled.
|
| 374 |
+
|
| 375 |
+
Returns:
|
| 376 |
+
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
| 377 |
+
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
| 378 |
+
- `VideoMetadata` object.
|
| 379 |
+
"""
|
| 380 |
+
# Lazy import torchcodec
|
| 381 |
+
import importlib
|
| 382 |
+
av = importlib.import_module("av")
|
| 383 |
+
|
| 384 |
+
with av.open(video_path) as container:
|
| 385 |
+
video_stream = container.streams.video[0]
|
| 386 |
+
fps = video_stream.average_rate or video_stream.guessed_rate
|
| 387 |
+
it = container.decode(video=0)
|
| 388 |
+
frames = list(it)
|
| 389 |
+
|
| 390 |
+
stream = container.streams.video[0]
|
| 391 |
+
start = frames[0].pts * stream.time_base
|
| 392 |
+
container_end = stream.duration
|
| 393 |
+
if container_end is not None:
|
| 394 |
+
container_end *= stream.time_base
|
| 395 |
+
if container_end is None or container_end < frames[-1].pts:
|
| 396 |
+
# Some problem with stream duration, so use the frame PTS directly
|
| 397 |
+
# and guess the duration of the last frame
|
| 398 |
+
end = frames[-1].pts * stream.time_base + 1/fps
|
| 399 |
+
else:
|
| 400 |
+
end = container_end
|
| 401 |
+
duration = float(end - start)
|
| 402 |
+
|
| 403 |
+
metadata = VideoMetadata(
|
| 404 |
+
total_num_frames=len(frames),
|
| 405 |
+
fps=float(fps),
|
| 406 |
+
duration=float(duration),
|
| 407 |
+
video_backend="pyav",
|
| 408 |
+
height=video_stream.height,
|
| 409 |
+
width=video_stream.width,
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
|
| 413 |
+
offset = float(start)
|
| 414 |
+
|
| 415 |
+
target_timestamps = np.array(target_timestamps)
|
| 416 |
+
end_time_stamps = np.array([float(frame.pts * stream.time_base) for frame in frames[1:]] + [duration])
|
| 417 |
+
indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side='right')
|
| 418 |
+
indices = np.minimum(indices, len(end_time_stamps) - 1)
|
| 419 |
+
|
| 420 |
+
video = np.stack(
|
| 421 |
+
[frames[i].to_ndarray(format="rgb24", channel_last=True) for i in indices],
|
| 422 |
+
axis=0,
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
metadata.frames_indices = target_timestamps * fps
|
| 426 |
+
|
| 427 |
+
return video, metadata
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
VIDEO_DECODERS = {
|
| 431 |
+
"decord": read_video_decord,
|
| 432 |
+
"torchcodec": read_video_torchcodec,
|
| 433 |
+
"pyav": read_video_pyav,
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def load_video(
|
| 438 |
+
video: VideoInput,
|
| 439 |
+
backend: str = "decord",
|
| 440 |
+
sample_timestamps_fn: Optional[Callable] = None,
|
| 441 |
+
**kwargs,
|
| 442 |
+
):
|
| 443 |
+
"""
|
| 444 |
+
Loads `video` to a numpy array.
|
| 445 |
+
|
| 446 |
+
Args:
|
| 447 |
+
video (`VideoInput`):
|
| 448 |
+
The video to convert to the numpy array format. Can be a link to video or local path.
|
| 449 |
+
backend (`str`, *optional*, defaults to `"decord"`):
|
| 450 |
+
The backend to use when loading the video. Can be any of ["decord", "pyav", ""torchcodec"]. Defaults to "decord".
|
| 451 |
+
sample_timestamps_fn (`Callable`):
|
| 452 |
+
A callable function that will return timestamps at which the video should be sampled.
|
| 453 |
+
"""
|
| 454 |
+
|
| 455 |
+
# Early exit if provided an array or `PIL` frames
|
| 456 |
+
if not isinstance(video, str):
|
| 457 |
+
metadata = [None] * len(video)
|
| 458 |
+
return video, metadata
|
| 459 |
+
|
| 460 |
+
if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
|
| 461 |
+
if not is_yt_dlp_available():
|
| 462 |
+
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
|
| 463 |
+
# Lazy import from yt_dlp
|
| 464 |
+
import importlib
|
| 465 |
+
yt_dlp = importlib.import_module("yt_dlp")
|
| 466 |
+
|
| 467 |
+
buffer = BytesIO()
|
| 468 |
+
with redirect_stdout(buffer), yt_dlp.YoutubeDL() as f:
|
| 469 |
+
f.download([video])
|
| 470 |
+
bytes_obj = buffer.getvalue()
|
| 471 |
+
file_obj = BytesIO(bytes_obj)
|
| 472 |
+
elif video.startswith("http://") or video.startswith("https://"):
|
| 473 |
+
file_obj = BytesIO(requests.get(video).content)
|
| 474 |
+
elif os.path.isfile(video):
|
| 475 |
+
file_obj = video
|
| 476 |
+
else:
|
| 477 |
+
raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")
|
| 478 |
+
|
| 479 |
+
# can also load with decord, but not cv2/torchvision
|
| 480 |
+
# both will fail in case of url links
|
| 481 |
+
video_is_url = video.startswith("http://") or video.startswith("https://")
|
| 482 |
+
if video_is_url and backend == "opencv":
|
| 483 |
+
raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")
|
| 484 |
+
|
| 485 |
+
if (
|
| 486 |
+
(not is_decord_available() and backend == "decord")
|
| 487 |
+
or (not is_torchcodec_available() and backend == "torchcodec")
|
| 488 |
+
or (not is_av_available() and backend == "pyav")
|
| 489 |
+
):
|
| 490 |
+
raise ImportError(
|
| 491 |
+
f"You chose backend={backend} for loading the video but the required library is not found in your environment "
|
| 492 |
+
f"Make sure to install {backend} before loading the video."
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
video_decoder = VIDEO_DECODERS[backend]
|
| 496 |
+
video, metadata = video_decoder(file_obj, sample_timestamps_fn, **kwargs)
|
| 497 |
+
return video, metadata
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def get_target_fps(
|
| 501 |
+
video_fps: float,
|
| 502 |
+
max_frames: int,
|
| 503 |
+
total_frames: int,
|
| 504 |
+
frame_sample_mode: str,
|
| 505 |
+
candidate_target_fps: tuple[float],
|
| 506 |
+
) -> float:
|
| 507 |
+
"""
|
| 508 |
+
Get the target fps that best spans the video and has the most frames sampled
|
| 509 |
+
"""
|
| 510 |
+
num_frames_sampled = 0
|
| 511 |
+
selected_target_fps = None
|
| 512 |
+
for target_fps in candidate_target_fps:
|
| 513 |
+
step_size = max(int(video_fps / target_fps), 1)
|
| 514 |
+
num_frames_sampled_at_fps = int(total_frames / step_size)
|
| 515 |
+
if num_frames_sampled == 0:
|
| 516 |
+
if "uniform" in frame_sample_mode:
|
| 517 |
+
if num_frames_sampled_at_fps > max_frames:
|
| 518 |
+
break
|
| 519 |
+
selected_target_fps = target_fps
|
| 520 |
+
num_frames_sampled = num_frames_sampled_at_fps
|
| 521 |
+
|
| 522 |
+
else:
|
| 523 |
+
# the candidate sampling fps increases so frame count can't decrease
|
| 524 |
+
assert num_frames_sampled <= num_frames_sampled_at_fps
|
| 525 |
+
if num_frames_sampled_at_fps > max_frames:
|
| 526 |
+
# choose the sampling fps that spans the video
|
| 527 |
+
continue
|
| 528 |
+
|
| 529 |
+
elif num_frames_sampled_at_fps > num_frames_sampled:
|
| 530 |
+
# both are less than max_frames, choose the one with higher density of frames sampled
|
| 531 |
+
selected_target_fps = target_fps
|
| 532 |
+
num_frames_sampled = num_frames_sampled_at_fps
|
| 533 |
+
return selected_target_fps
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def get_frame_times_and_chosen_fps(
|
| 537 |
+
selected_target_fps,
|
| 538 |
+
total_frames,
|
| 539 |
+
max_frames,
|
| 540 |
+
video_fps
|
| 541 |
+
):
|
| 542 |
+
if selected_target_fps is None:
|
| 543 |
+
frame_indices = np.linspace(0, total_frames, max_frames, endpoint=False, dtype=int)
|
| 544 |
+
else:
|
| 545 |
+
step_size = max(int(video_fps / selected_target_fps), 1)
|
| 546 |
+
frame_indices = np.arange(0, total_frames, step_size)
|
| 547 |
+
if len(frame_indices) > max_frames:
|
| 548 |
+
frame_indices = frame_indices[:max_frames]
|
| 549 |
+
return selected_target_fps, frame_indices
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
class MolmoAct2VideoProcessorKwargs(VideosKwargs, total=False):
|
| 553 |
+
patch_size: Optional[int]
|
| 554 |
+
pooling_size: Optional[list[int]]
|
| 555 |
+
frame_sample_mode: Optional[str]
|
| 556 |
+
max_fps: Optional[int]
|
| 557 |
+
sampling_fps: Optional[int]
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
class MolmoAct2VideoProcessor(BaseVideoProcessor):
|
| 561 |
+
resample = PILImageResampling.BILINEAR
|
| 562 |
+
size = {"height": 378, "width": 378}
|
| 563 |
+
image_mean = IMAGENET_STANDARD_MEAN
|
| 564 |
+
image_std = IMAGENET_STANDARD_STD
|
| 565 |
+
do_resize = True
|
| 566 |
+
do_rescale = True
|
| 567 |
+
do_normalize = True
|
| 568 |
+
do_convert_rgb = True
|
| 569 |
+
patch_size = 14
|
| 570 |
+
pooling_size = [3, 3]
|
| 571 |
+
do_sample_frames = True
|
| 572 |
+
frame_sample_mode = "uniform_last_frame"
|
| 573 |
+
max_fps = 2
|
| 574 |
+
sampling_fps = 2
|
| 575 |
+
valid_kwargs = MolmoAct2VideoProcessorKwargs
|
| 576 |
+
model_input_names = ["pixel_values_videos", "video_token_pooling", "video_grids"]
|
| 577 |
+
|
| 578 |
+
def __init__(self, **kwargs: Unpack[MolmoAct2VideoProcessorKwargs]):
|
| 579 |
+
super().__init__(**kwargs)
|
| 580 |
+
if self.size is not None and (
|
| 581 |
+
self.size.get("height", None) is None or self.size.get("width", None) is None
|
| 582 |
+
):
|
| 583 |
+
raise ValueError("size must contain 'height' and 'width' keys.")
|
| 584 |
+
|
| 585 |
+
def _further_process_kwargs(
|
| 586 |
+
self,
|
| 587 |
+
size: Optional[SizeDict] = None,
|
| 588 |
+
**kwargs,
|
| 589 |
+
) -> dict:
|
| 590 |
+
"""
|
| 591 |
+
Update kwargs that need further processing before being validated
|
| 592 |
+
Can be overridden by subclasses to customize the processing of kwargs.
|
| 593 |
+
"""
|
| 594 |
+
if size is not None and ("height" not in size or "width" not in size):
|
| 595 |
+
raise ValueError("size must contain 'height' and 'width' keys.")
|
| 596 |
+
|
| 597 |
+
return super()._further_process_kwargs(size=size, **kwargs)
|
| 598 |
+
|
| 599 |
+
def sample_times(
|
| 600 |
+
self,
|
| 601 |
+
metadata: VideoMetadata,
|
| 602 |
+
frame_sample_mode: str,
|
| 603 |
+
num_frames: int,
|
| 604 |
+
max_fps: Optional[int] = None,
|
| 605 |
+
sampling_fps: Optional[int] = None,
|
| 606 |
+
**kwargs,
|
| 607 |
+
) -> np.ndarray:
|
| 608 |
+
"""
|
| 609 |
+
Time-based sampling if an array video is passed
|
| 610 |
+
Args:
|
| 611 |
+
metadata (`VideoMetadata`):
|
| 612 |
+
Metadata of the video containing information about total duration, fps and total number of frames.
|
| 613 |
+
frame_sample_mode (`str`, *optional*):
|
| 614 |
+
Mode to sample frames. Defaults to `self.frame_sample_mode`.
|
| 615 |
+
num_frames (`int`, *optional*):
|
| 616 |
+
Maximum number of frames to sample. Defaults to `self.num_frames`.
|
| 617 |
+
man_fps (`int`, *optional*):
|
| 618 |
+
Maximum frames per second to sample.
|
| 619 |
+
sampling_fps (`int`, *optional*):
|
| 620 |
+
Sampling frames per second. Defaults to `self.sampling_fps`.
|
| 621 |
+
Used when `frame_sample_mode` is `"fps"`.
|
| 622 |
+
"""
|
| 623 |
+
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
|
| 624 |
+
num_frames = num_frames or self.num_frames
|
| 625 |
+
sampling_fps = sampling_fps or self.sampling_fps
|
| 626 |
+
|
| 627 |
+
duration = metadata.duration or metadata.total_num_frames / metadata.fps
|
| 628 |
+
if frame_sample_mode == "fps":
|
| 629 |
+
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
|
| 630 |
+
# Try larger and larger FPSs until we hit one that can't span the video
|
| 631 |
+
target_fps = candidate_target_fps[0]
|
| 632 |
+
for candidate_fps in candidate_target_fps[1:]:
|
| 633 |
+
if num_frames / candidate_fps < duration:
|
| 634 |
+
break
|
| 635 |
+
target_fps = candidate_fps
|
| 636 |
+
times = np.arange(0, num_frames) / target_fps
|
| 637 |
+
times = times[times < duration]
|
| 638 |
+
return times
|
| 639 |
+
elif frame_sample_mode == "uniform_last_frame":
|
| 640 |
+
if max_fps is not None:
|
| 641 |
+
max_duration = (num_frames-1) / max_fps # -1 to include the last frame
|
| 642 |
+
if max_duration < duration:
|
| 643 |
+
times = np.linspace(
|
| 644 |
+
0, duration, num=num_frames, endpoint=True, dtype=np.float64
|
| 645 |
+
)
|
| 646 |
+
else:
|
| 647 |
+
times = np.arange(0.0, stop=duration, step=1/max_fps)
|
| 648 |
+
times = np.concatenate([times, [duration]], axis=0)
|
| 649 |
+
assert len(times) <= num_frames
|
| 650 |
+
else:
|
| 651 |
+
times = np.linspace(
|
| 652 |
+
0, duration, num=num_frames, endpoint=True, dtype=np.float64
|
| 653 |
+
)
|
| 654 |
+
return times
|
| 655 |
+
else:
|
| 656 |
+
raise NotImplementedError(frame_sample_mode)
|
| 657 |
+
|
| 658 |
+
def sample_frames(
|
| 659 |
+
self,
|
| 660 |
+
metadata: VideoMetadata,
|
| 661 |
+
frame_sample_mode: Optional[str] = None,
|
| 662 |
+
num_frames: Optional[int] = None,
|
| 663 |
+
max_fps: Optional[int] = None,
|
| 664 |
+
sampling_fps: Optional[int] = None,
|
| 665 |
+
**kwargs,
|
| 666 |
+
) -> np.ndarray:
|
| 667 |
+
"""
|
| 668 |
+
Frame-based sampling if an array video is passed
|
| 669 |
+
Args:
|
| 670 |
+
metadata (`VideoMetadata`):
|
| 671 |
+
Metadata of the video containing information about total duration, fps and total number of frames.
|
| 672 |
+
frame_sample_mode (`str`, *optional*):
|
| 673 |
+
Mode to sample frames. Defaults to `self.frame_sample_mode`.
|
| 674 |
+
num_frames (`int`, *optional*):
|
| 675 |
+
Maximum number of frames to sample. Defaults to `self.num_frames`.
|
| 676 |
+
max_fps (`int`, *optional*):
|
| 677 |
+
Maximum frames per second to sample.
|
| 678 |
+
sampling_fps (`int`, *optional*):
|
| 679 |
+
Sampling frames per second. Defaults to `self.sampling_fps`.
|
| 680 |
+
Used when `frame_sample_mode` is `"fps"`.
|
| 681 |
+
"""
|
| 682 |
+
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
|
| 683 |
+
num_frames = num_frames or self.num_frames
|
| 684 |
+
sampling_fps = sampling_fps or self.sampling_fps
|
| 685 |
+
|
| 686 |
+
total_num_frames = metadata.total_num_frames
|
| 687 |
+
if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
|
| 688 |
+
duration = total_num_frames / metadata.fps
|
| 689 |
+
if total_num_frames <= 2:
|
| 690 |
+
return np.arange(total_num_frames).astype(int)
|
| 691 |
+
if duration > (num_frames - 1) / max_fps: # -1 to include the last frame
|
| 692 |
+
# uniform fallback
|
| 693 |
+
indices = np.linspace(
|
| 694 |
+
0,
|
| 695 |
+
total_num_frames - 1,
|
| 696 |
+
num=min(num_frames, total_num_frames),
|
| 697 |
+
endpoint=True,
|
| 698 |
+
).astype(int)
|
| 699 |
+
return indices
|
| 700 |
+
else:
|
| 701 |
+
float_indices = np.arange(
|
| 702 |
+
0.0, stop=total_num_frames - 1, step=float(metadata.fps / max_fps),
|
| 703 |
+
)
|
| 704 |
+
if np.round(float_indices[-1]) != total_num_frames - 1:
|
| 705 |
+
float_indices = np.concatenate([float_indices, [total_num_frames - 1]], axis=0)
|
| 706 |
+
indices = np.round(float_indices).astype(int)
|
| 707 |
+
assert indices[-1] < total_num_frames
|
| 708 |
+
assert len(float_indices) <= num_frames
|
| 709 |
+
return indices
|
| 710 |
+
elif frame_sample_mode == "uniform_last_frame":
|
| 711 |
+
indices = np.linspace(
|
| 712 |
+
0, total_num_frames - 1, num=min(num_frames, total_num_frames), endpoint=True,
|
| 713 |
+
).astype(int)
|
| 714 |
+
return indices
|
| 715 |
+
elif frame_sample_mode == "fps":
|
| 716 |
+
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
|
| 717 |
+
selected_target_fps = get_target_fps(
|
| 718 |
+
metadata.fps,
|
| 719 |
+
num_frames,
|
| 720 |
+
total_num_frames,
|
| 721 |
+
frame_sample_mode,
|
| 722 |
+
candidate_target_fps,
|
| 723 |
+
)
|
| 724 |
+
_, indices = get_frame_times_and_chosen_fps(
|
| 725 |
+
selected_target_fps,
|
| 726 |
+
total_num_frames,
|
| 727 |
+
num_frames,
|
| 728 |
+
metadata.fps,
|
| 729 |
+
)
|
| 730 |
+
return indices
|
| 731 |
+
else:
|
| 732 |
+
raise NotImplementedError(frame_sample_mode)
|
| 733 |
+
|
| 734 |
+
def fetch_videos(
|
| 735 |
+
self,
|
| 736 |
+
video_url_or_urls: Union[str, list[str], list[list[str]]],
|
| 737 |
+
sample_timestamps_fn=None
|
| 738 |
+
):
|
| 739 |
+
"""
|
| 740 |
+
Convert a single or a list of urls into the corresponding `np.array` objects.
|
| 741 |
+
|
| 742 |
+
If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
|
| 743 |
+
returned.
|
| 744 |
+
"""
|
| 745 |
+
if (
|
| 746 |
+
(not is_decord_available())
|
| 747 |
+
and (not is_torchcodec_available())
|
| 748 |
+
and (not is_av_available())
|
| 749 |
+
):
|
| 750 |
+
raise ImportError(
|
| 751 |
+
"MolmoAct2VideoProcessor requires `decord`, `torchcodec`, or `av` to be installed."
|
| 752 |
+
)
|
| 753 |
+
|
| 754 |
+
if is_decord_available():
|
| 755 |
+
backend = "decord"
|
| 756 |
+
elif is_torchcodec_available():
|
| 757 |
+
warnings.warn(
|
| 758 |
+
"`decord` is not installed and cannot be used to decode the video by default. "
|
| 759 |
+
"Falling back to `torchcodec`."
|
| 760 |
+
)
|
| 761 |
+
backend = "torchcodec"
|
| 762 |
+
else:
|
| 763 |
+
warnings.warn(
|
| 764 |
+
"`decord` is not installed and cannot be used to decode the video by default. "
|
| 765 |
+
"Falling back to `PyAV`."
|
| 766 |
+
)
|
| 767 |
+
backend = "pyav"
|
| 768 |
+
|
| 769 |
+
if isinstance(video_url_or_urls, list):
|
| 770 |
+
return list(zip(*[self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn) for x in video_url_or_urls]))
|
| 771 |
+
else:
|
| 772 |
+
return load_video(video_url_or_urls, backend=backend, sample_timestamps_fn=sample_timestamps_fn)
|
| 773 |
+
|
| 774 |
+
def _decode_and_sample_videos(
|
| 775 |
+
self,
|
| 776 |
+
videos: VideoInput,
|
| 777 |
+
video_metadata: Union[VideoMetadata, dict],
|
| 778 |
+
do_sample_frames: Optional[bool] = None,
|
| 779 |
+
sample_indices_fn: Optional[Callable] = None,
|
| 780 |
+
sample_timestamps_fn: Optional[Callable] = None,
|
| 781 |
+
):
|
| 782 |
+
"""
|
| 783 |
+
Decode input videos and sample frames if needed.
|
| 784 |
+
"""
|
| 785 |
+
videos = make_batched_videos(videos)
|
| 786 |
+
video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)
|
| 787 |
+
|
| 788 |
+
# Framed-based sampling if an array video is passed
|
| 789 |
+
# Otherwise, time-based sampling with decoding
|
| 790 |
+
if is_valid_video(videos[0]) and do_sample_frames:
|
| 791 |
+
assert video_metadata[0].fps is not None, "FPS must be provided for video input"
|
| 792 |
+
sampled_videos = []
|
| 793 |
+
sampled_metadata = []
|
| 794 |
+
for video, metadata in zip(videos, video_metadata):
|
| 795 |
+
indices = sample_indices_fn(metadata=metadata)
|
| 796 |
+
metadata.frames_indices = indices
|
| 797 |
+
sampled_videos.append(video[indices])
|
| 798 |
+
sampled_metadata.append(metadata)
|
| 799 |
+
videos = sampled_videos
|
| 800 |
+
video_metadata = sampled_metadata
|
| 801 |
+
elif not is_valid_video(videos[0]):
|
| 802 |
+
if sample_indices_fn is None:
|
| 803 |
+
logger.warning(
|
| 804 |
+
"do_sample_frames is False, but video array is not provided: "
|
| 805 |
+
"Will decode the video and sample frames using MolmoAct2's default sampling mode"
|
| 806 |
+
)
|
| 807 |
+
if isinstance(videos[0], list):
|
| 808 |
+
raise ValueError(
|
| 809 |
+
"A list of images is not supported for video input!"
|
| 810 |
+
)
|
| 811 |
+
else:
|
| 812 |
+
videos, video_metadata = self.fetch_videos(videos, sample_timestamps_fn=sample_timestamps_fn)
|
| 813 |
+
|
| 814 |
+
return videos, video_metadata
|
| 815 |
+
|
| 816 |
+
def _prepare_input_videos(
|
| 817 |
+
self,
|
| 818 |
+
videos: VideoInput,
|
| 819 |
+
**kwargs,
|
| 820 |
+
) -> list[np.ndarray]:
|
| 821 |
+
processed_videos = [to_numpy(video) for video in videos]
|
| 822 |
+
return processed_videos
|
| 823 |
+
|
| 824 |
+
def preprocess(
|
| 825 |
+
self,
|
| 826 |
+
videos: VideoInput,
|
| 827 |
+
**kwargs: Unpack[MolmoAct2VideoProcessorKwargs],
|
| 828 |
+
) -> BatchFeature:
|
| 829 |
+
validate_kwargs(
|
| 830 |
+
captured_kwargs=kwargs.keys(),
|
| 831 |
+
valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
+
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
| 835 |
+
# by the user, it gets its default value from the instance, or is set to None.
|
| 836 |
+
for kwarg_name in self.valid_kwargs.__annotations__:
|
| 837 |
+
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
| 838 |
+
|
| 839 |
+
do_sample_frames = kwargs.pop("do_sample_frames")
|
| 840 |
+
video_metadata = kwargs.pop("video_metadata")
|
| 841 |
+
|
| 842 |
+
sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
|
| 843 |
+
sample_timestamps_fn = partial(self.sample_times, **kwargs)
|
| 844 |
+
videos, video_metadata = self._decode_and_sample_videos(
|
| 845 |
+
videos,
|
| 846 |
+
video_metadata=video_metadata,
|
| 847 |
+
do_sample_frames=do_sample_frames,
|
| 848 |
+
sample_indices_fn=sample_indices_fn,
|
| 849 |
+
sample_timestamps_fn=sample_timestamps_fn,
|
| 850 |
+
)
|
| 851 |
+
videos = self._prepare_input_videos(videos=videos)
|
| 852 |
+
|
| 853 |
+
kwargs = self._further_process_kwargs(**kwargs)
|
| 854 |
+
|
| 855 |
+
return_metadata = kwargs.pop("return_metadata")
|
| 856 |
+
preprocessed_videos = self._preprocess(videos=videos, **kwargs)
|
| 857 |
+
if return_metadata:
|
| 858 |
+
preprocessed_videos["video_metadata"] = video_metadata
|
| 859 |
+
return preprocessed_videos
|
| 860 |
+
|
| 861 |
+
def _preprocess(
|
| 862 |
+
self,
|
| 863 |
+
videos: list[np.ndarray],
|
| 864 |
+
size: Optional[SizeDict] = None,
|
| 865 |
+
resample: Optional[PILImageResampling] = None,
|
| 866 |
+
image_mean: Optional[Union[float, list[float]]] = None,
|
| 867 |
+
image_std: Optional[Union[float, list[float]]] = None,
|
| 868 |
+
do_convert_rgb: Optional[bool] = None,
|
| 869 |
+
patch_size: Optional[int] = None,
|
| 870 |
+
pooling_size: Optional[list[int]] = None,
|
| 871 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 872 |
+
**kwargs,
|
| 873 |
+
) -> BatchFeature:
|
| 874 |
+
"""
|
| 875 |
+
Preprocess a video for the model.
|
| 876 |
+
Args:
|
| 877 |
+
videos (`VideoInput`):
|
| 878 |
+
Video to preprocess.
|
| 879 |
+
size (`SizeDict`, *optional*, defaults to `self.size`):
|
| 880 |
+
Size of the image after resizing.
|
| 881 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
| 882 |
+
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
|
| 883 |
+
has an effect if `do_resize` is set to `True`.
|
| 884 |
+
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
|
| 885 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 886 |
+
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
|
| 887 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
| 888 |
+
`True`.
|
| 889 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 890 |
+
Whether to convert the image to RGB.
|
| 891 |
+
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
| 892 |
+
The spatial patch size of the vision encoder.
|
| 893 |
+
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
|
| 894 |
+
The pooling size of the vision adapter.
|
| 895 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 896 |
+
The type of tensors to return. Can be one of:
|
| 897 |
+
- Unset: Return a list of `np.ndarray`.
|
| 898 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 899 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 900 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 901 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 902 |
+
|
| 903 |
+
Returns:
|
| 904 |
+
A `BatchFeature` containing the following keys:
|
| 905 |
+
- `pixel_values_videos`: The preprocessed videos.
|
| 906 |
+
- `video_token_pooling`: The indices of the patches in `crops` to pool for each token in `video_tokens`.
|
| 907 |
+
- `video_grids`: The video grids.
|
| 908 |
+
"""
|
| 909 |
+
if size.height is None or size.width is None:
|
| 910 |
+
raise ValueError("size must contain 'height' and 'width' keys.")
|
| 911 |
+
|
| 912 |
+
base_image_input_size = [size.height, size.width]
|
| 913 |
+
|
| 914 |
+
resample = resample or self.resample
|
| 915 |
+
image_mean = image_mean or self.image_mean
|
| 916 |
+
image_std = image_std or self.image_std
|
| 917 |
+
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
|
| 918 |
+
|
| 919 |
+
patch_size = patch_size or self.patch_size
|
| 920 |
+
pooling_size = pooling_size or self.pooling_size
|
| 921 |
+
|
| 922 |
+
image_pooling_h, image_pooling_w = pooling_size
|
| 923 |
+
|
| 924 |
+
batch_grids = []
|
| 925 |
+
batch_crops = []
|
| 926 |
+
batch_pooled_patches_idx = []
|
| 927 |
+
|
| 928 |
+
for video in videos:
|
| 929 |
+
all_crops = []
|
| 930 |
+
pooled_patches_idx = []
|
| 931 |
+
|
| 932 |
+
for frame in video:
|
| 933 |
+
image_grid, crops, pooled_idx = image_to_patches_and_grids(
|
| 934 |
+
frame,
|
| 935 |
+
base_image_input_size,
|
| 936 |
+
resample,
|
| 937 |
+
image_mean,
|
| 938 |
+
image_std,
|
| 939 |
+
patch_size,
|
| 940 |
+
image_pooling_w,
|
| 941 |
+
image_pooling_h,
|
| 942 |
+
)
|
| 943 |
+
offset = sum(np.prod(x.shape[:2]) for x in all_crops)
|
| 944 |
+
pooled_idx_with_offset = np.where(pooled_idx >= 0, pooled_idx + offset, pooled_idx)
|
| 945 |
+
pooled_patches_idx.append(pooled_idx_with_offset)
|
| 946 |
+
all_crops.append(crops)
|
| 947 |
+
|
| 948 |
+
video_grid = np.array([len(video), image_grid[0], image_grid[1]])
|
| 949 |
+
all_crops = np.concatenate(all_crops, 0)
|
| 950 |
+
pooled_patches_idx = np.concatenate(pooled_patches_idx, 0)
|
| 951 |
+
|
| 952 |
+
batch_grids.append(video_grid)
|
| 953 |
+
batch_crops.append(all_crops)
|
| 954 |
+
batch_pooled_patches_idx.append(pooled_patches_idx)
|
| 955 |
+
|
| 956 |
+
video_grids = np.stack(batch_grids, 0)
|
| 957 |
+
pixel_values_videos = np.concatenate(batch_crops, 0)
|
| 958 |
+
video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
|
| 959 |
+
|
| 960 |
+
data =dict(
|
| 961 |
+
pixel_values_videos=pixel_values_videos,
|
| 962 |
+
video_token_pooling=video_token_pooling,
|
| 963 |
+
video_grids=video_grids,
|
| 964 |
+
)
|
| 965 |
+
|
| 966 |
+
return BatchFeature(data, tensor_type=return_tensors)
|
| 967 |
+
|
| 968 |
+
|
| 969 |
+
MolmoAct2VideoProcessor.register_for_auto_class()
|