AdrianLlopart committed on
Commit
e01c114
·
0 Parent(s):

Duplicate from AdrianLlopart/rskill-molmoact2-so101-nf4

.gitattributes ADDED
@@ -0,0 +1,38 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/sample_realsense_top_rgb.png filter=lfs diff=lfs merge=lfs -text
37
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
38
+ assets/sample_realsense_side_rgb.png filter=lfs diff=lfs merge=lfs -text
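The patterns above route matching files through Git LFS. As a rough illustration of which paths they cover, here is a minimal sketch that approximates the matching with Python's `fnmatch`. The helper name `is_lfs_tracked` and the shortened pattern list are assumptions made for this example, and `fnmatch` only approximates Git's real pattern rules (for instance, it lets `*` cross directory separators).

```python
from fnmatch import fnmatch
from pathlib import PurePosixPath

# A few patterns copied from the .gitattributes hunk above (not the full list).
LFS_PATTERNS = [
    "*.bin",
    "*.safetensors",
    "*.tar.*",
    "saved_model/**/*",
    "*tfevents*",
    "tokenizer.json",
    "assets/sample_realsense_top_rgb.png",
]


def is_lfs_tracked(path: str, patterns=LFS_PATTERNS) -> bool:
    """Approximate check: does `path` match any of the LFS patterns?"""
    name = PurePosixPath(path).name
    for pattern in patterns:
        # Assumption mirroring gitattributes behaviour: patterns containing a
        # slash are matched against the whole path, others against the file name.
        target = path if "/" in pattern else name
        if fnmatch(target, pattern):
            return True
    return False


for candidate in ["model.safetensors", "README.md", "tokenizer.json", "config.json"]:
    print(candidate, is_lfs_tracked(candidate))
```

For an authoritative answer on a real checkout, `git check-attr filter -- <path>` reports the attribute Git actually applies.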
README.md ADDED
@@ -0,0 +1,192 @@
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - molmoact2
5
+ - robotics
6
+ - image-text-to-text
7
+ - so100
8
+ - so101
9
+ ---
10
+
11
+ <img src="assets/MolmoAct2.svg" alt="MolmoAct Logo" height="50">
12
+
13
+ # **MolmoAct2-SO100_101**
14
+
15
+ MolmoAct2 is an open vision-language-action model for robot control. It builds on Molmo2-ER and attaches a flow-matching continuous action expert that conditions on the VLM key-value cache through a per-layer connection.
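To give a feel for what a flow-matching action head does at inference time, the following is a minimal sketch of fixed-step Euler integration of a velocity field. It is not the MolmoAct2 implementation: the function name `euler_flow_sample`, the time direction (0 to 1), and the toy velocity field are all assumptions chosen for illustration. In the real model the velocity would come from the action expert conditioned on the VLM cache.

```python
import numpy as np


def euler_flow_sample(velocity_fn, x_start, num_steps=10):
    """Integrate dx/dt = velocity_fn(x, t) from t=0 to t=1 with fixed Euler steps."""
    x = np.asarray(x_start, dtype=np.float64).copy()
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = step * dt
        x = x + dt * velocity_fn(x, t)
    return x


# Toy check: on a straight-line path from noise to a target, the velocity is the
# constant (target - noise), so Euler integration lands on the target.
rng = np.random.default_rng(0)
noise = rng.standard_normal((30, 6))  # (horizon, action_dim); illustrative shapes only
target = np.zeros((30, 6))
actions = euler_flow_sample(lambda x, t: target - noise, noise, num_steps=10)
print(np.allclose(actions, target))
```

More steps generally trade latency for a more accurate integration of a learned, non-constant velocity field.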
16
+
17
+ This checkpoint is fine-tuned on the SO-100/101 mixture with absolute joint-pose control and annotated language instructions. It is intended for both further fine-tuning and SO-100/101 policy inference.
18
+
19
+ ## Quick Links
20
+
21
+ - 📂 Models: [Models](https://huggingface.co/collections/allenai/molmoact2-models), [Finetuned Models](https://huggingface.co/collections/allenai/molmoact2-finetuned-models)
22
+ - 📂 Datasets: [MolmoAct2-BimanualYAM Dataset](https://huggingface.co/collections/allenai/molmoact2-datasets), [MolmoAct2 Datasets](https://huggingface.co/collections/allenai/molmoact2-datasets), [Molmo2-ER Datasets](https://huggingface.co/collections/allenai/molmo2-er-datasets)
23
+ - 📄 Paper: [arXiv:2605.02881](https://arxiv.org/abs/2605.02881)
24
+ - 💻 Code: [allenai/molmoact2](https://github.com/allenai/molmoact2)
25
+ - 🎥 Blog Post: [MolmoAct2](https://allenai.org/blog/molmoact2)
26
+
27
+ ## Intended Use
28
+
29
+ Use this checkpoint for SO-100/101 inference or for further fine-tuning. Dataset normalization metadata is stored in `norm_stats.json`; pass `norm_tag="so100_so101_molmoact2"` at inference time.
30
+
31
+ Continuous action prediction is the intended and recommended inference mode. Discrete action prediction is exposed for parity and debugging, but we use continuous actions by default.
32
+
33
+ ## Install
34
+
35
+ ```bash
36
+ pip install torch transformers pillow numpy huggingface_hub
37
+ ```
38
+
39
+ ## Sample Input
40
+
41
+ This sample comes from `Beegbrain/pick_lemon_and_drop_in_bowl`, episode 0, frame 0. Camera order does not matter for this checkpoint; the images may be passed in any order.
42
+
43
+ | Realsense Top RGB | Realsense Side RGB |
44
+ | --- | --- |
45
+ | ![Sample realsense top RGB](assets/sample_realsense_top_rgb.png) | ![Sample realsense side RGB](assets/sample_realsense_side_rgb.png) |
46
+
47
+ ```python
48
+ from huggingface_hub import hf_hub_download
49
+ from PIL import Image
50
+ import numpy as np
51
+
52
+ repo_id = "allenai/MolmoAct2-SO100_101"
53
+
54
+ top_rgb = Image.open(
55
+ hf_hub_download(repo_id, "assets/sample_realsense_top_rgb.png")
56
+ ).convert("RGB")
57
+ side_rgb = Image.open(
58
+ hf_hub_download(repo_id, "assets/sample_realsense_side_rgb.png")
59
+ ).convert("RGB")
60
+
61
+ task = "Move the arm towards the lemon, grasp it, lift it up, and drop it into the red bowl."
62
+ robot_state = np.array(
63
+ [
64
+ -0.52734375,
65
+ 189.140625,
66
+ 181.40625,
67
+ 60.64453125,
68
+ -3.603515625,
69
+ 1.0971786975860596,
70
+ ],
71
+ dtype=np.float32,
72
+ )
73
+ ```
74
+
75
+ ## Continuous Actions
76
+
77
+ ```python
78
+ import numpy as np
79
+ import torch
80
+ from huggingface_hub import hf_hub_download
81
+ from PIL import Image
82
+ from transformers import AutoModelForImageTextToText, AutoProcessor
83
+
84
+ repo_id = "allenai/MolmoAct2-SO100_101"
85
+
86
+ top_rgb = Image.open(
87
+ hf_hub_download(repo_id, "assets/sample_realsense_top_rgb.png")
88
+ ).convert("RGB")
89
+ side_rgb = Image.open(
90
+ hf_hub_download(repo_id, "assets/sample_realsense_side_rgb.png")
91
+ ).convert("RGB")
92
+ task = "Move the arm towards the lemon, grasp it, lift it up, and drop it into the red bowl."
93
+ robot_state = np.array(
94
+ [
95
+ -0.52734375,
96
+ 189.140625,
97
+ 181.40625,
98
+ 60.64453125,
99
+ -3.603515625,
100
+ 1.0971786975860596,
101
+ ],
102
+ dtype=np.float32,
103
+ )
104
+
105
+ processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
106
+ model = AutoModelForImageTextToText.from_pretrained(
107
+ repo_id,
108
+ trust_remote_code=True,
109
+ dtype=torch.float32,
110
+ ).to("cuda").eval()
111
+
112
+ out = model.predict_action(
113
+ processor=processor,
114
+ images=[top_rgb, side_rgb],
115
+ task=task,
116
+ state=robot_state,
117
+ norm_tag="so100_so101_molmoact2",
118
+ inference_action_mode="continuous",
119
+ enable_depth_reasoning=False,
120
+ num_steps=10,
121
+ normalize_language=True,
122
+ enable_cuda_graph=True,
123
+ )
124
+
125
+ actions = out.actions
126
+ ```
127
+
128
+ MolmoAct2 was trained with mixed precision. For our reported experiments, we ran inference in `float32`. This path uses the most GPU memory: roughly 26GB with CUDA graph enabled, or around 24GB without CUDA graph.
129
+
130
+ If you have a GPU with less memory, you can run inference with `bfloat16` instead:
131
+
132
+ ```python
133
+ model = AutoModelForImageTextToText.from_pretrained(
134
+ repo_id,
135
+ trust_remote_code=True,
136
+ dtype=torch.bfloat16,
137
+ ).to("cuda").eval()
138
+
139
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
140
+     out = model.predict_action(...)
141
+ ```
142
+
143
+ Running in `bfloat16` is much more memory-efficient: in our tests it fits in under 16 GB of GPU memory, and it usually does not hurt performance much.
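The memory figures quoted above can be turned into a simple precision choice. The sketch below is only an illustration: `pick_inference_dtype` is a hypothetical helper, and the thresholds are the approximate numbers from this README (about 26 GB for `float32` with CUDA graph, about 24 GB without), not measured limits for your hardware.

```python
def pick_inference_dtype(free_gpu_gb: float, cuda_graph: bool = True) -> str:
    """Return "float32" when the README's rough float32 budget fits, else "bfloat16"."""
    # Approximate float32 requirements quoted in the README.
    float32_need_gb = 26.0 if cuda_graph else 24.0
    return "float32" if free_gpu_gb >= float32_need_gb else "bfloat16"


print(pick_inference_dtype(40.0))                    # float32
print(pick_inference_dtype(24.5, cuda_graph=False))  # float32
print(pick_inference_dtype(16.0))                    # bfloat16
```

With PyTorch installed, the free-memory value could be obtained from `torch.cuda.mem_get_info()`, which returns free and total bytes for the current device.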
144
+
145
+
146
+ Images may be PIL images or RGB arrays. Camera order does not need to be fixed for this checkpoint; any order is acceptable. `state` is the raw robot state, and actions are returned in robot scale.
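"Raw state in, robot-scale actions out" implies that normalization and its inverse happen inside the model call. The sketch below shows that general idea as a round trip. The layout of `norm_stats.json` is not shown here, so the mean/std scheme, the statistics, and the function names are all assumptions for illustration; the checkpoint may well use a different scheme such as quantile scaling.

```python
import numpy as np


def normalize(x, mean, std, eps=1e-8):
    """Map robot-scale values to a normalized space (assumed mean/std scheme)."""
    return (np.asarray(x, dtype=np.float64) - mean) / (std + eps)


def unnormalize(z, mean, std, eps=1e-8):
    """Inverse of normalize: map normalized values back to robot scale."""
    return np.asarray(z, dtype=np.float64) * (std + eps) + mean


# Hypothetical per-joint statistics; real values would come from norm_stats.json.
mean = np.array([0.0, 150.0, 150.0, 60.0, 0.0, 10.0])
std = np.array([30.0, 40.0, 40.0, 20.0, 30.0, 10.0])

raw_state = np.array([-0.527, 189.14, 181.41, 60.64, -3.60, 1.10])
restored = unnormalize(normalize(raw_state, mean, std), mean, std)
print(np.allclose(restored, raw_state))
```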
147
+
148
+ `normalize_language=True` is the default. It lowercases the task string and removes trailing sentence punctuation to match training preprocessing. Set it to `False` if you need to preserve the task text exactly.
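As a rough picture of what this preprocessing does to a task string, here is a minimal sketch. The helper `normalize_task` is hypothetical, and the exact set of characters treated as "sentence punctuation" (`.`, `!`, `?` below) is an assumption; the model's own preprocessing may differ in detail.

```python
def normalize_task(task: str) -> str:
    """Lowercase the task and drop trailing sentence punctuation (assumed: . ! ?)."""
    return task.strip().lower().rstrip(".!?").rstrip()


task = "Move the arm towards the lemon, grasp it, lift it up, and drop it into the red bowl."
print(normalize_task(task))
# move the arm towards the lemon, grasp it, lift it up, and drop it into the red bowl
```

Punctuation inside the sentence, such as the commas above, is left untouched; only the end of the string is trimmed.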
149
+
150
+ `enable_cuda_graph=True` is the default. The first few calls can be slow because the model warms up and captures CUDA graphs, so run several random warm-up calls before measuring deployment latency. `num_steps` sets the number of steps of the continuous flow solver and defaults to the checkpoint config value, 10.
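One way to keep warm-up cost out of latency numbers is to discard the first few calls before timing. The following is a minimal sketch; `measure_latency` is a hypothetical helper, and the callable passed in stands for whatever wraps your `predict_action` call.

```python
import statistics
import time


def measure_latency(fn, warmup=5, iters=20):
    """Call fn `warmup` times untimed, then time `iters` calls; return (median, samples)."""
    for _ in range(warmup):
        fn()  # untimed: lets graph capture and other one-off setup happen first
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples), samples


# Stand-in workload; on a GPU, fn should also wait for the device to finish
# (for example by calling torch.cuda.synchronize()) so the timing is meaningful.
median_s, samples = measure_latency(lambda: sum(range(10_000)), warmup=3, iters=10)
print(f"median latency: {median_s * 1e3:.3f} ms over {len(samples)} calls")
```

Reporting the median rather than the mean makes the figure less sensitive to occasional slow calls.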
151
+
152
+ Depth reasoning is disabled for this checkpoint. Passing `enable_depth_reasoning=True` will raise an error.
153
+
154
+ ## Discrete Actions
155
+
156
+ Discrete action inference requires a caller-provided action tokenizer, which is not saved in this repository. Discrete mode decodes action tokens directly; the continuous action expert is not used.
157
+
158
+ ```python
159
+ action_tokenizer = AutoProcessor.from_pretrained(
160
+ "allenai/MolmoAct2-FAST-Tokenizer",
161
+ trust_remote_code=True,
162
+ )
163
+
164
+ out = model.predict_action(
165
+ processor=processor,
166
+ images=[top_rgb, side_rgb],
167
+ task=task,
168
+ state=robot_state,
169
+ norm_tag="so100_so101_molmoact2",
170
+ inference_action_mode="discrete",
171
+ action_tokenizer=action_tokenizer,
172
+ enable_depth_reasoning=False,
173
+ )
174
+ ```
175
+
176
+ ## Model and Hardware Safety
177
+
178
+ MolmoAct2 models generate robot actions from visual observations and language instructions, but their behavior may vary across embodiments, environments, and hardware configurations. Users should carefully validate model outputs before deployment, especially when operating physical robots or other actuated systems. Where possible, actions should be monitored through interpretable intermediate outputs (such as adaptive depth maps), simulation rollouts, action limits, or other safety checks before execution on hardware. The model’s action space should be bounded by the training data, robot controller limits, and task-specific safety constraints, including limits on speed, workspace, torque, and contact force. Users should follow the hardware manufacturer’s safety guidelines, use appropriate emergency-stop mechanisms, and operate the system only in a safely configured environment with human supervision.
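As one concrete form of the action limits mentioned above, a predicted joint target can be clamped to a safe range and rate-limited before it is sent to the controller. This is a minimal sketch, not a complete safety layer: `limit_action` is a hypothetical helper, and the joint range and per-step limit are made-up values that must be replaced with your robot's real limits.

```python
import numpy as np


def limit_action(action, prev, low, high, max_delta):
    """Clamp a target to [low, high], then cap the change from the previous command."""
    action = np.clip(np.asarray(action, dtype=np.float64), low, high)
    step = np.clip(action - prev, -max_delta, max_delta)
    return prev + step


prev_cmd = np.zeros(3)
predicted = np.array([10.0, -10.0, 0.5])
safe_cmd = limit_action(predicted, prev_cmd, low=-5.0, high=5.0, max_delta=1.0)
print(safe_cmd)  # [ 1.  -1.   0.5]
```

A software clamp like this complements, and does not replace, controller-level limits and a hardware emergency stop.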
179
+
180
+ ## Citation
181
+
182
+ ```bibtex
183
+ @misc{fang2026molmoact2actionreasoningmodels,
184
+ title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
185
+ author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
186
+ year={2026},
187
+ eprint={2605.02881},
188
+ archivePrefix={arXiv},
189
+ primaryClass={cs.RO},
190
+ url={https://arxiv.org/abs/2605.02881},
191
+ }
192
+ ```
config.json ADDED
@@ -0,0 +1,153 @@
1
+ {
2
+ "action_end_token_id": 151933,
3
+ "action_expert_config": {
4
+ "attn_dropout": 0.0,
5
+ "causal_attn": false,
6
+ "context_layer_norm": true,
7
+ "dropout": 0.0,
8
+ "ffn_multiple_of": 256,
9
+ "hidden_size": 768,
10
+ "mlp_ratio": 4.0,
11
+ "model_type": "molmoact2_action_expert",
12
+ "num_heads": 8,
13
+ "num_layers": 36,
14
+ "qk_norm": true,
15
+ "qk_norm_eps": 1e-06,
16
+ "rope": true,
17
+ "timestep_embed_dim": 256
18
+ },
19
+ "action_expert_depth_gate": false,
20
+ "action_expert_depth_gate_init_bias": -4.0,
21
+ "action_expert_depth_gate_per_layer": false,
22
+ "action_mode": "both",
23
+ "max_action_horizon": 30,
24
+ "action_output_token_id": 151931,
25
+ "action_start_token_id": 151932,
26
+ "action_token_start_id": 151934,
27
+ "adapter_config": {
28
+ "attention_dropout": 0.0,
29
+ "attn_implementation": "sdpa",
30
+ "float32_attention": true,
31
+ "head_dim": 72,
32
+ "hidden_act": "silu",
33
+ "hidden_size": 1152,
34
+ "image_feature_dropout": 0.0,
35
+ "initializer_range": 0.02,
36
+ "intermediate_size": 9728,
37
+ "model_type": "molmoact2",
38
+ "num_attention_heads": 16,
39
+ "num_key_value_heads": 16,
40
+ "pooling_attention_mask": true,
41
+ "residual_dropout": 0.0,
42
+ "text_hidden_size": 2560,
43
+ "vit_layers": [
44
+ -3,
45
+ -9
46
+ ]
47
+ },
48
+ "add_action_expert": true,
49
+ "add_control_tokens": true,
50
+ "add_setup_tokens": true,
51
+ "architectures": [
52
+ "MolmoAct2ForConditionalGeneration"
53
+ ],
54
+ "auto_map": {
55
+ "AutoConfig": "configuration_molmoact2.MolmoAct2Config",
56
+ "AutoModelForImageTextToText": "modeling_molmoact2.MolmoAct2ForConditionalGeneration"
57
+ },
58
+ "depth_end_token_id": null,
59
+ "depth_mode": 2,
60
+ "depth_output_token_id": null,
61
+ "depth_start_token_id": null,
62
+ "depth_token_start_id": null,
63
+ "dtype": "float32",
64
+ "enable_depth_reasoning": false,
65
+ "flow_matching_beta_alpha": 1.0,
66
+ "flow_matching_beta_beta": 1.5,
67
+ "flow_matching_cutoff": 1.0,
68
+ "flow_matching_num_steps": 10,
69
+ "flow_matching_time_offset": 0.001,
70
+ "flow_matching_time_scale": 0.999,
71
+ "frame_end_token_id": 154632,
72
+ "frame_start_token_id": 154631,
73
+ "image_col_id": 154627,
74
+ "image_end_token_id": 154625,
75
+ "image_high_res_id": 154626,
76
+ "image_low_res_id": 154630,
77
+ "image_patch_id": 154626,
78
+ "image_start_token_id": 154624,
79
+ "initializer_range": 0.02,
80
+ "low_res_image_start_token_id": 154628,
81
+ "mask_action_dim_padding": true,
82
+ "max_action_dim": 32,
83
+ "model_type": "molmoact2",
84
+ "n_obs_steps": 1,
85
+ "norm_stats_filename": "norm_stats.json",
86
+ "num_action_tokens": 2048,
87
+ "num_depth_codes": 100,
88
+ "num_depth_tokens": 0,
89
+ "num_state_tokens": 256,
90
+ "state_end_token_id": 151674,
91
+ "state_format": "discrete",
92
+ "state_start_token_id": 151673,
93
+ "state_token_start_id": 151675,
94
+ "text_config": {
95
+ "additional_vocab_size": 128,
96
+ "attention_dropout": 0.0,
97
+ "attn_implementation": "sdpa",
98
+ "embedding_dropout": 0.0,
99
+ "head_dim": 128,
100
+ "hidden_act": "silu",
101
+ "hidden_size": 2560,
102
+ "initializer_range": 0.02,
103
+ "intermediate_size": 9728,
104
+ "layer_norm_eps": 1e-06,
105
+ "max_position_embeddings": 16384,
106
+ "model_type": "molmoact2_text",
107
+ "norm_after": false,
108
+ "num_attention_heads": 32,
109
+ "num_hidden_layers": 36,
110
+ "num_key_value_heads": 8,
111
+ "qk_norm_type": "qwen3",
112
+ "qkv_bias": false,
113
+ "residual_dropout": 0.0,
114
+ "rope_parameters": {
115
+ "rope_theta": 5000000.0,
116
+ "rope_type": "default"
117
+ },
118
+ "rope_scaling_layers": null,
119
+ "rope_theta": 5000000.0,
120
+ "tie_word_embeddings": false,
121
+ "use_cache": true,
122
+ "use_qk_norm": true,
123
+ "vocab_size": 154624
124
+ },
125
+ "tie_word_embeddings": false,
126
+ "transformers_version": "5.3.0",
127
+ "use_frame_special_tokens": true,
128
+ "vit_config": {
129
+ "attention_dropout": 0.0,
130
+ "attn_implementation": "sdpa",
131
+ "float32_attention": true,
132
+ "head_dim": 72,
133
+ "hidden_act": "gelu_pytorch_tanh",
134
+ "hidden_size": 1152,
135
+ "image_default_input_size": [
136
+ 378,
137
+ 378
138
+ ],
139
+ "image_num_pos": 729,
140
+ "image_patch_size": 14,
141
+ "initializer_range": 0.02,
142
+ "intermediate_size": 4304,
143
+ "layer_norm_eps": 1e-06,
144
+ "model_type": "molmoact2",
145
+ "num_attention_heads": 16,
146
+ "num_hidden_layers": 27,
147
+ "num_key_value_heads": 16,
148
+ "residual_dropout": 0.0
149
+ },
150
+ "bos_token_id": 151645,
151
+ "eos_token_id": 151645,
152
+ "pad_token_id": 151643
153
+ }
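A few of the numbers in this config can be cross-checked against each other with simple arithmetic. The sketch below copies values from the config above and recomputes the ViT patch grid (the same integer division used by the `image_num_patch` property in `configuration_molmoact2.py`) and the attention head ratios. Reading `num_key_value_heads < num_attention_heads` as grouped-query attention is an assumption based on the usual Transformers convention.

```python
# Values copied from the config.json hunk above.
vit_config = {"image_default_input_size": (378, 378), "image_patch_size": 14, "image_num_pos": 729}
text_config = {"num_attention_heads": 32, "num_key_value_heads": 8, "head_dim": 128, "num_hidden_layers": 36}
action_expert_config = {"num_layers": 36}

# Patch grid: integer-divide each image side by the patch size.
height, width = vit_config["image_default_input_size"]
patches_h = height // vit_config["image_patch_size"]
patches_w = width // vit_config["image_patch_size"]
num_patches = patches_h * patches_w

# Assumed grouped-query attention: query heads sharing each key/value head.
queries_per_kv_head = text_config["num_attention_heads"] // text_config["num_key_value_heads"]

# Combined width of all query heads and of all key/value heads.
query_width = text_config["num_attention_heads"] * text_config["head_dim"]
kv_width = text_config["num_key_value_heads"] * text_config["head_dim"]

print(patches_h, patches_w, num_patches)   # 27 27 729
print(queries_per_kv_head, query_width, kv_width)  # 4 4096 1024
print(action_expert_config["num_layers"] == text_config["num_hidden_layers"])  # True
```

The 27 x 27 grid matches `image_num_pos: 729` in this config, and the action expert's 36 layers match the text model's 36 hidden layers, which is consistent with the per-layer connection described in the README.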
configuration_molmoact2.py ADDED
@@ -0,0 +1,543 @@
1
+ """
2
+ MolmoAct2 configuration
3
+ """
4
+
5
+ from typing import Optional, Any
6
+
7
+ from transformers import PretrainedConfig
8
+ from transformers.modeling_rope_utils import rope_config_validation
9
+ from transformers.utils import logging
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+
14
+ class MolmoAct2VitConfig(PretrainedConfig):
15
+ r"""
16
+ This is the configuration class to store the configuration of a [`MolmoAct2VisionTransformer`].
17
+ It is used to instantiate a `MolmoAct2VisionTransformer` according to the specified arguments,
18
+ defining the model architecture.
19
+
20
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
21
+ documentation from [`PretrainedConfig`] for more information.
22
+
23
+ Example:
24
+ ```python
25
+ >>> from transformers import MolmoAct2VitConfig, MolmoAct2VisionTransformer
26
+
27
+ >>> # Initializing a MolmoAct2VitConfig
28
+ >>> configuration = MolmoAct2VitConfig()
29
+
30
+ >>> # Initializing a MolmoAct2VisionTransformer (with random weights)
31
+ >>> model = MolmoAct2VisionTransformer(configuration)
32
+
33
+ >>> # Accessing the model configuration
34
+ >>> configuration = model.config
35
+ ```"""
36
+
37
+ model_type = "molmoact2"
38
+ base_config_key = "vit_config"
39
+
40
+ def __init__(
41
+ self,
42
+ hidden_size: int = 1152,
43
+ intermediate_size: int = 4304,
44
+ num_hidden_layers: int = 27,
45
+ num_attention_heads: int = 16,
46
+ num_key_value_heads: int = 16,
47
+ head_dim: int = 72,
48
+ hidden_act: str = "gelu_pytorch_tanh",
49
+ layer_norm_eps: float = 1e-6,
50
+ image_default_input_size: tuple[int, int] = (378, 378),
51
+ image_patch_size: int = 14,
52
+ image_num_pos: int = 577,
53
+ attention_dropout: float = 0.0,
54
+ residual_dropout: float = 0.0,
55
+ initializer_range: float = 0.02,
56
+ float32_attention: bool = True,
57
+ attn_implementation: str = "eager",
58
+ **kwargs,
59
+ ):
60
+ self.attn_implementation = attn_implementation
61
+ super().__init__(
62
+ attn_implementation=attn_implementation,
63
+ **kwargs
64
+ )
65
+ self.hidden_size = hidden_size
66
+ self.intermediate_size = intermediate_size
67
+ self.num_hidden_layers = num_hidden_layers
68
+ self.num_attention_heads = num_attention_heads
69
+ self.num_key_value_heads = num_key_value_heads
70
+ self.head_dim = head_dim
71
+ self.hidden_act = hidden_act
72
+ self.layer_norm_eps = layer_norm_eps
73
+ self.image_default_input_size = image_default_input_size
74
+ self.image_patch_size = image_patch_size
75
+ self.image_num_pos = image_num_pos
76
+ self.attention_dropout = attention_dropout
77
+ self.residual_dropout = residual_dropout
78
+ self.initializer_range = initializer_range
79
+ self.float32_attention = float32_attention
80
+
81
+ @property
82
+ def image_num_patch(self):
83
+ h, w = self.image_default_input_size
84
+ return h // self.image_patch_size, w // self.image_patch_size
85
+
86
+
87
+ class MolmoAct2AdapterConfig(PretrainedConfig):
88
+ r"""
89
+ This is the configuration class to store the configuration of MolmoAct2Adapter. Together with MolmoAct2VitConfig,
90
+ it is used to instantiate a MolmoAct2VisionBackbone according to the specified arguments,
91
+ defining the model architecture.
92
+
93
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
94
+ documentation from [`PretrainedConfig`] for more information.
95
+
96
+ Example:
97
+
98
+ ```python
99
+ >>> from transformers import MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2VisionBackbone
100
+
101
+ >>> # Initializing a MolmoAct2VitConfig and a MolmoAct2AdapterConfig
102
+ >>> vit_config = MolmoAct2VitConfig()
103
+ >>> adapter_config = MolmoAct2AdapterConfig()
104
+
105
+ >>> # Initializing a MolmoAct2VisionBackbone (with random weights)
106
+ >>> model = MolmoAct2VisionBackbone(vit_config, adapter_config)
107
+
108
+ >>> # Accessing the model configuration
109
+ >>> vit_configuration = model.vit_config
110
+ >>> adapter_configuration = model.adapter_config
111
+ ```"""
112
+
113
+ model_type = "molmoact2"
114
+ base_config_key = "adapter_config"
115
+
116
+ def __init__(
117
+ self,
118
+ vit_layers: tuple = (-3, -9),
119
+ pooling_attention_mask: bool = False,
120
+ hidden_size: int = 1152,
121
+ num_attention_heads: int = 16,
122
+ num_key_value_heads: int = 16,
123
+ head_dim: int = 72,
124
+ float32_attention: bool = True,
125
+ attention_dropout: float = 0.0,
126
+ residual_dropout: float = 0.0,
127
+ hidden_act: str = "silu",
128
+ intermediate_size: int = 18944,
129
+ text_hidden_size: int = 3584,
130
+ image_feature_dropout: float = 0.0,
131
+ initializer_range: float = 0.02,
132
+ attn_implementation: str = "eager",
133
+ **kwargs,
134
+ ):
135
+ self.attn_implementation = attn_implementation
136
+ super().__init__(
137
+ attn_implementation=attn_implementation,
138
+ **kwargs
139
+ )
140
+ self.vit_layers = vit_layers
141
+ self.pooling_attention_mask = pooling_attention_mask
142
+ self.hidden_size = hidden_size
143
+ self.num_attention_heads = num_attention_heads
144
+ self.num_key_value_heads = num_key_value_heads
145
+ self.head_dim = head_dim
146
+ self.float32_attention = float32_attention
147
+ self.attention_dropout = attention_dropout
148
+ self.residual_dropout = residual_dropout
149
+ self.hidden_act = hidden_act
150
+ self.intermediate_size = intermediate_size
151
+ self.text_hidden_size = text_hidden_size
152
+ self.image_feature_dropout = image_feature_dropout
153
+ self.initializer_range = initializer_range
154
+
155
+
156
+ class MolmoAct2TextConfig(PretrainedConfig):
157
+ r"""
158
+ This is the configuration class to store the configuration of a [`MolmoAct2TextModel`]. It is used to instantiate a
159
+ `MolmoAct2TextModel` according to the specified arguments, defining the model architecture.
160
+
161
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
162
+ documentation from [`PretrainedConfig`] for more information.
163
+
164
+ Example:
165
+ ```python
166
+ >>> from transformers import MolmoAct2TextConfig, MolmoAct2TextModel
167
+
168
+ >>> # Initializing a MolmoAct2TextConfig
169
+ >>> configuration = MolmoAct2TextConfig()
170
+
171
+ >>> # Initializing a MolmoAct2TextModel (with random weights)
172
+ >>> model = MolmoAct2TextModel(configuration)
173
+
174
+ >>> # Accessing the model configuration
175
+ >>> configuration = model.config
176
+ ```"""
177
+
178
+ model_type = "molmoact2_text"
179
+ base_config_key = "text_config"
180
+ keys_to_ignore_at_inference = ["past_key_values"]
181
+ base_model_tp_plan = {
182
+ "blocks.*.self_attn.att_proj": "colwise",
183
+ "blocks.*.self_attn.attn_out": "rowwise",
184
+ "blocks.*.mlp.ff_proj": "colwise",
185
+ "blocks.*.mlp.ff_out": "rowwise",
186
+ }
187
+ base_model_pp_plan = {
188
+ "wte": (["input_ids"], ["inputs_embeds"]),
189
+ "blocks": (["hidden_states", "attention_mask"], ["hidden_states"]),
190
+ "ln_f": (["hidden_states"], ["hidden_states"]),
191
+ }
192
+
193
+ def __init__(
194
+ self,
195
+ hidden_size: int = 3584,
196
+ num_attention_heads: int = 28,
197
+ num_key_value_heads: Optional[int] = 4,
198
+ head_dim: int = 128,
199
+ vocab_size: int = 152064,
200
+ additional_vocab_size: int = 128,
201
+ qkv_bias: bool = True,
202
+ num_hidden_layers: int = 48,
203
+ intermediate_size: int = 18944,
204
+ hidden_act: str = "silu",
205
+ embedding_dropout: float = 0.0,
206
+ attention_dropout: float = 0.0,
207
+ residual_dropout: float = 0.0,
208
+ max_position_embeddings: int = 4096,
209
+ rope_theta: float = 1000000.0,
210
+ rope_scaling: Optional[dict[str, Any]] = None,
211
+ rope_scaling_layers: Optional[list[int]] = None,
212
+ use_qk_norm: bool = False,
213
+ qk_norm_type: str = "olmo",
214
+ layer_norm_eps: float = 1e-6,
215
+ norm_after: bool = False,
216
+ initializer_range: float = 0.02,
217
+ use_cache=True,
218
+ tie_word_embeddings=False,
219
+ attn_implementation: str = "eager",
220
+ **kwargs,
221
+ ):
222
+ self.attn_implementation = attn_implementation
223
+ super().__init__(
224
+ tie_word_embeddings=tie_word_embeddings,
225
+ attn_implementation=attn_implementation,
226
+ **kwargs
227
+ )
228
+ self.hidden_size = hidden_size
229
+ self.num_attention_heads = num_attention_heads
230
+ if num_key_value_heads is None:
231
+ num_key_value_heads = num_attention_heads
232
+ self.num_key_value_heads = num_key_value_heads
233
+ self.head_dim = head_dim
234
+ self.vocab_size = vocab_size
235
+ self.additional_vocab_size = additional_vocab_size
236
+ self.qkv_bias = qkv_bias
237
+ self.num_hidden_layers = num_hidden_layers
238
+ self.intermediate_size = intermediate_size
239
+ self.hidden_act = hidden_act
240
+ self.embedding_dropout = embedding_dropout
241
+ self.attention_dropout = attention_dropout
242
+ self.residual_dropout = residual_dropout
243
+ self.max_position_embeddings = max_position_embeddings
244
+ self.rope_theta = rope_theta
245
+ self.rope_scaling = rope_scaling
246
+ self.rope_scaling_layers = rope_scaling_layers
247
+ self.use_qk_norm = use_qk_norm
248
+ self.qk_norm_type = qk_norm_type
249
+ self.layer_norm_eps = layer_norm_eps
250
+ self.norm_after = norm_after
251
+ self.initializer_range = initializer_range
252
+ self.use_cache = use_cache
253
+
254
+ # Validate the correctness of rotary position embeddings parameters
255
+ rope_config_validation(self)
256
+
257
+
258
+ class MolmoAct2ActionExpertConfig(PretrainedConfig):
259
+ r"""Configuration for the MolmoAct2 modern action expert."""
260
+
261
+ model_type = "molmoact2_action_expert"
262
+ base_config_key = "action_expert_config"
263
+
264
+ def __init__(
265
+ self,
266
+ max_action_horizon: int = 32,
267
+ max_action_dim: int = 32,
268
+ hidden_size: int = 1024,
269
+ num_layers: int = 32,
270
+ num_heads: int = 16,
271
+ mlp_ratio: float = 8.0 / 3.0,
272
+ ffn_multiple_of: int = 256,
273
+ timestep_embed_dim: int = 256,
274
+ dropout: float = 0.0,
275
+ attn_dropout: float = 0.0,
276
+ context_layer_norm: bool = True,
277
+ qk_norm: bool = True,
278
+ qk_norm_eps: float = 1e-6,
279
+ rope: bool = True,
280
+ causal_attn: bool = False,
281
+ **kwargs,
282
+ ):
283
+ super().__init__(**kwargs)
284
+ self.max_action_horizon = max_action_horizon
285
+ self.max_action_dim = max_action_dim
286
+ self.hidden_size = hidden_size
287
+ self.num_layers = num_layers
288
+ self.num_heads = num_heads
289
+ self.mlp_ratio = mlp_ratio
290
+ self.ffn_multiple_of = ffn_multiple_of
291
+ self.timestep_embed_dim = timestep_embed_dim
292
+ self.dropout = dropout
293
+ self.attn_dropout = attn_dropout
294
+ self.context_layer_norm = context_layer_norm
295
+ self.qk_norm = qk_norm
296
+ self.qk_norm_eps = qk_norm_eps
297
+ self.rope = rope
298
+ self.causal_attn = causal_attn
299
+
300
+ def to_dict(self):
301
+ output = super().to_dict()
302
+ # These are derived from the parent MolmoAct2Config for HF exports. Keeping
303
+ # them out of the public nested config avoids duplicated sources of truth.
304
+ output.pop("max_action_horizon", None)
305
+ output.pop("max_action_dim", None)
306
+ return output
307
+
308
+
309
+ class MolmoAct2Config(PretrainedConfig):
310
+ r"""
311
+ This is the configuration class to store the configuration of a [`MolmoAct2ForConditionalGeneration`].
312
+ It is used to instantiate a MolmoAct2 model according to the specified arguments, defining the model architecture.
313
+
314
+ Example:
315
+
316
+ ```python
317
+ >>> from transformers import MolmoAct2Config, MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2TextConfig
318
+
319
+ >>> # Initializing a MolmoAct2VitConfig
320
+ >>> vit_config = MolmoAct2VitConfig()
321
+
322
+ >>> # Initializing a MolmoAct2AdapterConfig
323
+ >>> adapter_config = MolmoAct2AdapterConfig()
324
+
325
+ >>> # Initializing a MolmoAct2TextConfig
326
+ >>> text_config = MolmoAct2TextConfig()
327
+
328
+ >>> # Initializing a MolmoAct2Config
329
+ >>> configuration = MolmoAct2Config(
330
+ >>> vit_config=vit_config,
331
+ >>> adapter_config=adapter_config,
332
+ >>> text_config=text_config,
333
+ >>> image_start_token_id=151936,
334
+ >>> image_end_token_id=151937,
335
+ >>> image_patch_id=151938,
336
+ >>> image_col_id=151939,
337
+ >>> low_res_image_start_token_id=151940,
338
+ >>> image_low_res_id=151942,
339
+ >>> frame_start_token_id=151943,
340
+ >>> frame_end_token_id=151944,
341
+ >>> )
342
+
343
+ >>> # Initializing a model
344
+ >>> model = MolmoAct2ForConditionalGeneration(configuration)
345
+
346
+ >>> # Accessing the model configuration
347
+ >>> configuration = model.config
348
+ ```"""
349
+
350
+ model_type = "molmoact2"
351
+ sub_configs = {
352
+ "text_config": MolmoAct2TextConfig,
353
+ "vit_config": MolmoAct2VitConfig,
354
+ "adapter_config": MolmoAct2AdapterConfig,
355
+ "action_expert_config": MolmoAct2ActionExpertConfig,
356
+ }
357
+
358
+ def __init__(
359
+ self,
360
+ vit_config: MolmoAct2VitConfig = None,
361
+ adapter_config: MolmoAct2AdapterConfig = None,
362
+ text_config: MolmoAct2TextConfig = None,
363
+ action_expert_config: MolmoAct2ActionExpertConfig = None,
364
+ image_start_token_id: int = None,
365
+ low_res_image_start_token_id: int = None,
366
+ image_end_token_id: int = None,
367
+ image_low_res_id: int = None,
368
+ image_patch_id: int = None,
369
+ image_col_id: int = None,
370
+ frame_start_token_id: int = None,
371
+ frame_end_token_id: int = None,
372
+ use_frame_special_tokens: bool = True,
373
+ initializer_range: float = 0.02,
374
+ add_action_expert: bool = True,
375
+ max_action_dim: int = 32,
376
+ max_action_horizon: int = 30,
377
+ n_obs_steps: int = 30,
378
+ action_mode: str = "both",
379
+ state_format: str = "discrete",
380
+ flow_matching_num_steps: int = 10,
381
+ flow_matching_cutoff: float = 1.0,
382
+ flow_matching_time_offset: float = 0.001,
383
+ flow_matching_time_scale: float = 0.999,
384
+ flow_matching_beta_alpha: float = 1.0,
385
+ flow_matching_beta_beta: float = 1.5,
386
+ mask_action_dim_padding: bool = True,
387
+ enable_depth_reasoning: bool = False,
388
+ depth_mode: int = 2,
389
+ num_depth_codes: int = 100,
390
+ action_expert_depth_gate: bool = False,
391
+ action_expert_depth_gate_per_layer: bool = False,
392
+ action_expert_depth_gate_init_bias: float = -4.0,
393
+ action_output_token_id: int = None,
394
+ action_start_token_id: int = None,
395
+ action_end_token_id: int = None,
396
+ action_token_start_id: int = None,
397
+ num_action_tokens: int = 0,
398
+ depth_output_token_id: int = None,
399
+ depth_start_token_id: int = None,
400
+ depth_end_token_id: int = None,
401
+ depth_token_start_id: int = None,
402
+ num_depth_tokens: int = 0,
403
+ state_start_token_id: int = None,
404
+ state_end_token_id: int = None,
405
+ state_token_start_id: int = None,
406
+ num_state_tokens: int = 0,
407
+ add_setup_tokens: bool = True,
408
+ add_control_tokens: bool = True,
409
+ norm_stats_filename: str = "norm_stats.json",
410
+ **kwargs,
411
+ ):
412
+ super().__init__(**kwargs)
413
+ if vit_config is None:
414
+ self.vit_config = MolmoAct2VitConfig()
415
+ elif isinstance(vit_config, dict):
416
+ self.vit_config = MolmoAct2VitConfig(**vit_config)
417
+ else:
418
+ self.vit_config = vit_config
419
+ if adapter_config is None:
420
+ self.adapter_config = MolmoAct2AdapterConfig()
421
+ elif isinstance(adapter_config, dict):
422
+ self.adapter_config = MolmoAct2AdapterConfig(**adapter_config)
423
+ else:
424
+ self.adapter_config = adapter_config
425
+ if text_config is None:
426
+ self.text_config = MolmoAct2TextConfig()
427
+ elif isinstance(text_config, dict):
428
+ self.text_config = MolmoAct2TextConfig(**text_config)
429
+ else:
430
+ self.text_config = text_config
431
+ self.add_action_expert = bool(add_action_expert)
432
+ if not self.add_action_expert:
433
+ self.action_expert_config = None
434
+ elif action_expert_config is None:
435
+ self.action_expert_config = MolmoAct2ActionExpertConfig(
436
+ max_action_horizon=max_action_horizon,
437
+ max_action_dim=max_action_dim,
438
+ num_layers=self.text_config.num_hidden_layers,
439
+ )
440
+ elif isinstance(action_expert_config, dict):
441
+ self.action_expert_config = MolmoAct2ActionExpertConfig(**action_expert_config)
442
+ else:
443
+ self.action_expert_config = action_expert_config
444
+ if self.add_action_expert:
445
+ self.action_expert_config.max_action_dim = int(max_action_dim)
446
+ self.action_expert_config.max_action_horizon = int(max_action_horizon)
447
+ self._validate_release_action_config(
448
+ state_format=state_format,
449
+ )
450
+ self.image_start_token_id = image_start_token_id
451
+ self.low_res_image_start_token_id = low_res_image_start_token_id
452
+ self.image_end_token_id = image_end_token_id
453
+ self.image_low_res_id = image_low_res_id
454
+ self.image_high_res_id = image_patch_id
455
+ self.image_patch_id = image_patch_id
456
+ self.image_col_id = image_col_id
457
+ self.frame_start_token_id = frame_start_token_id
458
+ self.frame_end_token_id = frame_end_token_id
459
+ self.use_frame_special_tokens = use_frame_special_tokens
460
+ self.initializer_range = initializer_range
461
+ self.max_action_dim = max_action_dim
462
+ self.max_action_horizon = max_action_horizon
463
+ self.n_obs_steps = n_obs_steps
464
+ self.action_mode = action_mode
465
+ self.state_format = state_format
466
+ self.flow_matching_num_steps = flow_matching_num_steps
467
+ self.flow_matching_cutoff = flow_matching_cutoff
468
+ self.flow_matching_time_offset = flow_matching_time_offset
469
+ self.flow_matching_time_scale = flow_matching_time_scale
470
+ self.flow_matching_beta_alpha = flow_matching_beta_alpha
471
+ self.flow_matching_beta_beta = flow_matching_beta_beta
472
+ self.mask_action_dim_padding = mask_action_dim_padding
473
+ self.enable_depth_reasoning = enable_depth_reasoning
474
+ self.depth_mode = depth_mode
475
+ self.num_depth_codes = num_depth_codes
476
+ self.action_expert_depth_gate = action_expert_depth_gate
477
+ self.action_expert_depth_gate_per_layer = action_expert_depth_gate_per_layer
478
+ self.action_expert_depth_gate_init_bias = action_expert_depth_gate_init_bias
479
+ self.action_output_token_id = action_output_token_id
480
+ self.action_start_token_id = action_start_token_id
481
+ self.action_end_token_id = action_end_token_id
482
+ self.action_token_start_id = action_token_start_id
483
+ self.num_action_tokens = num_action_tokens
484
+ self.depth_output_token_id = depth_output_token_id
485
+ self.depth_start_token_id = depth_start_token_id
486
+ self.depth_end_token_id = depth_end_token_id
487
+ self.depth_token_start_id = depth_token_start_id
488
+ self.num_depth_tokens = num_depth_tokens
489
+ self.state_start_token_id = state_start_token_id
490
+ self.state_end_token_id = state_end_token_id
491
+ self.state_token_start_id = state_token_start_id
492
+ self.num_state_tokens = num_state_tokens
493
+ self.add_setup_tokens = add_setup_tokens
494
+ self.add_control_tokens = add_control_tokens
495
+ self.norm_stats_filename = norm_stats_filename
496
+
497
+ @staticmethod
498
+ def _validate_release_action_config(
499
+ *,
500
+ state_format: str,
501
+ ) -> None:
502
+ if state_format != "discrete":
503
+ raise ValueError("MolmoAct2 HF export supports only state_format='discrete'.")
504
+
505
+ @property
506
+ def image_num_patch(self):
507
+ assert self.vit_config is not None
508
+ return self.vit_config.image_num_patch
509
+
510
+ @property
511
+ def num_attention_heads(self):
512
+ return self.text_config.num_attention_heads
513
+
514
+ @property
515
+ def num_key_value_heads(self):
516
+ return self.text_config.num_key_value_heads
517
+
518
+ @property
519
+ def head_dim(self):
520
+ return self.text_config.head_dim
521
+
522
+ @property
523
+ def num_hidden_layers(self):
524
+ return self.text_config.num_hidden_layers
525
+
526
+ @property
527
+ def hidden_size(self):
528
+ return self.text_config.hidden_size
529
+
530
+ @property
531
+ def vocab_size(self):
532
+ return self.text_config.vocab_size
533
+
534
+ @property
535
+ def max_position_embeddings(self):
536
+ return self.text_config.max_position_embeddings
537
+
538
+
539
+ MolmoAct2VitConfig.register_for_auto_class()
540
+ MolmoAct2AdapterConfig.register_for_auto_class()
541
+ MolmoAct2TextConfig.register_for_auto_class()
542
+ MolmoAct2ActionExpertConfig.register_for_auto_class()
543
+ MolmoAct2Config.register_for_auto_class()
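The `__init__` above accepts each sub-config as `None`, a plain `dict`, or an already built config object. The following is a minimal sketch of that three-way coercion, not the repository's code: `ToySubConfig` and `coerce_sub_config` are hypothetical names introduced only for illustration, and a dataclass stands in for the real `PretrainedConfig` subclasses.

```python
from dataclasses import dataclass


@dataclass
class ToySubConfig:
    """Hypothetical stand-in for a sub-config such as MolmoAct2VitConfig."""
    hidden_size: int = 8
    num_layers: int = 2


def coerce_sub_config(value, config_cls):
    """Mirror the None / dict / instance branches used in __init__ (illustrative helper)."""
    if value is None:
        return config_cls()          # fall back to the class defaults
    if isinstance(value, dict):
        return config_cls(**value)   # e.g. a section loaded from config.json
    return value                     # already a config object, keep as-is


default_cfg = coerce_sub_config(None, ToySubConfig)
from_dict = coerce_sub_config({"hidden_size": 16}, ToySubConfig)
passthrough = coerce_sub_config(from_dict, ToySubConfig)
```

The dict branch is presumably what lets a serialized config round-trip: nested sections arrive as dictionaries and are rebuilt into typed objects.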
generation_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "bos_token_id": 151645,
3
+ "eos_token_id": 151645,
4
+ "pad_token_id": 151643,
5
+ "transformers_version": "5.3.0"
6
+ }
image_processing_molmoact2.py ADDED
@@ -0,0 +1,546 @@
1
+ """Image processor class for MolmoAct2"""
2
+ from typing import Optional, Union
3
+ import numpy as np
4
+ import einops
5
+ import torch
6
+ import torchvision.transforms
7
+
8
+ from transformers.image_utils import (
9
+ IMAGENET_STANDARD_MEAN,
10
+ IMAGENET_STANDARD_STD,
11
+ ImageInput,
12
+ PILImageResampling,
13
+ make_flat_list_of_images,
14
+ valid_images,
15
+ to_numpy_array,
16
+ )
17
+ from transformers.image_transforms import convert_to_rgb
18
+ from transformers.processing_utils import ImagesKwargs
19
+ from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
20
+ from transformers.utils import logging
21
+ from transformers.feature_extraction_utils import BatchFeature
22
+ from transformers.utils import TensorType
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ def normalize_image(
29
+ image: np.ndarray,
30
+ image_mean: list[float],
31
+ image_std: list[float],
32
+ ) -> np.ndarray:
33
+ if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
34
+ return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
35
+ image -= np.array(image_mean, dtype=np.float32)[None, None, :]
36
+ image /= np.array(image_std, dtype=np.float32)[None, None, :]
37
+ return image
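The fast path in `normalize_image` relies on the identity (x - 0.5) / 0.5 = 2x - 1. A small sketch of that equivalence on scalar samples, assuming inputs already scaled to [0, 1]; `normalize_generic` and `normalize_fast` are illustrative names, not functions from this file.

```python
def normalize_generic(x, mean, std):
    """Standard per-channel normalization applied to one scalar value."""
    return (x - mean) / std


def normalize_fast(x):
    """Shortcut that is valid only when mean == std == 0.5."""
    return x * 2.0 - 1.0


samples = [0.0, 0.25, 0.5, 1.0]
pairs = [(normalize_generic(x, 0.5, 0.5), normalize_fast(x)) for x in samples]
```

Both forms map the [0, 1] range onto [-1, 1]. Note that the fast path returns a new array, while the generic path in the file modifies its input in place.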
38
+
39
+
40
+ def resize_image(
41
+ image: np.ndarray,
42
+ desired_output_size: list[int],
43
+ resample: PILImageResampling,
44
+ ) -> np.ndarray:
45
+ image = torch.permute(torch.from_numpy(image), [2, 0, 1])
46
+ dtype = image.dtype
47
+ if torch.is_floating_point(image):
48
+ in_min = 0.0
49
+ in_max = 1.0
50
+ resized = torchvision.transforms.Resize(
51
+ desired_output_size,
52
+ resample,
53
+ antialias=False,
54
+ )(image)
55
+ resized = torch.clip(resized, 0.0, 1.0).to(dtype)
56
+ else:
57
+ assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(image.dtype)
58
+ in_min = 0.0
59
+ in_max = 255.0
60
+ resized = torchvision.transforms.Resize(
61
+ desired_output_size,
62
+ resample,
63
+ antialias=False,
64
+ )(image)
65
+ resized = torch.clip(resized, 0, 255).to(dtype)
66
+
67
+ resized = resized.to(torch.float32)
68
+ resized = (resized - in_min) / (in_max - in_min)
69
+
70
+ resized = torch.permute(resized, [1, 2, 0]).numpy()
71
+
72
+ return resized
73
+
74
+
75
+ def select_tiling(h, w, patch_size, max_num_crops):
76
+ """Divide an image of size [h, w] into up to max_num_crops tiles of size patch_size"""
77
+ original_size = np.stack([h, w]) # [1, 2]
78
+ original_res = h * w
79
+ tilings = []
80
+ for i in range(1, max_num_crops + 1):
81
+ for j in range(1, max_num_crops + 1):
82
+ if i*j <= max_num_crops:
83
+ tilings.append((i, j))
84
+ # sort so argmin and argmax favour smaller tilings in the event of a tie
85
+ tilings.sort(key=lambda x: (x[0]*x[1], x[0]))
86
+ candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
87
+ candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
88
+
89
+ # How much we would need to scale the image to fit exactly in each tiling
90
+ original_size = np.stack([h, w], dtype=np.float32) # [1, 2]
91
+
92
+ # The original size can be zero in rare cases if the image is smaller than the margin
93
+ # In those cases letting the scale become infinite means the tiling is based on the
94
+ # other side, or falls back to the smallest tiling
95
+ with np.errstate(divide='ignore'):
96
+ required_scale_d = candidate_resolutions.astype(np.float32) / original_size
97
+ required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
98
+ if np.all(required_scale < 1):
99
+ # We are forced to downscale, so try to minimize the amount of downscaling
100
+ ix = np.argmax(required_scale)
101
+ else:
102
+ # Pick the resolution that required the least upscaling so that it most closely fits the image
103
+ required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
104
+ ix = np.argmin(required_scale)
105
+ return candidate_tilings[ix]
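To make the selection rule in `select_tiling` concrete, here is a pure-Python sketch of the same logic under two assumptions: `h` and `w` are strictly positive (the NumPy version tolerates zero through `np.errstate`), and float64 arithmetic is acceptable in place of float32. `pick_tiling` is a hypothetical name and returns a tuple rather than an array.

```python
def pick_tiling(h, w, window, max_crops):
    """Pure-Python sketch of the tiling choice (illustrative, assumes h, w > 0)."""
    tilings = [
        (i, j)
        for i in range(1, max_crops + 1)
        for j in range(1, max_crops + 1)
        if i * j <= max_crops
    ]
    # Smaller tilings first, so ties resolve toward fewer crops.
    tilings.sort(key=lambda t: (t[0] * t[1], t[0]))
    # Scale needed for the whole image to fit inside each tiling.
    scales = [min(i * window / h, j * window / w) for i, j in tilings]
    if all(s < 1 for s in scales):
        # Every option downscales: take the one that downscales least.
        best = max(range(len(scales)), key=lambda k: scales[k])
    else:
        # Otherwise take the smallest scale that does not downscale.
        candidates = [k for k, s in enumerate(scales) if s >= 1]
        best = min(candidates, key=lambda k: scales[k])
    return tilings[best]


square = pick_tiling(100, 100, 100, 4)     # fits one tile exactly
wide = pick_tiling(100, 200, 100, 4)       # twice as wide as tall
large = pick_tiling(1000, 1000, 100, 4)    # must be downscaled
```

Under these inputs the square image maps to a single tile, the wide image to one row of two tiles, and the oversized image to the 2x2 tiling, which is the option that shrinks it least.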
106
+
107
+
108
+ def build_resized_image(
109
+ image: np.ndarray,
110
+ base_image_input_size: list[int],
111
+ resample: PILImageResampling,
112
+ image_mean: list[float],
113
+ image_std: list[float],
114
+ image_patch_size: int,
115
+ ) -> tuple[np.ndarray, np.ndarray]:
116
+ resized = resize_image(
117
+ image, base_image_input_size, resample,
118
+ )
119
+ resized = normalize_image(resized, image_mean, image_std)
120
+ if len(resized.shape) == 3:
121
+ resized = np.expand_dims(resized, 0)
122
+ crop_patch_w = base_image_input_size[1] // image_patch_size
123
+ crop_patch_h = base_image_input_size[0] // image_patch_size
124
+ resize_idx = np.arange(crop_patch_w*crop_patch_h).reshape([crop_patch_h, crop_patch_w])
125
+ return resized, resize_idx
126
+
127
+
128
+ def build_overlapping_crops(
129
+ image: np.ndarray,
130
+ max_crops: int,
131
+ overlap_margins: list[int],
132
+ base_image_input_size: list[int],
133
+ resample: PILImageResampling,
134
+ image_mean: list[float],
135
+ image_std: list[float],
136
+ image_patch_size: int,
137
+ ) -> tuple[np.ndarray, np.ndarray]:
138
+ """Decompose an image into a set of overlapping crops
139
+
140
+ :return crop_arr: [n_crops, h, w, 3] The crops
141
+ :return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
142
+ the crops were extracted from, what patch in `crop_arr` it corresponds to
143
+ """
144
+ original_image_h, original_image_w = image.shape[:2]
145
+ crop_size = base_image_input_size[0]
146
+ assert base_image_input_size[0] == base_image_input_size[1]
147
+
148
+ left_margin, right_margin = overlap_margins
149
+ total_margin_pixels = image_patch_size * (right_margin + left_margin) # pixels removed per dim
150
+ crop_patches = base_image_input_size[0] // image_patch_size # patches per crop dim
151
+ crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
152
+ crop_window_size = crop_window_patches * image_patch_size
153
+ crop_patch_w = base_image_input_size[1] // image_patch_size
154
+ crop_patch_h = base_image_input_size[0] // image_patch_size
155
+ original_image_h, original_image_w = image.shape[:2]
156
+ crop_size = base_image_input_size[0]
157
+
158
+ # Decide how to tile the image. To account for the overlap margins, we compute the tiling
159
+ # as if we had an image without the margins and were using a crop size without the margins
160
+ tiling = select_tiling(
161
+ original_image_h - total_margin_pixels,
162
+ original_image_w - total_margin_pixels,
163
+ crop_window_size,
164
+ max_crops,
165
+ )
166
+
167
+ src = resize_image(
168
+ image,
169
+ [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels],
170
+ resample,
171
+ )
172
+ src = normalize_image(src, image_mean, image_std)
173
+
174
+ # Now we have to split the image into crops, and track what patches came from
175
+ # where in `patch_idx_arr`
176
+ n_crops = tiling[0] * tiling[1]
177
+ crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
178
+ patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
179
+ on_crop = 0
180
+ for i in range(tiling[0]):
181
+ # Slide over `src` by `crop_window_size` steps, but extract crops of size `crop_size`
182
+ # which results in overlapping crop windows
183
+ y0 = i*crop_window_size
184
+ for j in range(tiling[1]):
185
+ x0 = j*crop_window_size
186
+ crop_arr[on_crop] = src[y0:y0+crop_size, x0:x0+crop_size]
187
+ patch_idx = np.arange(crop_patch_w*crop_patch_h).reshape(crop_patch_h, crop_patch_w)
188
+ patch_idx += on_crop * crop_patch_h * crop_patch_w
189
+
190
+ # Mask out idx that are in the overlap region
191
+ if i != 0:
192
+ patch_idx[:left_margin, :] = -1
193
+ if j != 0:
194
+ patch_idx[:, :left_margin] = -1
195
+ if i != tiling[0]-1:
196
+ patch_idx[-right_margin:, :] = -1
197
+ if j != tiling[1]-1:
198
+ patch_idx[:, -right_margin:] = -1
199
+ patch_idx_arr[on_crop] = patch_idx
200
+ on_crop += 1
201
+
202
+ # `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
203
+ # so it is in left-to-right order
204
+ patch_idx_arr = np.reshape(
205
+ patch_idx_arr,
206
+ [tiling[0], tiling[1], crop_patch_h, crop_patch_w]
207
+ )
208
+ patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
209
+ patch_idx_arr = np.reshape(patch_idx_arr, [-1])
210
+
211
+ # Now get the parts not in the overlap region, so it should map each patch in `src`
212
+ # to the correct patch it should come from in `crop_arr`
213
+ patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
214
+ src.shape[0]//image_patch_size,
215
+ src.shape[1]//image_patch_size,
216
+ )
217
+ return crop_arr, patch_idx_arr
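The margin arithmetic in `build_overlapping_crops` is easier to follow with numbers. The sketch below plugs in the image processor's defaults from this file (378-pixel crops, 14-pixel patches, margins `[4, 4]`); the helper names `overlap_geometry` and `resized_source_shape` are hypothetical and exist only for this illustration.

```python
def overlap_geometry(crop_size=378, patch_size=14, margins=(4, 4)):
    """Return (total_margin_pixels, crop_patches, crop_window_patches, crop_window_size)."""
    left, right = margins
    total_margin_pixels = patch_size * (left + right)     # pixels lost to overlap per dim
    crop_patches = crop_size // patch_size                # patches along one crop side
    crop_window_patches = crop_patches - (left + right)   # patches that are not overlap
    crop_window_size = crop_window_patches * patch_size   # stride between crops, in pixels
    return total_margin_pixels, crop_patches, crop_window_patches, crop_window_size


def resized_source_shape(tiling, crop_size=378, patch_size=14, margins=(4, 4)):
    """Size the image is resized to before crops are cut for a given (rows, cols) tiling."""
    total_margin, _, _, window = overlap_geometry(crop_size, patch_size, margins)
    return (tiling[0] * window + total_margin, tiling[1] * window + total_margin)


geometry = overlap_geometry()
src_shape = resized_source_shape((2, 3))
```

With these defaults the stride is 266 pixels, so the last crop in a row of three starts at 532 and ends at 532 + 378 = 910, exactly the width of the resized source.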
218
+
219
+
220
+ def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
221
+ """Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
222
+ if len(array.shape) == 3:
223
+ n_crops, h, w = array.shape
224
+ h_patches = h//patch_size
225
+ w_patches = w//patch_size
226
+ array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
227
+ array = np.transpose(array, [0, 1, 3, 2, 4])
228
+ array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size])
229
+ return array
230
+ else:
231
+ n_crops, h, w, c = array.shape
232
+ h_patches = h//patch_size
233
+ w_patches = w//patch_size
234
+ array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
235
+ array = np.transpose(array, [0, 1, 3, 2, 4, 5])
236
+ array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size*c])
237
+ return array
238
+
239
+
240
+ def arange_for_pooling(
241
+ idx_arr: np.ndarray,
242
+ pool_h: int,
243
+ pool_w: int,
244
+ ) -> np.ndarray:
245
+ h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
246
+ w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
247
+ idx_arr = np.pad(idx_arr, [[h_pad//2, (h_pad+1)//2], [w_pad//2, (w_pad+1)//2]],
248
+ mode='constant',constant_values=-1)
249
+ return einops.rearrange(
250
+ idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
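As a worked example of what `arange_for_pooling` produces, the sketch below reproduces the pad-then-group step with plain lists instead of NumPy and einops. This is a swapped-in pure-Python equivalent for illustration only, and `pool_windows` is a hypothetical name.

```python
def pool_windows(idx, pool_h, pool_w, fill=-1):
    """Pad a 2D index grid with `fill` and group it into pool_h x pool_w windows."""
    rows, cols = len(idx), len(idx[0])
    h_pad = -rows % pool_h            # rows needed to reach a multiple of pool_h
    w_pad = -cols % pool_w
    top, left = h_pad // 2, w_pad // 2   # any odd remainder goes to the bottom/right
    padded_rows, padded_cols = rows + h_pad, cols + w_pad
    padded = [[fill] * padded_cols for _ in range(padded_rows)]
    for r in range(rows):
        for c in range(cols):
            padded[top + r][left + c] = idx[r][c]
    # Equivalent to "(h dh) (w dw) -> h w (dh dw)".
    return [
        [
            [padded[h * pool_h + dh][w * pool_w + dw]
             for dh in range(pool_h) for dw in range(pool_w)]
            for w in range(padded_cols // pool_w)
        ]
        for h in range(padded_rows // pool_h)
    ]


grid = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
pooled = pool_windows(grid, 2, 2)
```

Each output cell lists the patch indices that feed one pooled token, with `-1` marking padding. With the default 27x27 patch grid and 2x2 pooling, the same rule pads to 28x28 and yields a 14x14 pooled grid.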
251
+
252
+
253
+ def image_to_patches_and_grids(
254
+ image: np.ndarray,
255
+ max_crops: int,
256
+ overlap_margins: list[int],
257
+ base_image_input_size: list[int],
258
+ resample: PILImageResampling,
259
+ image_mean: list[float],
260
+ image_std: list[float],
261
+ image_patch_size: int,
262
+ image_pooling_w: int,
263
+ image_pooling_h: int,
264
+ crop_mode: str = "overlap-and-resize-c2",
265
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
266
+ """
267
+ :return image_grids, the shape of each (low-res, high-res) image after pooling
268
+ :return crops, the image crops to process with the ViT
269
+ :return pooled_patch_idx, for each patch_id token in `image_tokens`, the indices of the
270
+ patches in `crops` to pool for that token, masked with -1
271
+ """
272
+ if isinstance(base_image_input_size, int):
273
+ base_image_input_size = (base_image_input_size, base_image_input_size)
274
+
275
+ base_image_input_d = image_patch_size
276
+ pooling_w = image_pooling_w
277
+ pooling_h = image_pooling_h
278
+ crop_patch_w = base_image_input_size[1] // base_image_input_d
279
+ crop_patch_h = base_image_input_size[0] // base_image_input_d
280
+
281
+ if crop_mode == "resize":
282
+ resized, resize_idx = build_resized_image(
283
+ image,
284
+ base_image_input_size,
285
+ resample,
286
+ image_mean,
287
+ image_std,
288
+ image_patch_size,
289
+ )
290
+ resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
291
+ resized_h, resized_w = resize_idx.shape[:2]
292
+ resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
293
+ image_grid = [np.array([resized_h, resized_w, 0, 0])]
294
+ return (
295
+ np.stack(image_grid, 0),
296
+ batch_pixels_to_patches(resized, image_patch_size),
297
+ resize_idx,
298
+ )
299
+
300
+ if crop_mode not in {"overlap-and-resize-c2", "overlap-and-resize"}:
301
+ raise ValueError(f"Unsupported MolmoAct2 image crop_mode {crop_mode!r}.")
302
+
303
+ crop_arr, patch_idx_arr = build_overlapping_crops(
304
+ image,
305
+ max_crops,
306
+ overlap_margins,
307
+ base_image_input_size,
308
+ resample,
309
+ image_mean,
310
+ image_std,
311
+ image_patch_size,
312
+ )
313
+ pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
314
+ h, w = pooling_idx.shape[:2]
315
+ pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w])
316
+
317
+ # Finally do the same for the global image
318
+ resized, resize_idx = build_resized_image(
319
+ image,
320
+ base_image_input_size,
321
+ resample,
322
+ image_mean,
323
+ image_std,
324
+ image_patch_size,
325
+ )
326
+ crop_arr = np.concatenate([resized, crop_arr], 0)
327
+
328
+ resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
329
+ resized_h, resized_w = resize_idx.shape[:2]
330
+ resize_idx = resize_idx.reshape([-1, pooling_h*pooling_w])
331
+
332
+ # Global image goes first, so the order of patches in previous crops gets increased
333
+ pooling_idx = np.where(
334
+ pooling_idx >= 0,
335
+ pooling_idx + crop_patch_h*crop_patch_w,
336
+ -1
337
+ )
338
+ pooling_idx = np.concatenate([resize_idx, pooling_idx])
339
+ image_grid = [np.array([resized_h, resized_w, h, w])]
340
+
341
+ return (
342
+ np.stack(image_grid, 0),
343
+ batch_pixels_to_patches(crop_arr, image_patch_size),
344
+ pooling_idx
345
+ )
346
+
347
+
348
+ class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
349
+ max_crops: Optional[int]
350
+ overlap_margins: Optional[list[int]]
351
+ crop_mode: Optional[str]
352
+ patch_size: Optional[int]
353
+ pooling_size: Optional[list[int]]
354
+
355
+
356
+ class MolmoAct2ImageProcessor(BaseImageProcessor):
357
+ r"""
358
+ Constructs a MolmoAct2 image processor that preprocesses images for the model.
359
+
360
+ Args:
361
+ size (`dict[str, int]` *optional*, defaults to `{"height": 378, "width": 378}`):
362
+ Size of the image after resizing.
363
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
364
+ Resampling filter to use when resizing the image.
365
+ image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
366
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
367
+ image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
368
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
369
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
370
+ Whether to convert the image to RGB.
371
+ max_crops (`int`, *optional*, defaults to `8`):
372
+ Maximum number of crops to use per image.
373
+ overlap_margins (`list[int]`, *optional*, defaults to `[4, 4]`):
374
+ Overlap margins to use.
375
+ patch_size (`int`, *optional*, defaults to 14):
376
+ The spatial patch size of the vision encoder.
377
+ pooling_size (`list[int]`, *optional*, defaults to `[2, 2]`):
378
+ The pooling size of the vision adapter.
379
+ """
380
+
381
+ model_input_names = ["pixel_values", "image_token_pooling", "image_grids", "image_num_crops"]
382
+
383
+ def __init__(
384
+ self,
385
+ size: Optional[dict[str, int]] = None,
386
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
387
+ image_mean: Optional[Union[float, list[float]]] = None,
388
+ image_std: Optional[Union[float, list[float]]] = None,
389
+ do_convert_rgb: bool = True,
390
+ max_crops: int = 8,
391
+ overlap_margins: list[int] = [4, 4],
392
+ crop_mode: str = "overlap-and-resize-c2",
393
+ patch_size: int = 14,
394
+ pooling_size: list[int] = [2, 2],
395
+ **kwargs,
396
+ ) -> None:
397
+ super().__init__(**kwargs)
398
+ size = size if size is not None else {"height": 378, "width": 378}
399
+ size = get_size_dict(size, default_to_square=True)
400
+ self.size = size
401
+
402
+ self.resample = resample
403
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
404
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
405
+ self.do_convert_rgb = do_convert_rgb
406
+
407
+ self.max_crops = max_crops
408
+ self.overlap_margins = overlap_margins
409
+ self.crop_mode = crop_mode
410
+ self.patch_size = patch_size
411
+ self.pooling_size = pooling_size
412
+
413
+ def preprocess(
414
+ self,
415
+ images: ImageInput,
416
+ size: Optional[dict[str, int]] = None,
417
+ resample: Optional[PILImageResampling] = None,
418
+ image_mean: Optional[Union[float, list[float]]] = None,
419
+ image_std: Optional[Union[float, list[float]]] = None,
420
+ do_convert_rgb: Optional[bool] = None,
421
+ max_crops: Optional[int] = None,
422
+ overlap_margins: Optional[list[int]] = None,
423
+ crop_mode: Optional[str] = None,
424
+ patch_size: Optional[int] = None,
425
+ pooling_size: Optional[list[int]] = None,
426
+ return_tensors: Optional[Union[str, TensorType]] = None,
427
+ **kwargs,
428
+ ) -> BatchFeature:
429
+ """
430
+ Args:
431
+ images (`ImageInput`):
432
+ Image to preprocess.
433
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
434
+ Size of the image after resizing.
435
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
436
+ Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
437
+ has an effect if `do_resize` is set to `True`.
438
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
439
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
440
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
441
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
442
+ `True`.
443
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
444
+ Whether to convert the image to RGB.
445
+ max_crops (`int`, *optional*, defaults to `self.max_crops`):
446
+ Maximum number of crops to use per image.
447
+ overlap_margins (`list[int]`, *optional*, defaults to `self.overlap_margins`):
448
+ Overlap margins to use.
449
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
450
+ The spatial patch size of the vision encoder.
451
+ pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
452
+ The pooling size of the vision adapter.
453
+ return_tensors (`str` or `TensorType`, *optional*):
454
+ The type of tensors to return. Can be one of:
455
+ - Unset: Return a list of `np.ndarray`.
456
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
457
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
458
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
459
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
460
+
461
+ Returns:
462
+ A `BatchFeature` containing the following keys:
463
+ - `pixel_values`: The preprocessed images.
464
+ - `image_token_pooling`: The indices of the patches in `crops` to pool for each token in `image_tokens`.
465
+ - `image_grids`: The image grids.
466
+ - `image_num_crops`: The number of crops for each image.
467
+ """
468
+ if size is not None:
469
+ if "height" not in size or "width" not in size:
470
+ raise ValueError("size must contain 'height' and 'width' keys.")
471
+ else:
472
+ size = {**self.size}
473
+
474
+ base_image_input_size = [size["height"], size["width"]]
475
+
476
+ resample = resample or self.resample
477
+ image_mean = image_mean or self.image_mean
478
+ image_std = image_std or self.image_std
479
+ do_convert_rgb = self.do_convert_rgb if do_convert_rgb is None else do_convert_rgb
480
+
481
+ max_crops = max_crops or self.max_crops
482
+ overlap_margins = overlap_margins or self.overlap_margins
483
+ crop_mode = crop_mode or self.crop_mode
484
+ patch_size = patch_size or self.patch_size
485
+ pooling_size = pooling_size or self.pooling_size
486
+
487
+ image_pooling_h, image_pooling_w = pooling_size
488
+
489
+ if images is not None:
490
+ images = self.fetch_images(images)
491
+ images = make_flat_list_of_images(images)
492
+
493
+ if images is not None and not valid_images(images):
494
+ raise ValueError(
495
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
496
+ "torch.Tensor, tf.Tensor or jax.ndarray."
497
+ )
498
+
499
+ if do_convert_rgb:
500
+ images = [convert_to_rgb(image) for image in images]
501
+
502
+ # All transformations expect numpy arrays.
503
+ images = [to_numpy_array(image) for image in images]
504
+
505
+ data = {}
506
+ if images is not None:
507
+ batch_grids = []
508
+ batch_crops = []
509
+ batch_pooled_patches_idx = []
510
+ batch_num_crops = []
511
+
512
+ for image in images:
513
+ image_grid, crops, pooled_idx = image_to_patches_and_grids(
514
+ image,
515
+ max_crops,
516
+ overlap_margins,
517
+ base_image_input_size,
518
+ resample,
519
+ image_mean,
520
+ image_std,
521
+ patch_size,
522
+ image_pooling_w,
523
+ image_pooling_h,
524
+ crop_mode,
525
+ )
526
+ batch_grids.append(image_grid)
527
+ batch_crops.append(crops)
528
+ batch_pooled_patches_idx.append(pooled_idx)
529
+ batch_num_crops.append(crops.shape[0])
530
+
531
+ pixel_values = np.concatenate(batch_crops, 0)
532
+ image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
533
+ image_grids = np.concatenate(batch_grids, 0)
534
+ image_num_crops = np.array(batch_num_crops)
535
+
536
+ data.update(
537
+ pixel_values=pixel_values,
538
+ image_token_pooling=image_token_pooling,
539
+ image_grids=image_grids,
540
+ image_num_crops=image_num_crops,
541
+ )
542
+
543
+ return BatchFeature(data, tensor_type=return_tensors)
544
+
545
+
546
+ MolmoAct2ImageProcessor.register_for_auto_class()
inference.py ADDED
@@ -0,0 +1,768 @@
1
+ """Inference utilities for MolmoAct2"""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Iterable, Optional, Sequence, Tuple
5
+
6
+ import torch
7
+ from torch.nn import functional as F
8
+ from transformers.cache_utils import Cache
9
+ from transformers.configuration_utils import PretrainedConfig
10
+
11
+
12
+ @dataclass
13
+ class _ActionFlowInputs:
14
+ trajectory: torch.Tensor
15
+ context: Any
16
+ modulations: Sequence[Any]
17
+ action_dim_is_pad: Optional[torch.Tensor]
18
+
19
+
20
+ @dataclass
21
+ class _ActionFlowCudaGraph:
22
+ key: Tuple[Any, ...]
23
+ graph: torch.cuda.CUDAGraph
24
+ static_inputs: _ActionFlowInputs
25
+ output: torch.Tensor
26
+
27
+
28
+ @dataclass
29
+ class _DepthDecodeCudaGraphLayerStage:
30
+ residual: torch.Tensor
31
+ query: torch.Tensor
32
+ key: torch.Tensor
33
+ value: torch.Tensor
34
+
35
+
36
+ @dataclass
37
+ class _DepthDecodeCudaGraphPostStage:
38
+ graph: torch.cuda.CUDAGraph
39
+ attn_context: torch.Tensor
40
+
41
+
42
+ @dataclass
43
+ class _DepthDecodeCudaGraph:
44
+ cache_key: Tuple[Any, ...]
45
+ pre_graph: torch.cuda.CUDAGraph
46
+ token_ids: torch.Tensor
47
+ cos: torch.Tensor
48
+ sin: torch.Tensor
49
+ positions: torch.Tensor
50
+ stages: Sequence[_DepthDecodeCudaGraphLayerStage]
51
+ post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
52
+ output: torch.Tensor
53
+
54
+
55
+ @dataclass
56
+ class _DepthDecodeCudaGraphSpec:
57
+ eligible: bool
58
+ cache_key_prefix: Tuple[Any, ...]
59
+ num_hidden_layers: int
60
+ head_dim: int
61
+ num_attention_heads: int
62
+
63
+
64
+ def _cache_seq_len_int(past_key_values: Optional[Cache]) -> int:
65
+ if past_key_values is None:
66
+ return 0
67
+ seq_len = past_key_values.get_seq_length()
68
+ if torch.is_tensor(seq_len):
69
+ return int(seq_len.item())
70
+ return int(seq_len)
71
+
72
+
73
+ def _cache_max_len_int(past_key_values: Optional[Cache]) -> int:
74
+ if past_key_values is None:
75
+ return -1
76
+ max_len = past_key_values.get_max_cache_shape()
77
+ if torch.is_tensor(max_len):
78
+ return int(max_len.item())
79
+ return int(max_len)
80
+
81
+
82
+ def _iter_cache_key_values(
83
+ past_key_values: Cache,
84
+ ) -> Iterable[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]:
85
+ layers = getattr(past_key_values, "layers", None)
86
+ if layers is not None:
87
+ for layer in layers:
88
+ yield getattr(layer, "keys", None), getattr(layer, "values", None)
89
+ return
90
+ for layer in past_key_values:
91
+ yield layer[0], layer[1]
92
+
93
+
94
+ class _DepthDecodeStaticLayerCache:
95
+ is_compileable = False
96
+ is_sliding = False
97
+
98
+ def __init__(self, max_cache_len: int) -> None:
99
+ self.max_cache_len = int(max_cache_len)
100
+ self.cumulative_length = 0
101
+ self.keys: Optional[torch.Tensor] = None
102
+ self.values: Optional[torch.Tensor] = None
103
+
104
+ def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
105
+ bsz, n_heads = key_states.shape[:2]
106
+ self.keys = torch.empty(
107
+ (bsz, n_heads, self.max_cache_len, key_states.shape[-1]),
108
+ dtype=key_states.dtype,
109
+ device=key_states.device,
110
+ )
111
+ self.values = torch.empty(
112
+ (bsz, n_heads, self.max_cache_len, value_states.shape[-1]),
113
+ dtype=value_states.dtype,
114
+ device=value_states.device,
115
+ )
116
+
117
+ def update(
118
+ self,
119
+ key_states: torch.Tensor,
120
+ value_states: torch.Tensor,
121
+ *args,
122
+ **kwargs,
123
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
124
+ if self.keys is None:
125
+ self._allocate(key_states, value_states)
126
+ start = self.cumulative_length
127
+ end = start + key_states.shape[-2]
128
+ if end > self.max_cache_len:
129
+ raise RuntimeError(
130
+ f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}."
131
+ )
132
+ self.keys[:, :, start:end, :].copy_(key_states)
133
+ self.values[:, :, start:end, :].copy_(value_states)
134
+ self.cumulative_length = end
135
+ return self.keys[:, :, :end, :], self.values[:, :, :end, :]
136
+
137
+ def get_seq_length(self) -> int:
138
+ return self.cumulative_length
139
+
140
+ def get_max_cache_shape(self) -> int:
141
+ return -1
142
+
143
+ def reset(self) -> None:
144
+ self.cumulative_length = 0
145
+
146
+
147
+ class _DepthDecodeStaticCache(Cache):
148
+ def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None:
149
+ text_config = config.get_text_config(decoder=True)
150
+ super().__init__(
151
+ layers=[
152
+ _DepthDecodeStaticLayerCache(max_cache_len=max_cache_len)
153
+ for _ in range(text_config.num_hidden_layers)
154
+ ]
155
+ )
156
+
157
+ def get_seq_length(self, layer_idx: int = 0) -> int:
158
+ return self.layers[layer_idx].get_seq_length()
159
+
160
+ def get_max_cache_shape(self, layer_idx: int = 0) -> int:
161
+ return self.layers[layer_idx].get_max_cache_shape()
162
+
163
+ def reset(self) -> None:
164
+ for layer in self.layers:
165
+ layer.reset()
166
+
167
+
168
+ class ActionCudaGraphManager:
169
+ def __init__(self, model: Any) -> None:
170
+ self.model = model
171
+ self.enabled = True
172
+ self.action_flow_graph: Optional[_ActionFlowCudaGraph] = None
173
+
174
+ def set_enabled(self, enabled: bool) -> None:
175
+ self.enabled = bool(enabled)
176
+
177
+ def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool:
178
+ action_model = self.model
179
+ if not self.enabled:
180
+ return False
181
+ if action_model.training or action_model._require_action_expert().training:
182
+ return False
183
+ if inputs.trajectory.device.type != "cuda":
184
+ return False
185
+
186
+ def all_on_cuda():
187
+ yield inputs.trajectory
188
+ for k, v in inputs.context.kv_contexts:
189
+ yield k
190
+ yield v
191
+ for t in (
192
+ inputs.context.cross_mask,
193
+ inputs.context.self_mask,
194
+ inputs.context.valid_action,
195
+ inputs.action_dim_is_pad,
196
+ ):
197
+ if t is not None:
198
+ yield t
199
+ if inputs.context.rope_cache is not None:
200
+ yield from inputs.context.rope_cache
201
+ for step in inputs.modulations:
202
+ yield step.conditioning
203
+ for block_modulation in step.block_modulations:
204
+ yield from block_modulation
205
+ yield from step.final_modulation
206
+
207
+ return all(t.device.type == "cuda" for t in all_on_cuda())
208
+
209
+ def run_action_flow(
210
+ self,
211
+ inputs: _ActionFlowInputs,
212
+ steps: int,
213
+ run_loop,
214
+ ) -> torch.Tensor:
215
+ key = _cuda_graph_key(inputs, steps)
216
+ cache = self.action_flow_graph
217
+ if cache is None or cache.key != key:
218
+ static_inputs = _clone_static_inputs(inputs)
219
+ graph, output = _capture_cuda_graph(
220
+ lambda: run_loop(static_inputs, steps),
221
+ inputs.trajectory.device,
222
+ after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory),
223
+ )
224
+ cache = _ActionFlowCudaGraph(
225
+ key=key,
226
+ graph=graph,
227
+ static_inputs=static_inputs,
228
+ output=output,
229
+ )
230
+ self.action_flow_graph = cache
231
+ else:
232
+ _copy_inputs_(cache.static_inputs, inputs)
233
+
234
+ cache.graph.replay()
235
+ return cache.output.clone()
236
+
237
+
238
+ class DepthDecodeCudaGraphManager:
239
+ def __init__(self, model: Any) -> None:
240
+ self.model = model
241
+ self.backbone = model.model
242
+ self.enabled = True
243
+ self.graph: Optional[_DepthDecodeCudaGraph] = None
244
+ self.graph_spec: Optional[_DepthDecodeCudaGraphSpec] = None
245
+
246
+ def set_enabled(self, enabled: bool) -> None:
247
+ self.enabled = bool(enabled)
248
+
249
+ def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache:
250
+ return _DepthDecodeStaticCache(
251
+ config=self.model.config.text_config,
252
+ max_cache_len=max_cache_len,
253
+ )
254
+
255
+ def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec:
256
+ static = self.graph_spec
257
+ if static is None:
258
+ cfg = self.backbone.transformer.config
259
+ rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None)
260
+ static = _DepthDecodeCudaGraphSpec(
261
+ eligible=(
262
+ not cfg.norm_after
263
+ and cfg.rope_scaling_layers is None
264
+ and getattr(rotary_emb, "rope_type", None) == "default"
265
+ and cfg._attn_implementation == "sdpa"
266
+ ),
267
+ cache_key_prefix=(
268
+ cfg.hidden_size,
269
+ cfg.num_attention_heads,
270
+ cfg.num_key_value_heads,
271
+ cfg.head_dim,
272
+ cfg.num_hidden_layers,
273
+ cfg.use_qk_norm,
274
+ cfg.qk_norm_type,
275
+ cfg._attn_implementation,
276
+ ),
277
+ num_hidden_layers=cfg.num_hidden_layers,
278
+ head_dim=cfg.head_dim,
279
+ num_attention_heads=cfg.num_attention_heads,
280
+ )
281
+ self.graph_spec = static
282
+ return static
283
+
284
+ def can_use(
285
+ self,
286
+ next_input_ids: torch.Tensor,
287
+ *,
288
+ past_key_values: Cache,
289
+ attention_bias: torch.Tensor,
290
+ ) -> bool:
291
+ if (
292
+ not self.enabled
293
+ or self.model.training
294
+ or self.backbone.transformer.training
295
+ ):
296
+ return False
297
+ if next_input_ids.device.type != "cuda":
298
+ return False
299
+ if (
300
+ next_input_ids.ndim != 2
301
+ or next_input_ids.shape[0] != 1
302
+ or next_input_ids.shape[1] != 1
303
+ ):
304
+ return False
305
+ if not isinstance(past_key_values, _DepthDecodeStaticCache):
306
+ return False
307
+ if (
308
+ not torch.is_tensor(attention_bias)
309
+ or attention_bias.device != next_input_ids.device
310
+ ):
311
+ return False
312
+ return self._depth_decode_spec().eligible
313
+
314
+ def _depth_decode_key(
315
+ self,
316
+ next_input_ids: torch.Tensor,
317
+ attention_bias: torch.Tensor,
318
+ ) -> Tuple[Any, ...]:
319
+ device = next_input_ids.device
320
+ return (
321
+ self._depth_decode_spec().cache_key_prefix,
322
+ device.type,
323
+ device.index,
324
+ self.model.lm_head.weight.dtype,
325
+ attention_bias.shape[-1],
326
+ )
327
+
328
+ def _select_depth_decode_rope(
329
+ self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int
330
+ ) -> None:
331
+ emb = self.backbone.transformer.rotary_emb
332
+ cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
333
+ sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])
334
+
335
+ def _depth_decode_pre_layer(
336
+ self,
337
+ layer_idx: int,
338
+ hidden_states: torch.Tensor,
339
+ cos: torch.Tensor,
340
+ sin: torch.Tensor,
341
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
342
+ block = self.backbone.transformer.blocks[layer_idx]
343
+ attention = block.self_attn
344
+ residual = hidden_states
345
+ hidden_states = block.attn_norm(hidden_states)
346
+
347
+ input_shape = hidden_states.shape[:-1]
348
+ hidden_shape = (*input_shape, -1, attention.head_dim)
349
+ qkv = attention.att_proj(hidden_states)
350
+ query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1)
351
+ value_states = value_states.view(hidden_shape)
352
+
353
+ apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None
354
+ norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3"
355
+
356
+ if apply_qk_norm and not norm_after_view:
357
+ query_states = attention.q_norm(query_states)
358
+ key_states = attention.k_norm(key_states)
359
+
360
+ query_states = query_states.view(hidden_shape)
361
+ key_states = key_states.view(hidden_shape)
362
+
363
+ if norm_after_view:
364
+ query_states = attention.q_norm(query_states)
365
+ key_states = attention.k_norm(key_states)
366
+
367
+ query_states = query_states.transpose(1, 2)
368
+ key_states = key_states.transpose(1, 2)
369
+ value_states = value_states.transpose(1, 2)
370
+ query_states, key_states = _apply_rotary_pos_emb(
371
+ query_states, key_states, cos, sin
372
+ )
373
+ return residual, query_states, key_states, value_states
374
+
375
+ def _depth_decode_pre0(
376
+ self,
377
+ token_ids: torch.Tensor,
378
+ cos: torch.Tensor,
379
+ sin: torch.Tensor,
380
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
381
+ inputs_embeds = self.model._embed_base_tokens(token_ids)
382
+ return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)
383
+
384
+ def _depth_decode_post_layer(
385
+ self,
386
+ layer_idx: int,
387
+ residual: torch.Tensor,
388
+ attn_context: torch.Tensor,
389
+ ) -> torch.Tensor:
390
+ block = self.backbone.transformer.blocks[layer_idx]
391
+ attention = block.self_attn
392
+ input_shape = residual.shape[:-1]
393
+ attn_output = attn_context.reshape(*input_shape, -1).contiguous()
394
+ attn_output = attention.attn_out(attn_output)
395
+ hidden_states = residual + block.dropout(attn_output)
396
+
397
+ residual = hidden_states
398
+ hidden_states = block.ff_norm(hidden_states)
399
+ hidden_states = block.mlp(hidden_states)
400
+ hidden_states = residual + block.dropout(hidden_states)
401
+ return hidden_states
402
+
403
+ def _depth_decode_post_and_pre_next(
404
+ self,
405
+ layer_idx: int,
406
+ residual: torch.Tensor,
407
+ attn_context: torch.Tensor,
408
+ cos: torch.Tensor,
409
+ sin: torch.Tensor,
410
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
411
+ hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
412
+ return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)
413
+
414
+ def _depth_decode_last_post(
415
+ self,
416
+ layer_idx: int,
417
+ residual: torch.Tensor,
418
+ attn_context: torch.Tensor,
419
+ ) -> torch.Tensor:
420
+ hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
421
+ return self.backbone.transformer.ln_f(hidden_states)
422
+
423
+ def _build_depth_decode_graph(
424
+ self,
425
+ next_input_ids: torch.Tensor,
426
+ *,
427
+ past_length: int,
428
+ attention_bias: torch.Tensor,
429
+ ) -> _DepthDecodeCudaGraph:
430
+ text_config = self.backbone.transformer.config
431
+ device = next_input_ids.device
432
+ dtype = self.model.lm_head.weight.dtype
433
+ static = self._depth_decode_spec()
434
+ num_layers = static.num_hidden_layers
435
+ head_dim = static.head_dim
436
+ max_cache_len = int(attention_bias.shape[-1])
437
+ max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
438
+ self.backbone.transformer.prepare_rope_cache(
439
+ device=device, max_seq_len=max_rope_len
440
+ )
441
+
442
+ token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
443
+ cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
444
+ sin = torch.empty_like(cos)
445
+ positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
446
+ context_shape = (1, 1, static.num_attention_heads, head_dim)
447
+
448
+ token_ids.copy_(next_input_ids)
449
+ self._select_depth_decode_rope(cos, sin, past_length=past_length)
450
+
451
+ pre_graph, pre_output = _capture_cuda_graph(
452
+ lambda: self._depth_decode_pre0(token_ids, cos, sin),
453
+ device,
454
+ )
455
+ stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)]
456
+ post_graphs = []
457
+ for layer_idx in range(num_layers - 1):
458
+ stage = stages[-1]
459
+ attn_context = torch.empty(context_shape, device=device, dtype=dtype)
460
+ graph, output = _capture_cuda_graph(
461
+ lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: (
462
+ self._depth_decode_post_and_pre_next(
463
+ layer_idx,
464
+ stage.residual,
465
+ attn_context,
466
+ cos,
467
+ sin,
468
+ )
469
+ ),
470
+ device,
471
+ )
472
+ post_graphs.append(
473
+ _DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context)
474
+ )
475
+ stages.append(_DepthDecodeCudaGraphLayerStage(*output))
476
+
477
+ last_stage = stages[-1]
478
+ last_attn_context = torch.empty(context_shape, device=device, dtype=dtype)
479
+ last_graph, last_output = _capture_cuda_graph(
480
+ lambda: self._depth_decode_last_post(
481
+ num_layers - 1,
482
+ last_stage.residual,
483
+ last_attn_context,
484
+ ),
485
+ device,
486
+ )
487
+ post_graphs.append(
488
+ _DepthDecodeCudaGraphPostStage(
489
+ graph=last_graph, attn_context=last_attn_context
490
+ )
491
+ )
492
+ return _DepthDecodeCudaGraph(
493
+ cache_key=self._depth_decode_key(next_input_ids, attention_bias),
494
+ pre_graph=pre_graph,
495
+ token_ids=token_ids,
496
+ cos=cos,
497
+ sin=sin,
498
+ positions=positions,
499
+ stages=tuple(stages),
500
+ post_graphs=tuple(post_graphs),
501
+ output=last_output,
502
+ )
503
+
504
+ def _get_depth_decode_graph(
505
+ self,
506
+ next_input_ids: torch.Tensor,
507
+ *,
508
+ past_length: int,
509
+ attention_bias: torch.Tensor,
510
+ ) -> _DepthDecodeCudaGraph:
511
+ key = self._depth_decode_key(next_input_ids, attention_bias)
512
+ decode_graph = self.graph
513
+ if decode_graph is None or decode_graph.cache_key != key:
514
+ decode_graph = self._build_depth_decode_graph(
515
+ next_input_ids,
516
+ past_length=past_length,
517
+ attention_bias=attention_bias,
518
+ )
519
+ self.graph = decode_graph
520
+ else:
521
+ decode_graph.token_ids.copy_(next_input_ids)
522
+ self._select_depth_decode_rope(
523
+ decode_graph.cos, decode_graph.sin, past_length=past_length
524
+ )
525
+ return decode_graph
526
+
527
+ def _run_depth_decode_attention_core(
528
+ self,
529
+ layer_idx: int,
530
+ stage: _DepthDecodeCudaGraphLayerStage,
531
+ *,
532
+ past_key_values: Cache,
533
+ attention_bias: torch.Tensor,
534
+ cache_position: torch.Tensor,
535
+ cos: torch.Tensor,
536
+ sin: torch.Tensor,
537
+ ) -> torch.Tensor:
538
+ attention = self.backbone.transformer.blocks[layer_idx].self_attn
539
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
540
+ key_states, value_states = past_key_values.update(
541
+ stage.key,
542
+ stage.value,
543
+ layer_idx,
544
+ cache_kwargs,
545
+ )
546
+ key_states = _repeat_kv(key_states, attention.num_key_value_groups)
547
+ value_states = _repeat_kv(value_states, attention.num_key_value_groups)
548
+ attn_output = F.scaled_dot_product_attention(
549
+ stage.query,
550
+ key_states,
551
+ value_states,
552
+ attn_mask=attention_bias,
553
+ dropout_p=0.0,
554
+ is_causal=False,
555
+ )
556
+ return attn_output.transpose(1, 2)
557
+
558
+ def run(
559
+ self,
560
+ next_input_ids: torch.Tensor,
561
+ *,
562
+ past_key_values: Cache,
563
+ attention_bias: torch.Tensor,
564
+ past_length: int,
565
+ ) -> Tuple[torch.Tensor, Cache]:
566
+ end = past_length + 1
567
+ decode_graph = self._get_depth_decode_graph(
568
+ next_input_ids,
569
+ past_length=past_length,
570
+ attention_bias=attention_bias,
571
+ )
572
+ cache_position = decode_graph.positions[past_length:end]
573
+ attention_bias_q = attention_bias[:, :, past_length:end, :end]
574
+
575
+ decode_graph.pre_graph.replay()
576
+
577
+ for layer_idx, post_graph in enumerate(decode_graph.post_graphs):
578
+ attn_context = self._run_depth_decode_attention_core(
579
+ layer_idx,
580
+ decode_graph.stages[layer_idx],
581
+ past_key_values=past_key_values,
582
+ attention_bias=attention_bias_q,
583
+ cache_position=cache_position,
584
+ cos=decode_graph.cos,
585
+ sin=decode_graph.sin,
586
+ )
587
+ post_graph.attn_context.copy_(attn_context)
588
+ post_graph.graph.replay()
589
+
590
+ return decode_graph.output, past_key_values
591
+
592
+
593
+ def _cuda_graph_tensor_signature(
594
+ tensor: Optional[torch.Tensor],
595
+ ) -> Optional[Tuple[Any, ...]]:
596
+ if tensor is None:
597
+ return None
598
+ return (
599
+ tuple(tensor.shape),
600
+ tuple(tensor.stride()),
601
+ str(tensor.dtype),
602
+ str(tensor.device),
603
+ )
604
+
605
+
606
+ def _cuda_graph_context_signature(context: Any) -> Tuple[Any, ...]:
607
+ sig = _cuda_graph_tensor_signature
608
+ return (
609
+ tuple((sig(k), sig(v)) for k, v in context.kv_contexts),
610
+ sig(context.cross_mask),
611
+ sig(context.self_mask),
612
+ sig(context.valid_action),
613
+ None
614
+ if context.rope_cache is None
615
+ else tuple(sig(t) for t in context.rope_cache),
616
+ )
617
+
618
+
619
+ def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> Tuple[Any, ...]:
620
+ sig = _cuda_graph_tensor_signature
621
+ return tuple(
622
+ (
623
+ sig(step.conditioning),
624
+ tuple(
625
+ tuple(sig(t) for t in block_modulation)
626
+ for block_modulation in step.block_modulations
627
+ ),
628
+ tuple(sig(t) for t in step.final_modulation),
629
+ )
630
+ for step in modulations
631
+ )
632
+
633
+
634
+ def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> Tuple[Any, ...]:
635
+ sig = _cuda_graph_tensor_signature
636
+ return (
637
+ sig(inputs.trajectory),
638
+ _cuda_graph_context_signature(inputs.context),
639
+ _cuda_graph_modulation_signature(inputs.modulations),
640
+ sig(inputs.action_dim_is_pad),
641
+ int(steps),
642
+ )
643
+
644
+
645
+ def _clone_static_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
646
+ if tensor is None:
647
+ return None
648
+ static = torch.empty_strided(
649
+ tuple(tensor.shape),
650
+ tuple(tensor.stride()),
651
+ device=tensor.device,
652
+ dtype=tensor.dtype,
653
+ )
654
+ static.copy_(tensor)
655
+ return static
656
+
657
+
658
+ def _clone_static_context(context: Any) -> Any:
659
+ rope_cache = None
660
+ if context.rope_cache is not None:
661
+ rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache)
662
+ return context.__class__(
663
+ kv_contexts=tuple(
664
+ (_clone_static_tensor(k), _clone_static_tensor(v))
665
+ for k, v in context.kv_contexts
666
+ ),
667
+ cross_mask=_clone_static_tensor(context.cross_mask),
668
+ self_mask=_clone_static_tensor(context.self_mask),
669
+ valid_action=_clone_static_tensor(context.valid_action),
670
+ rope_cache=rope_cache,
671
+ )
672
+
673
+
674
+ def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
675
+ return tuple(
676
+ step.__class__(
677
+ conditioning=_clone_static_tensor(step.conditioning),
678
+ block_modulations=tuple(
679
+ tuple(_clone_static_tensor(t) for t in block_modulation)
680
+ for block_modulation in step.block_modulations
681
+ ),
682
+ final_modulation=tuple(
683
+ _clone_static_tensor(t) for t in step.final_modulation
684
+ ),
685
+ )
686
+ for step in modulations
687
+ )
688
+
689
+
690
+ def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
691
+ return _ActionFlowInputs(
692
+ trajectory=_clone_static_tensor(inputs.trajectory),
693
+ context=_clone_static_context(inputs.context),
694
+ modulations=_clone_static_modulations(inputs.modulations),
695
+ action_dim_is_pad=_clone_static_tensor(inputs.action_dim_is_pad),
696
+ )
697
+
698
+
699
+ def _copy_context_(dst: Any, src: Any) -> None:
700
+ for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
701
+ dst_k.copy_(src_k)
702
+ dst_v.copy_(src_v)
703
+ if src.cross_mask is not None:
704
+ dst.cross_mask.copy_(src.cross_mask)
705
+ if src.self_mask is not None:
706
+ dst.self_mask.copy_(src.self_mask)
707
+ if src.valid_action is not None:
708
+ dst.valid_action.copy_(src.valid_action)
709
+ if src.rope_cache is not None:
710
+ for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
711
+ dst_tensor.copy_(src_tensor)
712
+
713
+
714
+ def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None:
715
+ dst.trajectory.copy_(src.trajectory)
716
+ _copy_context_(dst.context, src.context)
717
+ if src.action_dim_is_pad is not None:
718
+ dst.action_dim_is_pad.copy_(src.action_dim_is_pad)
719
+
720
+
721
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
722
+ x1 = x[..., : x.shape[-1] // 2]
723
+ x2 = x[..., x.shape[-1] // 2 :]
724
+ return torch.cat((-x2, x1), dim=-1)
725
+
726
+
727
+ def _apply_rotary_pos_emb(
728
+ q: torch.Tensor,
729
+ k: torch.Tensor,
730
+ cos: torch.Tensor,
731
+ sin: torch.Tensor,
732
+ unsqueeze_dim: int = 1,
733
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
734
+ cos = cos.unsqueeze(unsqueeze_dim)
735
+ sin = sin.unsqueeze(unsqueeze_dim)
736
+ q_embed = (q * cos) + (_rotate_half(q) * sin)
737
+ k_embed = (k * cos) + (_rotate_half(k) * sin)
738
+ return q_embed, k_embed
739
+
740
+
741
+ def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
742
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
743
+ if n_rep == 1:
744
+ return hidden_states
745
+ hidden_states = hidden_states[:, :, None, :, :].expand(
746
+ batch, num_key_value_heads, n_rep, slen, head_dim
747
+ )
748
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
749
+
750
+
751
+ def _capture_cuda_graph(
752
+ fn,
753
+ device: torch.device,
754
+ *,
755
+ after_warmup=None,
756
+ ) -> Tuple[torch.cuda.CUDAGraph, Any]:
757
+ warmup_stream = torch.cuda.Stream(device=device)
758
+ warmup_stream.wait_stream(torch.cuda.current_stream(device))
759
+ with torch.cuda.stream(warmup_stream):
760
+ fn()
761
+ torch.cuda.current_stream(device).wait_stream(warmup_stream)
762
+ if after_warmup is not None:
763
+ after_warmup()
764
+
765
+ graph = torch.cuda.CUDAGraph()
766
+ with torch.cuda.graph(graph):
767
+ output = fn()
768
+ return graph, output
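The `_DepthDecodeStaticLayerCache` in the file above preallocates its key/value buffers once and then writes new entries in place, tracking the valid length with a cursor. The sketch below illustrates only that bookkeeping, using plain Python lists in place of tensors. The class name `StaticBufferSketch` and its attributes are hypothetical and are not part of the repository; unlike the tensor version, the returned prefix is a copy rather than a view.

```python
from typing import List, Optional


class StaticBufferSketch:
    """Fixed-capacity append buffer (hypothetical stand-in for a static KV layer cache)."""

    def __init__(self, max_len: int) -> None:
        self.max_len = int(max_len)
        self.length = 0  # number of valid slots, analogous to cumulative_length
        self.slots: Optional[List[Optional[float]]] = None  # allocated lazily

    def update(self, new_items: List[float]) -> List[Optional[float]]:
        # Allocate the full-capacity buffer once, on first use.
        if self.slots is None:
            self.slots = [None] * self.max_len
        start = self.length
        end = start + len(new_items)
        # Refuse to grow past the preallocated capacity.
        if end > self.max_len:
            raise RuntimeError(f"length {end} exceeds max_len={self.max_len}")
        # Write in place; the buffer keeps its size and identity.
        self.slots[start:end] = new_items
        self.length = end
        # Return only the valid prefix (a copy here, a view in the tensor version).
        return self.slots[:end]

    def reset(self) -> None:
        # Rewind the cursor; stale contents are overwritten by later updates.
        self.length = 0
```

Keeping the storage at a fixed size is what lets the buffer be reused across decode steps without reallocation; only the cursor changes between calls.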
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a62d97555bec82618aa1e7c3f143a41067ba7e8bed074d825b9b1e88391d516a
3
+ size 4183452369
modeling_molmoact2.py ADDED
The diff for this file is too large to render. See raw diff
 
norm_stats.json ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "format": "molmoact2_norm_stats.v1",
3
+ "norm_mode": "q01_q99",
4
+ "metadata_by_tag": {
5
+ "so100_so101_molmoact2": {
6
+ "action_key": "action",
7
+ "state_key": "observation.state",
8
+ "camera_keys": [],
9
+ "normalize_gripper": true,
10
+ "action_horizon": 30,
11
+ "n_action_steps": 30,
12
+ "setup_type": "single so100/so101 robotic arm in molmoact2",
13
+ "control_mode": "absolute joint pose",
14
+ "action_stats": {
15
+ "min": [
16
+ -122.607421875,
17
+ -270.0,
18
+ -269.208984375,
19
+ -125.771484375,
20
+ -269.912109375,
21
+ -31.57327651977539
22
+ ],
23
+ "max": [
24
+ 179.208984375,
25
+ 219.638671875,
26
+ 195.380859375,
27
+ 178.9453125,
28
+ 269.82421875,
29
+ 119.40789031982422
30
+ ],
31
+ "mean": [
32
+ 3.343996486826433,
33
+ 125.7905980370996,
34
+ 120.20220128113388,
35
+ 55.88144220174933,
36
+ -11.543010633027725,
37
+ 11.25886240824774
38
+ ],
39
+ "std": [
40
+ 28.909870406169997,
41
+ 52.25069634659296,
42
+ 47.94432906599221,
43
+ 36.01019142727721,
44
+ 69.35504013212369,
45
+ 17.116239869449775
46
+ ],
47
+ "count": [
48
+ 19619650.0
49
+ ],
50
+ "q01": [
51
+ -42.1300246338976,
52
+ 45.18258358164995,
53
+ 35.40059182962813,
54
+ 4.929781836327758,
55
+ -65.57568617645342,
56
+ -0.3016556932619033
57
+ ],
58
+ "q10": [
59
+ -25.040070398997557,
60
+ 68.27827215165794,
61
+ 65.76540485606242,
62
+ 26.58811186925123,
63
+ -39.81868441470048,
64
+ 0.26123181871944706
65
+ ],
66
+ "q50": [
67
+ 3.0828094324713105,
68
+ 124.5495736487354,
69
+ 122.75175717637279,
70
+ 57.77960070056314,
71
+ -11.094802886190045,
72
+ 4.866634607477139
73
+ ],
74
+ "q90": [
75
+ 31.591544866079253,
76
+ 181.76986724267596,
77
+ 168.5741215400282,
78
+ 82.4353358815596,
79
+ 16.05609349144359,
80
+ 32.12324970648343
81
+ ],
82
+ "q99": [
83
+ 48.55349563198916,
84
+ 186.10646680077767,
85
+ 173.6076722013997,
86
+ 93.41056417929472,
87
+ 43.53107398260694,
88
+ 44.74649336930881
89
+ ],
90
+ "names": [
91
+ "shoulder_pan",
92
+ "shoulder_lift",
93
+ "elbow_flex",
94
+ "wrist_flex",
95
+ "wrist_roll",
96
+ "gripper"
97
+ ],
98
+ "mask": [
99
+ true,
100
+ true,
101
+ true,
102
+ true,
103
+ true,
104
+ true
105
+ ]
106
+ },
107
+ "state_stats": {
108
+ "min": [
109
+ -115.048828125,
110
+ -270.0,
111
+ -235.8984375,
112
+ -113.818359375,
113
+ -268.9453125,
114
+ -8.521058082580566
115
+ ],
116
+ "max": [
117
+ 178.505859375,
118
+ 218.49609375,
119
+ 192.041015625,
120
+ 207.861328125,
121
+ 250.048828125,
122
+ 118.2519302368164
123
+ ],
124
+ "mean": [
125
+ 3.3225097946752244,
126
+ 124.40594064960378,
127
+ 121.59550610749059,
128
+ 55.903039878016074,
129
+ -11.41740021122887,
130
+ 13.358497334686597
131
+ ],
132
+ "std": [
133
+ 28.79265204113751,
134
+ 52.702867303079756,
135
+ 47.00596021941705,
136
+ 35.53803566355756,
137
+ 69.12836626047817,
138
+ 16.333280282904557
139
+ ],
140
+ "count": [
141
+ 19619650.0
142
+ ],
143
+ "q01": [
144
+ -41.90962240941357,
145
+ 43.66791235922949,
146
+ 38.38770483255723,
147
+ 5.711740446834044,
148
+ -63.44539045209019,
149
+ 0.9435577790191543
150
+ ],
151
+ "q10": [
152
+ -24.949315993050774,
153
+ 66.30007546431412,
154
+ 68.16816985859437,
155
+ 27.120731646136054,
156
+ -39.50255020332888,
157
+ 1.6190225837869365
158
+ ],
159
+ "q50": [
160
+ 3.066375725640164,
161
+ 123.16482094240277,
162
+ 124.39930058290133,
163
+ 57.88605464633133,
164
+ -11.037436711677765,
165
+ 9.241478261568748
166
+ ],
167
+ "q90": [
168
+ 31.472920732960127,
169
+ 180.87158401301218,
170
+ 168.5699720215359,
171
+ 81.64709150074712,
172
+ 15.887605114617852,
173
+ 31.887861734718296
174
+ ],
175
+ "q99": [
176
+ 48.29435703371732,
177
+ 185.2611055842669,
178
+ 173.13578487933165,
179
+ 91.78122415137209,
180
+ 42.94491979114059,
181
+ 44.13755601580974
182
+ ],
183
+ "names": [
184
+ "shoulder_pan",
185
+ "shoulder_lift",
186
+ "elbow_flex",
187
+ "wrist_flex",
188
+ "wrist_roll",
189
+ "gripper"
190
+ ],
191
+ "mask": [
192
+ true,
193
+ true,
194
+ true,
195
+ true,
196
+ true,
197
+ true
198
+ ]
199
+ }
200
+ }
201
+ }
202
+ }
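The stats file above declares `"norm_mode": "q01_q99"` but does not show the formula applied. A minimal sketch follows, assuming the common convention of mapping the 1st percentile to -1 and the 99th percentile to +1, with optional clipping. The function names and the clipping behaviour are assumptions for illustration, not the repository's implementation; the sample bounds are the `shoulder_pan` and `gripper` entries of `action_stats`.

```python
from typing import List, Sequence


def normalize_q01_q99(
    values: Sequence[float],
    q01: Sequence[float],
    q99: Sequence[float],
    clip: bool = True,
) -> List[float]:
    """Map each value so that q01 -> -1 and q99 -> +1 (assumed convention)."""
    out = []
    for v, lo, hi in zip(values, q01, q99):
        span = hi - lo
        # Guard against a degenerate dimension with no spread.
        x = 0.0 if span <= 0 else 2.0 * (v - lo) / span - 1.0
        if clip:
            x = max(-1.0, min(1.0, x))
        out.append(x)
    return out


def denormalize_q01_q99(
    normed: Sequence[float],
    q01: Sequence[float],
    q99: Sequence[float],
) -> List[float]:
    """Inverse of normalize_q01_q99 for values that were not clipped."""
    return [(x + 1.0) * 0.5 * (hi - lo) + lo for x, lo, hi in zip(normed, q01, q99)]


# Bounds copied from action_stats for shoulder_pan and gripper.
Q01 = [-42.1300246338976, -0.3016556932619033]
Q99 = [48.55349563198916, 44.74649336930881]

normed = normalize_q01_q99([3.0, 10.0], Q01, Q99)
restored = denormalize_q01_q99(normed, Q01, Q99)
```

Quantile bounds are less sensitive to outliers than the `min`/`max` entries, which in this file reach values such as -270.0 that lie far outside the q01 to q99 range.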
processing_molmoact2.py ADDED
@@ -0,0 +1,418 @@
1
+ """
2
+ Processor class for MolmoAct2.
3
+ """
4
+ from typing import Optional, Union
5
+ import dataclasses
6
+
7
+ import numpy as np
8
+
9
+ from transformers.image_utils import ImageInput
10
+ from transformers.video_utils import VideoInput
11
+ from transformers.processing_utils import (
12
+ Unpack,
13
+ ProcessingKwargs,
14
+ ProcessorMixin,
15
+ )
16
+ from transformers.feature_extraction_utils import BatchFeature
17
+ from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
18
+ from transformers.utils import logging
19
+
20
+ from transformers import AutoTokenizer
21
+ from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
22
+ from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ # Special tokens. These must be present in any tokenizer we use, because the preprocessor relies on them.
29
+ IMAGE_PATCH_TOKEN = "<im_patch>" # Where to insert high-res tokens
30
+ IMAGE_LOW_RES_TOKEN = "<im_low>" # Where to insert low-res tokens
31
+ IM_START_TOKEN = "<im_start>"
32
+ LOW_RES_IMAGE_START_TOKEN = "<low_res_im_start>"
33
+ FRAME_START_TOKEN = "<frame_start>"
34
+ IM_END_TOKEN = "<im_end>"
35
+ FRAME_END_TOKEN = "<frame_end>"
36
+ IM_COL_TOKEN = "<im_col>"
37
+ IMAGE_PROMPT = "<|image|>"
38
+ VIDEO_PROMPT = "<|video|>"
39
+
40
+ IMAGE_TOKENS = [
41
+ IMAGE_PATCH_TOKEN,
42
+ IM_COL_TOKEN,
43
+ IM_START_TOKEN,
44
+ LOW_RES_IMAGE_START_TOKEN,
45
+ FRAME_START_TOKEN,
46
+ IM_END_TOKEN,
47
+ FRAME_END_TOKEN,
48
+ IMAGE_LOW_RES_TOKEN,
49
+ ]
50
+
51
+
52
+ class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False):
53
+ """MolmoAct2 processor kwargs"""
54
+ images_kwargs: MolmoAct2ImagesKwargs
55
+ videos_kwargs: MolmoAct2VideoProcessorKwargs
56
+ _defaults = {
57
+ "text_kwargs": {
58
+ "padding": False,
59
+ "return_mm_token_type_ids": True,
60
+ },
61
+ "videos_kwargs": {"return_metadata": True},
62
+ }
63
+
64
+
65
+ class MolmoAct2Processor(ProcessorMixin):
66
+ attributes = ["image_processor", "video_processor", "tokenizer"]
67
+ optional_attributes = [
68
+ "chat_template",
69
+ "time_mode",
70
+ "image_use_col_tokens",
71
+ "use_single_crop_col_tokens",
72
+ "use_single_crop_start_token",
73
+ "video_use_col_tokens",
74
+ "use_frame_special_tokens",
75
+ ]
76
+ image_processor_class = "AutoImageProcessor"
77
+ video_processor_class = "AutoVideoProcessor"
78
+ tokenizer_class = "AutoTokenizer"
79
+
80
+ def __init__(
81
+ self,
82
+ image_processor: Optional[MolmoAct2ImageProcessor] = None,
82
+ video_processor: Optional[MolmoAct2VideoProcessor] = None,
83
+ tokenizer: Optional[AutoTokenizer] = None,
85
+ chat_template: Optional[str] = None,
86
+ image_use_col_tokens: Optional[bool] = True,
87
+ use_single_crop_col_tokens: Optional[bool] = None,
88
+ use_single_crop_start_token: Optional[bool] = True,
89
+ video_use_col_tokens: Optional[bool] = False,
90
+ use_frame_special_tokens: Optional[bool] = True,
91
+ **kwargs
92
+ ) -> None:
93
+ super().__init__(
94
+ image_processor,
95
+ video_processor,
96
+ tokenizer,
97
+ chat_template=chat_template,
98
+ )
99
+ self.image_use_col_tokens = image_use_col_tokens
100
+ self.use_single_crop_col_tokens = use_single_crop_col_tokens
101
+ self.use_single_crop_start_token = use_single_crop_start_token
102
+ self.video_use_col_tokens = video_use_col_tokens
103
+ self.use_frame_special_tokens = use_frame_special_tokens
104
+
105
+ self.image_placeholder_token = IMAGE_PROMPT
106
+ self.video_placeholder_token = VIDEO_PROMPT
107
+ self.image_token_ids = [
108
+ tokenizer.convert_tokens_to_ids(token)
109
+ for token in IMAGE_TOKENS
110
+ ]
111
+
112
+ def get_image_tokens(self, image_grid: np.ndarray):
113
+ resized_h, resized_w, height, width = image_grid
114
+ if int(height) == 0 or int(width) == 0:
115
+ per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
116
+ use_single_crop_col_tokens = (
117
+ self.image_use_col_tokens
118
+ if self.use_single_crop_col_tokens is None
119
+ else self.use_single_crop_col_tokens
120
+ )
121
+ if use_single_crop_col_tokens:
122
+ per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
123
+ joint = [
124
+ [IM_START_TOKEN],
125
+ np.tile(per_row, [resized_h]),
126
+ [IM_END_TOKEN],
127
+ ]
128
+ return np.concatenate(joint)
129
+ per_row = np.full(width, IMAGE_PATCH_TOKEN)
130
+ if self.image_use_col_tokens:
131
+ per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
132
+ joint = [
133
+ [IM_START_TOKEN],
134
+ np.tile(per_row, [height]),
135
+ [IM_END_TOKEN],
136
+ ]
137
+ per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
138
+ use_single_crop_col_tokens = (
139
+ self.image_use_col_tokens
140
+ if self.use_single_crop_col_tokens is None
141
+ else self.use_single_crop_col_tokens
142
+ )
143
+ image_start_token = (
144
+ LOW_RES_IMAGE_START_TOKEN
145
+ if self.use_single_crop_start_token
146
+ else IM_START_TOKEN
147
+ )
148
+ if use_single_crop_col_tokens:
149
+ per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
150
+ joint = [
151
+ [image_start_token],
152
+ np.tile(per_row, [resized_h]),
153
+ [IM_END_TOKEN],
154
+ ] + joint
155
+
156
+ return np.concatenate(joint)
157
+
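The token layout built by `get_image_tokens` can be sketched as a standalone function. This is an illustrative rewrite under stated assumptions, not the repository's code: `layout_image_tokens` and `block` are hypothetical names, and `"<im_patch>"` is assumed to be the value of `IMAGE_PATCH_TOKEN` (it appears in the tokenizer's `extra_special_tokens`).

```python
# Token strings; "<im_patch>" is an assumed value for IMAGE_PATCH_TOKEN.
IMAGE_PATCH_TOKEN = "<im_patch>"
IM_COL_TOKEN = "<im_col>"
IM_START_TOKEN = "<im_start>"
IM_END_TOKEN = "<im_end>"
LOW_RES_IMAGE_START_TOKEN = "<low_res_im_start>"


def block(start, rows, cols, use_col_tokens):
    """One image block: start token, rows x cols patch tokens
    (plus one column token per row if requested), end token."""
    per_row = [IMAGE_PATCH_TOKEN] * cols
    if use_col_tokens:
        per_row = per_row + [IM_COL_TOKEN]
    return [start] + per_row * rows + [IM_END_TOKEN]


def layout_image_tokens(
    image_grid,
    image_use_col_tokens=True,
    use_single_crop_col_tokens=None,
    use_single_crop_start_token=True,
):
    """Hypothetical helper mirroring the branch structure of get_image_tokens."""
    resized_h, resized_w, height, width = (int(v) for v in image_grid)
    # The low-res block falls back to the high-res column setting when unset.
    single_col = (
        image_use_col_tokens
        if use_single_crop_col_tokens is None
        else use_single_crop_col_tokens
    )
    if height == 0 or width == 0:
        # Single-crop case: only the resized image is emitted.
        return block(IM_START_TOKEN, resized_h, resized_w, single_col)
    low_start = (
        LOW_RES_IMAGE_START_TOKEN if use_single_crop_start_token else IM_START_TOKEN
    )
    low = block(low_start, resized_h, resized_w, single_col)
    high = block(IM_START_TOKEN, height, width, image_use_col_tokens)
    # The low-res block comes first, followed by the high-res block.
    return low + high


# A 2x2 low-res grid followed by a 3x4 high-res grid.
tokens = layout_image_tokens((2, 2, 3, 4), use_single_crop_col_tokens=False)
```

With these settings the low-res block carries no column tokens, while each of the three high-res rows ends in `<im_col>`.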
158
+ def get_video_string(
159
+ self,
160
+ video_grid: np.ndarray,
161
+ timestamps: np.ndarray,
162
+ ):
163
+ if self.use_frame_special_tokens:
164
+ start_token_id = FRAME_START_TOKEN
165
+ end_token_id = FRAME_END_TOKEN
166
+ else:
167
+ start_token_id = IM_START_TOKEN
168
+ end_token_id = IM_END_TOKEN
169
+
170
+ num_frames, h, w = video_grid
171
+ video_string: str = ""
172
+ for frame_idx, frame_time in enumerate(timestamps):
173
+ # `per-frame-compact` time mode
174
+ prev_space = " " if frame_idx > 0 else ""
175
+ frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens
176
+
177
+ video_string += frame_prefix
178
+ per_row = np.full(w, IMAGE_PATCH_TOKEN)
179
+ if self.video_use_col_tokens:
180
+ per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
181
+ extra_tokens = np.tile(per_row, [h])
182
+ video_tokens = [
183
+ [start_token_id],
184
+ extra_tokens,
185
+ [end_token_id],
186
+ ]
187
+ video_string += "".join(np.concatenate(video_tokens, 0))
188
+
189
+ return video_string
190
+
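The per-frame string assembled by `get_video_string` can be illustrated with a small standalone sketch. `build_video_string` is a hypothetical name, and `"<im_patch>"` is assumed to be the patch token; the timestamp prefix and separator rules follow the loop above.

```python
def build_video_string(rows, cols, timestamps,
                       use_frame_special_tokens=True, use_col_tokens=False):
    """Hypothetical sketch: one '<time> <start>patches<end>' chunk per frame."""
    if use_frame_special_tokens:
        start, end = "<frame_start>", "<frame_end>"
    else:
        start, end = "<im_start>", "<im_end>"
    # "<im_patch>" is an assumed value for IMAGE_PATCH_TOKEN.
    per_row = ["<im_patch>"] * cols + (["<im_col>"] if use_col_tokens else [])
    out = ""
    for idx, t in enumerate(timestamps):
        # A space separates frames; the timestamp is printed with one decimal.
        prefix = (" " if idx > 0 else "") + f"{t:.1f} "
        out += prefix + start + "".join(per_row * rows) + end
    return out


example = build_video_string(rows=1, cols=2, timestamps=[0.0, 0.5])
```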
191
+ def insert_bos(
192
+ self,
193
+ input_ids: np.ndarray,
194
+ attention_mask: np.ndarray,
195
+ bos_token_id: int,
196
+ pad_token_id: int,
197
+ ):
198
+ """
199
+ Args:
200
+ input_ids: [B, S] array with left padding
201
+ attention_mask: [B, S] array (0 for pad, 1 for valid)
202
+ bos_token_id: int
203
+ pad_token_id: int
204
+ Returns:
205
+ input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
206
+ attention_mask_out: same shape as input_ids_out
207
+ """
208
+
209
+ need_to_expand = len(input_ids.shape) == 1
210
+ if need_to_expand:
211
+ input_ids = input_ids[None, :]
212
+ attention_mask = attention_mask[None, :]
213
+
214
+ B, S = input_ids.shape
215
+
216
+ # Handle zero-length sequence
217
+ if S == 0:
218
+ new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype)
219
+ new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype)
220
+ if need_to_expand:
221
+ new_input_ids = new_input_ids[0]
222
+ new_attention_mask = new_attention_mask[0]
223
+ return new_input_ids, new_attention_mask
224
+
225
+ first_valid_index = (attention_mask == 1).argmax(axis=-1) # [B]
226
+ bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id)
227
+
228
+ if bos_already_present:
229
+ if need_to_expand:
230
+ input_ids = input_ids[0]
231
+ attention_mask = attention_mask[0]
232
+ return input_ids, attention_mask
233
+ else:
234
+ new_input_ids = np.full((B, S+1), pad_token_id, dtype=input_ids.dtype)
235
+ new_attention_mask = np.zeros((B, S+1), dtype=attention_mask.dtype)
236
+
237
+ src_idx = np.tile(np.arange(S), (B, 1)) # [B, S]
238
+ valid_mask = src_idx >= first_valid_index[:, None] # [B, S]
239
+ tgt_idx = src_idx + 1 # shift right
240
+ batch_idx = np.tile(np.arange(B)[:, None], (1, S)) # [B, S]
241
+
242
+ # flatten valid_positions
243
+ flat_vals = input_ids[valid_mask]
244
+ flat_batch = batch_idx[valid_mask]
245
+ flat_tgt = tgt_idx[valid_mask]
246
+
247
+ new_input_ids[flat_batch, flat_tgt] = flat_vals
248
+ new_attention_mask[flat_batch, flat_tgt] = 1
249
+
250
+ insert_pos = first_valid_index
251
+ new_input_ids[np.arange(B), insert_pos] = bos_token_id
252
+ new_attention_mask[np.arange(B), insert_pos] = 1
253
+
254
+ if need_to_expand:
255
+ new_input_ids = new_input_ids[0]
256
+ new_attention_mask = new_attention_mask[0]
257
+
258
+ return new_input_ids, new_attention_mask
259
+
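The effect of `insert_bos` on a left-padded batch can be seen in a simplified sketch. `insert_bos_left_padded` is a hypothetical helper that handles only the 2-D case and always inserts; the method above additionally handles 1-D input, empty sequences, and the case where BOS is already present.

```python
import numpy as np


def insert_bos_left_padded(input_ids, attention_mask, bos_id, pad_id):
    """Simplified sketch: grow each row by one column and place BOS
    at the first valid (non-pad) position, shifting valid tokens right."""
    B, S = input_ids.shape
    first_valid = (attention_mask == 1).argmax(axis=-1)  # [B]
    out_ids = np.full((B, S + 1), pad_id, dtype=input_ids.dtype)
    out_mask = np.zeros((B, S + 1), dtype=attention_mask.dtype)
    for b in range(B):
        k = first_valid[b]
        out_ids[b, k] = bos_id
        out_ids[b, k + 1:] = input_ids[b, k:]
        out_mask[b, k:] = 1
    return out_ids, out_mask


ids = np.array([[0, 0, 11, 12], [21, 22, 23, 24]])
mask = np.array([[0, 0, 1, 1], [1, 1, 1, 1]])
new_ids, new_mask = insert_bos_left_padded(ids, mask, bos_id=1, pad_id=0)
```

The padded prefix of each row is left untouched, so the output stays left-padded.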
260
+ def __call__(
261
+ self,
262
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
263
+ images: ImageInput = None,
264
+ videos: VideoInput = None,
265
+ **kwargs: Unpack[MolmoAct2ProcessorKwargs],
266
+ ) -> BatchFeature:
267
+ """
268
+
269
+ Args:
270
+ text (`str`, `list[str]`, `list[list[str]]`):
271
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
272
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
273
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
274
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
275
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
276
+ tensor. Both channels-first and channels-last formats are supported.
277
+ videos (`dict[str, Any]` or `list[dict[str, Any]]`):
278
+ The video or batch of videos to be prepared. Each video can be a dictionary with the following keys:
279
+ - `"frames"`: `np.ndarray` of shape (T, H, W, 3)
280
+ - `"timestamps"`: `np.ndarray` of shape (T,)
281
+ - `"sampled_fps"`: `float` (optional)
282
+ - `"sampling_augmentation"`: `str` (optional)
283
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
284
+ If set, will return tensors of a particular framework. Acceptable values are:
285
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
286
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
287
+ - `'np'`: Return NumPy `np.ndarray` objects.
288
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
289
+
290
+ Returns:
291
+ `BatchFeature`: A [`BatchFeature`] with the following fields:
292
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
293
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
294
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`).
295
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
296
+ - **image_token_pooling** -- Indices of the patches in `image_grids` to pool for each token in `image_tokens`.
297
+ Returned when `images` is not `None`.
298
+ - **image_grids** -- Grids of images. Returned when `images` is not `None`.
299
+ - **image_num_crops** -- Number of crops for each image. Returned when `images` is not `None`.
300
+ - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
301
+ - **video_token_pooling** -- Indices of the patches in `video_grids` to pool for each token in `video_tokens`.
302
+ Returned when `videos` is not `None`.
303
+ - **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
304
+ """
305
+
306
+ output_kwargs = self._merge_kwargs(
307
+ MolmoAct2ProcessorKwargs,
308
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
309
+ **kwargs,
310
+ )
311
+
312
+ if images is not None:
313
+ image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
314
+ image_grids = image_inputs["image_grids"]
315
+ else:
316
+ image_inputs = {}
317
+ image_grids = None
318
+
319
+ if videos is not None:
320
+ videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
321
+ video_grids = videos_inputs["video_grids"]
322
+ # If user has not requested video metadata, pop it
323
+ if "return_metadata" not in kwargs:
324
+ video_metadata = videos_inputs.pop("video_metadata")
325
+ else:
326
+ video_metadata = videos_inputs["video_metadata"]
327
+ else:
328
+ videos_inputs = {}
329
+ video_grids = None
330
+
331
+ if not isinstance(text, list):
332
+ text = [text]
333
+
334
+ text = text.copy() # below lines change text in-place
335
+
336
+ if image_grids is not None:
337
+ index = 0
338
+ for i in range(len(text)):
339
+ num_images = text[i].count(self.image_placeholder_token)
340
+ image_grids_i = image_grids[index:index+num_images]
341
+ for image_grid in image_grids_i:
342
+ image_tokens = self.get_image_tokens(image_grid)
343
+ image_string = "".join(image_tokens)
344
+ text[i] = text[i].replace(self.image_placeholder_token, image_string, 1)
345
+ index += num_images
346
+
347
+ if video_grids is not None:
348
+ index = 0
349
+ for i in range(len(text)):
350
+ num_videos = text[i].count(self.video_placeholder_token)
351
+ assert num_videos in {0, 1}, "At most one video is supported for now"
352
+ video_grids_i = video_grids[index:index+num_videos]
353
+ metadata_i = video_metadata[index:index+num_videos]
354
+ for video_grid, metadata in zip(video_grids_i, metadata_i):
355
+ video_string = self.get_video_string(
356
+ video_grid,
357
+ metadata.timestamps,
358
+ )
359
+ text[i] = text[i].replace(self.video_placeholder_token, video_string, 1)
360
+ index += num_videos
361
+
362
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
363
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
364
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
365
+
366
+ input_ids = text_inputs["input_ids"]
367
+ attention_mask = text_inputs["attention_mask"]
368
+
369
+ input_ids = np.array(input_ids)
370
+ attention_mask = np.array(attention_mask)
371
+
372
+ bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
373
+ input_ids, attention_mask = self.insert_bos(
374
+ input_ids, attention_mask, bos, self.tokenizer.pad_token_id
375
+ )
376
+
377
+ if return_mm_token_type_ids:
378
+ image_tokens = np.array(self.image_token_ids).astype(input_ids.dtype)
379
+ token_type_ids = np.any(input_ids[:, :, None] == image_tokens[None, None, :], axis=-1)
380
+ text_inputs["token_type_ids"] = token_type_ids.tolist()
381
+
382
+ text_inputs["input_ids"] = input_ids.tolist()
383
+ text_inputs["attention_mask"] = attention_mask.tolist()
384
+
385
+ return BatchFeature(
386
+ data={**text_inputs, **image_inputs, **videos_inputs},
387
+ tensor_type=return_tensors,
388
+ )
389
+
390
+ def post_process_image_text_to_text(
391
+ self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
392
+ ):
393
+ """
394
+ Post-process the output of the model to decode the text.
395
+
396
+ Args:
397
+ generated_outputs (`torch.Tensor` or `np.ndarray`):
398
+ The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
399
+ or `(sequence_length,)`.
400
+ skip_special_tokens (`bool`, *optional*, defaults to `True`):
401
+ Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
402
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
403
+ Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
404
+ **kwargs:
405
+ Additional arguments to be passed to the tokenizer's `batch_decode` method.
406
+
407
+ Returns:
408
+ `list[str]`: The decoded text.
409
+ """
410
+ return self.tokenizer.batch_decode(
411
+ generated_outputs,
412
+ skip_special_tokens=skip_special_tokens,
413
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
414
+ **kwargs,
415
+ )
416
+
417
+
418
+ MolmoAct2Processor.register_for_auto_class()
processor_config.json ADDED
@@ -0,0 +1,85 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
4
+ },
5
+ "image_processor": {
6
+ "auto_map": {
7
+ "AutoImageProcessor": "image_processing_molmoact2.MolmoAct2ImageProcessor",
8
+ "AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
9
+ },
10
+ "crop_mode": "resize",
11
+ "do_convert_rgb": true,
12
+ "image_mean": [
13
+ 0.5,
14
+ 0.5,
15
+ 0.5
16
+ ],
17
+ "image_processor_type": "MolmoAct2ImageProcessor",
18
+ "image_std": [
19
+ 0.5,
20
+ 0.5,
21
+ 0.5
22
+ ],
23
+ "max_crops": 8,
24
+ "overlap_margins": [
25
+ 4,
26
+ 4
27
+ ],
28
+ "patch_size": 14,
29
+ "pooling_size": [
30
+ 2,
31
+ 2
32
+ ],
33
+ "resample": 2,
34
+ "size": {
35
+ "height": 378,
36
+ "width": 378
37
+ }
38
+ },
39
+ "image_use_col_tokens": true,
40
+ "processor_class": "MolmoAct2Processor",
41
+ "use_frame_special_tokens": true,
42
+ "use_single_crop_col_tokens": false,
43
+ "use_single_crop_start_token": true,
44
+ "video_processor": {
45
+ "auto_map": {
46
+ "AutoProcessor": "processing_molmoact2.MolmoAct2Processor",
47
+ "AutoVideoProcessor": "video_processing_molmoact2.MolmoAct2VideoProcessor"
48
+ },
49
+ "data_format": "channels_first",
50
+ "default_to_square": true,
51
+ "do_convert_rgb": true,
52
+ "do_normalize": true,
53
+ "do_rescale": true,
54
+ "do_resize": true,
55
+ "do_sample_frames": true,
56
+ "frame_sample_mode": "uniform_last_frame",
57
+ "image_mean": [
58
+ 0.5,
59
+ 0.5,
60
+ 0.5
61
+ ],
62
+ "image_std": [
63
+ 0.5,
64
+ 0.5,
65
+ 0.5
66
+ ],
67
+ "max_fps": 2.0,
68
+ "num_frames": 8,
69
+ "patch_size": 14,
70
+ "pooling_size": [
71
+ 3,
72
+ 3
73
+ ],
74
+ "resample": 2,
75
+ "rescale_factor": 0.00392156862745098,
76
+ "return_metadata": false,
77
+ "sampling_fps": 2,
78
+ "size": {
79
+ "height": 378,
80
+ "width": 378
81
+ },
82
+ "video_processor_type": "MolmoAct2VideoProcessor"
83
+ },
84
+ "video_use_col_tokens": false
85
+ }
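The sizes in this config imply a fixed token grid per crop or frame. The arithmetic below is a sketch that assumes patches are padded up to a multiple of the pooling size before pooling (as `arange_for_pooling` in the video processor does) and that the image processor pools the same way; `tokens_per_side` is a hypothetical helper.

```python
import math


def tokens_per_side(size_px: int, patch_size: int, pooling: int) -> int:
    """Pooled tokens along one side: patches per side, rounded up to the pooling size."""
    patches = size_px // patch_size
    return math.ceil(patches / pooling)


# 378 px / 14 px per patch = 27 patches per side.
image_side = tokens_per_side(378, 14, 2)  # image pooling_size [2, 2]
video_side = tokens_per_side(378, 14, 3)  # video pooling_size [3, 3]

image_patch_tokens = image_side * image_side          # per 378x378 crop
video_patch_tokens_per_frame = video_side * video_side
video_patch_tokens = 8 * video_patch_tokens_per_frame  # num_frames = 8
```

These counts cover patch tokens only; start, end, and column tokens are added on top by the processor.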
quantization_metadata.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "source_repo": "allenai/MolmoAct2-SO100_101",
3
+ "source_revision": "152569fe57914d97be91055800035f54e250d009",
4
+ "policy_class": "transformers:AutoModelForImageTextToText",
5
+ "quantization": {
6
+ "scheme": "nf4",
7
+ "backend": "bitsandbytes",
8
+ "compute_dtype": "bfloat16",
9
+ "min_params_to_quantize": 4000000,
10
+ "rule": "Linear modules with >=4_000_000 weight elements rewritten to bnb.nn.Linear4bit; smaller heads kept in compute_dtype (bfloat16).",
11
+ "runtime_status": "loader-backed (install_prequantized_linears)"
12
+ },
13
+ "dropped_state_entries": []
14
+ }
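The `rule` field can be read as a simple size threshold. The sketch below assumes "weight elements" means `in_features * out_features` (bias excluded); `should_quantize` is a hypothetical name, and the actual loader (`install_prequantized_linears`) is not shown in this commit excerpt.

```python
MIN_PARAMS_TO_QUANTIZE = 4_000_000  # from "min_params_to_quantize"


def should_quantize(in_features: int, out_features: int,
                    min_params: int = MIN_PARAMS_TO_QUANTIZE) -> bool:
    """Assumed reading of the rule: quantize a Linear layer to NF4 only if its
    weight matrix has at least `min_params` elements; otherwise keep bfloat16."""
    return in_features * out_features >= min_params


large = should_quantize(2048, 2048)   # 4,194,304 elements
small = should_quantize(1024, 1024)   # 1,048,576 elements
```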
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5395aefc9b1b7f0385d8c86a2f1775e5af81bdfbf9f2d97827ea37921d9f862
3
+ size 11983605
tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
5
+ },
6
+ "backend": "tokenizers",
7
+ "bos_token": "<|im_end|>",
8
+ "clean_up_tokenization_spaces": false,
9
+ "eos_token": "<|im_end|>",
10
+ "errors": "replace",
11
+ "extra_special_tokens": [
12
+ "<im_start>",
13
+ "<im_end>",
14
+ "<im_patch>",
15
+ "<im_col>",
16
+ "<low_res_im_start>",
17
+ "<|image|>",
18
+ "<im_low>",
19
+ "<frame_start>",
20
+ "<frame_end>",
21
+ "<|video|>",
22
+ "<|points|>",
23
+ "<|token_index|>",
24
+ "<|vit_index|>",
25
+ "<|vit_loc|>"
26
+ ],
27
+ "is_local": false,
28
+ "model_max_length": 1010000,
29
+ "pad_token": "<|endoftext|>",
30
+ "processor_class": "MolmoAct2Processor",
31
+ "split_special_tokens": false,
32
+ "tokenizer_class": "Qwen2Tokenizer",
33
+ "unk_token": null
34
+ }
video_processing_molmoact2.py ADDED
@@ -0,0 +1,969 @@
1
+ """Video processor class for MolmoAct2"""
2
+ from functools import partial
3
+ import os
4
+ import warnings
5
+ from contextlib import redirect_stdout
6
+ from io import BytesIO
7
+ from urllib.parse import urlparse
8
+ from typing import Optional, Union, Callable
9
+
10
+ import numpy as np
11
+ import requests
12
+ import einops
13
+ import torch
14
+ import torchvision.transforms
15
+
16
+ from transformers.image_utils import (
17
+ IMAGENET_STANDARD_MEAN,
18
+ IMAGENET_STANDARD_STD,
19
+ ImageInput,
20
+ PILImageResampling,
21
+ SizeDict,
22
+ validate_kwargs,
23
+ )
24
+ from transformers.video_utils import (
25
+ VideoInput,
26
+ is_valid_video,
27
+ make_batched_videos,
28
+ make_batched_metadata,
29
+ VideoMetadata,
30
+ )
31
+ from transformers.processing_utils import Unpack, VideosKwargs
32
+ from transformers.video_processing_utils import BaseVideoProcessor
33
+ from transformers.utils import logging
34
+ from transformers.feature_extraction_utils import BatchFeature
35
+ from transformers.utils import (
36
+ is_av_available,
37
+ is_decord_available,
38
+ is_torchcodec_available,
39
+ is_yt_dlp_available,
40
+ TensorType,
41
+ logging,
42
+ to_numpy,
43
+ )
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ MAX_VIDEO_FPS = 8
49
+
50
+
51
+ def normalize_image(
52
+ image: np.ndarray,
53
+ image_mean: list[float],
54
+ image_std: list[float],
55
+ ) -> np.ndarray:
56
+ if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
57
+ return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
58
+ image -= np.array(image_mean, dtype=np.float32)[None, None, :]
59
+ image /= np.array(image_std, dtype=np.float32)[None, None, :]
60
+ return image
61
+
62
+
63
+ def resize_image(
64
+ image: np.ndarray,
65
+ desired_output_size: list[int],
66
+ resample: PILImageResampling,
67
+ ) -> np.ndarray:
68
+ if len(image.shape) == 3:
69
+ is_video = False
70
+ image = torch.permute(torch.from_numpy(image), [2, 0, 1])
71
+ else:
72
+ is_video = True
73
+ image = torch.permute(torch.from_numpy(image), [0, 3, 1, 2])
74
+ dtype = image.dtype
75
+ if torch.is_floating_point(image):
76
+ in_min = 0.0
77
+ in_max = 1.0
78
+ resized = torchvision.transforms.Resize(
79
+ desired_output_size,
80
+ resample,
81
+ antialias=False,
82
+ )(image)
83
+ resized = torch.clip(resized, 0.0, 1.0).to(dtype)
84
+ else:
85
+ assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(image.dtype)
86
+ in_min = 0.0
87
+ in_max = 255.0
88
+ resized = torchvision.transforms.Resize(
89
+ desired_output_size,
90
+ resample,
91
+ antialias=False,
92
+ )(image)
93
+ resized = torch.clip(resized, 0, 255).to(dtype)
94
+
95
+ resized = resized.to(torch.float32)
96
+ resized = (resized - in_min) / (in_max - in_min)
97
+
98
+ if is_video:
99
+ resized = torch.permute(resized, [0, 2, 3, 1]).numpy()
100
+ else:
101
+ resized = torch.permute(resized, [1, 2, 0]).numpy()
102
+
103
+ return resized
104
+
105
+
106
+ def build_resized_image(
107
+ image: np.ndarray,
108
+ base_image_input_size: list[int],
109
+ resample: PILImageResampling,
110
+ image_mean: list[float],
111
+ image_std: list[float],
112
+ image_patch_size: int,
113
+ ) -> tuple[np.ndarray, np.ndarray]:
114
+ resized = resize_image(
115
+ image, base_image_input_size, resample,
116
+ )
117
+ resized = normalize_image(resized, image_mean, image_std)
118
+ if len(resized.shape) == 3:
119
+ resized = np.expand_dims(resized, 0)
120
+ crop_patch_w = base_image_input_size[1] // image_patch_size
121
+ crop_patch_h = base_image_input_size[0] // image_patch_size
122
+ resize_idx = np.arange(crop_patch_w*crop_patch_h).reshape([crop_patch_h, crop_patch_w])
123
+ return resized, resize_idx
124
+
125
+
126
+ def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
127
+ """Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
128
+ if len(array.shape) == 3:
129
+ n_crops, h, w = array.shape
130
+ h_patches = h//patch_size
131
+ w_patches = w//patch_size
132
+ array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
133
+ array = np.transpose(array, [0, 1, 3, 2, 4])
134
+ array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size])
135
+ return array
136
+ else:
137
+ n_crops, h, w, c = array.shape
138
+ h_patches = h//patch_size
139
+ w_patches = w//patch_size
140
+ array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
141
+ array = np.transpose(array, [0, 1, 3, 2, 4, 5])
142
+ array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size*c])
143
+ return array
144
+
145
+
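The reshape/transpose sequence in `batch_pixels_to_patches` can be checked on a tiny input. This is a standalone sketch of the 4-D branch only; `pixels_to_patches` is a hypothetical name.

```python
import numpy as np


def pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
    """[n, h, w, c] -> [n, n_patches, patch_size * patch_size * c], patches in row-major order."""
    n, h, w, c = array.shape
    hp, wp = h // patch_size, w // patch_size
    array = array.reshape(n, hp, patch_size, wp, patch_size, c)
    array = array.transpose(0, 1, 3, 2, 4, 5)  # group the two patch-grid axes together
    return array.reshape(n, hp * wp, patch_size * patch_size * c)


frames = np.arange(2 * 4 * 4 * 3, dtype=np.float32).reshape(2, 4, 4, 3)
patches = pixels_to_patches(frames, patch_size=2)
```

Each output row is one flattened patch, so `patches[0, 1]` holds the top-right 2x2 block of the first frame. With the configured 378x378 input and `patch_size=14`, the same function would yield 729 patches of 588 values each.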
146
+ def arange_for_pooling(
147
+ idx_arr: np.ndarray,
148
+ pool_h: int,
149
+ pool_w: int,
150
+ ) -> np.ndarray:
151
+ h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
152
+ w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
153
+ idx_arr = np.pad(idx_arr, [[h_pad//2, (h_pad+1)//2], [w_pad//2, (w_pad+1)//2]],
154
+ mode='constant',constant_values=-1)
155
+ return einops.rearrange(
156
+ idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
157
+
158
+
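A small worked example makes the padding and grouping in `arange_for_pooling` easier to follow. The sketch below uses plain NumPy reshapes in place of `einops.rearrange`; `pooling_index` is a hypothetical name.

```python
import numpy as np


def pooling_index(idx_arr: np.ndarray, pool_h: int, pool_w: int) -> np.ndarray:
    """Pad a patch-index grid with -1 to a multiple of the pooling window,
    then group it into [h, w, pool_h * pool_w] windows."""
    h, w = idx_arr.shape
    h_pad = -h % pool_h  # rows needed to reach the next multiple of pool_h
    w_pad = -w % pool_w
    padded = np.pad(
        idx_arr,
        [[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
        mode="constant",
        constant_values=-1,
    )
    H, W = padded.shape
    out = padded.reshape(H // pool_h, pool_h, W // pool_w, pool_w)
    out = out.transpose(0, 2, 1, 3)  # same as "(h dh) (w dw) -> h w (dh dw)"
    return out.reshape(H // pool_h, W // pool_w, pool_h * pool_w)


idx = np.arange(9).reshape(3, 3)
pooled = pooling_index(idx, 2, 2)
```

For a 3x3 grid and a 2x2 window, one padding row and one padding column are added at the bottom and right, and the `-1` entries mark positions to mask out when pooling.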
159
+ def image_to_patches_and_grids(
160
+ image: ImageInput,
161
+ base_image_input_size: list[int],
162
+ resample: PILImageResampling,
163
+ image_mean: list[float],
164
+ image_std: list[float],
165
+ image_patch_size: int,
166
+ image_pooling_w: int,
167
+ image_pooling_h: int,
168
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
169
+ """
170
+ :return image_grids, the shape of each image after pooling
171
+ :return crops, the image crops to process with the ViT
172
+ :return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
173
+ patches in `crops` to pool for that token, masked with -1
174
+ """
175
+ if isinstance(base_image_input_size, int):
176
+ base_image_input_size = (base_image_input_size, base_image_input_size)
177
+
178
+ pooling_w = image_pooling_w
179
+ pooling_h = image_pooling_h
180
+
181
+ resized, resize_idx = build_resized_image(
182
+ image,
183
+ base_image_input_size,
184
+ resample,
185
+ image_mean,
186
+ image_std,
187
+ image_patch_size,
188
+ )
189
+ pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
190
+ h, w = pooling_idx.shape[:2]
191
+ pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w])
192
+ image_grid = [h, w]
193
+ return (
194
+ image_grid,
195
+ batch_pixels_to_patches(resized, image_patch_size),
196
+ pooling_idx,
197
+ )
198
+
199
+
200
+ def get_candidate_target_fps(
201
+ video_fps: Union[int, float],
202
+ sampling_fps: Union[int, float],
203
+ max_fps: Union[int, float] = MAX_VIDEO_FPS,
204
+ ) -> list[float]:
205
+ """
206
+ Return the subset of `video_fps` factors that remain multiples of `sampling_fps`.
207
+
208
+ Examples:
209
+ >>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
210
+ [2.0, 6.0]
211
+ >>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
212
+ [1.0, 5.0]
213
+ >>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
214
+ [2.0]
215
+ >>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
216
+ Traceback (most recent call last):
217
+ ...
218
+ ValueError: sampling_fps=2 must divide video_fps=5.
219
+ """
220
+ if sampling_fps is None:
221
+ raise ValueError("sampling_fps must be provided")
222
+ video_fps = int(video_fps)
223
+ sampling_fps = int(sampling_fps)
224
+ max_fps = int(max_fps)
225
+
226
+ if video_fps <= 0 or sampling_fps <= 0:
227
+ raise ValueError(f"video_fps and sampling_fps must be positive (got {video_fps}, {sampling_fps})")
228
+ if video_fps % sampling_fps != 0:
229
+ raise ValueError(f"sampling_fps={sampling_fps} must divide video_fps={video_fps}.")
230
+
231
+ candidates = []
232
+ for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
233
+ if candidate > max_fps:
234
+ break
235
+ if video_fps % candidate == 0:
236
+ candidates.append(float(candidate))
237
+
238
+ return candidates
239
+
240
+
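The candidate selection in `get_candidate_target_fps` can be tried out with a compact re-implementation. `candidate_target_fps` is a hypothetical standalone version written for illustration; it keeps the same integer truncation, divisibility check, and `max_fps` cap.

```python
def candidate_target_fps(video_fps, sampling_fps, max_fps=8):
    """Multiples of sampling_fps that divide video_fps and do not exceed max_fps."""
    video_fps, sampling_fps, max_fps = int(video_fps), int(sampling_fps), int(max_fps)
    if video_fps <= 0 or sampling_fps <= 0:
        raise ValueError("video_fps and sampling_fps must be positive")
    if video_fps % sampling_fps != 0:
        raise ValueError(f"sampling_fps={sampling_fps} must divide video_fps={video_fps}.")
    upper = min(video_fps, max_fps)
    return [
        float(c)
        for c in range(sampling_fps, upper + 1, sampling_fps)
        if video_fps % c == 0
    ]


fps_30 = candidate_target_fps(30, 2)  # 4 and 8 are skipped: they do not divide 30
fps_24 = candidate_target_fps(24, 2)
```

Because the inputs are truncated with `int`, a fractional rate such as 29.97 becomes 29, which is not divisible by a sampling rate of 2 and therefore raises `ValueError`.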
241
+ def read_video_decord(
242
+ video_path,
243
+ sample_timestamps_fn: Callable,
244
+ **kwargs,
245
+ ) -> tuple[np.ndarray, VideoMetadata]:
246
+ """
247
+ Decode a video using the Decord backend.
248
+
249
+ Args:
250
+ video_path (`str`):
251
+ Path to the video file.
252
+ sample_timestamps_fn (`Callable`):
253
+ A callable function that will return timestamps at which the video should be sampled.
254
+
255
+ Returns:
256
+ tuple[`np.array`, `VideoMetadata`]: A tuple containing:
257
+ - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
258
+ - `VideoMetadata` object.
259
+ """
260
+ # Lazy import from decord
261
+ import importlib
262
+ decord = importlib.import_module("decord")
263
+
264
+ vr = decord.VideoReader(uri=video_path, ctx=decord.cpu(0)) # decord has problems with gpu
265
+ video_fps = vr.get_avg_fps()
266
+ total_num_frames = len(vr)
267
+ time_stamps = vr.get_frame_timestamp(list(range(len(vr))))
268
+ duration = time_stamps[-1][1] - time_stamps[0][0]
269
+
270
+ metadata = VideoMetadata(
271
+ total_num_frames=int(total_num_frames),
272
+ fps=float(video_fps),
273
+ duration=float(duration),
274
+ video_backend="decord",
275
+ )
276
+
277
+ target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
278
+ target_timestamps = np.array(target_timestamps)
279
+ offset = time_stamps[0, 0]
280
+
281
+ ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side='right')
282
+ ix = np.minimum(ix, len(time_stamps) - 1)
283
+
284
+ video = vr.get_batch(ix).asnumpy()
285
+ metadata.update(
286
+ {
287
+ "frames_indices": target_timestamps * video_fps,
288
+ "height": video.shape[1],
289
+ "width": video.shape[2],
290
+ }
291
+ )
292
+ return video, metadata
293
+
294
+
295
+ def read_video_torchcodec(
296
+ video_path,
297
+ sample_timestamps_fn: Callable,
298
+ **kwargs,
299
+ ) -> tuple[np.ndarray, VideoMetadata]:
300
+ """
301
+ Decode a video using torchcodec decoder.
302
+
303
+ Args:
304
+ video_path (`str`):
305
+ Path to the video file.
306
+ sample_timestamps_fn (`Callable`):
307
+ A callable function that will return timestamps at which the video should be sampled.
308
+
309
+ Returns:
310
+ tuple[`np.array`, `VideoMetadata`]: A tuple containing:
311
+ - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
312
+ - `VideoMetadata` object.
313
+ """
314
+ # Lazy import torchcodec
315
+ import importlib
316
+ torchcodec = importlib.import_module("torchcodec")
317
+
318
+ decoder = torchcodec.decoders.VideoDecoder(
319
+ video_path,
320
+ # Interestingly, `exact` mode takes less time than `approximate` when we load the whole video
321
+ seek_mode="exact",
322
+ # Let FFmpeg decide on the number of threads for efficiency
323
+ num_ffmpeg_threads=0,
324
+ )
325
+ # If the first frame starts at > 0, we effectively clip the video starting at that time
326
+ # since (most) video players would also skip to that time
327
+ time_offset = decoder.metadata.begin_stream_seconds_from_content
328
+ # Note this duration does assume we started playing at `time_offset`
329
+ duration = decoder.metadata.duration_seconds
330
+
331
+ metadata = VideoMetadata(
332
+ total_num_frames=decoder.metadata.num_frames,
333
+ fps=decoder.metadata.average_fps,
334
+ duration=duration,
335
+ video_backend="torchcodec",
336
+ height=decoder.metadata.height,
337
+ width=decoder.metadata.width,
338
+ )
339
+
340
+ target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
341
+
342
+ # Floating point/rounding issues might cause `target_timestamps` to be very slightly
343
+ # out-of-bounds, to handle this we sanity check then clip them
344
+ assert all(x >= 0 for x in target_timestamps)
345
+ assert all(x < duration+1e-6 for x in target_timestamps)
346
+ # 1e-6 padding since torchcodec can throw out-of-bounds errors even if you ask for the
347
+ # exact boundary value, we should still get the first/last frame anyway
348
+ max_timestamp = decoder.metadata.end_stream_seconds_from_content - 1e-6
349
+ min_timestamp = decoder.metadata.begin_stream_seconds_from_content + 1e-6
350
+ # Note we avoid using numpy ops here to reduce floating precision issues
351
+ timestamps = [x + time_offset for x in target_timestamps]
352
+ timestamps = [max(min_timestamp, min(max_timestamp, x)) for x in timestamps]
353
+
354
+ video = decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1) # Convert to THWC format
355
+ target_timestamps = np.array(target_timestamps)
356
+ metadata.frames_indices = target_timestamps * metadata.fps
357
+
358
+ return video, metadata
359
+
360
+
361
+ def read_video_pyav(
362
+ video_path,
363
+ sample_timestamps_fn: Callable,
364
+ **kwargs,
365
+ ) -> tuple[np.ndarray, VideoMetadata]:
366
+ """
367
+ Decode a video using the PyAV backend.
368
+
369
+ Args:
370
+ video_path (`str`):
371
+ Path to the video file.
372
+ sample_timestamps_fn (`Callable`):
373
+ A callable function that will return timestamps at which the video should be sampled.
374
+
375
+ Returns:
376
+ tuple[`np.array`, `VideoMetadata`]: A tuple containing:
377
+ - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
378
+ - `VideoMetadata` object.
379
+ """
380
+ # Lazy import av (PyAV)
381
+ import importlib
382
+ av = importlib.import_module("av")
383
+
384
+ with av.open(video_path) as container:
385
+ video_stream = container.streams.video[0]
386
+ fps = video_stream.average_rate or video_stream.guessed_rate
387
+ it = container.decode(video=0)
388
+ frames = list(it)
389
+
390
+ stream = container.streams.video[0]
391
+ start = frames[0].pts * stream.time_base
392
+ container_end = stream.duration
393
+ if container_end is not None:
394
+ container_end *= stream.time_base
395
+ if container_end is None or container_end < frames[-1].pts * stream.time_base:
396
+ # Some problem with stream duration, so use the frame PTS directly
397
+ # and guess the duration of the last frame
398
+ end = frames[-1].pts * stream.time_base + 1/fps
399
+ else:
400
+ end = container_end
401
+ duration = float(end - start)
402
+
403
+ metadata = VideoMetadata(
404
+ total_num_frames=len(frames),
405
+ fps=float(fps),
406
+ duration=float(duration),
407
+ video_backend="pyav",
408
+ height=video_stream.height,
409
+ width=video_stream.width,
410
+ )
411
+
412
+ target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
413
+ offset = float(start)
414
+
415
+ target_timestamps = np.array(target_timestamps)
416
+ end_time_stamps = np.array([float(frame.pts * stream.time_base) for frame in frames[1:]] + [duration])
417
+ indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side='right')
418
+ indices = np.minimum(indices, len(end_time_stamps) - 1)
419
+
420
+ video = np.stack(
421
+ [frames[i].to_ndarray(format="rgb24", channel_last=True) for i in indices],
422
+ axis=0,
423
+ )
424
+
425
+ metadata.frames_indices = target_timestamps * fps
426
+
427
+ return video, metadata
428
+
429
+
430
+ VIDEO_DECODERS = {
431
+ "decord": read_video_decord,
432
+ "torchcodec": read_video_torchcodec,
433
+ "pyav": read_video_pyav,
434
+ }
435
+
436
+
437
+ def load_video(
438
+ video: VideoInput,
439
+ backend: str = "decord",
440
+ sample_timestamps_fn: Optional[Callable] = None,
441
+ **kwargs,
442
+ ):
443
+ """
444
+ Loads `video` to a numpy array.
445
+
446
+ Args:
447
+ video (`VideoInput`):
448
+ The video to convert to the numpy array format. Can be a link to video or local path.
449
+ backend (`str`, *optional*, defaults to `"decord"`):
450
+ The backend to use when loading the video. Can be any of ["decord", "pyav", "torchcodec"].
451
+ sample_timestamps_fn (`Callable`):
452
+ A callable function that will return timestamps at which the video should be sampled.
453
+ """
454
+
455
+ # Early exit if provided an array or `PIL` frames
456
+ if not isinstance(video, str):
457
+ metadata = [None] * len(video)
458
+ return video, metadata
459
+
460
+ if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
461
+ if not is_yt_dlp_available():
462
+ raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
463
+ # Lazy import from yt_dlp
464
+ import importlib
465
+ yt_dlp = importlib.import_module("yt_dlp")
466
+
467
+ buffer = BytesIO()
468
+ with redirect_stdout(buffer), yt_dlp.YoutubeDL() as f:
469
+ f.download([video])
470
+ bytes_obj = buffer.getvalue()
471
+ file_obj = BytesIO(bytes_obj)
472
+ elif video.startswith("http://") or video.startswith("https://"):
473
+ file_obj = BytesIO(requests.get(video).content)
474
+ elif os.path.isfile(video):
475
+ file_obj = video
476
+ else:
477
+ raise TypeError("Incorrect format used for video. Should be a URL linking to a video or a local path.")
478
+
479
+ # can also load with decord, but not cv2/torchvision
480
+ # both will fail in case of url links
481
+ video_is_url = video.startswith("http://") or video.startswith("https://")
482
+ if video_is_url and backend == "opencv":
483
+ raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")
484
+
485
+ if (
486
+ (not is_decord_available() and backend == "decord")
487
+ or (not is_torchcodec_available() and backend == "torchcodec")
488
+ or (not is_av_available() and backend == "pyav")
489
+ ):
490
+ raise ImportError(
491
+ f"You chose backend={backend} for loading the video but the required library is not found in your environment. "
492
+ f"Make sure to install {backend} before loading the video."
493
+ )
494
+
495
+ video_decoder = VIDEO_DECODERS[backend]
496
+ video, metadata = video_decoder(file_obj, sample_timestamps_fn, **kwargs)
497
+ return video, metadata
498
+
499
+
500
+ def get_target_fps(
501
+ video_fps: float,
502
+ max_frames: int,
503
+ total_frames: int,
504
+ frame_sample_mode: str,
505
+ candidate_target_fps: tuple[float, ...],
506
+ ) -> float:
507
+ """
508
+ Get the target fps that best spans the video and has the most frames sampled
509
+ """
510
+ num_frames_sampled = 0
511
+ selected_target_fps = None
512
+ for target_fps in candidate_target_fps:
513
+ step_size = max(int(video_fps / target_fps), 1)
514
+ num_frames_sampled_at_fps = int(total_frames / step_size)
515
+ if num_frames_sampled == 0:
516
+ if "uniform" in frame_sample_mode:
517
+ if num_frames_sampled_at_fps > max_frames:
518
+ break
519
+ selected_target_fps = target_fps
520
+ num_frames_sampled = num_frames_sampled_at_fps
521
+
522
+ else:
523
+ # the candidate sampling fps increases so frame count can't decrease
524
+ assert num_frames_sampled <= num_frames_sampled_at_fps
525
+ if num_frames_sampled_at_fps > max_frames:
526
+ # choose the sampling fps that spans the video
527
+ continue
528
+
529
+ elif num_frames_sampled_at_fps > num_frames_sampled:
530
+ # both are less than max_frames, choose the one with higher density of frames sampled
531
+ selected_target_fps = target_fps
532
+ num_frames_sampled = num_frames_sampled_at_fps
533
+ return selected_target_fps
534
+
535
+
536
+ def get_frame_times_and_chosen_fps(
537
+ selected_target_fps,
538
+ total_frames,
539
+ max_frames,
540
+ video_fps
541
+ ):
542
+ if selected_target_fps is None:
543
+ frame_indices = np.linspace(0, total_frames, max_frames, endpoint=False, dtype=int)
544
+ else:
545
+ step_size = max(int(video_fps / selected_target_fps), 1)
546
+ frame_indices = np.arange(0, total_frames, step_size)
547
+ if len(frame_indices) > max_frames:
548
+ frame_indices = frame_indices[:max_frames]
549
+ return selected_target_fps, frame_indices
550
+
551
+
552
+ class MolmoAct2VideoProcessorKwargs(VideosKwargs, total=False):
553
+ patch_size: Optional[int]
554
+ pooling_size: Optional[list[int]]
555
+ frame_sample_mode: Optional[str]
556
+ max_fps: Optional[int]
557
+ sampling_fps: Optional[int]
558
+
559
+
560
+ class MolmoAct2VideoProcessor(BaseVideoProcessor):
561
+ resample = PILImageResampling.BILINEAR
562
+ size = {"height": 378, "width": 378}
563
+ image_mean = IMAGENET_STANDARD_MEAN
564
+ image_std = IMAGENET_STANDARD_STD
565
+ do_resize = True
566
+ do_rescale = True
567
+ do_normalize = True
568
+ do_convert_rgb = True
569
+ patch_size = 14
570
+ pooling_size = [3, 3]
571
+ do_sample_frames = True
572
+ frame_sample_mode = "uniform_last_frame"
573
+ max_fps = 2
574
+ sampling_fps = 2
575
+ valid_kwargs = MolmoAct2VideoProcessorKwargs
576
+ model_input_names = ["pixel_values_videos", "video_token_pooling", "video_grids"]
577
+
578
+ def __init__(self, **kwargs: Unpack[MolmoAct2VideoProcessorKwargs]):
579
+ super().__init__(**kwargs)
580
+ if self.size is not None and (
581
+ self.size.get("height", None) is None or self.size.get("width", None) is None
582
+ ):
583
+ raise ValueError("size must contain 'height' and 'width' keys.")
584
+
585
+ def _further_process_kwargs(
586
+ self,
587
+ size: Optional[SizeDict] = None,
588
+ **kwargs,
589
+ ) -> dict:
590
+ """
591
+ Update kwargs that need further processing before being validated.
592
+ Can be overridden by subclasses to customize the processing of kwargs.
593
+ """
594
+ if size is not None and ("height" not in size or "width" not in size):
595
+ raise ValueError("size must contain 'height' and 'width' keys.")
596
+
597
+ return super()._further_process_kwargs(size=size, **kwargs)
598
+
599
+ def sample_times(
600
+ self,
601
+ metadata: VideoMetadata,
602
+ frame_sample_mode: str,
603
+ num_frames: int,
604
+ max_fps: Optional[int] = None,
605
+ sampling_fps: Optional[int] = None,
606
+ **kwargs,
607
+ ) -> np.ndarray:
608
+ """
609
+ Time-based sampling, used when the video is decoded from a path or URL
610
+ Args:
611
+ metadata (`VideoMetadata`):
612
+ Metadata of the video containing information about total duration, fps and total number of frames.
613
+ frame_sample_mode (`str`, *optional*):
614
+ Mode to sample frames. Defaults to `self.frame_sample_mode`.
615
+ num_frames (`int`, *optional*):
616
+ Maximum number of frames to sample. Defaults to `self.num_frames`.
617
+ max_fps (`int`, *optional*):
618
+ Maximum frames per second to sample.
619
+ sampling_fps (`int`, *optional*):
620
+ Sampling frames per second. Defaults to `self.sampling_fps`.
621
+ Used when `frame_sample_mode` is `"fps"`.
622
+ """
623
+ frame_sample_mode = frame_sample_mode or self.frame_sample_mode
624
+ num_frames = num_frames or self.num_frames
625
+ sampling_fps = sampling_fps or self.sampling_fps
626
+
627
+ duration = metadata.duration or metadata.total_num_frames / metadata.fps
628
+ if frame_sample_mode == "fps":
629
+ candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
630
+ # Try larger and larger FPSs until we hit one that can't span the video
631
+ target_fps = candidate_target_fps[0]
632
+ for candidate_fps in candidate_target_fps[1:]:
633
+ if num_frames / candidate_fps < duration:
634
+ break
635
+ target_fps = candidate_fps
636
+ times = np.arange(0, num_frames) / target_fps
637
+ times = times[times < duration]
638
+ return times
639
+ elif frame_sample_mode == "uniform_last_frame":
640
+ if max_fps is not None:
641
+ max_duration = (num_frames-1) / max_fps # -1 to include the last frame
642
+ if max_duration < duration:
643
+ times = np.linspace(
644
+ 0, duration, num=num_frames, endpoint=True, dtype=np.float64
645
+ )
646
+ else:
647
+ times = np.arange(0.0, stop=duration, step=1/max_fps)
648
+ times = np.concatenate([times, [duration]], axis=0)
649
+ assert len(times) <= num_frames
650
+ else:
651
+ times = np.linspace(
652
+ 0, duration, num=num_frames, endpoint=True, dtype=np.float64
653
+ )
654
+ return times
655
+ else:
656
+ raise NotImplementedError(frame_sample_mode)
657
+
658
+ def sample_frames(
659
+ self,
660
+ metadata: VideoMetadata,
661
+ frame_sample_mode: Optional[str] = None,
662
+ num_frames: Optional[int] = None,
663
+ max_fps: Optional[int] = None,
664
+ sampling_fps: Optional[int] = None,
665
+ **kwargs,
666
+ ) -> np.ndarray:
667
+ """
668
+ Frame-based sampling if an array video is passed
669
+ Args:
670
+ metadata (`VideoMetadata`):
671
+ Metadata of the video containing information about total duration, fps and total number of frames.
672
+ frame_sample_mode (`str`, *optional*):
673
+ Mode to sample frames. Defaults to `self.frame_sample_mode`.
674
+ num_frames (`int`, *optional*):
675
+ Maximum number of frames to sample. Defaults to `self.num_frames`.
676
+ max_fps (`int`, *optional*):
677
+ Maximum frames per second to sample.
678
+ sampling_fps (`int`, *optional*):
679
+ Sampling frames per second. Defaults to `self.sampling_fps`.
680
+ Used when `frame_sample_mode` is `"fps"`.
681
+ """
682
+ frame_sample_mode = frame_sample_mode or self.frame_sample_mode
683
+ num_frames = num_frames or self.num_frames
684
+ sampling_fps = sampling_fps or self.sampling_fps
685
+
686
+ total_num_frames = metadata.total_num_frames
687
+ if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
688
+ duration = total_num_frames / metadata.fps
689
+ if total_num_frames <= 2:
690
+ return np.arange(total_num_frames).astype(int)
691
+ if duration > (num_frames - 1) / max_fps: # -1 to include the last frame
692
+ # uniform fallback
693
+ indices = np.linspace(
694
+ 0,
695
+ total_num_frames - 1,
696
+ num=min(num_frames, total_num_frames),
697
+ endpoint=True,
698
+ ).astype(int)
699
+ return indices
700
+ else:
701
+ float_indices = np.arange(
702
+ 0.0, stop=total_num_frames - 1, step=float(metadata.fps / max_fps),
703
+ )
704
+ if np.round(float_indices[-1]) != total_num_frames - 1:
705
+ float_indices = np.concatenate([float_indices, [total_num_frames - 1]], axis=0)
706
+ indices = np.round(float_indices).astype(int)
707
+ assert indices[-1] < total_num_frames
708
+ assert len(float_indices) <= num_frames
709
+ return indices
710
+ elif frame_sample_mode == "uniform_last_frame":
711
+ indices = np.linspace(
712
+ 0, total_num_frames - 1, num=min(num_frames, total_num_frames), endpoint=True,
713
+ ).astype(int)
714
+ return indices
715
+ elif frame_sample_mode == "fps":
716
+ candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
717
+ selected_target_fps = get_target_fps(
718
+ metadata.fps,
719
+ num_frames,
720
+ total_num_frames,
721
+ frame_sample_mode,
722
+ candidate_target_fps,
723
+ )
724
+ _, indices = get_frame_times_and_chosen_fps(
725
+ selected_target_fps,
726
+ total_num_frames,
727
+ num_frames,
728
+ metadata.fps,
729
+ )
730
+ return indices
731
+ else:
732
+ raise NotImplementedError(frame_sample_mode)
733
+
734
+ def fetch_videos(
735
+ self,
736
+ video_url_or_urls: Union[str, list[str], list[list[str]]],
737
+ sample_timestamps_fn=None
738
+ ):
739
+ """
740
+ Convert a single URL or a list of URLs into the corresponding `np.array` objects.
741
+
742
+ If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
743
+ returned.
744
+ """
745
+ if (
746
+ (not is_decord_available())
747
+ and (not is_torchcodec_available())
748
+ and (not is_av_available())
749
+ ):
750
+ raise ImportError(
751
+ "MolmoAct2VideoProcessor requires `decord`, `torchcodec`, or `av` to be installed."
752
+ )
753
+
754
+ if is_decord_available():
755
+ backend = "decord"
756
+ elif is_torchcodec_available():
757
+ warnings.warn(
758
+ "`decord` is not installed and cannot be used to decode the video by default. "
759
+ "Falling back to `torchcodec`."
760
+ )
761
+ backend = "torchcodec"
762
+ else:
763
+ warnings.warn(
764
+ "`decord` is not installed and cannot be used to decode the video by default. "
765
+ "Falling back to `PyAV`."
766
+ )
767
+ backend = "pyav"
768
+
769
+ if isinstance(video_url_or_urls, list):
770
+ return list(zip(*[self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn) for x in video_url_or_urls]))
771
+ else:
772
+ return load_video(video_url_or_urls, backend=backend, sample_timestamps_fn=sample_timestamps_fn)
773
+
774
+ def _decode_and_sample_videos(
775
+ self,
776
+ videos: VideoInput,
777
+ video_metadata: Union[VideoMetadata, dict],
778
+ do_sample_frames: Optional[bool] = None,
779
+ sample_indices_fn: Optional[Callable] = None,
780
+ sample_timestamps_fn: Optional[Callable] = None,
781
+ ):
782
+ """
783
+ Decode input videos and sample frames if needed.
784
+ """
785
+ videos = make_batched_videos(videos)
786
+ video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)
787
+
788
+ # Frame-based sampling if an array video is passed
789
+ # Otherwise, time-based sampling with decoding
790
+ if is_valid_video(videos[0]) and do_sample_frames:
791
+ assert video_metadata[0].fps is not None, "FPS must be provided for video input"
792
+ sampled_videos = []
793
+ sampled_metadata = []
794
+ for video, metadata in zip(videos, video_metadata):
795
+ indices = sample_indices_fn(metadata=metadata)
796
+ metadata.frames_indices = indices
797
+ sampled_videos.append(video[indices])
798
+ sampled_metadata.append(metadata)
799
+ videos = sampled_videos
800
+ video_metadata = sampled_metadata
801
+ elif not is_valid_video(videos[0]):
802
+ if sample_indices_fn is None:
803
+ logger.warning(
804
+ "do_sample_frames is False, but video array is not provided: "
805
+ "Will decode the video and sample frames using MolmoAct2's default sampling mode"
806
+ )
807
+ if isinstance(videos[0], list):
808
+ raise ValueError(
809
+ "A list of images is not supported for video input!"
810
+ )
811
+ else:
812
+ videos, video_metadata = self.fetch_videos(videos, sample_timestamps_fn=sample_timestamps_fn)
813
+
814
+ return videos, video_metadata
815
+
816
+ def _prepare_input_videos(
817
+ self,
818
+ videos: VideoInput,
819
+ **kwargs,
820
+ ) -> list[np.ndarray]:
821
+ processed_videos = [to_numpy(video) for video in videos]
822
+ return processed_videos
823
+
824
+ def preprocess(
825
+ self,
826
+ videos: VideoInput,
827
+ **kwargs: Unpack[MolmoAct2VideoProcessorKwargs],
828
+ ) -> BatchFeature:
829
+ validate_kwargs(
830
+ captured_kwargs=kwargs.keys(),
831
+ valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
832
+ )
833
+
834
+ # Set default kwargs from self. This ensures that if a kwarg is not provided
835
+ # by the user, it gets its default value from the instance, or is set to None.
836
+ for kwarg_name in self.valid_kwargs.__annotations__:
837
+ kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
838
+
839
+ do_sample_frames = kwargs.pop("do_sample_frames")
840
+ video_metadata = kwargs.pop("video_metadata")
841
+
842
+ sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
843
+ sample_timestamps_fn = partial(self.sample_times, **kwargs)
844
+ videos, video_metadata = self._decode_and_sample_videos(
845
+ videos,
846
+ video_metadata=video_metadata,
847
+ do_sample_frames=do_sample_frames,
848
+ sample_indices_fn=sample_indices_fn,
849
+ sample_timestamps_fn=sample_timestamps_fn,
850
+ )
851
+ videos = self._prepare_input_videos(videos=videos)
852
+
853
+ kwargs = self._further_process_kwargs(**kwargs)
854
+
855
+ return_metadata = kwargs.pop("return_metadata")
856
+ preprocessed_videos = self._preprocess(videos=videos, **kwargs)
857
+ if return_metadata:
858
+ preprocessed_videos["video_metadata"] = video_metadata
859
+ return preprocessed_videos
860
+
861
+ def _preprocess(
862
+ self,
863
+ videos: list[np.ndarray],
864
+ size: Optional[SizeDict] = None,
865
+ resample: Optional[PILImageResampling] = None,
866
+ image_mean: Optional[Union[float, list[float]]] = None,
867
+ image_std: Optional[Union[float, list[float]]] = None,
868
+ do_convert_rgb: Optional[bool] = None,
869
+ patch_size: Optional[int] = None,
870
+ pooling_size: Optional[list[int]] = None,
871
+ return_tensors: Optional[Union[str, TensorType]] = None,
872
+ **kwargs,
873
+ ) -> BatchFeature:
874
+ """
875
+ Preprocess a video for the model.
876
+ Args:
877
+ videos (`VideoInput`):
878
+ Video to preprocess.
879
+ size (`SizeDict`, *optional*, defaults to `self.size`):
880
+ Size of the image after resizing.
881
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
882
+ Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
883
+ has an effect if `do_resize` is set to `True`.
884
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
885
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
886
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
887
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
888
+ `True`.
889
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
890
+ Whether to convert the image to RGB.
891
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
892
+ The spatial patch size of the vision encoder.
893
+ pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
894
+ The pooling size of the vision adapter.
895
+ return_tensors (`str` or `TensorType`, *optional*):
896
+ The type of tensors to return. Can be one of:
897
+ - Unset: Return a list of `np.ndarray`.
898
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
899
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
900
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
901
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
902
+
903
+ Returns:
904
+ A `BatchFeature` containing the following keys:
905
+ - `pixel_values_videos`: The preprocessed videos.
906
+ - `video_token_pooling`: The indices of the patches in `crops` to pool for each token in `video_tokens`.
907
+ - `video_grids`: The video grids.
908
+ """
909
+ if size.height is None or size.width is None:
910
+ raise ValueError("size must contain 'height' and 'width' keys.")
911
+
912
+ base_image_input_size = [size.height, size.width]
913
+
914
+ resample = resample or self.resample
915
+ image_mean = image_mean or self.image_mean
916
+ image_std = image_std or self.image_std
917
+ do_convert_rgb = do_convert_rgb or self.do_convert_rgb
918
+
919
+ patch_size = patch_size or self.patch_size
920
+ pooling_size = pooling_size or self.pooling_size
921
+
922
+ image_pooling_h, image_pooling_w = pooling_size
923
+
924
+ batch_grids = []
925
+ batch_crops = []
926
+ batch_pooled_patches_idx = []
927
+
928
+ for video in videos:
929
+ all_crops = []
930
+ pooled_patches_idx = []
931
+
932
+ for frame in video:
933
+ image_grid, crops, pooled_idx = image_to_patches_and_grids(
934
+ frame,
935
+ base_image_input_size,
936
+ resample,
937
+ image_mean,
938
+ image_std,
939
+ patch_size,
940
+ image_pooling_w,
941
+ image_pooling_h,
942
+ )
943
+ offset = sum(np.prod(x.shape[:2]) for x in all_crops)
944
+ pooled_idx_with_offset = np.where(pooled_idx >= 0, pooled_idx + offset, pooled_idx)
945
+ pooled_patches_idx.append(pooled_idx_with_offset)
946
+ all_crops.append(crops)
947
+
948
+ video_grid = np.array([len(video), image_grid[0], image_grid[1]])
949
+ all_crops = np.concatenate(all_crops, 0)
950
+ pooled_patches_idx = np.concatenate(pooled_patches_idx, 0)
951
+
952
+ batch_grids.append(video_grid)
953
+ batch_crops.append(all_crops)
954
+ batch_pooled_patches_idx.append(pooled_patches_idx)
955
+
956
+ video_grids = np.stack(batch_grids, 0)
957
+ pixel_values_videos = np.concatenate(batch_crops, 0)
958
+ video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
959
+
960
+ data = dict(
961
+ pixel_values_videos=pixel_values_videos,
962
+ video_token_pooling=video_token_pooling,
963
+ video_grids=video_grids,
964
+ )
965
+
966
+ return BatchFeature(data, tensor_type=return_tensors)
967
+
968
+
969
+ MolmoAct2VideoProcessor.register_for_auto_class()
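
The `"uniform_last_frame"` branch of `sample_times` above can be tried in isolation. The following is a minimal sketch, not part of the committed file: `uniform_last_frame_times` is a hypothetical helper name, and it assumes only NumPy and the `duration`, `num_frames`, and `max_fps` values that the method derives from its metadata and arguments.

```python
import numpy as np


def uniform_last_frame_times(duration, num_frames, max_fps=None):
    """Hypothetical standalone copy of the "uniform_last_frame" timestamp logic."""
    if max_fps is None:
        # No rate cap: spread the frame budget evenly, including the final timestamp.
        return np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
    # Longest clip that num_frames samples can span at max_fps
    # (-1 reserves one slot for the last frame).
    max_duration = (num_frames - 1) / max_fps
    if max_duration < duration:
        # Clip is too long for max_fps: fall back to uniform spacing over the whole clip.
        return np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
    # Clip is short enough: sample every 1/max_fps seconds, then append the end time.
    times = np.arange(0.0, stop=duration, step=1 / max_fps)
    return np.concatenate([times, [duration]], axis=0)


# A 3.2 s clip fits within 8 frames at 2 fps; a 60 s clip does not.
short_clip = uniform_last_frame_times(duration=3.2, num_frames=8, max_fps=2)
long_clip = uniform_last_frame_times(duration=60.0, num_frames=8, max_fps=2)
print(short_clip)
print(long_clip)
```

In both cases the last timestamp equals the clip duration, which is why the decoders clip timestamps to just inside the stream bounds before fetching frames.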