chrisc36 committed on
Commit 4d0968d · verified · 1 Parent(s): 3f47d4a

Upload folder using huggingface_hub

README.md DELETED
@@ -1,164 +0,0 @@
1
- ---
2
- license: apache-2.0
3
- language:
4
- - en
5
- base_model:
6
- - Qwen/Qwen3-8B
7
- - google/siglip-so400m-patch14-384
8
- pipeline_tag: image-text-to-text
9
- library_name: transformers
10
- tags:
11
- - multimodal
12
- - olmo
13
- - molmo
14
- - molmo2
15
- ---
16
-
17
- # MolmoPoint-8B
18
- MolmoPoint-8B is a fully open VLM developed by the Allen Institute for AI (Ai2) that supports image, video, and multi-image understanding and grounding.
19
- It has a novel pointing mechanism that improves image pointing, video pointing, and video tracking; see our technical report for details.
20
-
21
- Note that the Hugging Face MolmoPoint model does not support training; see our GitHub repo for the training code.
22
-
23
- Quick links:
24
- - 💬 [Code](https://github.com/allenai/molmo2)
25
- - 📂 [All Models](https://huggingface.co/collections/allenai/molmo_point)
26
- - 📃 [Paper](https://allenai.org/papers/molmo_point)
27
- - 📝 [Blog](https://allenai.org/blog/molmo_point)
28
-
29
-
30
- ## Quick Start
31
-
32
- ### Setup Conda Environment
33
- ```
34
- conda create --name transformers4571 python=3.11
35
- conda activate transformers4571
36
- pip install transformers==4.57.1
37
- pip install torch pillow einops torchvision accelerate decord2
38
- ```
39
-
40
- ## Inference
41
- We recommend running MolmoPoint with `logits_processor=model.build_logit_processor_from_inputs(model_inputs)`
42
- to ensure that point tokens are generated in a valid way.
43
-
44
- In MolmoPoint, points are generated as a series of special
45
- tokens rather than coordinates, and decoding the tokens back into points requires some additional
46
- metadata from the preprocessor.
47
- The metadata is returned by the preprocessor using the `return_pointing_metadata` flag.
48
- Then `model.extract_image_points` and `model.extract_video_points` do the decoding; they
49
- return a list of (object_id, {image_id|timestamp}, pixel_x, pixel_y) output points.
50
-
51
-
52
- ### Image Pointing Example:
53
-
54
- ```python
55
- from transformers import AutoProcessor, AutoModelForImageTextToText
56
- from PIL import Image
57
- import requests
58
- import torch
59
-
60
- checkpoint_dir = "allenai/MolmoPoint-8B" # or path to a converted HF checkpoint
61
-
62
- model = AutoModelForImageTextToText.from_pretrained(
63
- checkpoint_dir,
64
- trust_remote_code=True,
65
- dtype="auto",
66
- device_map="auto",
67
- )
68
-
69
- processor = AutoProcessor.from_pretrained(
70
- checkpoint_dir,
71
- trust_remote_code=True,
72
- padding_side="left",
73
- )
74
-
75
- image_messages = [
76
- {
77
- "role": "user",
78
- "content": [
79
- {"type": "text", "text": "Point to the eyes"},
80
- {"type": "image", "image": Image.open(requests.get(
81
- "https://picsum.photos/id/237/536/354", stream=True
82
- ).raw)},
83
- ]
84
- }
85
- ]
86
-
87
- inputs = processor.apply_chat_template(
88
- image_messages,
89
- tokenize=True,
90
- add_generation_prompt=True,
91
- return_tensors="pt",
92
- return_dict=True,
93
- padding=True,
94
- return_pointing_metadata=True
95
- )
96
- metadata = inputs.pop("metadata")
97
- inputs = {k: v.to("cuda") for k, v in inputs.items()}
98
-
99
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
100
- output = model.generate(
101
- **inputs,
102
- logits_processor=model.build_logit_processor_from_inputs(inputs),
103
- max_new_tokens=200
104
- )
105
-
106
- generated_tokens = output[:, inputs["input_ids"].size(1):]
107
- generated_text = processor.post_process_image_text_to_text(generated_tokens, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
108
- points = model.extract_image_points(
109
- generated_text,
110
- metadata["token_pooling"],
111
- metadata["subpatch_mapping"],
112
- metadata["image_sizes"]
113
- )
114
- print(points)
115
- ```
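The extracted points can be visualized directly on the input image. Below is a minimal sketch (not part of the official API) that assumes `points` is a list of `(object_id, image_id, x, y)` tuples in pixel coordinates, matching the `extract_image_points` docstring; check the tuple order for your checkpoint before relying on it:

```python
# Sketch: overlay predicted points on an image with PIL.
# Assumes each point is an (object_id, image_id, x, y) tuple in pixel
# coordinates — verify against extract_image_points for your checkpoint.
from PIL import Image, ImageDraw

def draw_points(image, points, radius=6):
    """Return a copy of `image` with a circle and object id at each point."""
    out = image.convert("RGB").copy()
    draw = ImageDraw.Draw(out)
    for object_id, _image_id, x, y in points:
        draw.ellipse(
            [x - radius, y - radius, x + radius, y + radius],
            outline="red", width=3,
        )
        draw.text((x + radius + 2, y - radius), str(object_id), fill="red")
    return out

# Usage: pass the original PIL image and the `points` list from above, e.g.
# draw_points(my_image, points).save("pointed.png")
```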
116
-
117
-
118
- ### Video Pointing Example:
119
- ```python
120
- video_path = "https://storage.googleapis.com/oe-training-public/demo_videos/many_penguins.mp4"
121
- video_messages = [
122
- {
123
- "role": "user",
124
- "content": [
125
- dict(type="text", text="Point to the penguins"),
126
- dict(type="video", video=video_path),
127
- ]
128
- }
129
- ]
130
-
131
- inputs = processor.apply_chat_template(
132
- video_messages,
133
- tokenize=True,
134
- add_generation_prompt=True,
135
- return_tensors="pt",
136
- return_dict=True,
137
- padding=True,
138
- return_pointing_metadata=True
139
- )
140
-
141
- metadata = inputs.pop("metadata")
- inputs = {k: v.to("cuda") for k, v in inputs.items()}
142
-
143
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
144
- output = model.generate(
145
- **inputs,
146
- logits_processor=model.build_logit_processor_from_inputs(inputs),
147
- max_new_tokens=200
148
- )
149
-
150
- generated_tokens = output[:, inputs['input_ids'].size(1):]
151
- generated_text = processor.post_process_image_text_to_text(generated_tokens, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
152
- points = model.extract_video_points(
153
- generated_text,
154
- metadata["token_pooling"],
155
- metadata["subpatch_mapping"],
156
- metadata["timestamps"],
157
- metadata["video_size"]
158
- )
159
- print(points)
160
- ```
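For tracking, the video points can be grouped into per-object tracks. A minimal sketch, assuming `points` is a list of `(object_id, timestamp, x, y)` tuples as described in the `extract_video_points` docstring (the tuple order is an assumption; verify it for your checkpoint):

```python
# Sketch: group (object_id, timestamp, x, y) points into per-object tracks.
# The tuple order is assumed from the extract_video_points docstring.
from collections import defaultdict

def points_to_tracks(points):
    """Group points by object id into timestamp-sorted tracks, so each
    object can be followed across frames."""
    tracks = defaultdict(list)
    for object_id, timestamp, x, y in points:
        tracks[object_id].append((timestamp, x, y))
    return {obj: sorted(track) for obj, track in tracks.items()}

# Example: two detections of object 0 and one of object 1.
example = [(0, 1.0, 10, 20), (1, 1.0, 30, 40), (0, 2.0, 12, 22)]
# points_to_tracks(example)
# -> {0: [(1.0, 10, 20), (2.0, 12, 22)], 1: [(1.0, 30, 40)]}
```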
161
-
162
- ## License and Use
163
-
164
- This model is licensed under Apache 2.0. It is intended for research and educational use in accordance with Ai2’s Responsible Use Guidelines. This model is trained on third party datasets that are subject to academic and non-commercial research use only. Please review the sources to determine if this model is appropriate for your use case.
__init__.py DELETED
File without changes
config.json CHANGED
@@ -83,7 +83,7 @@
83
  "tie_word_embeddings": false,
84
  "token_prediction_rotary": "one_d",
85
  "token_prediction_rotary_theta": 50000.0,
86
- "transformers_version": "4.57.6",
87
  "use_cache": true,
88
  "use_frame_special_tokens": true,
89
  "vit_config": {
 
83
  "tie_word_embeddings": false,
84
  "token_prediction_rotary": "one_d",
85
  "token_prediction_rotary_theta": 50000.0,
86
+ "transformers_version": "4.57.1",
87
  "use_cache": true,
88
  "use_frame_special_tokens": true,
89
  "vit_config": {
generation_config.json CHANGED
@@ -2,5 +2,5 @@
2
  "bos_token_id": 151645,
3
  "eos_token_id": 151645,
4
  "pad_token_id": 151643,
5
- "transformers_version": "4.57.6"
6
  }
 
2
  "bos_token_id": 151645,
3
  "eos_token_id": 151645,
4
  "pad_token_id": 151643,
5
+ "transformers_version": "4.57.1"
6
  }
modeling_molmo_point.py CHANGED
@@ -796,7 +796,7 @@ class MolmoPointConnector(nn.Module):
796
  def extract_image_points(output_text, pooling, mappings, no_more_points_class, location, image_sizes):
797
  """Extract points from MolmoPoint image output text
798
 
799
- return points: [n_points, 4] array of (image_num, object_id, x, y) points
800
  """
801
  if len(mappings) != len(image_sizes):
802
  raise ValueError("Mapping and image sizes must have the same length")
@@ -831,7 +831,7 @@ def extract_video_points(output_text, pooling, mapping, timestamps, no_more_poin
831
  """
832
  Extract points from MolmoPoint video output text
833
 
834
- return points: [n_points, 4] array of (timestamp, object_id, x, y) points
835
  """
836
  extracted_points = []
837
  for vit_patch_id, location_id, example_id in get_subpatch_ids(output_text, pooling, no_more_points_class):
@@ -1263,7 +1263,7 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
1263
  **kwargs: Unpack[TransformersKwargs],
1264
  ) -> Union[tuple, MolmoPointModelOutputWithPast]:
1265
  """
1266
- last_point_patch_id: The patch id the last generatd point pointed to
1267
  """
1268
 
1269
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1289,7 +1289,6 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
1289
  raise NotImplementedError("Custom inputs_embeds is not implemented yet")
1290
 
1291
  input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
1292
- print(f"ON: {input_ids[0, -1]}")
1293
 
1294
  if image_data is not None:
1295
  can_point = True
 
796
  def extract_image_points(output_text, pooling, mappings, no_more_points_class, location, image_sizes):
797
  """Extract points from MolmoPoint image output text
798
 
799
+ return points: [n_points, 4] array of (object_id, image_num, x, y) points
800
  """
801
  if len(mappings) != len(image_sizes):
802
  raise ValueError("Mapping and image sizes must have the same length")
 
831
  """
832
  Extract points from MolmoPoint video output text
833
 
834
+ return points: [n_points, 4] array of (object_id, timestamp, x, y) points
835
  """
836
  extracted_points = []
837
  for vit_patch_id, location_id, example_id in get_subpatch_ids(output_text, pooling, no_more_points_class):
 
1263
  **kwargs: Unpack[TransformersKwargs],
1264
  ) -> Union[tuple, MolmoPointModelOutputWithPast]:
1265
  """
1266
+ last_point_patch_id: The patch id the last generated point pointed to
1267
  """
1268
 
1269
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
1289
  raise NotImplementedError("Custom inputs_embeds is not implemented yet")
1290
 
1291
  input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
 
1292
 
1293
  if image_data is not None:
1294
  can_point = True
modelling_molmo_point.py DELETED
@@ -1,1914 +0,0 @@
1
- import math
2
- from copy import deepcopy
3
- from dataclasses import dataclass
4
- from typing import Optional, Union, Callable
5
-
6
- import torch
7
- from torch import nn
8
-
9
- from torch.nn import functional as F
10
-
11
- from transformers.models.auto import AutoModelForImageTextToText
12
- from transformers.activations import ACT2FN
13
- from transformers.configuration_utils import PretrainedConfig
14
- from transformers.cache_utils import Cache, DynamicCache
15
- from transformers.generation import GenerationMixin
16
- from transformers.masking_utils import create_causal_mask, create_masks_for_generate
17
- from transformers.modeling_flash_attention_utils import (
18
- _flash_attention_forward,
19
- FlashAttentionKwargs,
20
- flash_attn_supports_top_left_mask,
21
- )
22
- from transformers.modeling_layers import GradientCheckpointingLayer
23
- from transformers.modeling_outputs import (
24
- BaseModelOutputWithPast,
25
- )
26
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
27
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
28
- from transformers.processing_utils import Unpack
29
- from transformers.utils import (
30
- ModelOutput,
31
- TransformersKwargs,
32
- can_return_tuple,
33
- logging,
34
- )
35
-
36
- from .configuration_molmo2 import Molmo2VitConfig, Molmo2TextConfig, Molmo2AdapterConfig
37
- from .configuration_molmo_point import MolmoPointConfig, MolmoPointAdapterConfig
38
-
39
-
40
- logger = logging.get_logger(__name__)
41
-
42
-
43
- @dataclass
44
- class MolmoPointCausalLMOutputWithPast(ModelOutput):
45
- """
46
- Base class for MolmoPoint causal language model (or autoregressive) outputs.
47
-
48
- Args:
49
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
50
- Language modeling loss (for next-token prediction).
51
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
52
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
53
- past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
54
- It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
55
-
56
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
57
- `past_key_values` input) to speed up sequential decoding.
58
- image_hidden_states (`torch.FloatTensor`, *optional*):
59
- A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
60
- image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
61
- """
62
-
63
- loss: Optional[torch.FloatTensor] = None
64
- logits: Optional[torch.FloatTensor] = None
65
- past_key_values: Optional[Cache] = None
66
- hidden_states: Optional[tuple[torch.FloatTensor]] = None
67
- attentions: Optional[tuple[torch.FloatTensor]] = None
68
- image_hidden_states: Optional[torch.FloatTensor] = None
69
-
70
-
71
- @dataclass
72
- class MolmoPointModelOutputWithPast(BaseModelOutputWithPast):
73
- """
74
- Base class for Molmo2 outputs, with hidden states and attentions.
75
-
76
- Args:
77
- image_hidden_states (`torch.FloatTensor`, *optional*):
78
- A `torch.FloatTensor` of size `(batch_num_patches, hidden_size)`.
79
- image_hidden_states of the model produced by the vision backbone
80
- """
81
- last_hidden_state: Optional[torch.FloatTensor] = None
82
- past_key_values: Optional[Cache] = None
83
- hidden_states: Optional[tuple[torch.FloatTensor]] = None
84
- attentions: Optional[tuple[torch.FloatTensor]] = None
85
- image_hidden_states: Optional[torch.FloatTensor] = None
86
-
87
-
88
-
89
- class ViTMLP(nn.Module):
90
- def __init__(self, dim: int, hidden_dim: int, hidden_act: str, device: Union[str, torch.device] = None):
91
- super().__init__()
92
- self.w1 = nn.Linear(dim, hidden_dim, bias=True, device=device)
93
- self.act = ACT2FN[hidden_act]
94
- self.w2 = nn.Linear(hidden_dim, dim, bias=True, device=device)
95
-
96
- def forward(self, x: torch.Tensor) -> torch.Tensor:
97
- return self.w2(self.act(self.w1(x)))
98
-
99
-
100
- class ViTMultiHeadDotProductAttention(nn.Module):
101
- def __init__(
102
- self,
103
- hidden_size: int,
104
- num_heads: int,
105
- num_key_value_heads: int,
106
- head_dim: int,
107
- use_bias: bool = True,
108
- input_dim: Optional[int] = None,
109
- float32_attention: bool = True,
110
- attention_dropout: float = 0.0,
111
- residual_dropout: float = 0.0,
112
- device: Union[str, torch.device] = None,
113
- attn_implementation: str = "eager",
114
- out_layer: bool=True
115
- ):
116
- super().__init__()
117
-
118
- self.hidden_size = hidden_size
119
- self.num_heads = num_heads
120
- self.head_dim = head_dim
121
- self.num_key_value_heads = num_key_value_heads
122
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
123
- self.attn_implementation = attn_implementation
124
- self.is_causal = False
125
-
126
- input_dim = input_dim or hidden_size
127
-
128
- self.wq = nn.Linear(
129
- input_dim,
130
- self.num_heads * self.head_dim,
131
- bias=use_bias,
132
- device=device,
133
- )
134
- self.wk = nn.Linear(
135
- input_dim,
136
- self.num_key_value_heads * self.head_dim,
137
- bias=use_bias,
138
- device=device,
139
- )
140
- self.wv = nn.Linear(
141
- input_dim,
142
- self.num_key_value_heads * self.head_dim,
143
- bias=use_bias,
144
- device=device,
145
- )
146
- if out_layer:
147
- self.wo = nn.Linear(
148
- self.num_heads * self.head_dim,
149
- self.hidden_size,
150
- )
151
- else:
152
- self.w0 = None
153
- self.float32_attention = float32_attention
154
- self.attention_dropout = attention_dropout
155
- self.residual_dropout = nn.Dropout(residual_dropout)
156
-
157
- def _split_heads(self, hidden_states, num_heads) -> torch.Tensor:
158
- return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
159
-
160
- def _merge_heads(self, hidden_states) -> torch.Tensor:
161
- return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))
162
-
163
- def forward(
164
- self,
165
- inputs_q: torch.Tensor,
166
- inputs_kv: Optional[torch.Tensor] = None,
167
- attn_mask: Optional[torch.Tensor] = None,
168
- ) -> torch.Tensor:
169
-
170
- if inputs_kv is not None:
171
- inputs_k = inputs_kv
172
- inputs_v = inputs_kv
173
- else:
174
- inputs_k = inputs_q
175
- inputs_v = inputs_q
176
-
177
- xq, xk, xv = self.wq(inputs_q), self.wk(inputs_k), self.wv(inputs_v)
178
-
179
- xq = self._split_heads(xq, self.num_heads)
180
- xk = self._split_heads(xk, self.num_key_value_heads)
181
- xv = self._split_heads(xv, self.num_key_value_heads)
182
-
183
- if self.num_heads != self.num_key_value_heads:
184
- xk = xk.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)
185
- xv = xv.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)
186
-
187
- og_dtype = xq.dtype
188
-
189
- if self.float32_attention:
190
- xq = xq.to(torch.float)
191
- xk = xk.to(torch.float)
192
-
193
- dropout_p = 0.0 if not self.training else self.attention_dropout
194
-
195
- if self.attn_implementation == "eager":
196
- attn_weights = torch.einsum("...qhd,...khd->...hqk", xq / math.sqrt(xq.size(-1)), xk)
197
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(xq.dtype)
198
- attn_weights = F.dropout(
199
- attn_weights,
200
- p=dropout_p,
201
- training=self.training
202
- )
203
- attn_output = torch.einsum("...hqk,...khd->...qhd", attn_weights.to(xv.dtype), xv)
204
-
205
- elif self.attn_implementation == "sdpa":
206
- if not torch.is_autocast_enabled():
207
- xv = xv.to(torch.float)
208
-
209
- attn_output = F.scaled_dot_product_attention(
210
- xq.transpose(1, 2).contiguous(),
211
- xk.transpose(1, 2).contiguous(),
212
- xv.transpose(1, 2).contiguous(),
213
- attn_mask=attn_mask,
214
- is_causal=False,
215
- dropout_p=dropout_p,
216
- ).transpose(1, 2)
217
-
218
- elif self.attn_implementation == "flash_attention_2":
219
- if xq.dtype == torch.float32:
220
- if torch.is_autocast_enabled():
221
- target_dtype = torch.get_autocast_gpu_dtype()
222
- else:
223
- target_dtype = self.wq.weight.dtype
224
- attn_output = _flash_attention_forward(
225
- xq,
226
- xk,
227
- xv,
228
- attention_mask=attn_mask,
229
- query_length=inputs_q.shape[1],
230
- is_causal=False,
231
- dropout=dropout_p,
232
- softmax_scale=xq.shape[-1] ** -0.5,
233
- use_top_left_mask=flash_attn_supports_top_left_mask(),
234
- target_dtype=target_dtype,
235
- implementation=self.attn_implementation,
236
- )
237
- else:
238
- raise ValueError(f"Attention implementation {self.attn_implementation} not supported")
239
-
240
- attn_output = attn_output.to(og_dtype)
241
- attn_output = self._merge_heads(attn_output)
242
- if self.wo is not None:
243
- attn_output = self.wo(attn_output)
244
- attn_output = self.residual_dropout(attn_output)
245
-
246
- return attn_output
247
-
248
-
249
- class Molmo2VisionBlock(nn.Module):
250
-
251
- def __init__(self, config: Molmo2VitConfig, device: Union[str, torch.device] = None):
252
- super().__init__()
253
- self.attention = ViTMultiHeadDotProductAttention(
254
- hidden_size=config.hidden_size,
255
- num_heads=config.num_attention_heads,
256
- num_key_value_heads=config.num_key_value_heads,
257
- head_dim=config.head_dim,
258
- float32_attention=config.float32_attention,
259
- attention_dropout=config.attention_dropout,
260
- residual_dropout=config.residual_dropout,
261
- device=device,
262
- attn_implementation=config._attn_implementation,
263
- )
264
- self.feed_forward = ViTMLP(config.hidden_size, config.intermediate_size, config.hidden_act, device=device)
265
- self.attention_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device)
266
- self.ffn_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device)
267
-
268
- def forward(self, x: torch.Tensor) -> torch.Tensor:
269
- x = x + self.attention(self.attention_norm(x))
270
- x = x + self.feed_forward(self.ffn_norm(x))
271
- return x
272
-
273
-
274
- class Molmo2VisionBlockCollection(nn.Module):
275
-
276
- def __init__(self, config: Molmo2VitConfig, device: Union[str, torch.device] = None):
277
- super().__init__()
278
- self.config = config
279
- self.resblocks = nn.ModuleList([
280
- Molmo2VisionBlock(config, device) for _ in range(config.num_hidden_layers)
281
- ])
282
-
283
- def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
284
- hidden_states = []
285
- for r in self.resblocks:
286
- x = r(x)
287
- hidden_states.append(x)
288
- return hidden_states
289
-
290
-
291
- class Molmo2VisionTransformer(nn.Module):
292
-
293
- def __init__(self, config: Molmo2VitConfig, device: Union[str, torch.device] = None):
294
- super().__init__()
295
- self.config = config
296
-
297
- # positional embeddings
298
- self.scale = config.hidden_size ** -0.5
299
- self.num_prefix_tokens: int = 0 # no class embeddings
300
- self.positional_embedding = nn.Parameter(
301
- torch.zeros(config.image_num_pos, config.hidden_size, device=device),
302
- )
303
-
304
- image_patch_size = config.image_patch_size
305
- self.patch_embedding = nn.Linear(
306
- image_patch_size * image_patch_size * 3,
307
- config.hidden_size,
308
- bias=True,
309
- device=device,
310
- )
311
-
312
- self.transformer = Molmo2VisionBlockCollection(config, device)
313
-
314
- def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
315
- pos_emb = self.positional_embedding
316
-
317
- pos_emb = pos_emb.reshape(
318
- (int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])
319
- )
320
-
321
- (patch_num_0, patch_num_1) = patch_num
322
-
323
- if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
324
- # Derived from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
325
- # antialias: default True in jax.image.resize
326
- pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
327
- pos_emb = F.interpolate(
328
- pos_emb, size=(patch_num_0, patch_num_1), mode="bicubic", align_corners=False, antialias=True,
329
- )
330
- pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)
331
-
332
- pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
333
- x = x + pos_emb[None, :, :].to(x.dtype)
334
- return x
335
-
336
- def forward(self, x: torch.Tensor, patch_num: int = None) -> list[torch.Tensor]:
337
- """
338
- : param x: (batch_size, num_patch, n_pixels)
339
- """
340
- if patch_num is None:
341
- patch_num = self.config.image_num_patch
342
-
343
- B, N, D = x.shape
344
-
345
- x = self.patch_embedding(x)
346
-
347
- # class embeddings and positional embeddings
348
- x = self.add_pos_emb(x, patch_num)
349
-
350
- hidden_states = self.transformer(x)
351
- return hidden_states
352
-
353
-
354
- class ImageProjectorMLP(nn.Module):
355
-
356
- def __init__(
357
- self,
358
- input_dim: int,
359
- hidden_dim: int,
360
- output_dim: int,
361
- hidden_act: str,
362
- device: Union[str, torch.device] = None,
363
- ):
364
- super().__init__()
365
- self.w1 = nn.Linear(input_dim, hidden_dim, bias=False, device=device)
366
- self.w2 = nn.Linear(hidden_dim, output_dim, bias=False, device=device)
367
- self.w3 = nn.Linear(input_dim, hidden_dim, bias=False, device=device)
368
- self.act = ACT2FN[hidden_act]
369
-
370
- def forward(self, x: torch.Tensor) -> torch.Tensor:
371
- return self.w2(self.act(self.w1(x)) * self.w3(x))
372
-
373
-
374
- class Molmo2VisionBackbone(nn.Module):
375
- def __init__(self, vit_config: Molmo2VitConfig, adapter_config: Molmo2AdapterConfig):
376
- super().__init__()
377
- self.vit_config = vit_config
378
- self.adapter_config = adapter_config
379
-
380
- self.vit_layers = []
381
- for layer in adapter_config.vit_layers:
382
- if layer >= 0:
383
- self.vit_layers.append(layer)
384
- else:
385
- self.vit_layers.append(layer + vit_config.num_hidden_layers)
386
-
387
- last_layer_needed = max(self.vit_layers) + 1
388
- if last_layer_needed < vit_config.num_hidden_layers:
389
- new_vit_config = deepcopy(vit_config)
390
- new_vit_config.num_hidden_layers = last_layer_needed
391
- self.image_vit = Molmo2VisionTransformer(new_vit_config)
392
- else:
393
- self.image_vit = Molmo2VisionTransformer(vit_config)
394
-
395
- self.num_prefix_tokens: int = self.image_vit.num_prefix_tokens
396
-
397
- pool_dim = vit_config.hidden_size * len(adapter_config.vit_layers)
398
- self.image_pooling_2d = ViTMultiHeadDotProductAttention(
399
- hidden_size=adapter_config.hidden_size,
400
- num_heads=adapter_config.num_attention_heads,
401
- num_key_value_heads=adapter_config.num_key_value_heads,
402
- head_dim=adapter_config.head_dim,
403
- input_dim=pool_dim,
404
- float32_attention=adapter_config.float32_attention,
405
- attention_dropout=adapter_config.attention_dropout,
406
- residual_dropout=adapter_config.residual_dropout,
407
- attn_implementation=adapter_config._attn_implementation,
408
- )
409
- self.image_projector = ImageProjectorMLP(
410
- adapter_config.hidden_size,
411
- adapter_config.intermediate_size,
412
- adapter_config.text_hidden_size,
413
- adapter_config.hidden_act,
414
- )
415
- self.image_feature_dropout = nn.Dropout(adapter_config.image_feature_dropout)
416
-
417
- def encode_image(self, images: torch.Tensor) -> torch.Tensor:
418
- """
419
- : param images: (batch_size, num_crops, num_patch, n_pixels)
420
- """
421
- B, T, N, D = images.shape
422
- images = images.view(B * T, N, D)
423
- image_features = self.image_vit(images)
424
-
425
- features = []
426
- for layer in self.vit_layers:
427
- features.append(image_features[layer])
428
- image_features = torch.cat(features, dim=-1)
429
-
430
- if self.num_prefix_tokens > 0:
431
- image_features = image_features[:, 1:]
432
- image_features = image_features.view(B, T, N, -1)
433
- return image_features
434
-
435
- @property
436
- def dtype(self) -> torch.dtype:
437
- return self.image_vit.patch_embedding.weight.dtype
438
-
439
- @property
440
- def device(self) -> torch.device:
441
- return self.image_vit.patch_embedding.weight.device
442
-
443
- def forward(
444
- self,
445
- images: torch.Tensor,
446
- pooled_patches_idx: torch.Tensor,
447
- ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
448
-
449
- # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
450
- batch_size, num_image = images.shape[:2]
451
- images = images.to(device=self.device, dtype=self.dtype)
452
- image_features = self.encode_image(images)
453
-
454
- image_features = self.image_feature_dropout(image_features)
455
- dim = image_features.shape[-1]
456
- valid = pooled_patches_idx >= 0
457
- valid_token = torch.any(valid, -1)
458
-
459
- # Use `pooled_patches_idx` to arrange the features for image pooling
460
- batch_idx = torch.arange(pooled_patches_idx.shape[0], dtype=torch.long, device=pooled_patches_idx.device)
461
- batch_idx = torch.tile(batch_idx.view(batch_size, 1, 1), [1, pooled_patches_idx.shape[1], pooled_patches_idx.shape[2]])
462
-
463
- # Now [batch, num_high_res_features, pool_dim, dim]
464
- to_pool = image_features.reshape(batch_size, -1, dim)[batch_idx, torch.clip(pooled_patches_idx, 0)]
465
- to_pool = to_pool * valid.to(self.dtype)[:, :, :, None]
466
- to_pool = to_pool.reshape([-1, pooled_patches_idx.shape[-1], dim])
467
- if self.adapter_config.pooling_attention_mask:
468
- attn_mask = valid.reshape([-1, 1, 1, valid.shape[-1]])
469
- denom = valid.view(-1, to_pool.shape[-2]).float().sum(-1)
470
- denom = torch.where(denom == 0, 1, denom)
471
- query = to_pool.sum(-2, keepdim=True) / denom[:, None, None].to(to_pool.dtype)
472
- else:
473
- attn_mask = None
474
- query = to_pool.mean(-2, keepdim=True)
475
- pooled_features = self.image_pooling_2d(query, to_pool, attn_mask=attn_mask)
476
- pooled_features = pooled_features.reshape([batch_size, -1, pooled_features.shape[-1]])
477
-
478
- # MLP layer to map the feature.
479
- pooled_features = self.image_projector(pooled_features)
480
- return pooled_features.view(-1, pooled_features.shape[-1])[valid_token.flatten()]
481
-
482
-
483
- # Copied from transformers.models.llama.modeling_llama.rotate_half
484
- def rotate_half(x):
485
- """Rotates half the hidden dims of the input."""
486
- x1 = x[..., : x.shape[-1] // 2]
487
- x2 = x[..., x.shape[-1] // 2 :]
488
- return torch.cat((-x2, x1), dim=-1)
489
-
490
-
491
- # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
492
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
493
- """Applies Rotary Position Embedding to the query and key tensors.
494
-
495
- Args:
496
- q (`torch.Tensor`): The query tensor.
497
- k (`torch.Tensor`): The key tensor.
498
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
499
- sin (`torch.Tensor`): The sine part of the rotary embedding.
500
- position_ids (`torch.Tensor`, *optional*):
501
- Deprecated and unused.
502
- unsqueeze_dim (`int`, *optional*, defaults to 1):
503
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
504
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
505
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
506
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
507
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
508
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
509
- Returns:
510
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
511
- """
512
- cos = cos.unsqueeze(unsqueeze_dim)
513
- sin = sin.unsqueeze(unsqueeze_dim)
514
- q_embed = (q * cos) + (rotate_half(q) * sin)
515
- k_embed = (k * cos) + (rotate_half(k) * sin)
516
- return q_embed, k_embed
517
-
518
-
519
- class Molmo2RotaryEmbedding(nn.Module):
520
-     inv_freq: torch.Tensor  # fix linting for `register_buffer`
-
-     def __init__(
-         self,
-         config: Molmo2TextConfig,
-         device: Union[str, torch.device] = None,
-         rope_type: Optional[str] = None,
-     ):
-         super().__init__()
-         if rope_type is not None:
-             self.rope_type = rope_type
-         elif hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
-             # BC: "rope_type" was originally "type"
-             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-         else:
-             self.rope_type = "default"
-         self.max_seq_len_cached = config.max_position_embeddings
-         self.original_max_seq_len = config.max_position_embeddings
-
-         self.config = config
-         rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
-         self.register_buffer("inv_freq", inv_freq, persistent=False)
-         self.original_inv_freq = self.inv_freq
-
-     @torch.no_grad()
-     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-     def forward(self, x, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
-         position_ids_expanded = position_ids[:, None, :].float()
-
-         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-             emb = torch.cat((freqs, freqs), dim=-1)
-             cos = emb.cos() * self.attention_scaling
-             sin = emb.sin() * self.attention_scaling
-
-         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
- class Molmo2RMSNorm(nn.Module):
-
-     def __init__(
-         self,
-         size: int,
-         eps: float = 1e-6,
-         device: Union[str, torch.device] = None,
-     ):
-         super().__init__()
-         self.weight = nn.Parameter(torch.ones(size, device=device))
-         self.eps = eps
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         with torch.autocast(enabled=False, device_type=x.device.type):
-             og_dtype = x.dtype
-             x = x.to(torch.float32)
-             variance = x.pow(2).mean(-1, keepdim=True)
-             x = x * torch.rsqrt(variance + self.eps)
-             x = x.to(og_dtype)
-
-         return self.weight * x
-
-     def extra_repr(self):
-         return f"{tuple(self.weight.shape)}, eps={self.eps}"
-
-
- # Copied from transformers.models.llama.modeling_llama.repeat_kv
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-     """
-     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-     """
-     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-     if n_rep == 1:
-         return hidden_states
-     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
- def eager_attention_forward(
-     module: nn.Module,
-     query: torch.Tensor,
-     key: torch.Tensor,
-     value: torch.Tensor,
-     attention_mask: Optional[torch.Tensor],
-     scaling: float,
-     dropout: float = 0.0,
-     **kwargs,
- ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-     key_states = repeat_kv(key, module.num_key_value_groups)
-     value_states = repeat_kv(value, module.num_key_value_groups)
-
-     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-     if attention_mask is not None:
-         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-         attn_weights = attn_weights + causal_mask
-
-     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-     attn_output = torch.matmul(attn_weights, value_states)
-     attn_output = attn_output.transpose(1, 2).contiguous()
-
-     return attn_output, attn_weights
-
-
- class Molmo2Attention(nn.Module):
-     """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-     def __init__(self, config: Molmo2TextConfig, layer_idx: int) -> None:
-         super().__init__()
-         self.config = config
-         self.layer_idx = layer_idx
-         self.num_heads = config.num_attention_heads
-         self.num_key_value_heads = config.num_key_value_heads
-         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
-         self.head_dim = config.head_dim
-         self.scaling = self.head_dim**-0.5
-         self.is_causal = True
-
-         self.fused_dims = (
-             config.num_attention_heads * config.head_dim,
-             config.head_dim * config.num_key_value_heads,
-             config.head_dim * config.num_key_value_heads,
-         )
-         self.att_proj = nn.Linear(
-             config.hidden_size,
-             sum(self.fused_dims),
-             bias=config.qkv_bias,
-         )
-
-         # Layer norms.
-         self.k_norm: Optional[Molmo2RMSNorm] = None
-         self.q_norm: Optional[Molmo2RMSNorm] = None
-         self.qk_norm_type: Optional[str] = None
-         if config.use_qk_norm:
-             k_norm_size = (
-                 config.head_dim
-                 if config.qk_norm_type == "qwen3" else
-                 config.num_key_value_heads * config.head_dim
-             )
-             self.k_norm = Molmo2RMSNorm(k_norm_size, eps=config.layer_norm_eps)
-             q_norm_size = (
-                 config.head_dim
-                 if config.qk_norm_type == "qwen3" else
-                 config.num_attention_heads * config.head_dim
-             )
-             self.q_norm = Molmo2RMSNorm(q_norm_size, eps=config.layer_norm_eps)
-             self.qk_norm_type = config.qk_norm_type
-
-         self.attention_dropout = config.attention_dropout
-
-         self.attn_out = nn.Linear(
-             config.head_dim * config.num_attention_heads,
-             config.hidden_size,
-             bias=False,
-         )
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         position_embeddings: tuple[torch.Tensor, torch.Tensor],
-         attention_mask: Optional[torch.Tensor],
-         past_key_values: Optional[Cache] = None,
-         cache_position: Optional[torch.LongTensor] = None,
-         **kwargs: Unpack[FlashAttentionKwargs],
-     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
-         input_shape = hidden_states.shape[:-1]
-         hidden_shape = (*input_shape, -1, self.head_dim)
-
-         qkv = self.att_proj(hidden_states)
-         query_states, key_states, value_states = qkv.split(self.fused_dims, dim=-1)
-         value_states = value_states.view(hidden_shape)
-
-         # Optionally apply layer norm to keys and queries.
-         if self.q_norm is not None and self.k_norm is not None and self.qk_norm_type != "qwen3":
-             query_states = self.q_norm(query_states)
-             key_states = self.k_norm(key_states)
-
-         query_states = query_states.view(hidden_shape)
-         key_states = key_states.view(hidden_shape)
-         if self.q_norm is not None and self.k_norm is not None and self.qk_norm_type == "qwen3":
-             query_states = self.q_norm(query_states)
-             key_states = self.k_norm(key_states)
-         query_states = query_states.transpose(1, 2)
-         key_states = key_states.transpose(1, 2)
-         value_states = value_states.transpose(1, 2)
-
-         cos, sin = position_embeddings
-         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-         if past_key_values is not None:
-             # sin and cos are specific to RoPE models; cache_position needed for the static cache
-             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-         attention_interface: Callable = eager_attention_forward
-         if self.config._attn_implementation != "eager":
-             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
-         attn_output, attn_weights = attention_interface(
-             self,
-             query_states,
-             key_states,
-             value_states,
-             attention_mask,
-             dropout=0.0 if not self.training else self.attention_dropout,
-             scaling=self.scaling,
-             **kwargs,
-         )
-
-         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-         attn_output = self.attn_out(attn_output)
-         return attn_output, attn_weights
-
-
- class LanguageModelMLP(nn.Module):
-
-     def __init__(
-         self,
-         input_dim: int,
-         intermediate_size: int,
-         hidden_act: str,
-         device: Union[str, torch.device] = None,
-     ):
-         super().__init__()
-         self.ff_proj = nn.Linear(input_dim, intermediate_size * 2, bias=False, device=device)
-         self.ff_out = nn.Linear(intermediate_size, input_dim, bias=False, device=device)
-         self.act = ACT2FN[hidden_act]
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         x = self.ff_proj(x)
-         x, gate = x.chunk(2, dim=-1)
-         x = self.act(gate) * x
-         x = self.ff_out(x)
-         return x
-
-
- class Molmo2DecoderLayer(GradientCheckpointingLayer):
-
-     def __init__(
-         self,
-         config: Molmo2TextConfig,
-         layer_idx: Optional[int] = None,
-         device: Union[str, torch.device] = None
-     ):
-         super().__init__()
-         self.config = config
-
-         self.self_attn = Molmo2Attention(config, layer_idx)
-         self.attn_norm = Molmo2RMSNorm(
-             config.hidden_size, eps=config.layer_norm_eps, device=device)
-         self.dropout = nn.Dropout(config.residual_dropout)
-         self.mlp = LanguageModelMLP(
-             config.hidden_size, config.intermediate_size, config.hidden_act, device=device)
-         self.ff_norm = Molmo2RMSNorm(
-             config.hidden_size, eps=config.layer_norm_eps, device=device)
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         position_embeddings: tuple[torch.Tensor, torch.Tensor],
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[Cache] = None,
-         output_attentions: Optional[bool] = False,
-         use_cache: Optional[bool] = False,
-         cache_position: Optional[torch.LongTensor] = None,
-         **kwargs: Unpack[TransformersKwargs],
-     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
-
-         residual = hidden_states
-         hidden_states = self.attn_norm(hidden_states)
-
-         # Self Attention
-         hidden_states, self_attn_weights = self.self_attn(
-             hidden_states=hidden_states,
-             position_embeddings=position_embeddings,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             output_attentions=output_attentions,
-             use_cache=use_cache,
-             cache_position=cache_position,
-             **kwargs,
-         )
-
-         hidden_states = residual + self.dropout(hidden_states)
-
-         # Fully Connected
-         residual = hidden_states
-         hidden_states = self.ff_norm(hidden_states)
-         hidden_states = self.mlp(hidden_states)
-
-         hidden_states = residual + self.dropout(hidden_states)
-
-         outputs = (hidden_states,)
-
-         if output_attentions:
-             outputs += (self_attn_weights,)
-
-         return outputs
-
-
- class Molmo2PostNormDecoderLayer(Molmo2DecoderLayer):
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         position_embeddings: tuple[torch.Tensor, torch.Tensor],
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[Cache] = None,
-         output_attentions: Optional[bool] = False,
-         use_cache: Optional[bool] = False,
-         cache_position: Optional[torch.LongTensor] = None,
-         **kwargs,
-     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
-
-         residual = hidden_states
-
-         # Self Attention
-         hidden_states, self_attn_weights = self.self_attn(
-             hidden_states=hidden_states,
-             position_embeddings=position_embeddings,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             output_attentions=output_attentions,
-             use_cache=use_cache,
-             cache_position=cache_position,
-         )
-         hidden_states = self.attn_norm(hidden_states)
-
-         hidden_states = residual + self.dropout(hidden_states)
-
-         # Fully Connected
-         residual = hidden_states
-         hidden_states = self.mlp(hidden_states)
-         hidden_states = self.ff_norm(hidden_states)
-
-         hidden_states = residual + self.dropout(hidden_states)
-
-         outputs = (hidden_states,)
-
-         if output_attentions:
-             outputs += (self_attn_weights,)
-
-         return outputs
-
-
- class Molmo2Embedding(nn.Module):
-     def __init__(
-         self,
-         num_embeddings: int,
-         num_new_embeddings: int,
-         features: int,
-         device: Union[str, torch.device] = None,
-     ):
-         super().__init__()
-         self.embedding = nn.Parameter(
-             torch.zeros(num_embeddings, features, device=device),
-         )
-         self.new_embedding = nn.Parameter(
-             torch.zeros(num_new_embeddings, features, device=device),
-         )
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         return F.embedding(x, torch.cat([self.embedding, self.new_embedding], dim=0))
-
-
- class MolmoPointPreTrainedModel(PreTrainedModel):
-     config: MolmoPointConfig
-     base_model_prefix = "model"
-     supports_gradient_checkpointing = True
-     _no_split_modules = [
-         "Molmo2DecoderLayer",
-         "Molmo2PostNormDecoderLayer",
-         "Molmo2VisionBlock",
-         "ViTMultiHeadDotProductAttention",
-     ]
-     _skip_keys_device_placement = "past_key_values"
-     _supports_flash_attn = True
-     _supports_sdpa = True
-
-     _can_compile_fullgraph = True
-     _supports_attention_backend = True
-     _can_record_outputs = {
-         "hidden_states": Molmo2DecoderLayer,
-         "attentions": Molmo2Attention,
-     }
-
-     def _init_weights(self, module):
-         std = self.config.initializer_range
-         if isinstance(module, (nn.Linear,)):
-             module.weight.data.normal_(mean=0.0, std=std)
-             if module.bias is not None:
-                 module.bias.data.zero_()
-         elif isinstance(module, Molmo2Embedding):
-             module.embedding.data.normal_(mean=0.0, std=std)
-             module.new_embedding.data.normal_(mean=0.0, std=std)
-         elif isinstance(module, nn.Embedding):
-             module.weight.data.normal_(mean=0.0, std=std)
-             if module.padding_idx is not None:
-                 module.weight.data[module.padding_idx].zero_()
-         elif isinstance(module, Molmo2RMSNorm):
-             module.weight.data.fill_(1.0)
-         elif isinstance(module, nn.LayerNorm):
-             module.weight.data.fill_(1.0)
-             if module.bias is not None:
-                 module.bias.data.zero_()
-
-
- class MolmoPointTextModel(PreTrainedModel):
-     config: Molmo2TextConfig
-     _no_split_modules = ["Molmo2DecoderLayer", "Molmo2PostNormDecoderLayer"]
-     base_model_prefix = "model"
-     supports_gradient_checkpointing = True
-     _skip_keys_device_placement = "past_key_values"
-     _supports_flash_attn = True
-     _supports_sdpa = True
-
-     _can_compile_fullgraph = True
-     _supports_attention_backend = True
-     _can_record_outputs = {
-         "hidden_states": Molmo2DecoderLayer,
-         "attentions": Molmo2Attention,
-     }
-
-     def __init__(self, config: Molmo2TextConfig):
-         super().__init__(config)
-         if config.additional_vocab_size is not None:
-             self.wte = Molmo2Embedding(
-                 config.vocab_size,
-                 config.additional_vocab_size,
-                 config.hidden_size,
-             )
-         else:
-             self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
-         self.emb_drop = nn.Dropout(config.embedding_dropout)
-         decoder_layer = Molmo2PostNormDecoderLayer if config.norm_after else Molmo2DecoderLayer
-         self.blocks = nn.ModuleList(
-             [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
-         )
-         self.ln_f = Molmo2RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
-         if config.rope_scaling_layers is not None:
-             self.rotary_embs = nn.ModuleDict(
-                 {
-                     "default": Molmo2RotaryEmbedding(config, rope_type="default"),
-                     "scaling": Molmo2RotaryEmbedding(config),
-                 }
-             )
-         else:
-             self.rotary_emb = Molmo2RotaryEmbedding(config)
-         self.gradient_checkpointing = False
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def get_input_embeddings(self) -> torch.nn.Module:
-         return self.wte
-
-     def set_input_embeddings(self, value: torch.nn.Module) -> None:
-         self.wte = value
-
-     @can_return_tuple
-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[Cache] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         cache_position: Optional[torch.LongTensor] = None,
-         **kwargs: Unpack[TransformersKwargs],
-     ) -> BaseModelOutputWithPast:
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-         if (input_ids is None) ^ (inputs_embeds is not None):
-             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-         if self.gradient_checkpointing and self.training and use_cache:
-             logger.warning_once(
-                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-             )
-             use_cache = False
-
-         if inputs_embeds is None:
-             input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
-             inputs_embeds = self.wte(input_ids)
-
-         # torch.jit.trace() doesn't support cache objects in the output
-         if use_cache and past_key_values is None and not torch.jit.is_tracing():
-             past_key_values = DynamicCache(config=self.config)
-
-         if cache_position is None:
-             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-             cache_position = torch.arange(
-                 past_seen_tokens,
-                 past_seen_tokens + inputs_embeds.shape[1],
-                 device=inputs_embeds.device,
-             )
-
-         if position_ids is None:
-             position_ids = cache_position.unsqueeze(0)
-
-         # It may already have been prepared by e.g. `generate`
-         if not isinstance(causal_mask_mapping := attention_mask, dict):
-             # Prepare mask arguments
-             mask_kwargs = {
-                 "config": self.config,
-                 "input_embeds": inputs_embeds,
-                 "attention_mask": attention_mask,
-                 "cache_position": cache_position,
-                 "past_key_values": past_key_values,
-                 "position_ids": position_ids,
-             }
-
-             # Create the mask
-             causal_mask_mapping = create_causal_mask(**mask_kwargs)
-
-         hidden_states = inputs_embeds
-
-         # create position embeddings to be shared across the decoder layers
-         if self.config.rope_scaling_layers is not None:
-             position_embeddings_mapping = {
-                 "default": self.rotary_embs["default"](hidden_states, position_ids),
-                 "scaling": self.rotary_embs["scaling"](hidden_states, position_ids),
-             }
-         else:
-             position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
-         # decoder layers
-         all_hidden_states = () if output_hidden_states else None
-         all_self_attns = () if output_attentions else None
-
-         for layer_idx, decoder_block in enumerate(self.blocks[: self.config.num_hidden_layers]):
-             if output_hidden_states:
-                 all_hidden_states += (hidden_states,)
-
-             if self.config.rope_scaling_layers is not None:
-                 position_embeddings_i = (
-                     position_embeddings_mapping["scaling"]
-                     if layer_idx in self.config.rope_scaling_layers
-                     else position_embeddings_mapping["default"]
-                 )
-             else:
-                 position_embeddings_i = position_embeddings
-
-             layer_outputs = decoder_block(
-                 hidden_states,
-                 attention_mask=causal_mask_mapping,
-                 position_ids=position_ids,
-                 past_key_values=past_key_values,
-                 output_attentions=output_attentions,
-                 use_cache=use_cache,
-                 cache_position=cache_position,
-                 position_embeddings=position_embeddings_i,
-                 **kwargs,
-             )
-
-             hidden_states = layer_outputs[0]
-
-             if output_attentions:
-                 all_self_attns += (layer_outputs[1],)
-
-         hidden_states = self.ln_f(hidden_states)
-
-         # add hidden states from the last decoder layer
-         if output_hidden_states:
-             all_hidden_states += (hidden_states,)
-
-         return BaseModelOutputWithPast(
-             last_hidden_state=hidden_states,
-             past_key_values=past_key_values,
-             hidden_states=all_hidden_states,
-             attentions=all_self_attns,
-         )
-
- # Adapted from transformers.models.gemma3.modeling_gemma3
- def token_type_ids_mask_function(
-     token_type_ids: Optional[torch.Tensor] = None,
- ) -> Optional[Callable]:
-     """
-     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
-     not start and end indices.
-     """
-     # Do not return an additional mask in this case
-     if token_type_ids is None:
-         return None
-
-     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
-         # If it's 1 for both query and key/value, we are in an image block
-         # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
-         # Since vmap doesn't support `if statement` we workaround it with `torch.where`
-         safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
-         token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
-         token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
-
-         is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
-
-         # This is bidirectional attention whenever we are dealing with image tokens
-         return is_image_block
-
-     return inner_mask
-
-
- class MolmoPointPadWithLearnedVector(nn.Module):
-     """Module that pads a sequence with a learned vector
-
-     Used to add in the no-more-point key/value
-     """
-     def __init__(self, dim: int):
-         super().__init__()
-         self.dim = dim
-         self.vector = nn.Parameter(torch.zeros([dim]))
-
-     def reset_parameters(self):
-         torch.nn.init.zeros_(self.vector)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         vector = torch.tile(self.vector[None, :], [x.shape[0], 1])
-         return torch.concatenate([x, vector[:, None, :]], dim=1)
-
-
- class AddPosEmbed(nn.Module):
-
-     def __init__(self, in_features: int, n_pos: int) -> None:
-         super().__init__()
-         self.bias = nn.Parameter(torch.zeros([n_pos, in_features]))
-
-     def forward(self, input: torch.Tensor) -> torch.Tensor:
-         return input + self.bias[None, :input.shape[-2], :]
-
-
- class MolmoPointConnector(nn.Module):
-     def __init__(self, config: MolmoPointAdapterConfig, vit_config: Molmo2VitConfig):
-         super().__init__()
-         self.config = config
-         self.n_vit_layers = len(config.vit_layers)
-         pool_dim = vit_config.hidden_size * self.n_vit_layers
-         self.norm = None
-         self.image_projector = ImageProjectorMLP(
-             config.hidden_size,
-             config.intermediate_size,
-             config.text_hidden_size,
-             config.hidden_act,
-         )
-         self.act = ACT2FN[config.hidden_act]
-         self.image_pooling_2d = ViTMultiHeadDotProductAttention(
-             hidden_size=config.hidden_size,
-             num_heads=config.num_attention_heads,
-             num_key_value_heads=config.num_key_value_heads,
-             head_dim=config.head_dim,
-             input_dim=pool_dim,
-             float32_attention=config.float32_attention,
-             attention_dropout=config.attention_dropout,
-             residual_dropout=config.residual_dropout,
-             attn_implementation=config._attn_implementation,
-             out_layer=False
-         )
-         if self.config.positional_embeddings:
-             self.positional_embeddings = AddPosEmbed(pool_dim, self.config.positional_embeddings)
-         else:
-             self.positional_embeddings = None
-
-     def __call__(self, to_pool, to_pool_mask):
-         """
-         to_pool: [n_to_pool, pooling_dim, vit_dim]
-         to_pool_mask: [n_to_pool, pooling_dim]
-
-         returns:
-             pooled_features: [n_to_pool, llm_dim]
-         """
-         cfg = self.config
-
-         if self.config.positional_embeddings:
-             to_pool = self.positional_embeddings(to_pool)
-
-         if self.config.pooling_attention_mask:
-             attn_mask = to_pool_mask.reshape([-1, 1, 1, to_pool_mask.shape[-1]])
-         else:
-             attn_mask = None
-         to_pool = to_pool * to_pool_mask.float()[:, :, None]
-
-         denom = to_pool_mask.view(-1, to_pool.shape[-2]).float().sum(-1)
-         denom = torch.where(denom == 0, 1, denom)
-         query = to_pool.sum(-2, keepdim=True) / denom[:, None, None]
-
-         pooled_features = self.image_pooling_2d(query, to_pool, attn_mask=attn_mask)
-         pooled_features = self.act(pooled_features)
-         pooled_features = self.image_projector(pooled_features)
-         return pooled_features
-
-
- class MolmoPointModel(MolmoPointPreTrainedModel):
-     base_model_prefix = ""
-     _checkpoint_conversion_mapping = {}
-     # Reference: fix gemma3 grad acc #37208
-     accepts_loss_kwargs = False
-     config: MolmoPointConfig
-
-     def __init__(self, config: MolmoPointConfig):
-         super().__init__(config)
-         self.transformer: MolmoPointTextModel = MolmoPointTextModel(config.text_config)
-
-         vit_config = config.vit_config
-         adapter_config = config.adapter_config
-         self.vit_layers = []
-         for layer in adapter_config.vit_layers:
-             if layer >= 0:
-                 self.vit_layers.append(layer)
-             else:
-                 self.vit_layers.append(layer + vit_config.num_hidden_layers)
-
-         last_layer_needed = max(self.vit_layers) + 1
-         if last_layer_needed < vit_config.num_hidden_layers:
-             new_vit_config = deepcopy(vit_config)
-             new_vit_config.num_hidden_layers = last_layer_needed
-             self.vit = Molmo2VisionTransformer(new_vit_config)
-         else:
-             self.vit = Molmo2VisionTransformer(vit_config)
-
-         self.connector = MolmoPointConnector(adapter_config, vit_config)
-
-         vit_dim = self.config.vit_config.hidden_size * len(self.config.adapter_config.vit_layers)
-         llm_dim = self.config.text_config.hidden_size
-         self.patch_rotary = None
-         self.patch_q = nn.Linear(llm_dim, config.patch_embed_dim)
-         self.patch_k = nn.Linear(llm_dim, config.patch_embed_dim)
-         self.subpatch_q = nn.Linear(llm_dim, config.patch_embed_dim)
-         self.subpatch_k = nn.Linear(vit_dim, config.patch_embed_dim)
-         self.add_no_point_class_embed = MolmoPointPadWithLearnedVector(config.patch_embed_dim)
-
-         if self.config.embed_selected_vit_patch == "linear":
-             self.build_vit_embedding = nn.Linear(vit_dim, llm_dim, bias=True)
-         else:
-             raise NotImplementedError(f"Embedding {self.config.embed_selected_vit_patch} not implemented")
-
-         if self.config.patch_location == "3x3":
-             self.subpatch_loc_k = nn.Linear(llm_dim, 9)
-         elif self.config.patch_location is None:
-             self.subpatch_loc_k = None
-         else:
-             raise NotImplementedError(f"Patch location {self.config.patch_location} not implemented")
-
-         if self.config.layer_norm_x:
-             self.x_norm = Molmo2RMSNorm(llm_dim, eps=self.config.text_config.layer_norm_eps)
-         else:
-             self.x_norm = None
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def get_input_embeddings(self) -> torch.nn.Module:
-         return self.transformer.wte
-
-     def set_input_embeddings(self, value: torch.nn.Module) -> None:
-         self.transformer.wte = value
-
-     def set_decoder(self, decoder):
-         self.transformer = decoder
-
-     def get_decoder(self):
-         return self.transformer
-
-     @property
-     def device(self) -> torch.device:
-         return self.transformer.ln_f.weight.device
-
-     def build_batched_images(
-         self,
-         input_ids: torch.LongTensor,
-         pixel_values: torch.Tensor,
-         image_token_pooling: torch.Tensor,
-         image_grids: torch.Tensor,
-         image_num_crops: torch.Tensor,
-     ) -> tuple[torch.Tensor, torch.Tensor]:
-         # 1) Count the number of images in each example
-         raw_counts = (input_ids == self.config.image_end_token_id).sum(1)  # [N]
-         # Each image is represented by a global view and a high-res view,
-         # so we divide by 2 to get the number of images
-         counts = raw_counts // 2
-         N = counts.size(0)
-         device = input_ids.device
-
-         # Total number of images in the batch
-         num_images = int(counts.sum().item())
-
-         # Sanity check
-         assert image_grids.size(0) == num_images, \
-             f"Expected {num_images} image grids, but got {image_grids.size(0)}"
-         assert image_num_crops.size(0) == num_images, \
-             f"Expected {num_images} image num crops, but got {image_num_crops.size(0)}"
-
-         # 1-1) Compute per-image pooled patch count from image grids
-         with torch.no_grad():
-             first_prod = image_grids[:, :2].prod(dim=1)  # [num_images]
-             second_prod = image_grids[:, 2:].prod(dim=1)  # [num_images]
-             num_pooled_patches_per_image = (first_prod + second_prod).to(image_num_crops.dtype)  # [num_images]
-
-         # pixel_values: [n_crops, n_patches, pixels_per_patch]
-         n_crops, n_patches, pixels_per_patch = pixel_values.shape
-
-         # 2) Map each image index → example index
-         # Example: if counts = [2, 1, 3], then this becomes [0,0,1,2,2,2]
-         example_ids_for_image = torch.arange(N, device=device).repeat_interleave(counts)  # [num_images]
-         assert example_ids_for_image.numel() == num_images
-
-         # 2-1) Compute crops_per_example by summing per-image crop counts
-         crops_per_example = torch.zeros(
-             N, dtype=image_num_crops.dtype, device=image_num_crops.device
-         )
-         crops_per_example.index_add_(0, example_ids_for_image, image_num_crops)  # [N]
-
-         # 2-2) Per-image number of patches = (crops per image) * n_patches
-         patches_per_image = image_num_crops * n_patches  # [num_images]
-
-         # 2-3) Compute per-example per-image patch offsets
-         counts_list = counts.tolist()
-         index_offset_per_example_list = []
-         offset_img = 0
-         for c in counts_list:
-             per_img_patches = patches_per_image[offset_img:offset_img + c]  # [c]
-             # Offsets: [0, img0_total_patches, img0+img1_total_patches, ...]
-             index_offset = [0] + per_img_patches.cumsum(0).tolist()[:-1]
-             index_offset_per_example_list.append(index_offset)
-             offset_img += c
-
-         # 2-4) Compute num_pooled_patches_per_example
-         num_pooled_patches_per_example = torch.zeros(
-             N, dtype=num_pooled_patches_per_image.dtype, device=num_pooled_patches_per_image.device
-         )
-         num_pooled_patches_per_example.index_add_(
-             0, example_ids_for_image, num_pooled_patches_per_image
-         )
-
-         # Sanity checks
-         total_crops = int(crops_per_example.sum().item())
-         assert total_crops == n_crops, \
-             f"Expected {total_crops} crops, but got {n_crops}"
-
-         total_num_pooled_patches = int(num_pooled_patches_per_example.sum().item())
-         assert total_num_pooled_patches == image_token_pooling.size(0), \
-             f"Expected {total_num_pooled_patches} pooled patches, but got {image_token_pooling.size(0)}"
-
-         # 3) Build images tensor filled with -1
-         M = int(crops_per_example.max().item())
-         images = torch.full(
-             (N, M, n_patches, pixels_per_patch),
-             fill_value=-1,
-             dtype=pixel_values.dtype,
-             device=pixel_values.device,
-         )
-
-         # 4) Fill images with per-example slices from pixel_values
-         offset_crop = 0
-         for i in range(N):
-             num = int(crops_per_example[i].item())
-             cur = pixel_values[offset_crop:offset_crop + num]  # [num, n_patches, pixels_per_patch]
-             images[i, :num] = cur
-             offset_crop += num
-
-         # Sanity check
-         assert offset_crop == n_crops
-
-         # 5) Build new_token_pooling tensor filled with -1
-         P = int(num_pooled_patches_per_example.max().item())
-         _, dim = image_token_pooling.shape
-         new_token_pooling = torch.full(
-             (N, P, dim),
-             fill_value=-1,
-             dtype=image_token_pooling.dtype,
-             device=image_token_pooling.device,
-         )
-
-         # 6) Fill token_pooling with per-example slices, adding per-image patch offsets
-         patch_offset = 0
-         img_offset = 0
-
-         for i, c in enumerate(counts_list):
-             num_patches = int(num_pooled_patches_per_example[i].item())
-
-             # Subsequence of pooled tokens belonging to this example
-             cur = image_token_pooling[patch_offset:patch_offset + num_patches].clone()  # [num_patches, dim]
-
-             index_offset_per_example = index_offset_per_example_list[i]  # length = c
-             per_img_pooled = num_pooled_patches_per_image[img_offset:img_offset + c]  # [c]
-
-             assert len(index_offset_per_example) == per_img_pooled.numel()
-
-             # Apply per-image offsets to the (ragged) subsequence
-             offset = 0
-             for j in range(c):
-                 index_offset = int(index_offset_per_example[j])
-                 n = int(per_img_pooled[j].item())
-                 cur_slice = cur[offset:offset + n]
-
-                 # Apply offset across all columns
-                 cur[offset:offset + n] = torch.where(
-                     cur_slice >= 0,
-                     cur_slice + index_offset,
-                     cur_slice,
-                 )
-                 offset += n
-
-             new_token_pooling[i, :num_patches] = cur
-
-             patch_offset += num_patches
-             img_offset += c
-
-         # Final sanity checks
-         assert patch_offset == total_num_pooled_patches
-         assert img_offset == num_images
-
-         return images, new_token_pooling
-
1443
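The padding logic above turns a flat, concatenated run of crops into one fixed-size row per example, padded with -1. A minimal plain-Python sketch of that ragged-to-padded step (illustrative names `pad_ragged` and `counts`, not from the model code):

```python
def pad_ragged(flat, counts, pad=-1):
    """Split `flat` by per-example `counts` and right-pad each row to the max count."""
    max_len = max(counts) if counts else 0
    rows, offset = [], 0
    for c in counts:
        rows.append(flat[offset:offset + c] + [pad] * (max_len - c))
        offset += c
    assert offset == len(flat)  # same sanity check the model performs
    return rows

# counts = [2, 1, 3] mirrors the repeat_interleave comment in the code:
# example ids for the flat items become [0, 0, 1, 2, 2, 2]
padded = pad_ragged([10, 11, 20, 30, 31, 32], [2, 1, 3])
```

Each row here plays the role of one `images[i]` slice; the -1 padding marks crop slots past `crops_per_example[i]`.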
-     def build_batched_videos(
-         self,
-         input_ids: torch.LongTensor,
-         pixel_values_videos: torch.Tensor,
-         video_token_pooling: torch.Tensor,
-         video_grids: torch.Tensor,
-     ) -> tuple[torch.Tensor, torch.Tensor]:
-
-         # 1) Count the number of videos in each example
-         if self.config.use_frame_special_tokens:
-             end_token_id = self.config.frame_end_token_id
-         else:
-             end_token_id = self.config.image_end_token_id
-         counts = (input_ids == end_token_id).any(dim=1).long()  # [N]
-         N = counts.size(0)
-         device = input_ids.device
-
-         # Total number of videos in the batch
-         num_videos = int(counts.sum().item())
-
-         # Sanity check
-         assert video_grids.size(0) == num_videos, \
-             f"Expected {num_videos} videos, but got {video_grids.size(0)}"
-
-         video_num_frames = video_grids[:, 0]  # [num_videos]
-         num_pooled_patches_per_video = video_grids.prod(dim=1)  # [num_videos]
-
-         # pixel_values_videos: [n_frames, n_patches, pixels_per_patch]
-         n_frames, n_patches, pixels_per_patch = pixel_values_videos.shape
-
-         # 2) Map each video index -> example index
-         # Example: if counts = [2, 1, 3], then this becomes [0,0,1,2,2,2]
-         example_ids_for_video = torch.arange(N, device=device).repeat_interleave(counts)  # [num_videos]
-         assert example_ids_for_video.numel() == num_videos
-
-         # 2-1) Compute frames_per_example by summing per-video frame counts
-         frames_per_example = torch.zeros(
-             N, dtype=video_num_frames.dtype, device=device,
-         )
-         frames_per_example.index_add_(0, example_ids_for_video, video_num_frames)  # [N]
-
-         # 2-2) Compute num_pooled_patches_per_example
-         num_pooled_patches_per_example = torch.zeros(
-             N, dtype=num_pooled_patches_per_video.dtype, device=num_pooled_patches_per_video.device,
-         )
-         num_pooled_patches_per_example.index_add_(
-             0, example_ids_for_video, num_pooled_patches_per_video,
-         )
-
-         # Sanity checks
-         total_frames = int(frames_per_example.sum().item())
-         assert total_frames == n_frames, \
-             f"Expected {total_frames} frames, but got {n_frames}"
-
-         total_num_pooled_patches = int(num_pooled_patches_per_example.sum().item())
-         assert total_num_pooled_patches == video_token_pooling.size(0), \
-             f"Expected {total_num_pooled_patches} pooled patches, but got {video_token_pooling.size(0)}"
-
-         # 3) Build videos tensor filled with -1
-         M = int(frames_per_example.max().item())
-         videos = torch.full(
-             (N, M, n_patches, pixels_per_patch),
-             fill_value=-1,
-             dtype=pixel_values_videos.dtype,
-             device=device,
-         )
-
-         # 4) Fill videos with per-example slices from pixel_values_videos
-         offset_frame = 0
-         for i in range(N):
-             num = int(frames_per_example[i].item())
-             cur = pixel_values_videos[offset_frame:offset_frame + num]  # [num, n_patches, pixels_per_patch]
-             videos[i, :num] = cur
-             offset_frame += num
-
-         # Sanity check
-         assert offset_frame == n_frames
-
-         # 5) Build new token_pooling tensor filled with -1
-         P = int(num_pooled_patches_per_example.max().item())
-         _, dim = video_token_pooling.shape
-         new_token_pooling = torch.full(
-             (N, P, dim),
-             fill_value=-1,
-             dtype=video_token_pooling.dtype,
-             device=video_token_pooling.device,
-         )
-
-         # 6) Fill new token_pooling with per-example slices from video_token_pooling
-         patch_offset = 0
-         for i in range(N):
-             num_patches = int(num_pooled_patches_per_example[i].item())
-             cur = video_token_pooling[patch_offset:patch_offset + num_patches]  # [num_patches, dim]
-             new_token_pooling[i, :num_patches] = cur
-             patch_offset += num_patches
-
-         # Final sanity checks
-         assert patch_offset == total_num_pooled_patches
-
-         return videos, new_token_pooling
-
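The `index_add_` calls above aggregate per-video quantities into per-example totals through the video-to-example index map. A minimal plain-Python sketch of that aggregation (illustrative name `index_add`, not from the model code):

```python
def index_add(n_examples, example_ids, values):
    """Sum each value into the bucket named by its example id (like Tensor.index_add_)."""
    totals = [0] * n_examples
    for ex, v in zip(example_ids, values):
        totals[ex] += v
    return totals

# Two examples; example 0 has two videos (8 and 4 frames), example 1 has one (16).
# The index map [0, 0, 1] is what repeat_interleave produces from counts = [2, 1].
frames_per_example = index_add(2, [0, 0, 1], [8, 4, 16])
```

The same pattern is applied twice in `build_batched_videos`: once for frame counts and once for pooled-patch counts.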
-     def merge_visual_inputs(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         pixel_values: Optional[torch.Tensor] = None,
-         image_token_pooling: Optional[torch.Tensor] = None,
-         image_grids: Optional[torch.Tensor] = None,
-         image_num_crops: Optional[torch.Tensor] = None,
-         pixel_values_videos: Optional[torch.Tensor] = None,
-         video_token_pooling: Optional[torch.Tensor] = None,
-         video_grids: Optional[torch.Tensor] = None,
-     ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
-         if pixel_values is not None and pixel_values_videos is not None:
-             raise ValueError("pixel_values and pixel_values_videos are provided at the same time")
-         elif pixel_values is not None:
-             assert input_ids is not None
-             images, token_pooling = self.build_batched_images(
-                 input_ids=input_ids,
-                 pixel_values=pixel_values,
-                 image_token_pooling=image_token_pooling,
-                 image_grids=image_grids,
-                 image_num_crops=image_num_crops,
-             )
-         elif pixel_values_videos is not None:
-             assert input_ids is not None
-             images, token_pooling = self.build_batched_videos(
-                 input_ids=input_ids,
-                 pixel_values_videos=pixel_values_videos,
-                 video_token_pooling=video_token_pooling,
-                 video_grids=video_grids,
-             )
-         else:
-             images, token_pooling = None, None
-         return images, token_pooling
-
-     def build_input_embeddings(
-         self,
-         input_ids: torch.LongTensor,
-         images: Optional[torch.FloatTensor] = None,  # image inputs
-         token_pooling: Optional[torch.LongTensor] = None,
-     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-
-         # Get embeddings of input.
-         # shape: (batch_size, seq_len, d_model)
-         input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
-         x = self.transformer.wte(input_ids)
-
-         image_features: Optional[torch.FloatTensor] = None
-         if images is not None:
-             image_features = self.vision_backbone(images, token_pooling).to(x.device)
-             is_image_patch = input_ids.view(-1) == self.config.image_patch_id
-             assert is_image_patch.sum() == len(image_features)
-             x.view(-1, x.shape[-1])[is_image_patch] += image_features
-
-         # shape: (batch_size, seq_len, d_model)
-         x = self.transformer.emb_drop(x)  # type: ignore
-
-         return x, image_features
-
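`build_input_embeddings` adds (rather than substitutes) the vision features at the sequence positions whose token id is `image_patch_id`, in order. A plain-Python sketch of that additive merge; `IMAGE_PATCH_ID = 9` and the function name are made up for illustration:

```python
IMAGE_PATCH_ID = 9  # hypothetical placeholder-token id

def merge_features(input_ids, text_embeds, image_feats):
    """Add each image feature onto the embedding at the next image-patch position."""
    positions = [i for i, t in enumerate(input_ids) if t == IMAGE_PATCH_ID]
    assert len(positions) == len(image_feats)  # same check as the model
    out = list(text_embeds)
    for pos, feat in zip(positions, image_feats):
        out[pos] = out[pos] + feat
    return out

# Two placeholder tokens (id 9) receive the two vision features additively.
merged = merge_features([1, 9, 9, 2], [0.5, 1.0, 1.0, 0.5], [0.25, 0.75])
```

Scalars stand in for `d_model`-sized vectors; the real code does the same thing row-wise on a flattened `(batch * seq_len, d_model)` view.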
-     @can_return_tuple
-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         pixel_values: Optional[torch.FloatTensor] = None,
-         image_token_pooling: Optional[torch.Tensor] = None,
-         image_grids: Optional[torch.Tensor] = None,
-         image_num_crops: Optional[torch.Tensor] = None,
-         pixel_values_videos: Optional[torch.Tensor] = None,
-         video_token_pooling: Optional[torch.Tensor] = None,
-         video_grids: Optional[torch.Tensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.Tensor] = None,
-         past_key_values: Optional[Cache] = None,
-         token_type_ids: Optional[torch.LongTensor] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         cache_position: Optional[torch.LongTensor] = None,
-         **kwargs: Unpack[TransformersKwargs],
-     ) -> Union[tuple, MolmoPointModelOutputWithPast]:
-
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-         if (input_ids is None) ^ (inputs_embeds is not None):
-             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-         images, token_pooling = self.merge_visual_inputs(
-             input_ids=input_ids,
-             pixel_values=pixel_values,
-             image_token_pooling=image_token_pooling,
-             image_grids=image_grids,
-             image_num_crops=image_num_crops,
-             pixel_values_videos=pixel_values_videos,
-             video_token_pooling=video_token_pooling,
-             video_grids=video_grids,
-         )
-
-         if images is not None and inputs_embeds is not None:
-             raise ValueError(
-                 "You cannot specify both images and inputs_embeds at the same time."
-             )
-
-         if inputs_embeds is None:
-             inputs_embeds, image_features = self.build_input_embeddings(
-                 input_ids, images, token_pooling,
-             )
-
-         if cache_position is None:
-             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-             cache_position = torch.arange(
-                 past_seen_tokens,
-                 past_seen_tokens + inputs_embeds.shape[1],
-                 device=inputs_embeds.device,
-             )
-
-         # Adapted from transformers.models.gemma3.modeling_gemma3
-         # It may already have been prepared by e.g. `generate`
-         if not isinstance(causal_mask_mapping := attention_mask, dict):
-             # Prepare mask arguments
-             mask_kwargs = {
-                 "config": self.config.get_text_config(),
-                 "input_embeds": inputs_embeds,
-                 "attention_mask": attention_mask,
-                 "cache_position": cache_position,
-                 "past_key_values": past_key_values,
-                 "position_ids": position_ids,
-             }
-
-             # NOTE: this `is_prefill` logic is not flawless, it fails when we're using a cache eagerly initialized
-             # (e.g. compiled prefill) AND `images` are not provided. Determining prefill in that case requires
-             # checking data values, which is not compile-compatible.
-             is_prefill = (
-                 not use_cache
-                 or past_key_values is None
-                 or not past_key_values.is_initialized
-                 or images is not None
-             )
-             if token_type_ids is not None and is_prefill:
-                 # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
-                 mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
-                     token_type_ids.to(cache_position.device)
-                 )
-
-             # Create the mask
-             causal_mask_mapping = create_causal_mask(**mask_kwargs)
-
-         outputs = self.transformer(
-             attention_mask=causal_mask_mapping,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             cache_position=cache_position,
-             **kwargs,
-         )
-
-         return MolmoPointModelOutputWithPast(
-             last_hidden_state=outputs.last_hidden_state,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-             image_hidden_states=image_features if images is not None else None,
-         )
-
-
- class MolmoPointForConditionalGeneration(MolmoPointPreTrainedModel, GenerationMixin):
-     _checkpoint_conversion_mapping = {}
-     _tied_weights_keys = []  # Weights are not tied
-     # Reference: fix gemma3 grad acc #37208
-     accepts_loss_kwargs = False
-     config: MolmoPointConfig
-
-     def __init__(self, config: MolmoPointConfig):
-         super().__init__(config)
-
-         self.model = MolmoPointModel(config)
-         self.output_embeddings = nn.Parameter(torch.zeros([config.vocab_size, config.hidden_size]))
-         self.new_output_embeddings = nn.Parameter(torch.zeros([128, config.hidden_size]))
-         self.vocab_size = config.vocab_size
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def get_input_embeddings(self) -> torch.nn.Module:
-         return self.model.transformer.wte
-
-     def set_input_embeddings(self, value: torch.nn.Module) -> None:
-         self.model.transformer.wte = value
-
-     def set_decoder(self, decoder):
-         self.model.set_decoder(decoder)
-
-     def get_decoder(self):
-         return self.model.get_decoder()
-
-     # Make modules available through conditional class for BC
-     @property
-     def language_model(self) -> torch.nn.Module:
-         return self.model.transformer
-
-     @property
-     def vision_backbone(self) -> torch.nn.Module:
-         return self.model.vision_backbone
-
-     @can_return_tuple
-     def forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         pixel_values: Optional[torch.Tensor] = None,
-         image_token_pooling: Optional[torch.Tensor] = None,
-         image_grids: Optional[torch.Tensor] = None,
-         image_num_crops: Optional[torch.Tensor] = None,
-         pixel_values_videos: Optional[torch.Tensor] = None,
-         video_token_pooling: Optional[torch.Tensor] = None,
-         video_grids: Optional[torch.Tensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[list[torch.FloatTensor]] = None,
-         token_type_ids: Optional[torch.LongTensor] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         cache_position: Optional[torch.LongTensor] = None,
-         logits_to_keep: Union[int, torch.Tensor] = 0,
-         **kwargs: Unpack[TransformersKwargs],
-     ) -> Union[tuple, MolmoPointCausalLMOutputWithPast]:
-         r"""
-         ```python
-         >>> from PIL import Image
-         >>> import requests
-         >>> from transformers import AutoProcessor, MolmoPointForConditionalGeneration
-
-         >>> model = MolmoPointForConditionalGeneration.from_pretrained("...")
-         >>> processor = AutoProcessor.from_pretrained("...")
-
-         >>> prompt = "What's the content of the image?"
-         >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
-         >>> image = Image.open(requests.get(url, stream=True).raw)
-
-         >>> messages = [{"role": "user", "content": [{"type": "text", "text": prompt}, {"type": "image", "image": image}]}]
-
-         >>> inputs = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True)
-
-         >>> # Generate
-         >>> generated_ids = model.generate(**inputs, max_new_tokens=15)
-         >>> generated_tokens = generated_ids[:, inputs['input_ids'].size(1):]
-         >>> processor.post_process_image_text_to_text(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-         "The image shows a bustling street scene in what appears to be a Chinatown area. There's ..."
-         ```"""
-         outputs = self.model(
-             input_ids=input_ids,
-             pixel_values=pixel_values,
-             image_token_pooling=image_token_pooling,
-             image_grids=image_grids,
-             image_num_crops=image_num_crops,
-             pixel_values_videos=pixel_values_videos,
-             video_token_pooling=video_token_pooling,
-             video_grids=video_grids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             token_type_ids=token_type_ids,
-             inputs_embeds=inputs_embeds,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             cache_position=cache_position,
-             **kwargs,
-         )
-
-         hidden_states = outputs.last_hidden_state
-         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-         lm_head = torch.concatenate([self.output_embeddings, self.new_output_embeddings], dim=0)
-         logits = F.linear(hidden_states[:, slice_indices, :], lm_head)
-
-         loss = None
-         if labels is not None:
-             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size)
-
-         return MolmoPointCausalLMOutputWithPast(
-             loss=loss,
-             logits=logits,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-             image_hidden_states=outputs.image_hidden_states,
-         )
-
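The output head above is the base output embeddings concatenated with 128 extra rows (`new_output_embeddings`) for the added point tokens, so logit indices at or beyond the base vocab size address the new tokens. A toy plain-Python sketch of that row-wise concatenation (illustrative name `lm_head_logits`, scalar-sized vectors):

```python
def lm_head_logits(hidden, base_rows, new_rows):
    """Logits against base rows followed by extra rows, as in torch.concatenate(dim=0)."""
    head = base_rows + new_rows  # row-wise concat: new-token rows come last
    return [sum(h * w for h, w in zip(hidden, row)) for row in head]

# 2-dim hidden state, a 2-row "base vocab", and one extra new-token row.
logits = lm_head_logits([1.0, 2.0], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0]])
```

Index 2 of `logits` (one past the toy base vocab) corresponds to the single new token, mirroring how the real head appends its 128 rows after `vocab_size`.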
-     def prepare_inputs_for_generation(
-         self,
-         input_ids: torch.LongTensor,
-         past_key_values: Optional[list[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         pixel_values: Optional[torch.FloatTensor] = None,
-         image_token_pooling: Optional[torch.Tensor] = None,
-         image_grids: Optional[torch.Tensor] = None,
-         image_num_crops: Optional[torch.Tensor] = None,
-         pixel_values_videos: Optional[torch.Tensor] = None,
-         video_token_pooling: Optional[torch.Tensor] = None,
-         video_grids: Optional[torch.Tensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.LongTensor] = None,
-         cache_position: Optional[torch.LongTensor] = None,
-         logits_to_keep: Optional[Union[int, torch.Tensor]] = None,
-         **kwargs,
-     ):
-
-         model_inputs = super().prepare_inputs_for_generation(
-             input_ids,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             attention_mask=attention_mask,
-             cache_position=cache_position,
-             logits_to_keep=logits_to_keep,
-             token_type_ids=token_type_ids,
-             **kwargs,
-         )
-
-         if cache_position[0] == 0:
-             model_inputs["pixel_values"] = pixel_values
-             model_inputs["image_token_pooling"] = image_token_pooling
-             model_inputs["image_grids"] = image_grids
-             model_inputs["image_num_crops"] = image_num_crops
-             model_inputs["pixel_values_videos"] = pixel_values_videos
-             model_inputs["video_token_pooling"] = video_token_pooling
-             model_inputs["video_grids"] = video_grids
-
-         return model_inputs
-
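The `cache_position[0] == 0` check above forwards the vision inputs only on the prefill step; later decode steps reuse the KV cache and get no pixel values. A minimal sketch of that gating (illustrative function name, strings standing in for tensors):

```python
def gate_vision_inputs(model_inputs, cache_position, pixel_values):
    """Attach pixel values only when generation is at its first (prefill) position."""
    if cache_position[0] == 0:  # prefill: the whole prompt, vision inputs included
        model_inputs["pixel_values"] = pixel_values
    return model_inputs  # decode steps leave model_inputs untouched

prefill_inputs = gate_vision_inputs({}, [0, 1, 2], "PIXELS")
decode_inputs = gate_vision_inputs({}, [3], "PIXELS")
```

This is why resubmitting the image on every step is unnecessary: the image tokens are already encoded into the cache during prefill.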
-     # Adapted from transformers.models.gemma3.modeling_gemma3
-     @staticmethod
-     def create_masks_for_generate(
-         config: PretrainedConfig,
-         input_embeds: torch.Tensor,
-         attention_mask: Optional[torch.Tensor],
-         cache_position: torch.Tensor,
-         past_key_values: Optional[Cache],
-         position_ids: Optional[torch.Tensor],
-         token_type_ids: Optional[torch.Tensor] = None,
-         **kwargs,
-     ) -> dict:
-         # Prepare mask arguments
-         mask_kwargs = {
-             "config": config.get_text_config(),
-             "input_embeds": input_embeds,
-             "attention_mask": attention_mask,
-             "cache_position": cache_position,
-             "past_key_values": past_key_values,
-             "position_ids": position_ids,
-         }
-         # Add the token type ids mask for generate as well
-         if token_type_ids is not None and input_embeds.shape[1] != 1:
-             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
-             mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
-                 token_type_ids.to(cache_position.device)
-             )
-
-         return create_masks_for_generate(**mask_kwargs)
-
-
- # Always register for multi-modal features
- AutoModelForImageTextToText.register(MolmoPointConfig, MolmoPointForConditionalGeneration)
unified_demo.py DELETED
@@ -1,334 +0,0 @@
- import functools
- import os
- import argparse
- import logging
- from collections import defaultdict
- from PIL import Image, ImageFile, ImageDraw
- import PIL
-
- import numpy as np
- import torch
- from transformers import AutoProcessor, AutoModelForImageTextToText
-
- from olmo.models.video_olmo.video_olmo import VideoOlmoConfig
- from olmo.html_utils import postprocess_prompt
- from olmo.util import (
-     prepare_cli_environment,
-     resource_path,
- )
-
- import gradio as gr
-
- try:
-     from molmo_utils import process_vision_info
- except ImportError:
-     # raise ImportError("molmo_utils not found. Please install it with `pip install molmo-utils`.")
-     pass
-
-
- Image.MAX_IMAGE_PIXELS = None
- ImageFile.LOAD_TRUNCATED_IMAGES = True
-
-
- CACHE = "model_cache"
- log = logging.getLogger(__name__)
- ALLOWED_PATH = [CACHE]
- MAX_IMAGE_SIZE = 512
- MAX_VIDEO_HEIGHT = 512
- POINT_SIZE = 0.01
-
- DEVICE = None
-
- # load the model, processor
- MODEL = None
- PROCESSOR = None
- POINT_FORMATTER = None
-
-
- def draw_points(image, points):
-     if isinstance(image, np.ndarray):
-         annotation = PIL.Image.fromarray(image)
-     else:
-         annotation = image.copy()
-     draw = ImageDraw.Draw(annotation)
-     w, h = annotation.size
-     size = max(5, int(max(w, h) * POINT_SIZE))
-     for x, y in points:
-         draw.ellipse((x-size, y-size, x+size, y+size), fill="rgb(240, 82, 156)", outline=None)
-     return annotation
-
-
- def get_message(
-     images: list[Image.Image] | None,
-     video_path: str | None,
-     max_frames: int,
-     frame_sample_mode: str,
-     max_fps: int | None,
-     sampling_fps: int | None,
-     input_text: str,
-     style: str,
- ):
-     content = [
-         dict(type="text", text=input_text, style=style)
-     ]
-     if images:
-         image_content = [
-             dict(type="image", image=image)
-             for image in images
-         ]
-         content.extend(image_content)
-     if video_path:
-         video_kwargs = {
-             "num_frames": max_frames,
-             "frame_sample_mode": frame_sample_mode,
-         }
-         if max_fps is not None:
-             video_kwargs["max_fps"] = max_fps
-         if sampling_fps is not None:
-             video_kwargs["sampling_fps"] = sampling_fps
-         video_content = dict(type="video", video=video_path, **video_kwargs)
-         content.append(video_content)
-
-     return [
-         {
-             "role": "user",
-             "content": content,
-         }
-     ]
-
-
- def cast_float_dtype(t: torch.Tensor):
-     if torch.is_floating_point(t):
-         t = t.to(torch.bfloat16)
-     return t
-
-
- def run_single_inference(*inputs, annotations=None):
-     video_path, images, input_text, style, frame_sample_mode, max_frames, max_fps, sampling_fps, max_steps = inputs
-     assert images is not None or video_path is not None, "Either images or video_path must be provided"
-     assert images is None or video_path is None, "Both images and video_path cannot be provided at the same time"
-     nimages = 0
-     if images:
-         images = [t[0] for t in images]
-         nimages = len(images)
-     logging.info(f"# of images: {nimages}")
-
-     messages = get_message(
-         images=images,
-         video_path=video_path,
-         max_frames=max_frames,
-         frame_sample_mode=frame_sample_mode,
-         max_fps=max_fps,
-         sampling_fps=sampling_fps,
-         input_text=input_text,
-         style=style,
-     )
-     images, videos, video_kwargs = process_vision_info(messages)
-     if videos:
-         videos, video_metadatas = zip(*videos)
-         videos, video_metadatas = list(videos), list(video_metadatas)
-         logging.info(
-             f"Videos: {videos[0].shape}, frame_sample_mode: {frame_sample_mode}, "
-             f"max_frames: {max_frames}, max_fps: {max_fps}, sampling_fps: {sampling_fps}"
-         )
-     else:
-         video_metadatas = None
-     logging.info(f"Running inference for prompt: \"{input_text}\", style={style} steps={max_steps}")
-     text = PROCESSOR.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-
-     inputs = PROCESSOR(
-         images=images,
-         videos=videos,
-         video_metadata=video_metadatas,
-         text=text,
-         padding=True,
-         return_tensors="pt",
-         **video_kwargs,
-     )
-
-     if MODEL.config.dtype == torch.bfloat16:
-         inputs = {k: cast_float_dtype(v.to(DEVICE)) for k, v in inputs.items()}
-     else:
-         inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
-     with torch.inference_mode():
-         if MODEL.config.dtype == torch.bfloat16:
-             output = MODEL.generate(**inputs, max_new_tokens=max_steps)
-         else:
-             with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
-                 output = MODEL.generate(**inputs, max_new_tokens=max_steps)
-     prompts = output[0, :inputs['input_ids'].size(1)]
-     prompt_text = PROCESSOR.decode(prompts, skip_special_tokens=False)
-     prompt_text = postprocess_prompt(prompt_text)
-     logging.info(f"hf prompt: {prompt_text}")
-     generated_tokens = output[0, inputs['input_ids'].size(1):]
-     generated_text = PROCESSOR.decode(generated_tokens, skip_special_tokens=True)
-     logging.info(f"hf generated_text: {generated_text}")
-     if annotations:
-         if video_path is None and nimages == 1:
-             w, h = images[0].size
-             points = POINT_FORMATTER.extract_points(generated_text, w, h)
-             if points:
-                 return generated_text, [draw_points(images[0], points)]
-             else:
-                 return generated_text, []
-         elif video_path is None and nimages > 1:
-             w, h = [x.size[0] for x in images], [x.size[1] for x in images]
-             points = POINT_FORMATTER.extract_multi_image_points(generated_text, w, h)
-             if points:
-                 group_by_index = defaultdict(list)
-                 for ix, x, y in points:
-                     group_by_index[ix].append((x, y))
-                 out = []
-                 for ix, points in group_by_index.items():
-                     out.append(draw_points(images[ix-1], points))
-                 return generated_text, out
-             else:
-                 return generated_text, []
-         else:
-             h, w = videos[0].shape[1:3]
-             group_by_time = defaultdict(list)
-             points = POINT_FORMATTER.extract_multi_image_points(generated_text, w, h)
-             if points:
-                 for ts, x, y in points:
-                     group_by_time[ts].append((x, y))
-             else:
-                 track = POINT_FORMATTER.extract_trajectories(generated_text, w, h, 30)
-                 for ex in track:
-                     group_by_time[ex["time"]] = [(x["x"], x["y"]) for x in ex["points"]]
-             grouped_by_frame = defaultdict(list)
-             for ts, points in group_by_time.items():
-                 timestamps = video_metadatas[0]["frames_indices"] / video_metadatas[0]["fps"]
-                 ix = int(np.argmin(np.abs(timestamps - ts)))
-                 grouped_by_frame[ix] += points
-             out = []
-             for ix, points in grouped_by_frame.items():
-                 out.append(draw_points(videos[0][ix], points))
-             return generated_text, out
-     return generated_text
-
-
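For video annotations, `run_single_inference` assigns each predicted timestamp to the sampled frame whose timestamp is nearest (the `np.argmin(np.abs(timestamps - ts))` step). A plain-Python sketch of that nearest-frame lookup (illustrative name `nearest_frame`):

```python
def nearest_frame(timestamps, ts):
    """Index of the sampled frame whose timestamp is closest to the predicted time ts."""
    return min(range(len(timestamps)), key=lambda i: abs(timestamps[i] - ts))

# Frames sampled at 0.0s, 0.5s, and 1.0s; a point predicted at t=0.6s
# is drawn on the second frame (index 1).
frame_ix = nearest_frame([0.0, 0.5, 1.0], 0.6)
```

Points sharing a nearest frame are then grouped and drawn together on that frame, which is what `grouped_by_frame` does in the demo.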
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("ckpt_home", type=str)
    parser.add_argument("--server_name")
    parser.add_argument("--default_max_tokens", type=int, default=2048)
    parser.add_argument("--cloudflare_tunnel", action="store_true")
    parser.add_argument("--original_ckpt_home", type=str, default=None)
    parser.add_argument("--annotations", action="store_true")
    parser.add_argument("--no_share", action="store_true")
    parser.add_argument("--port", type=int, default=7860)
    args = parser.parse_args()

    prepare_cli_environment()

    global DEVICE, MODEL, PROCESSOR
    if torch.cuda.is_available():
        DEVICE = torch.device("cuda")
    else:
        logging.warning("No GPU available, using CPU")
        DEVICE = torch.device("cpu")
    if MODEL is not None:
        MODEL.to(DEVICE)

    MODEL = AutoModelForImageTextToText.from_pretrained(
        args.ckpt_home,
        trust_remote_code=True,
        dtype="auto",
        device_map="auto",
    )

    PROCESSOR = AutoProcessor.from_pretrained(
        args.ckpt_home,
        trust_remote_code=True,
        padding_side="left",
    )

    if args.annotations:
        assert args.original_ckpt_home is not None, "original_ckpt_home must be provided when annotations are enabled"
        global POINT_FORMATTER
        model_cfg_path = resource_path(args.original_ckpt_home, "config.yaml")
        model_cfg = VideoOlmoConfig.load(model_cfg_path, key="model", validate_paths=False)
        preprocessor = model_cfg.build_preprocessor(for_inference=True, is_training=False)
        POINT_FORMATTER = preprocessor.formatter._point_formatter

    CSS = """
    #input_image img {
        object-fit: contain !important;
    }
    #input_video video {
        object-fit: contain !important;
    }
    """

    frame_sample_mode = PROCESSOR.video_processor.frame_sample_mode
    max_frames = PROCESSOR.video_processor.num_frames
    max_fps = PROCESSOR.video_processor.max_fps
    sampling_fps = PROCESSOR.video_processor.sampling_fps

    with gr.Blocks(css=CSS) as demo:
        gr.Markdown(
            """
            ## Molmo2 Demo
            Provide either a video or images and a prompt for question answering.
            """
        )
        with gr.Row():
            with gr.Tabs():
                with gr.TabItem("video"):
                    video = gr.Video(label="Input Video", elem_id="input_video", height=MAX_VIDEO_HEIGHT)
                with gr.TabItem("image(s)"):
                    images = gr.Gallery(label="Input Images", elem_id="input_image", type="pil", height=MAX_IMAGE_SIZE)

        with gr.Row():
            input_text = gr.Textbox(placeholder="Enter the prompt", label="Input text")

        with gr.Row():
            style = gr.Textbox(value="demo", label="style")
            frame_sample_mode = gr.Textbox(value=frame_sample_mode, label="frame_sample_mode")
            max_frames = gr.Number(value=max_frames, label="max_frames")
            max_fps = gr.Number(value=max_fps, label="max_fps")
            sampling_fps = gr.Number(value=sampling_fps, label="sampling_fps")
            max_tok_slider = gr.Slider(label="max_tokens", minimum=1, maximum=4096, step=1, value=args.default_max_tokens)

        with gr.Row():
            submit_button = gr.Button("Submit", scale=3)
            clear_all_button = gr.ClearButton(components=[video, images, input_text], value="Clear All", scale=1)

        with gr.Row():
            output_text = gr.Textbox(placeholder="Output text", label="Output text", lines=10)

        if args.annotations:
            with gr.Row():
                output_annotations = gr.Gallery(label="Annotations", height=MAX_IMAGE_SIZE)
            outputs = [output_text, output_annotations]
            fn = functools.partial(run_single_inference, annotations="points")
        else:
            fn = run_single_inference
            outputs = [output_text]

        submit_button.click(
            fn=fn,
            inputs=[video, images, input_text, style, frame_sample_mode, max_frames, max_fps, sampling_fps, max_tok_slider],
            outputs=outputs,
        )

    # Clamp max_threads so it stays positive on machines with few cores
    max_threads = max((os.cpu_count() or 1) - 10, 1)
    if args.cloudflare_tunnel:
        import cloudflared_tunnel
        with cloudflared_tunnel.run() as port:
            demo.queue().launch(
                share=False, show_error=True, max_threads=max_threads, server_port=port,
                allowed_paths=ALLOWED_PATH
            )
    else:
        demo.queue().launch(
            server_name=args.server_name,
            share=not args.no_share, show_error=True, max_threads=max_threads,
            server_port=args.port,
            allowed_paths=ALLOWED_PATH
        )


if __name__ == "__main__":
    main()