Upload folder using huggingface_hub
Browse files- README.md +0 -164
- __init__.py +0 -0
- config.json +1 -1
- generation_config.json +1 -1
- modeling_molmo_point.py +3 -4
- modelling_molmo_point.py +0 -1914
- unified_demo.py +0 -334
README.md
DELETED
|
@@ -1,164 +0,0 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
-
language:
|
| 4 |
-
- en
|
| 5 |
-
base_model:
|
| 6 |
-
- Qwen/Qwen3-8B
|
| 7 |
-
- google/siglip-so400m-patch14-384
|
| 8 |
-
pipeline_tag: image-text-to-text
|
| 9 |
-
library_name: transformers
|
| 10 |
-
tags:
|
| 11 |
-
- multimodal
|
| 12 |
-
- olmo
|
| 13 |
-
- molmo
|
| 14 |
-
- molmo2
|
| 15 |
-
---
|
| 16 |
-
|
| 17 |
-
# MolmoPoint-8B
|
| 18 |
-
MolmoPoint-8B is a fully-open VLM developed by the Allen Institute for AI (Ai2) that support image, video and multi-image understanding and grounding.
|
| 19 |
-
It has novel pointing mechansim that improves image pointing, video pointing, and video tracking, see our technical report for details.
|
| 20 |
-
|
| 21 |
-
Note the huggingface MolmoPoint model does not support training, see our github repo for the training code.
|
| 22 |
-
|
| 23 |
-
Quick links:
|
| 24 |
-
- 💬 [Code](https://github.com/allenai/molmo2)
|
| 25 |
-
- 📂 [All Models](https://huggingface.co/collections/allenai/molmo_point)
|
| 26 |
-
- 📃 [Paper](https://allenai.org/papers/molmo_point)
|
| 27 |
-
- 📝 [Blog](https://allenai.org/blog/molmo_point)
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
## Quick Start
|
| 31 |
-
|
| 32 |
-
### Setup Conda Environment
|
| 33 |
-
```
|
| 34 |
-
conda create --name transformers4571 python=3.11
|
| 35 |
-
conda activate transformers4571
|
| 36 |
-
pip install transformers==4.57.1
|
| 37 |
-
pip install torch pillow einops torchvision accelerate decord2
|
| 38 |
-
```
|
| 39 |
-
|
| 40 |
-
## Inference
|
| 41 |
-
We recommend running MolmoPoint with `logits_processor=model.build_logit_processor_from_inputs(model_inputs)`
|
| 42 |
-
to enforce points tokens are generated in a valid way.
|
| 43 |
-
|
| 44 |
-
In MolmoPoint, instead of coordinates points will be generated as a series of special
|
| 45 |
-
tokens, to decode the tokens back into points requires some additional
|
| 46 |
-
metadata from the preprocessor.
|
| 47 |
-
The metadata is returned by the preprocessor using the `return_pointing_metadata` flag.
|
| 48 |
-
Then `model.extract_image_points` and `model.extract_video_points` do the decoding, they
|
| 49 |
-
return a list of ({image_id|timestamps}, object_id, pixel_x, pixel_y) output points.
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
### Image Pointing Example:
|
| 53 |
-
|
| 54 |
-
```python
|
| 55 |
-
from transformers import AutoProcessor, AutoModelForImageTextToText
|
| 56 |
-
from PIL import Image
|
| 57 |
-
import requests
|
| 58 |
-
import torch
|
| 59 |
-
|
| 60 |
-
checkpoint_dir = "allenai/MolmoPoint-8B" # or path to a converted HF checkpoint
|
| 61 |
-
|
| 62 |
-
model = AutoModelForImageTextToText.from_pretrained(
|
| 63 |
-
checkpoint_dir,
|
| 64 |
-
trust_remote_code=True,
|
| 65 |
-
dtype="auto",
|
| 66 |
-
device_map="auto",
|
| 67 |
-
)
|
| 68 |
-
|
| 69 |
-
processor = AutoProcessor.from_pretrained(
|
| 70 |
-
checkpoint_dir,
|
| 71 |
-
trust_remote_code=True,
|
| 72 |
-
padding_side="left",
|
| 73 |
-
)
|
| 74 |
-
|
| 75 |
-
image_messages = [
|
| 76 |
-
{
|
| 77 |
-
"role": "user",
|
| 78 |
-
"content": [
|
| 79 |
-
{"type": "text", "text": "Point to the eyes"},
|
| 80 |
-
{"type": "image", "image": Image.open(requests.get(
|
| 81 |
-
"https://picsum.photos/id/237/536/354", stream=True
|
| 82 |
-
).raw)},
|
| 83 |
-
]
|
| 84 |
-
}
|
| 85 |
-
]
|
| 86 |
-
|
| 87 |
-
inputs = processor.apply_chat_template(
|
| 88 |
-
image_messages,
|
| 89 |
-
tokenize=True,
|
| 90 |
-
add_generation_prompt=True,
|
| 91 |
-
return_tensors="pt",
|
| 92 |
-
return_dict=True,
|
| 93 |
-
padding=True,
|
| 94 |
-
return_pointing_metadata=True
|
| 95 |
-
)
|
| 96 |
-
metadata = inputs.pop("metadata")
|
| 97 |
-
inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
| 98 |
-
|
| 99 |
-
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
| 100 |
-
output = model.generate(
|
| 101 |
-
**inputs,
|
| 102 |
-
logits_processor=model.build_logit_processor_from_inputs(inputs),
|
| 103 |
-
max_new_tokens=200
|
| 104 |
-
)
|
| 105 |
-
|
| 106 |
-
generated_tokens = output[:, inputs["input_ids"].size(1):]
|
| 107 |
-
generated_text = processor.post_process_image_text_to_text(generated_tokens, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
|
| 108 |
-
points = model.extract_image_points(
|
| 109 |
-
generated_text,
|
| 110 |
-
metadata["token_pooling"],
|
| 111 |
-
metadata["subpatch_mapping"],
|
| 112 |
-
metadata["image_sizes"]
|
| 113 |
-
)
|
| 114 |
-
print(points)
|
| 115 |
-
```
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
### Video Pointing Example:
|
| 119 |
-
```python
|
| 120 |
-
video_path = "https://storage.googleapis.com/oe-training-public/demo_videos/many_penguins.mp4"
|
| 121 |
-
video_messages = [
|
| 122 |
-
{
|
| 123 |
-
"role": "user",
|
| 124 |
-
"content": [
|
| 125 |
-
dict(type="text", text="Point to the penguins"),
|
| 126 |
-
dict(type="video", video=video_path),
|
| 127 |
-
]
|
| 128 |
-
}
|
| 129 |
-
]
|
| 130 |
-
|
| 131 |
-
inputs = processor.apply_chat_template(
|
| 132 |
-
video_messages,
|
| 133 |
-
tokenize=True,
|
| 134 |
-
add_generation_prompt=True,
|
| 135 |
-
return_tensors="pt",
|
| 136 |
-
return_dict=True,
|
| 137 |
-
padding=True,
|
| 138 |
-
return_pointing_metadata=True
|
| 139 |
-
)
|
| 140 |
-
|
| 141 |
-
inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
| 142 |
-
|
| 143 |
-
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
| 144 |
-
output = model.generate(
|
| 145 |
-
**inputs,
|
| 146 |
-
logits_processor=model.build_logit_processor_from_inputs(inputs)
|
| 147 |
-
max_new_tokens=200
|
| 148 |
-
)
|
| 149 |
-
|
| 150 |
-
generated_tokens = output[:, inputs['input_ids'].size(1):]
|
| 151 |
-
generated_text = processor.post_process_image_text_to_text(generated_tokens, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
|
| 152 |
-
points = model.extract_video_points(
|
| 153 |
-
generated_text,
|
| 154 |
-
metadata["token_pooling"],
|
| 155 |
-
metadata["subpatch_mapping"],
|
| 156 |
-
metadata["timestamps"],
|
| 157 |
-
metadata["video_size"]
|
| 158 |
-
)
|
| 159 |
-
print(points)
|
| 160 |
-
```
|
| 161 |
-
|
| 162 |
-
## License and Use
|
| 163 |
-
|
| 164 |
-
This model is licensed under Apache 2.0. It is intended for research and educational use in accordance with Ai2’s Responsible Use Guidelines. This model is trained on third party datasets that are subject to academic and non-commercial research use only. Please review the sources to determine if this model is appropriate for your use case.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__init__.py
DELETED
|
File without changes
|
config.json
CHANGED
|
@@ -83,7 +83,7 @@
|
|
| 83 |
"tie_word_embeddings": false,
|
| 84 |
"token_prediction_rotary": "one_d",
|
| 85 |
"token_prediction_rotary_theta": 50000.0,
|
| 86 |
-
"transformers_version": "4.57.
|
| 87 |
"use_cache": true,
|
| 88 |
"use_frame_special_tokens": true,
|
| 89 |
"vit_config": {
|
|
|
|
| 83 |
"tie_word_embeddings": false,
|
| 84 |
"token_prediction_rotary": "one_d",
|
| 85 |
"token_prediction_rotary_theta": 50000.0,
|
| 86 |
+
"transformers_version": "4.57.1",
|
| 87 |
"use_cache": true,
|
| 88 |
"use_frame_special_tokens": true,
|
| 89 |
"vit_config": {
|
generation_config.json
CHANGED
|
@@ -2,5 +2,5 @@
|
|
| 2 |
"bos_token_id": 151645,
|
| 3 |
"eos_token_id": 151645,
|
| 4 |
"pad_token_id": 151643,
|
| 5 |
-
"transformers_version": "4.57.
|
| 6 |
}
|
|
|
|
| 2 |
"bos_token_id": 151645,
|
| 3 |
"eos_token_id": 151645,
|
| 4 |
"pad_token_id": 151643,
|
| 5 |
+
"transformers_version": "4.57.1"
|
| 6 |
}
|
modeling_molmo_point.py
CHANGED
|
@@ -796,7 +796,7 @@ class MolmoPointConnector(nn.Module):
|
|
| 796 |
def extract_image_points(output_text, pooling, mappings, no_more_points_class, location, image_sizes):
|
| 797 |
"""Extract points from MolmoPoint image output text
|
| 798 |
|
| 799 |
-
return points: [n_points, 4] array of (
|
| 800 |
"""
|
| 801 |
if len(mappings) != len(image_sizes):
|
| 802 |
raise ValueError("Mapping and image sizes must have the same length")
|
|
@@ -831,7 +831,7 @@ def extract_video_points(output_text, pooling, mapping, timestamps, no_more_poin
|
|
| 831 |
"""
|
| 832 |
Extract points from MolmoPoint video output text
|
| 833 |
|
| 834 |
-
return points: [n_points, 4] array of (
|
| 835 |
"""
|
| 836 |
extracted_points = []
|
| 837 |
for vit_patch_id, location_id, example_id in get_subpatch_ids(output_text, pooling, no_more_points_class):
|
|
@@ -1263,7 +1263,7 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
|
|
| 1263 |
**kwargs: Unpack[TransformersKwargs],
|
| 1264 |
) -> Union[tuple, MolmoPointModelOutputWithPast]:
|
| 1265 |
"""
|
| 1266 |
-
last_point_patch_id: The patch id the last
|
| 1267 |
"""
|
| 1268 |
|
| 1269 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
@@ -1289,7 +1289,6 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
|
|
| 1289 |
raise NotImplementedError("Custom inputs_embeds is not implemented yet")
|
| 1290 |
|
| 1291 |
input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
|
| 1292 |
-
print(f"ON: {input_ids[0, -1]}")
|
| 1293 |
|
| 1294 |
if image_data is not None:
|
| 1295 |
can_point = True
|
|
|
|
| 796 |
def extract_image_points(output_text, pooling, mappings, no_more_points_class, location, image_sizes):
|
| 797 |
"""Extract points from MolmoPoint image output text
|
| 798 |
|
| 799 |
+
return points: [n_points, 4] array of (object_id, image_num, x, y) points
|
| 800 |
"""
|
| 801 |
if len(mappings) != len(image_sizes):
|
| 802 |
raise ValueError("Mapping and image sizes must have the same length")
|
|
|
|
| 831 |
"""
|
| 832 |
Extract points from MolmoPoint video output text
|
| 833 |
|
| 834 |
+
return points: [n_points, 4] array of (object_id, timestamp, x, y) points
|
| 835 |
"""
|
| 836 |
extracted_points = []
|
| 837 |
for vit_patch_id, location_id, example_id in get_subpatch_ids(output_text, pooling, no_more_points_class):
|
|
|
|
| 1263 |
**kwargs: Unpack[TransformersKwargs],
|
| 1264 |
) -> Union[tuple, MolmoPointModelOutputWithPast]:
|
| 1265 |
"""
|
| 1266 |
+
last_point_patch_id: The patch id the last generated point pointed to
|
| 1267 |
"""
|
| 1268 |
|
| 1269 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
|
|
| 1289 |
raise NotImplementedError("Custom inputs_embeds is not implemented yet")
|
| 1290 |
|
| 1291 |
input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
|
|
|
|
| 1292 |
|
| 1293 |
if image_data is not None:
|
| 1294 |
can_point = True
|
modelling_molmo_point.py
DELETED
|
@@ -1,1914 +0,0 @@
|
|
| 1 |
-
import math
|
| 2 |
-
from copy import deepcopy
|
| 3 |
-
from dataclasses import dataclass
|
| 4 |
-
from typing import Optional, Union, Callable
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
from torch import nn
|
| 8 |
-
|
| 9 |
-
from torch.nn import functional as F
|
| 10 |
-
|
| 11 |
-
from transformers.models.auto import AutoModelForImageTextToText
|
| 12 |
-
from transformers.activations import ACT2FN
|
| 13 |
-
from transformers.configuration_utils import PretrainedConfig
|
| 14 |
-
from transformers.cache_utils import Cache, DynamicCache
|
| 15 |
-
from transformers.generation import GenerationMixin
|
| 16 |
-
from transformers.masking_utils import create_causal_mask, create_masks_for_generate
|
| 17 |
-
from transformers.modeling_flash_attention_utils import (
|
| 18 |
-
_flash_attention_forward,
|
| 19 |
-
FlashAttentionKwargs,
|
| 20 |
-
flash_attn_supports_top_left_mask,
|
| 21 |
-
)
|
| 22 |
-
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 23 |
-
from transformers.modeling_outputs import (
|
| 24 |
-
BaseModelOutputWithPast,
|
| 25 |
-
)
|
| 26 |
-
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 27 |
-
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 28 |
-
from transformers.processing_utils import Unpack
|
| 29 |
-
from transformers.utils import (
|
| 30 |
-
ModelOutput,
|
| 31 |
-
TransformersKwargs,
|
| 32 |
-
can_return_tuple,
|
| 33 |
-
logging,
|
| 34 |
-
)
|
| 35 |
-
|
| 36 |
-
from .configuration_molmo2 import Molmo2VitConfig, Molmo2TextConfig, Molmo2AdapterConfig
|
| 37 |
-
from .configuration_molmo_point import MolmoPointConfig, MolmoPointAdapterConfig
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
logger = logging.get_logger(__name__)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
@dataclass
|
| 44 |
-
class MolmoPointCausalLMOutputWithPast(ModelOutput):
|
| 45 |
-
"""
|
| 46 |
-
Base class for MolmoPoint causal language model (or autoregressive) outputs.
|
| 47 |
-
|
| 48 |
-
Args:
|
| 49 |
-
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 50 |
-
Language modeling loss (for next-token prediction).
|
| 51 |
-
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 52 |
-
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 53 |
-
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 54 |
-
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
| 55 |
-
|
| 56 |
-
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 57 |
-
`past_key_values` input) to speed up sequential decoding.
|
| 58 |
-
image_hidden_states (`torch.FloatTensor`, *optional*):
|
| 59 |
-
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
| 60 |
-
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
| 61 |
-
"""
|
| 62 |
-
|
| 63 |
-
loss: Optional[torch.FloatTensor] = None
|
| 64 |
-
logits: Optional[torch.FloatTensor] = None
|
| 65 |
-
past_key_values: Optional[Cache] = None
|
| 66 |
-
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
| 67 |
-
attentions: Optional[tuple[torch.FloatTensor]] = None
|
| 68 |
-
image_hidden_states: Optional[torch.FloatTensor] = None
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
@dataclass
|
| 72 |
-
class MolmoPointModelOutputWithPast(BaseModelOutputWithPast):
|
| 73 |
-
"""
|
| 74 |
-
Base class for Molmo2 outputs, with hidden states and attentions.
|
| 75 |
-
|
| 76 |
-
Args:
|
| 77 |
-
image_hidden_states (`torch.FloatTensor`, *optional*):
|
| 78 |
-
A `torch.FloatTensor` of size `(batch_num_patches, hidden_size)`.
|
| 79 |
-
image_hidden_states of the model produced by the vision backbone
|
| 80 |
-
"""
|
| 81 |
-
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 82 |
-
past_key_values: Optional[Cache] = None
|
| 83 |
-
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
| 84 |
-
attentions: Optional[tuple[torch.FloatTensor]] = None
|
| 85 |
-
image_hidden_states: Optional[torch.FloatTensor] = None
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
class ViTMLP(nn.Module):
|
| 90 |
-
def __init__(self, dim: int, hidden_dim: int, hidden_act: str, device: Union[str, torch.device] = None):
|
| 91 |
-
super().__init__()
|
| 92 |
-
self.w1 = nn.Linear(dim, hidden_dim, bias=True, device=device)
|
| 93 |
-
self.act = ACT2FN[hidden_act]
|
| 94 |
-
self.w2 = nn.Linear(hidden_dim, dim, bias=True, device=device)
|
| 95 |
-
|
| 96 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 97 |
-
return self.w2(self.act(self.w1(x)))
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
class ViTMultiHeadDotProductAttention(nn.Module):
|
| 101 |
-
def __init__(
|
| 102 |
-
self,
|
| 103 |
-
hidden_size: int,
|
| 104 |
-
num_heads: int,
|
| 105 |
-
num_key_value_heads: int,
|
| 106 |
-
head_dim: int,
|
| 107 |
-
use_bias: bool = True,
|
| 108 |
-
input_dim: Optional[int] = None,
|
| 109 |
-
float32_attention: bool = True,
|
| 110 |
-
attention_dropout: float = 0.0,
|
| 111 |
-
residual_dropout: float = 0.0,
|
| 112 |
-
device: Union[str, torch.device] = None,
|
| 113 |
-
attn_implementation: str = "eager",
|
| 114 |
-
out_layer: bool=True
|
| 115 |
-
):
|
| 116 |
-
super().__init__()
|
| 117 |
-
|
| 118 |
-
self.hidden_size = hidden_size
|
| 119 |
-
self.num_heads = num_heads
|
| 120 |
-
self.head_dim = head_dim
|
| 121 |
-
self.num_key_value_heads = num_key_value_heads
|
| 122 |
-
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 123 |
-
self.attn_implementation = attn_implementation
|
| 124 |
-
self.is_causal = False
|
| 125 |
-
|
| 126 |
-
input_dim = input_dim or hidden_size
|
| 127 |
-
|
| 128 |
-
self.wq = nn.Linear(
|
| 129 |
-
input_dim,
|
| 130 |
-
self.num_heads * self.head_dim,
|
| 131 |
-
bias=use_bias,
|
| 132 |
-
device=device,
|
| 133 |
-
)
|
| 134 |
-
self.wk = nn.Linear(
|
| 135 |
-
input_dim,
|
| 136 |
-
self.num_key_value_heads * self.head_dim,
|
| 137 |
-
bias=use_bias,
|
| 138 |
-
device=device,
|
| 139 |
-
)
|
| 140 |
-
self.wv = nn.Linear(
|
| 141 |
-
input_dim,
|
| 142 |
-
self.num_key_value_heads * self.head_dim,
|
| 143 |
-
bias=use_bias,
|
| 144 |
-
device=device,
|
| 145 |
-
)
|
| 146 |
-
if out_layer:
|
| 147 |
-
self.wo = nn.Linear(
|
| 148 |
-
self.num_heads * self.head_dim,
|
| 149 |
-
self.hidden_size,
|
| 150 |
-
)
|
| 151 |
-
else:
|
| 152 |
-
self.w0 = None
|
| 153 |
-
self.float32_attention = float32_attention
|
| 154 |
-
self.attention_dropout = attention_dropout
|
| 155 |
-
self.residual_dropout = nn.Dropout(residual_dropout)
|
| 156 |
-
|
| 157 |
-
def _split_heads(self, hidden_states, num_heads) -> torch.Tensor:
|
| 158 |
-
return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
|
| 159 |
-
|
| 160 |
-
def _merge_heads(self, hidden_states) -> torch.Tensor:
|
| 161 |
-
return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))
|
| 162 |
-
|
| 163 |
-
def forward(
|
| 164 |
-
self,
|
| 165 |
-
inputs_q: torch.Tensor,
|
| 166 |
-
inputs_kv: Optional[torch.Tensor] = None,
|
| 167 |
-
attn_mask: Optional[torch.Tensor] = None,
|
| 168 |
-
) -> torch.Tensor:
|
| 169 |
-
|
| 170 |
-
if inputs_kv is not None:
|
| 171 |
-
inputs_k = inputs_kv
|
| 172 |
-
inputs_v = inputs_kv
|
| 173 |
-
else:
|
| 174 |
-
inputs_k = inputs_q
|
| 175 |
-
inputs_v = inputs_q
|
| 176 |
-
|
| 177 |
-
xq, xk, xv = self.wq(inputs_q), self.wk(inputs_k), self.wv(inputs_v)
|
| 178 |
-
|
| 179 |
-
xq = self._split_heads(xq, self.num_heads)
|
| 180 |
-
xk = self._split_heads(xk, self.num_key_value_heads)
|
| 181 |
-
xv = self._split_heads(xv, self.num_key_value_heads)
|
| 182 |
-
|
| 183 |
-
if self.num_heads != self.num_key_value_heads:
|
| 184 |
-
xk = xk.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)
|
| 185 |
-
xv = xv.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)
|
| 186 |
-
|
| 187 |
-
og_dtype = xq.dtype
|
| 188 |
-
|
| 189 |
-
if self.float32_attention:
|
| 190 |
-
xq = xq.to(torch.float)
|
| 191 |
-
xk = xk.to(torch.float)
|
| 192 |
-
|
| 193 |
-
dropout_p = 0.0 if not self.training else self.attention_dropout
|
| 194 |
-
|
| 195 |
-
if self.attn_implementation == "eager":
|
| 196 |
-
attn_weights = torch.einsum("...qhd,...khd->...hqk", xq / math.sqrt(xq.size(-1)), xk)
|
| 197 |
-
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(xq.dtype)
|
| 198 |
-
attn_weights = F.dropout(
|
| 199 |
-
attn_weights,
|
| 200 |
-
p=dropout_p,
|
| 201 |
-
training=self.training
|
| 202 |
-
)
|
| 203 |
-
attn_output = torch.einsum("...hqk,...khd->...qhd", attn_weights.to(xv.dtype), xv)
|
| 204 |
-
|
| 205 |
-
elif self.attn_implementation == "sdpa":
|
| 206 |
-
if not torch.is_autocast_enabled():
|
| 207 |
-
xv = xv.to(torch.float)
|
| 208 |
-
|
| 209 |
-
attn_output = F.scaled_dot_product_attention(
|
| 210 |
-
xq.transpose(1, 2).contiguous(),
|
| 211 |
-
xk.transpose(1, 2).contiguous(),
|
| 212 |
-
xv.transpose(1, 2).contiguous(),
|
| 213 |
-
attn_mask=attn_mask,
|
| 214 |
-
is_causal=False,
|
| 215 |
-
dropout_p=dropout_p,
|
| 216 |
-
).transpose(1, 2)
|
| 217 |
-
|
| 218 |
-
elif self.attn_implementation == "flash_attention_2":
|
| 219 |
-
if xq.dtype == torch.float32:
|
| 220 |
-
if torch.is_autocast_enabled():
|
| 221 |
-
target_dtype = torch.get_autocast_gpu_dtype()
|
| 222 |
-
else:
|
| 223 |
-
target_dtype = self.wq.weight.dtype
|
| 224 |
-
attn_output = _flash_attention_forward(
|
| 225 |
-
xq,
|
| 226 |
-
xk,
|
| 227 |
-
xv,
|
| 228 |
-
attention_mask=attn_mask,
|
| 229 |
-
query_length=inputs_q.shape[1],
|
| 230 |
-
is_causal=False,
|
| 231 |
-
dropout=dropout_p,
|
| 232 |
-
softmax_scale=xq.shape[-1] ** -0.5,
|
| 233 |
-
use_top_left_mask=flash_attn_supports_top_left_mask(),
|
| 234 |
-
target_dtype=target_dtype,
|
| 235 |
-
implementation=self.attn_implementation,
|
| 236 |
-
)
|
| 237 |
-
else:
|
| 238 |
-
raise ValueError(f"Attention implementation {self.attn_implementation} not supported")
|
| 239 |
-
|
| 240 |
-
attn_output = attn_output.to(og_dtype)
|
| 241 |
-
attn_output = self._merge_heads(attn_output)
|
| 242 |
-
if self.wo is not None:
|
| 243 |
-
attn_output = self.wo(attn_output)
|
| 244 |
-
attn_output = self.residual_dropout(attn_output)
|
| 245 |
-
|
| 246 |
-
return attn_output
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
class Molmo2VisionBlock(nn.Module):
|
| 250 |
-
|
| 251 |
-
def __init__(self, config: Molmo2VitConfig, device: Union[str, torch.device] = None):
|
| 252 |
-
super().__init__()
|
| 253 |
-
self.attention = ViTMultiHeadDotProductAttention(
|
| 254 |
-
hidden_size=config.hidden_size,
|
| 255 |
-
num_heads=config.num_attention_heads,
|
| 256 |
-
num_key_value_heads=config.num_key_value_heads,
|
| 257 |
-
head_dim=config.head_dim,
|
| 258 |
-
float32_attention=config.float32_attention,
|
| 259 |
-
attention_dropout=config.attention_dropout,
|
| 260 |
-
residual_dropout=config.residual_dropout,
|
| 261 |
-
device=device,
|
| 262 |
-
attn_implementation=config._attn_implementation,
|
| 263 |
-
)
|
| 264 |
-
self.feed_forward = ViTMLP(config.hidden_size, config.intermediate_size, config.hidden_act, device=device)
|
| 265 |
-
self.attention_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device)
|
| 266 |
-
self.ffn_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device)
|
| 267 |
-
|
| 268 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 269 |
-
x = x + self.attention(self.attention_norm(x))
|
| 270 |
-
x = x + self.feed_forward(self.ffn_norm(x))
|
| 271 |
-
return x
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
class Molmo2VisionBlockCollection(nn.Module):
|
| 275 |
-
|
| 276 |
-
def __init__(self, config: Molmo2VitConfig, device: Union[str, torch.device] = None):
|
| 277 |
-
super().__init__()
|
| 278 |
-
self.conifg = config
|
| 279 |
-
self.resblocks = nn.ModuleList([
|
| 280 |
-
Molmo2VisionBlock(config, device) for _ in range(config.num_hidden_layers)
|
| 281 |
-
])
|
| 282 |
-
|
| 283 |
-
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
| 284 |
-
hidden_states = []
|
| 285 |
-
for r in self.resblocks:
|
| 286 |
-
x = r(x)
|
| 287 |
-
hidden_states.append(x)
|
| 288 |
-
return hidden_states
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
class Molmo2VisionTransformer(nn.Module):
|
| 292 |
-
|
| 293 |
-
def __init__(self, config: Molmo2VitConfig, device: Union[str, torch.device] = None):
|
| 294 |
-
super().__init__()
|
| 295 |
-
self.config = config
|
| 296 |
-
|
| 297 |
-
# positional embeddings
|
| 298 |
-
self.scale = config.hidden_size ** -0.5
|
| 299 |
-
self.num_prefix_tokens: int = 0 # no class embeddings
|
| 300 |
-
self.positional_embedding = nn.Parameter(
|
| 301 |
-
torch.zeros(config.image_num_pos, config.hidden_size, device=device),
|
| 302 |
-
)
|
| 303 |
-
|
| 304 |
-
image_patch_size = config.image_patch_size
|
| 305 |
-
self.patch_embedding = nn.Linear(
|
| 306 |
-
image_patch_size * image_patch_size * 3,
|
| 307 |
-
config.hidden_size,
|
| 308 |
-
bias=True,
|
| 309 |
-
device=device,
|
| 310 |
-
)
|
| 311 |
-
|
| 312 |
-
self.transformer = Molmo2VisionBlockCollection(config, device)
|
| 313 |
-
|
| 314 |
-
def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
|
| 315 |
-
pos_emb = self.positional_embedding
|
| 316 |
-
|
| 317 |
-
pos_emb = pos_emb.reshape(
|
| 318 |
-
(int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])
|
| 319 |
-
)
|
| 320 |
-
|
| 321 |
-
(patch_num_0, patch_num_1) = patch_num
|
| 322 |
-
|
| 323 |
-
if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
|
| 324 |
-
# Dervied from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
| 325 |
-
# antialias: default True in jax.image.resize
|
| 326 |
-
pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
|
| 327 |
-
pos_emb = F.interpolate(
|
| 328 |
-
pos_emb, size=(patch_num_0, patch_num_1), mode="bicubic", align_corners=False, antialias=True,
|
| 329 |
-
)
|
| 330 |
-
pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)
|
| 331 |
-
|
| 332 |
-
pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
|
| 333 |
-
x = x + pos_emb[None, :, :].to(x.dtype)
|
| 334 |
-
return x
|
| 335 |
-
|
| 336 |
-
def forward(self, x: torch.Tensor, patch_num: int = None) -> list[torch.Tensor]:
|
| 337 |
-
"""
|
| 338 |
-
: param x: (batch_size, num_patch, n_pixels)
|
| 339 |
-
"""
|
| 340 |
-
if patch_num is None:
|
| 341 |
-
patch_num = self.config.image_num_patch
|
| 342 |
-
|
| 343 |
-
B, N, D = x.shape
|
| 344 |
-
|
| 345 |
-
x = self.patch_embedding(x)
|
| 346 |
-
|
| 347 |
-
# class embeddings and positional embeddings
|
| 348 |
-
x = self.add_pos_emb(x, patch_num)
|
| 349 |
-
|
| 350 |
-
hidden_states = self.transformer(x)
|
| 351 |
-
return hidden_states
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
class ImageProjectorMLP(nn.Module):
|
| 355 |
-
|
| 356 |
-
def __init__(
|
| 357 |
-
self,
|
| 358 |
-
input_dim: int,
|
| 359 |
-
hidden_dim: int,
|
| 360 |
-
output_dim: int,
|
| 361 |
-
hidden_act: str,
|
| 362 |
-
device: Union[str, torch.device] = None,
|
| 363 |
-
):
|
| 364 |
-
super().__init__()
|
| 365 |
-
self.w1 = nn.Linear(input_dim, hidden_dim, bias=False, device=device)
|
| 366 |
-
self.w2 = nn.Linear(hidden_dim, output_dim, bias=False, device=device)
|
| 367 |
-
self.w3 = nn.Linear(input_dim, hidden_dim, bias=False, device=device)
|
| 368 |
-
self.act = ACT2FN[hidden_act]
|
| 369 |
-
|
| 370 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 371 |
-
return self.w2(self.act(self.w1(x)) * self.w3(x))
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
class Molmo2VisionBackbone(nn.Module):
|
| 375 |
-
def __init__(self, vit_config: Molmo2VitConfig, adapter_config: Molmo2AdapterConfig):
|
| 376 |
-
super().__init__()
|
| 377 |
-
self.vit_config = vit_config
|
| 378 |
-
self.adapter_config = adapter_config
|
| 379 |
-
|
| 380 |
-
self.vit_layers = []
|
| 381 |
-
for layer in adapter_config.vit_layers:
|
| 382 |
-
if layer >= 0:
|
| 383 |
-
self.vit_layers.append(layer)
|
| 384 |
-
else:
|
| 385 |
-
self.vit_layers.append(layer + vit_config.num_hidden_layers)
|
| 386 |
-
|
| 387 |
-
last_layer_needed = max(self.vit_layers) + 1
|
| 388 |
-
if last_layer_needed < vit_config.num_hidden_layers:
|
| 389 |
-
new_vit_config = deepcopy(vit_config)
|
| 390 |
-
new_vit_config.num_hidden_layers = last_layer_needed
|
| 391 |
-
self.image_vit = Molmo2VisionTransformer(new_vit_config)
|
| 392 |
-
else:
|
| 393 |
-
self.image_vit = Molmo2VisionTransformer(vit_config)
|
| 394 |
-
|
| 395 |
-
self.num_prefix_tokens: int = self.image_vit.num_prefix_tokens
|
| 396 |
-
|
| 397 |
-
pool_dim = vit_config.hidden_size * len(adapter_config.vit_layers)
|
| 398 |
-
self.image_pooling_2d = ViTMultiHeadDotProductAttention(
|
| 399 |
-
hidden_size=adapter_config.hidden_size,
|
| 400 |
-
num_heads=adapter_config.num_attention_heads,
|
| 401 |
-
num_key_value_heads=adapter_config.num_key_value_heads,
|
| 402 |
-
head_dim=adapter_config.head_dim,
|
| 403 |
-
input_dim=pool_dim,
|
| 404 |
-
float32_attention=adapter_config.float32_attention,
|
| 405 |
-
attention_dropout=adapter_config.attention_dropout,
|
| 406 |
-
residual_dropout=adapter_config.residual_dropout,
|
| 407 |
-
attn_implementation=adapter_config._attn_implementation,
|
| 408 |
-
)
|
| 409 |
-
self.image_projector = ImageProjectorMLP(
|
| 410 |
-
adapter_config.hidden_size,
|
| 411 |
-
adapter_config.intermediate_size,
|
| 412 |
-
adapter_config.text_hidden_size,
|
| 413 |
-
adapter_config.hidden_act,
|
| 414 |
-
)
|
| 415 |
-
self.image_feature_dropout = nn.Dropout(adapter_config.image_feature_dropout)
|
| 416 |
-
|
| 417 |
-
def encode_image(self, images: torch.Tensor) -> torch.Tensor:
|
| 418 |
-
"""
|
| 419 |
-
: param images: (batch_size, num_crops, num_patch, n_pixels)
|
| 420 |
-
"""
|
| 421 |
-
B, T, N, D = images.shape
|
| 422 |
-
images = images.view(B * T, N, D)
|
| 423 |
-
image_features = self.image_vit(images)
|
| 424 |
-
|
| 425 |
-
features = []
|
| 426 |
-
for layer in self.vit_layers:
|
| 427 |
-
features.append(image_features[layer])
|
| 428 |
-
image_features = torch.cat(features, dim=-1)
|
| 429 |
-
|
| 430 |
-
if self.num_prefix_tokens > 0:
|
| 431 |
-
image_features = image_features[:, 1:]
|
| 432 |
-
image_features = image_features.view(B, T, N, -1)
|
| 433 |
-
return image_features
|
| 434 |
-
|
| 435 |
-
@property
|
| 436 |
-
def dtype(self) -> torch.dtype:
|
| 437 |
-
return self.image_vit.patch_embedding.weight.dtype
|
| 438 |
-
|
| 439 |
-
@property
|
| 440 |
-
def device(self) -> torch.device:
|
| 441 |
-
return self.image_vit.patch_embedding.weight.device
|
| 442 |
-
|
| 443 |
-
def forward(
|
| 444 |
-
self,
|
| 445 |
-
images: torch.Tensor,
|
| 446 |
-
pooled_patches_idx: torch.Tensor,
|
| 447 |
-
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 448 |
-
|
| 449 |
-
# image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
|
| 450 |
-
batch_size, num_image = images.shape[:2]
|
| 451 |
-
images = images.to(device=self.device, dtype=self.dtype)
|
| 452 |
-
image_features = self.encode_image(images)
|
| 453 |
-
|
| 454 |
-
image_features = self.image_feature_dropout(image_features)
|
| 455 |
-
dim = image_features.shape[-1]
|
| 456 |
-
valid = pooled_patches_idx >= 0
|
| 457 |
-
valid_token = torch.any(valid, -1)
|
| 458 |
-
|
| 459 |
-
# Use `pooled_patches_idx` to arange the features for image pooling
|
| 460 |
-
batch_idx = torch.arange(pooled_patches_idx.shape[0], dtype=torch.long, device=pooled_patches_idx.device)
|
| 461 |
-
batch_idx = torch.tile(batch_idx.view(batch_size, 1, 1), [1, pooled_patches_idx.shape[1], pooled_patches_idx.shape[2]])
|
| 462 |
-
|
| 463 |
-
# Now [batch, num_high_res_features, pool_dim, dim]
|
| 464 |
-
to_pool = image_features.reshape(batch_size, -1, dim)[batch_idx, torch.clip(pooled_patches_idx, 0)]
|
| 465 |
-
to_pool = to_pool * valid.to(self.dtype)[:, :, :, None]
|
| 466 |
-
to_pool = to_pool.reshape([-1, pooled_patches_idx.shape[-1], dim])
|
| 467 |
-
if self.adapter_config.pooling_attention_mask:
|
| 468 |
-
attn_mask = valid.reshape([-1, 1, 1, valid.shape[-1]])
|
| 469 |
-
denom = valid.view(-1, to_pool.shape[-2]).float().sum(-1)
|
| 470 |
-
denom = torch.where(denom == 0, 1, denom)
|
| 471 |
-
query = to_pool.sum(-2, keepdim=True) / denom[:, None, None].to(to_pool.dtype)
|
| 472 |
-
else:
|
| 473 |
-
attn_mask = None
|
| 474 |
-
query = to_pool.mean(-2, keepdim=True)
|
| 475 |
-
pooled_features = self.image_pooling_2d(query, to_pool, attn_mask=attn_mask)
|
| 476 |
-
pooled_features = pooled_features.reshape([batch_size, -1, pooled_features.shape[-1]])
|
| 477 |
-
|
| 478 |
-
# MLP layer to map the feature.
|
| 479 |
-
pooled_features = self.image_projector(pooled_features)
|
| 480 |
-
return pooled_features.view(-1, pooled_features.shape[-1])[valid_token.flatten()]
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
| 484 |
-
def rotate_half(x):
|
| 485 |
-
"""Rotates half the hidden dims of the input."""
|
| 486 |
-
x1 = x[..., : x.shape[-1] // 2]
|
| 487 |
-
x2 = x[..., x.shape[-1] // 2 :]
|
| 488 |
-
return torch.cat((-x2, x1), dim=-1)
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
| 492 |
-
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 493 |
-
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 494 |
-
|
| 495 |
-
Args:
|
| 496 |
-
q (`torch.Tensor`): The query tensor.
|
| 497 |
-
k (`torch.Tensor`): The key tensor.
|
| 498 |
-
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 499 |
-
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 500 |
-
position_ids (`torch.Tensor`, *optional*):
|
| 501 |
-
Deprecated and unused.
|
| 502 |
-
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 503 |
-
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 504 |
-
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 505 |
-
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 506 |
-
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 507 |
-
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 508 |
-
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 509 |
-
Returns:
|
| 510 |
-
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 511 |
-
"""
|
| 512 |
-
cos = cos.unsqueeze(unsqueeze_dim)
|
| 513 |
-
sin = sin.unsqueeze(unsqueeze_dim)
|
| 514 |
-
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 515 |
-
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 516 |
-
return q_embed, k_embed
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
class Molmo2RotaryEmbedding(nn.Module):
|
| 520 |
-
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 521 |
-
|
| 522 |
-
def __init__(
|
| 523 |
-
self,
|
| 524 |
-
config: Molmo2TextConfig,
|
| 525 |
-
device: Union[str, torch.device] = None,
|
| 526 |
-
rope_type: Optional[str] = None,
|
| 527 |
-
):
|
| 528 |
-
super().__init__()
|
| 529 |
-
if rope_type is not None:
|
| 530 |
-
self.rope_type = rope_type
|
| 531 |
-
elif hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
| 532 |
-
# BC: "rope_type" was originally "type"
|
| 533 |
-
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 534 |
-
else:
|
| 535 |
-
self.rope_type = "default"
|
| 536 |
-
self.max_seq_len_cached = config.max_position_embeddings
|
| 537 |
-
self.original_max_seq_len = config.max_position_embeddings
|
| 538 |
-
|
| 539 |
-
self.config = config
|
| 540 |
-
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 541 |
-
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
|
| 542 |
-
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 543 |
-
self.original_inv_freq = self.inv_freq
|
| 544 |
-
|
| 545 |
-
@torch.no_grad()
|
| 546 |
-
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 547 |
-
def forward(self, x, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 548 |
-
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 549 |
-
position_ids_expanded = position_ids[:, None, :].float()
|
| 550 |
-
|
| 551 |
-
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 552 |
-
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 553 |
-
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 554 |
-
emb = torch.cat((freqs, freqs), dim=-1)
|
| 555 |
-
cos = emb.cos() * self.attention_scaling
|
| 556 |
-
sin = emb.sin() * self.attention_scaling
|
| 557 |
-
|
| 558 |
-
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
class Molmo2RMSNorm(nn.Module):
|
| 562 |
-
|
| 563 |
-
def __init__(
|
| 564 |
-
self,
|
| 565 |
-
size: int,
|
| 566 |
-
eps: float = 1e-6,
|
| 567 |
-
device: Union[str, torch.device] = None,
|
| 568 |
-
):
|
| 569 |
-
super().__init__()
|
| 570 |
-
self.weight = nn.Parameter(torch.ones(size, device=device))
|
| 571 |
-
self.eps = eps
|
| 572 |
-
|
| 573 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 574 |
-
with torch.autocast(enabled=False, device_type=x.device.type):
|
| 575 |
-
og_dtype = x.dtype
|
| 576 |
-
x = x.to(torch.float32)
|
| 577 |
-
variance = x.pow(2).mean(-1, keepdim=True)
|
| 578 |
-
x = x * torch.rsqrt(variance + self.eps)
|
| 579 |
-
x = x.to(og_dtype)
|
| 580 |
-
|
| 581 |
-
return self.weight * x
|
| 582 |
-
|
| 583 |
-
def extra_repr(self):
|
| 584 |
-
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
| 588 |
-
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 589 |
-
"""
|
| 590 |
-
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 591 |
-
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 592 |
-
"""
|
| 593 |
-
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 594 |
-
if n_rep == 1:
|
| 595 |
-
return hidden_states
|
| 596 |
-
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 597 |
-
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
def eager_attention_forward(
|
| 601 |
-
module: nn.Module,
|
| 602 |
-
query: torch.Tensor,
|
| 603 |
-
key: torch.Tensor,
|
| 604 |
-
value: torch.Tensor,
|
| 605 |
-
attention_mask: Optional[torch.Tensor],
|
| 606 |
-
scaling: float,
|
| 607 |
-
dropout: float = 0.0,
|
| 608 |
-
**kwargs,
|
| 609 |
-
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 610 |
-
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 611 |
-
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 612 |
-
|
| 613 |
-
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 614 |
-
if attention_mask is not None:
|
| 615 |
-
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 616 |
-
attn_weights = attn_weights + causal_mask
|
| 617 |
-
|
| 618 |
-
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 619 |
-
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 620 |
-
attn_output = torch.matmul(attn_weights, value_states)
|
| 621 |
-
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 622 |
-
|
| 623 |
-
return attn_output, attn_weights
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
class Molmo2Attention(nn.Module):
|
| 627 |
-
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 628 |
-
|
| 629 |
-
def __init__(self, config: Molmo2TextConfig, layer_idx: int) -> None:
|
| 630 |
-
super().__init__()
|
| 631 |
-
self.config = config
|
| 632 |
-
self.layer_idx = layer_idx
|
| 633 |
-
self.num_heads = config.num_attention_heads
|
| 634 |
-
self.num_key_value_heads = config.num_key_value_heads
|
| 635 |
-
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 636 |
-
self.head_dim = config.head_dim
|
| 637 |
-
self.scaling = self.head_dim**-0.5
|
| 638 |
-
self.is_causal = True
|
| 639 |
-
|
| 640 |
-
self.fused_dims = (
|
| 641 |
-
config.num_attention_heads * config.head_dim,
|
| 642 |
-
config.head_dim * config.num_key_value_heads,
|
| 643 |
-
config.head_dim * config.num_key_value_heads,
|
| 644 |
-
)
|
| 645 |
-
self.att_proj = nn.Linear(
|
| 646 |
-
config.hidden_size,
|
| 647 |
-
sum(self.fused_dims),
|
| 648 |
-
bias=config.qkv_bias,
|
| 649 |
-
)
|
| 650 |
-
|
| 651 |
-
# Layer norms.
|
| 652 |
-
self.k_norm: Optional[Molmo2RMSNorm] = None
|
| 653 |
-
self.q_norm: Optional[Molmo2RMSNorm] = None
|
| 654 |
-
self.qk_norm_type: Optional[str] = None
|
| 655 |
-
if config.use_qk_norm:
|
| 656 |
-
k_norm_size = (
|
| 657 |
-
config.head_dim
|
| 658 |
-
if config.qk_norm_type == "qwen3" else
|
| 659 |
-
config.num_key_value_heads * config.head_dim
|
| 660 |
-
)
|
| 661 |
-
self.k_norm = Molmo2RMSNorm(k_norm_size, eps=config.layer_norm_eps)
|
| 662 |
-
q_norm_size = (
|
| 663 |
-
config.head_dim
|
| 664 |
-
if config.qk_norm_type == "qwen3" else
|
| 665 |
-
config.num_attention_heads * config.head_dim
|
| 666 |
-
)
|
| 667 |
-
self.q_norm = Molmo2RMSNorm(q_norm_size, eps=config.layer_norm_eps)
|
| 668 |
-
self.qk_norm_type = config.qk_norm_type
|
| 669 |
-
|
| 670 |
-
self.attention_dropout = config.attention_dropout
|
| 671 |
-
|
| 672 |
-
self.attn_out = nn.Linear(
|
| 673 |
-
config.head_dim * config.num_attention_heads,
|
| 674 |
-
config.hidden_size,
|
| 675 |
-
bias=False,
|
| 676 |
-
)
|
| 677 |
-
|
| 678 |
-
def forward(
|
| 679 |
-
self,
|
| 680 |
-
hidden_states: torch.Tensor,
|
| 681 |
-
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 682 |
-
attention_mask: Optional[torch.Tensor],
|
| 683 |
-
past_key_values: Optional[Cache] = None,
|
| 684 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 685 |
-
**kwargs: Unpack[FlashAttentionKwargs],
|
| 686 |
-
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
| 687 |
-
input_shape = hidden_states.shape[:-1]
|
| 688 |
-
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 689 |
-
|
| 690 |
-
qkv = self.att_proj(hidden_states)
|
| 691 |
-
query_states, key_states, value_states = qkv.split(self.fused_dims, dim=-1)
|
| 692 |
-
value_states = value_states.view(hidden_shape)
|
| 693 |
-
|
| 694 |
-
# Optionally apply layer norm to keys and queries.
|
| 695 |
-
if self.q_norm is not None and self.k_norm is not None and self.qk_norm_type != "qwen3":
|
| 696 |
-
query_states = self.q_norm(query_states)
|
| 697 |
-
key_states = self.k_norm(key_states)
|
| 698 |
-
|
| 699 |
-
query_states = query_states.view(hidden_shape)
|
| 700 |
-
key_states = key_states.view(hidden_shape)
|
| 701 |
-
if self.q_norm is not None and self.k_norm is not None and self.qk_norm_type == "qwen3":
|
| 702 |
-
query_states = self.q_norm(query_states)
|
| 703 |
-
key_states = self.k_norm(key_states)
|
| 704 |
-
query_states = query_states.transpose(1, 2)
|
| 705 |
-
key_states = key_states.transpose(1, 2)
|
| 706 |
-
value_states = value_states.transpose(1, 2)
|
| 707 |
-
|
| 708 |
-
cos, sin = position_embeddings
|
| 709 |
-
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 710 |
-
|
| 711 |
-
if past_key_values is not None:
|
| 712 |
-
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 713 |
-
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 714 |
-
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 715 |
-
|
| 716 |
-
attention_interface: Callable = eager_attention_forward
|
| 717 |
-
if self.config._attn_implementation != "eager":
|
| 718 |
-
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 719 |
-
|
| 720 |
-
attn_output, attn_weights = attention_interface(
|
| 721 |
-
self,
|
| 722 |
-
query_states,
|
| 723 |
-
key_states,
|
| 724 |
-
value_states,
|
| 725 |
-
attention_mask,
|
| 726 |
-
dropout=0.0 if not self.training else self.attention_dropout,
|
| 727 |
-
scaling=self.scaling,
|
| 728 |
-
**kwargs,
|
| 729 |
-
)
|
| 730 |
-
|
| 731 |
-
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 732 |
-
attn_output = self.attn_out(attn_output)
|
| 733 |
-
return attn_output, attn_weights
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
class LanguageModelMLP(nn.Module):
|
| 737 |
-
|
| 738 |
-
def __init__(
|
| 739 |
-
self,
|
| 740 |
-
input_dim: int,
|
| 741 |
-
intermediate_size: int,
|
| 742 |
-
hidden_act: str,
|
| 743 |
-
device: Union[str, torch.device] = None,
|
| 744 |
-
):
|
| 745 |
-
super().__init__()
|
| 746 |
-
self.ff_proj = nn.Linear(input_dim, intermediate_size * 2, bias=False, device=device)
|
| 747 |
-
self.ff_out = nn.Linear(intermediate_size, input_dim, bias=False, device=device)
|
| 748 |
-
self.act = ACT2FN[hidden_act]
|
| 749 |
-
|
| 750 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 751 |
-
x = self.ff_proj(x)
|
| 752 |
-
x, gate = x.chunk(2, dim=-1)
|
| 753 |
-
x = self.act(gate) * x
|
| 754 |
-
x = self.ff_out(x)
|
| 755 |
-
return x
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
class Molmo2DecoderLayer(GradientCheckpointingLayer):
|
| 759 |
-
|
| 760 |
-
def __init__(
|
| 761 |
-
self,
|
| 762 |
-
config: Molmo2TextConfig,
|
| 763 |
-
layer_idx: Optional[int] = None,
|
| 764 |
-
device: Union[str, torch.device] = None
|
| 765 |
-
):
|
| 766 |
-
super().__init__()
|
| 767 |
-
self.config = config
|
| 768 |
-
|
| 769 |
-
self.self_attn = Molmo2Attention(config, layer_idx)
|
| 770 |
-
self.attn_norm = Molmo2RMSNorm(
|
| 771 |
-
config.hidden_size, eps=config.layer_norm_eps, device=device)
|
| 772 |
-
self.dropout = nn.Dropout(config.residual_dropout)
|
| 773 |
-
self.mlp = LanguageModelMLP(
|
| 774 |
-
config.hidden_size, config.intermediate_size, config.hidden_act, device=device)
|
| 775 |
-
self.ff_norm = Molmo2RMSNorm(
|
| 776 |
-
config.hidden_size, eps=config.layer_norm_eps, device=device)
|
| 777 |
-
|
| 778 |
-
def forward(
|
| 779 |
-
self,
|
| 780 |
-
hidden_states: torch.Tensor,
|
| 781 |
-
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 782 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 783 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 784 |
-
past_key_values: Optional[Cache] = None,
|
| 785 |
-
output_attentions: Optional[bool] = False,
|
| 786 |
-
use_cache: Optional[bool] = False,
|
| 787 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 788 |
-
**kwargs: Unpack[TransformersKwargs],
|
| 789 |
-
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 790 |
-
|
| 791 |
-
residual = hidden_states
|
| 792 |
-
hidden_states = self.attn_norm(hidden_states)
|
| 793 |
-
|
| 794 |
-
# Self Attention
|
| 795 |
-
hidden_states, self_attn_weights = self.self_attn(
|
| 796 |
-
hidden_states=hidden_states,
|
| 797 |
-
position_embeddings=position_embeddings,
|
| 798 |
-
attention_mask=attention_mask,
|
| 799 |
-
position_ids=position_ids,
|
| 800 |
-
past_key_values=past_key_values,
|
| 801 |
-
output_attentions=output_attentions,
|
| 802 |
-
use_cache=use_cache,
|
| 803 |
-
cache_position=cache_position,
|
| 804 |
-
**kwargs,
|
| 805 |
-
)
|
| 806 |
-
|
| 807 |
-
hidden_states = residual + self.dropout(hidden_states)
|
| 808 |
-
|
| 809 |
-
# Fully Connected
|
| 810 |
-
residual = hidden_states
|
| 811 |
-
hidden_states = self.ff_norm(hidden_states)
|
| 812 |
-
hidden_states = self.mlp(hidden_states)
|
| 813 |
-
|
| 814 |
-
hidden_states = residual + self.dropout(hidden_states)
|
| 815 |
-
|
| 816 |
-
outputs = (hidden_states,)
|
| 817 |
-
|
| 818 |
-
if output_attentions:
|
| 819 |
-
outputs += (self_attn_weights,)
|
| 820 |
-
|
| 821 |
-
return outputs
|
| 822 |
-
|
| 823 |
-
|
| 824 |
-
class Molmo2PostNormDecoderLayer(Molmo2DecoderLayer):
|
| 825 |
-
def forward(
|
| 826 |
-
self,
|
| 827 |
-
hidden_states: torch.Tensor,
|
| 828 |
-
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 829 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 830 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 831 |
-
past_key_values: Optional[Cache] = None,
|
| 832 |
-
output_attentions: Optional[bool] = False,
|
| 833 |
-
use_cache: Optional[bool] = False,
|
| 834 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 835 |
-
**kwargs,
|
| 836 |
-
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 837 |
-
|
| 838 |
-
residual = hidden_states
|
| 839 |
-
|
| 840 |
-
# Self Attention
|
| 841 |
-
hidden_states, self_attn_weights = self.self_attn(
|
| 842 |
-
hidden_states=hidden_states,
|
| 843 |
-
position_embeddings=position_embeddings,
|
| 844 |
-
attention_mask=attention_mask,
|
| 845 |
-
position_ids=position_ids,
|
| 846 |
-
past_key_values=past_key_values,
|
| 847 |
-
output_attentions=output_attentions,
|
| 848 |
-
use_cache=use_cache,
|
| 849 |
-
cache_position=cache_position,
|
| 850 |
-
)
|
| 851 |
-
hidden_states = self.attn_norm(hidden_states)
|
| 852 |
-
|
| 853 |
-
hidden_states = residual + self.dropout(hidden_states)
|
| 854 |
-
|
| 855 |
-
# Fully Connected
|
| 856 |
-
residual = hidden_states
|
| 857 |
-
hidden_states = self.mlp(hidden_states)
|
| 858 |
-
hidden_states = self.ff_norm(hidden_states)
|
| 859 |
-
|
| 860 |
-
hidden_states = residual + self.dropout(hidden_states)
|
| 861 |
-
|
| 862 |
-
outputs = (hidden_states,)
|
| 863 |
-
|
| 864 |
-
if output_attentions:
|
| 865 |
-
outputs += (self_attn_weights,)
|
| 866 |
-
|
| 867 |
-
return outputs
|
| 868 |
-
|
| 869 |
-
|
| 870 |
-
class Molmo2Embedding(nn.Module):
|
| 871 |
-
def __init__(
|
| 872 |
-
self,
|
| 873 |
-
num_embeddings: int,
|
| 874 |
-
num_new_embeddings: int,
|
| 875 |
-
features: int,
|
| 876 |
-
device: Union[str, torch.device] = None,
|
| 877 |
-
):
|
| 878 |
-
super().__init__()
|
| 879 |
-
self.embedding = nn.Parameter(
|
| 880 |
-
torch.zeros(num_embeddings, features, device=device),
|
| 881 |
-
)
|
| 882 |
-
self.new_embedding = nn.Parameter(
|
| 883 |
-
torch.zeros(num_new_embeddings, features, device=device),
|
| 884 |
-
)
|
| 885 |
-
|
| 886 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 887 |
-
return F.embedding(x, torch.cat([self.embedding, self.new_embedding], dim=0))
|
| 888 |
-
|
| 889 |
-
|
| 890 |
-
class MolmoPointPreTrainedModel(PreTrainedModel):
|
| 891 |
-
config: MolmoPointConfig
|
| 892 |
-
base_model_prefix = "model"
|
| 893 |
-
supports_gradient_checkpointing = True
|
| 894 |
-
_no_split_modules = [
|
| 895 |
-
"Molmo2DecoderLayer",
|
| 896 |
-
"Molmo2PostNormDecoderLayer",
|
| 897 |
-
"Molmo2VisionBlock",
|
| 898 |
-
"ViTMultiHeadDotProductAttention",
|
| 899 |
-
]
|
| 900 |
-
_skip_keys_device_placement = "past_key_values"
|
| 901 |
-
_supports_flash_attn = True
|
| 902 |
-
_supports_sdpa = True
|
| 903 |
-
|
| 904 |
-
_can_compile_fullgraph = True
|
| 905 |
-
_supports_attention_backend = True
|
| 906 |
-
_can_record_outputs = {
|
| 907 |
-
"hidden_states": Molmo2DecoderLayer,
|
| 908 |
-
"attentions": Molmo2Attention,
|
| 909 |
-
}
|
| 910 |
-
|
| 911 |
-
def _init_weights(self, module):
|
| 912 |
-
std = self.config.initializer_range
|
| 913 |
-
if isinstance(module, (nn.Linear,)):
|
| 914 |
-
module.weight.data.normal_(mean=0.0, std=std)
|
| 915 |
-
if module.bias is not None:
|
| 916 |
-
module.bias.data.zero_()
|
| 917 |
-
elif isinstance(module, Molmo2Embedding):
|
| 918 |
-
module.embedding.data.normal_(mean=0.0, std=std)
|
| 919 |
-
module.new_embedding.data.normal_(mean=0.0, std=std)
|
| 920 |
-
elif isinstance(module, nn.Embedding):
|
| 921 |
-
module.weight.data.normal_(mean=0.0, std=std)
|
| 922 |
-
if module.padding_idx is not None:
|
| 923 |
-
module.weight.data[module.padding_idx].zero_()
|
| 924 |
-
elif isinstance(module, Molmo2RMSNorm):
|
| 925 |
-
module.weight.data.fill_(1.0)
|
| 926 |
-
elif isinstance(module, nn.LayerNorm):
|
| 927 |
-
module.weight.data.fill_(1.0)
|
| 928 |
-
if module.bias is not None:
|
| 929 |
-
module.bias.data.zero_()
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
class MolmoPointTextModel(PreTrainedModel):
|
| 933 |
-
config: Molmo2TextConfig
|
| 934 |
-
_no_split_modules = ["Molmo2DecoderLayer", "Molmo2PostNormDecoderLayer"]
|
| 935 |
-
base_model_prefix = "model"
|
| 936 |
-
supports_gradient_checkpointing = True
|
| 937 |
-
_skip_keys_device_placement = "past_key_values"
|
| 938 |
-
_supports_flash_attn = True
|
| 939 |
-
_supports_sdpa = True
|
| 940 |
-
|
| 941 |
-
_can_compile_fullgraph = True
|
| 942 |
-
_supports_attention_backend = True
|
| 943 |
-
_can_record_outputs = {
|
| 944 |
-
"hidden_states": Molmo2DecoderLayer,
|
| 945 |
-
"attentions": Molmo2Attention,
|
| 946 |
-
}
|
| 947 |
-
|
| 948 |
-
def __init__(self, config: Molmo2TextConfig):
|
| 949 |
-
super().__init__(config)
|
| 950 |
-
if config.additional_vocab_size is not None:
|
| 951 |
-
self.wte = Molmo2Embedding(
|
| 952 |
-
config.vocab_size,
|
| 953 |
-
config.additional_vocab_size,
|
| 954 |
-
config.hidden_size,
|
| 955 |
-
)
|
| 956 |
-
else:
|
| 957 |
-
self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
|
| 958 |
-
self.emb_drop = nn.Dropout(config.embedding_dropout)
|
| 959 |
-
decoder_layer = Molmo2PostNormDecoderLayer if config.norm_after else Molmo2DecoderLayer
|
| 960 |
-
self.blocks = nn.ModuleList(
|
| 961 |
-
[decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 962 |
-
)
|
| 963 |
-
self.ln_f = Molmo2RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 964 |
-
if config.rope_scaling_layers is not None:
|
| 965 |
-
self.rotary_embs = nn.ModuleDict(
|
| 966 |
-
{
|
| 967 |
-
"default": Molmo2RotaryEmbedding(config, rope_type="default"),
|
| 968 |
-
"scaling": Molmo2RotaryEmbedding(config),
|
| 969 |
-
}
|
| 970 |
-
)
|
| 971 |
-
else:
|
| 972 |
-
self.rotary_emb = Molmo2RotaryEmbedding(config)
|
| 973 |
-
self.gradient_checkpointing = False
|
| 974 |
-
|
| 975 |
-
# Initialize weights and apply final processing
|
| 976 |
-
self.post_init()
|
| 977 |
-
|
| 978 |
-
def get_input_embeddings(self) -> torch.nn.Module:
|
| 979 |
-
return self.wte
|
| 980 |
-
|
| 981 |
-
def set_input_embeddings(self, value: torch.nn.Module) -> None:
|
| 982 |
-
self.wte = value
|
| 983 |
-
|
| 984 |
-
@can_return_tuple
|
| 985 |
-
def forward(
|
| 986 |
-
self,
|
| 987 |
-
input_ids: Optional[torch.LongTensor] = None,
|
| 988 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 989 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 990 |
-
past_key_values: Optional[Cache] = None,
|
| 991 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 992 |
-
use_cache: Optional[bool] = None,
|
| 993 |
-
output_attentions: Optional[bool] = None,
|
| 994 |
-
output_hidden_states: Optional[bool] = None,
|
| 995 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 996 |
-
**kwargs: Unpack[TransformersKwargs],
|
| 997 |
-
) -> BaseModelOutputWithPast:
|
| 998 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 999 |
-
output_hidden_states = (
|
| 1000 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1001 |
-
)
|
| 1002 |
-
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 1003 |
-
|
| 1004 |
-
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 1005 |
-
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 1006 |
-
|
| 1007 |
-
if self.gradient_checkpointing and self.training and use_cache:
|
| 1008 |
-
logger.warning_once(
|
| 1009 |
-
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
| 1010 |
-
)
|
| 1011 |
-
use_cache = False
|
| 1012 |
-
|
| 1013 |
-
if inputs_embeds is None:
|
| 1014 |
-
input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
|
| 1015 |
-
inputs_embeds = self.wte(input_ids)
|
| 1016 |
-
|
| 1017 |
-
# torch.jit.trace() doesn't support cache objects in the output
|
| 1018 |
-
if use_cache and past_key_values is None and not torch.jit.is_tracing():
|
| 1019 |
-
past_key_values = DynamicCache(config=self.config)
|
| 1020 |
-
|
| 1021 |
-
if cache_position is None:
|
| 1022 |
-
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1023 |
-
cache_position = torch.arange(
|
| 1024 |
-
past_seen_tokens,
|
| 1025 |
-
past_seen_tokens + inputs_embeds.shape[1],
|
| 1026 |
-
device=inputs_embeds.device,
|
| 1027 |
-
)
|
| 1028 |
-
|
| 1029 |
-
if position_ids is None:
|
| 1030 |
-
position_ids = cache_position.unsqueeze(0)
|
| 1031 |
-
|
| 1032 |
-
# It may already have been prepared by e.g. `generate`
|
| 1033 |
-
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
| 1034 |
-
# Prepare mask arguments
|
| 1035 |
-
mask_kwargs = {
|
| 1036 |
-
"config": self.config,
|
| 1037 |
-
"input_embeds": inputs_embeds,
|
| 1038 |
-
"attention_mask": attention_mask,
|
| 1039 |
-
"cache_position": cache_position,
|
| 1040 |
-
"past_key_values": past_key_values,
|
| 1041 |
-
"position_ids": position_ids,
|
| 1042 |
-
}
|
| 1043 |
-
|
| 1044 |
-
# Create the mask
|
| 1045 |
-
causal_mask_mapping = create_causal_mask(**mask_kwargs)
|
| 1046 |
-
|
| 1047 |
-
hidden_states = inputs_embeds
|
| 1048 |
-
|
| 1049 |
-
# create position embeddings to be shared across the decoder layers
|
| 1050 |
-
if self.config.rope_scaling_layers is not None:
|
| 1051 |
-
position_embeddings_mapping = {
|
| 1052 |
-
"default": self.rotary_embs["default"](hidden_states, position_ids),
|
| 1053 |
-
"scaling": self.rotary_embs["scaling"](hidden_states, position_ids),
|
| 1054 |
-
}
|
| 1055 |
-
else:
|
| 1056 |
-
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 1057 |
-
|
| 1058 |
-
# decoder layers
|
| 1059 |
-
all_hidden_states = () if output_hidden_states else None
|
| 1060 |
-
all_self_attns = () if output_attentions else None
|
| 1061 |
-
|
| 1062 |
-
for layer_idx, decoder_block in enumerate(self.blocks[: self.config.num_hidden_layers]):
|
| 1063 |
-
if output_hidden_states:
|
| 1064 |
-
all_hidden_states += (hidden_states,)
|
| 1065 |
-
|
| 1066 |
-
if self.config.rope_scaling_layers is not None:
|
| 1067 |
-
position_embeddings_i = (
|
| 1068 |
-
position_embeddings_mapping["scaling"]
|
| 1069 |
-
if layer_idx in self.config.rope_scaling_layers
|
| 1070 |
-
else position_embeddings_mapping["default"]
|
| 1071 |
-
)
|
| 1072 |
-
else:
|
| 1073 |
-
position_embeddings_i = position_embeddings
|
| 1074 |
-
|
| 1075 |
-
layer_outputs = decoder_block(
|
| 1076 |
-
hidden_states,
|
| 1077 |
-
attention_mask=causal_mask_mapping,
|
| 1078 |
-
position_ids=position_ids,
|
| 1079 |
-
past_key_values=past_key_values,
|
| 1080 |
-
output_attentions=output_attentions,
|
| 1081 |
-
use_cache=use_cache,
|
| 1082 |
-
cache_position=cache_position,
|
| 1083 |
-
position_embeddings=position_embeddings_i,
|
| 1084 |
-
**kwargs,
|
| 1085 |
-
)
|
| 1086 |
-
|
| 1087 |
-
hidden_states = layer_outputs[0]
|
| 1088 |
-
|
| 1089 |
-
if output_attentions:
|
| 1090 |
-
all_self_attns += (layer_outputs[1],)
|
| 1091 |
-
|
| 1092 |
-
hidden_states = self.ln_f(hidden_states)
|
| 1093 |
-
|
| 1094 |
-
# add hidden states from the last decoder layer
|
| 1095 |
-
if output_hidden_states:
|
| 1096 |
-
all_hidden_states += (hidden_states,)
|
| 1097 |
-
|
| 1098 |
-
return BaseModelOutputWithPast(
|
| 1099 |
-
last_hidden_state=hidden_states,
|
| 1100 |
-
past_key_values=past_key_values,
|
| 1101 |
-
hidden_states=all_hidden_states,
|
| 1102 |
-
attentions=all_self_attns,
|
| 1103 |
-
)
|
| 1104 |
-
|
| 1105 |
-
# Adapted from transformers.models.gemma3.modeling_gemma3
|
| 1106 |
-
def token_type_ids_mask_function(
|
| 1107 |
-
token_type_ids: Optional[torch.Tensor] = None,
|
| 1108 |
-
) -> Optional[Callable]:
|
| 1109 |
-
"""
|
| 1110 |
-
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
|
| 1111 |
-
not start and end indices.
|
| 1112 |
-
"""
|
| 1113 |
-
# Do not return an additional mask in this case
|
| 1114 |
-
if token_type_ids is None:
|
| 1115 |
-
return None
|
| 1116 |
-
|
| 1117 |
-
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
| 1118 |
-
# If it's 1 for both query and key/value, we are in an image block
|
| 1119 |
-
# NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
|
| 1120 |
-
# Since vmap doesn't support `if statement` we workaround it with `torch.where`
|
| 1121 |
-
safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
|
| 1122 |
-
token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
|
| 1123 |
-
token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
|
| 1124 |
-
|
| 1125 |
-
is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
|
| 1126 |
-
|
| 1127 |
-
# This is bidirectional attention whenever we are dealing with image tokens
|
| 1128 |
-
return is_image_block & is_image_block
|
| 1129 |
-
|
| 1130 |
-
return inner_mask
|
| 1131 |
-
|
| 1132 |
-
|
| 1133 |
-
class MolmoPointPadWithLearnedVector(nn.Module):
|
| 1134 |
-
"""Module that pads vector
|
| 1135 |
-
|
| 1136 |
-
Used to add in the no-more-point key value
|
| 1137 |
-
"""
|
| 1138 |
-
def __init__(self, dim: int):
|
| 1139 |
-
super().__init__()
|
| 1140 |
-
self.dim = dim
|
| 1141 |
-
self.vector = nn.Parameter(torch.zeros([dim]))
|
| 1142 |
-
|
| 1143 |
-
def reset_parameters(self):
|
| 1144 |
-
torch.nn.init.zeros_(self.vector)
|
| 1145 |
-
|
| 1146 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 1147 |
-
vector = torch.tile(self.vector[None, :], [x.shape[0], 1])
|
| 1148 |
-
return torch.concatenate([x, vector[:, None, :]], dim=1)
|
| 1149 |
-
|
| 1150 |
-
|
| 1151 |
-
class AddPosEmbed(nn.Module):
|
| 1152 |
-
|
| 1153 |
-
def __init__(self, in_features: int, n_pos: int) -> None:
|
| 1154 |
-
super().__init__()
|
| 1155 |
-
self.bias = nn.Parameter(torch.zeros([n_pos, in_features]))
|
| 1156 |
-
|
| 1157 |
-
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 1158 |
-
return input + self.bias[None, :input.shape[-2], :]
|
| 1159 |
-
|
| 1160 |
-
|
| 1161 |
-
class MolmoPointConnector(nn.Module):
|
| 1162 |
-
def __init__(self, config: MolmoPointAdapterConfig, vit_config: Molmo2VitConfig):
|
| 1163 |
-
super().__init__()
|
| 1164 |
-
self.config = config
|
| 1165 |
-
self.n_vit_layers = len(config.vit_layers)
|
| 1166 |
-
pool_dim = vit_config.hidden_size * self.n_vit_layers
|
| 1167 |
-
self.norm = None
|
| 1168 |
-
self.image_projector = ImageProjectorMLP(
|
| 1169 |
-
config.hidden_size,
|
| 1170 |
-
config.intermediate_size,
|
| 1171 |
-
config.text_hidden_size,
|
| 1172 |
-
config.hidden_act,
|
| 1173 |
-
)
|
| 1174 |
-
self.act = ACT2FN[config.hidden_act]
|
| 1175 |
-
self.image_pooling_2d = ViTMultiHeadDotProductAttention(
|
| 1176 |
-
hidden_size=config.hidden_size,
|
| 1177 |
-
num_heads=config.num_attention_heads,
|
| 1178 |
-
num_key_value_heads=config.num_key_value_heads,
|
| 1179 |
-
head_dim=config.head_dim,
|
| 1180 |
-
input_dim=pool_dim,
|
| 1181 |
-
float32_attention=config.float32_attention,
|
| 1182 |
-
attention_dropout=config.attention_dropout,
|
| 1183 |
-
residual_dropout=config.residual_dropout,
|
| 1184 |
-
attn_implementation=config._attn_implementation,
|
| 1185 |
-
out_layer=False
|
| 1186 |
-
)
|
| 1187 |
-
if self.config.positional_embeddings:
|
| 1188 |
-
self.positional_embeddings = AddPosEmbed(pool_dim, self.config.positional_embeddings)
|
| 1189 |
-
else:
|
| 1190 |
-
self.positional_embeddings = None
|
| 1191 |
-
|
| 1192 |
-
def __call__(self, to_pool, to_pool_mask):
|
| 1193 |
-
"""
|
| 1194 |
-
to_pool: [n_to_pool, pooling_dim, vit_dim]
|
| 1195 |
-
to_pool_mask: [n_to_pool, pooling_dim]
|
| 1196 |
-
|
| 1197 |
-
returns:
|
| 1198 |
-
pooled_features: [n_to_pool, llm_dim]
|
| 1199 |
-
"""
|
| 1200 |
-
cfg = self.config
|
| 1201 |
-
|
| 1202 |
-
if self.config.positional_embeddings:
|
| 1203 |
-
to_pool = self.positional_embeddings(to_pool)
|
| 1204 |
-
|
| 1205 |
-
if self.config.pooling_attention_mask:
|
| 1206 |
-
attn_mask = to_pool_mask.reshape([-1, 1, 1, to_pool_mask.shape[-1]])
|
| 1207 |
-
else:
|
| 1208 |
-
attn_mask = None
|
| 1209 |
-
to_pool = to_pool * to_pool_mask.float()[:, :, None]
|
| 1210 |
-
|
| 1211 |
-
denom = to_pool_mask.view(-1, to_pool.shape[-2]).float().sum(-1)
|
| 1212 |
-
denom = torch.where(denom == 0, 1, denom)
|
| 1213 |
-
query = to_pool.sum(-2, keepdim=True) / denom[:, None, None]
|
| 1214 |
-
|
| 1215 |
-
pooled_features = self.image_pooling_2d(query, to_pool, attn_mask=attn_mask)
|
| 1216 |
-
pooled_features = self.act(pooled_features)
|
| 1217 |
-
pooled_features = self.image_projector(pooled_features)
|
| 1218 |
-
return pooled_features
|
| 1219 |
-
|
| 1220 |
-
|
| 1221 |
-
class MolmoPointModel(MolmoPointPreTrainedModel):
|
| 1222 |
-
base_model_prefix = ""
|
| 1223 |
-
_checkpoint_conversion_mapping = {}
|
| 1224 |
-
# Reference: fix gemma3 grad acc #37208
|
| 1225 |
-
accepts_loss_kwargs = False
|
| 1226 |
-
config: MolmoPointConfig
|
| 1227 |
-
|
| 1228 |
-
def __init__(self, config: MolmoPointConfig):
|
| 1229 |
-
super().__init__(config)
|
| 1230 |
-
self.transformer: MolmoPointTextModel = MolmoPointTextModel(config.text_config)
|
| 1231 |
-
|
| 1232 |
-
vit_config = config.vit_config
|
| 1233 |
-
adapter_config = config.adapter_config
|
| 1234 |
-
self.vit_layers = []
|
| 1235 |
-
for layer in adapter_config.vit_layers:
|
| 1236 |
-
if layer >= 0:
|
| 1237 |
-
self.vit_layers.append(layer)
|
| 1238 |
-
else:
|
| 1239 |
-
self.vit_layers.append(layer + vit_config.num_hidden_layers)
|
| 1240 |
-
|
| 1241 |
-
last_layer_needed = max(self.vit_layers) + 1
|
| 1242 |
-
if last_layer_needed < vit_config.num_hidden_layers:
|
| 1243 |
-
new_vit_config = deepcopy(vit_config)
|
| 1244 |
-
new_vit_config.num_hidden_layers = last_layer_needed
|
| 1245 |
-
self.vit = Molmo2VisionTransformer(new_vit_config)
|
| 1246 |
-
else:
|
| 1247 |
-
self.vit = Molmo2VisionTransformer(vit_config)
|
| 1248 |
-
|
| 1249 |
-
self.connector = MolmoPointConnector(adapter_config, vit_config)
|
| 1250 |
-
|
| 1251 |
-
vit_dim = self.config.vit_config.hidden_size * len(self.config.adapter_config.vit_layers)
|
| 1252 |
-
llm_dim = self.config.text_config.hidden_size
|
| 1253 |
-
self.patch_rotary = None
|
| 1254 |
-
self.patch_q = nn.Linear(llm_dim, config.patch_embed_dim)
|
| 1255 |
-
self.patch_k = nn.Linear(llm_dim, config.patch_embed_dim)
|
| 1256 |
-
self.subpatch_q = nn.Linear(llm_dim, config.patch_embed_dim)
|
| 1257 |
-
self.subpatch_k = nn.Linear(vit_dim, config.patch_embed_dim)
|
| 1258 |
-
self.add_no_point_class_embed = MolmoPointPadWithLearnedVector(config.patch_embed_dim)
|
| 1259 |
-
|
| 1260 |
-
if self.config.embed_selected_vit_patch == "linear":
|
| 1261 |
-
self.build_vit_embedding = nn.Linear(vit_dim, llm_dim, bias=True)
|
| 1262 |
-
else:
|
| 1263 |
-
raise NotImplementedError(f"Embedding {self.config.embed_selected_vit_patch} not implemented")
|
| 1264 |
-
|
| 1265 |
-
if self.config.patch_location == "3x3":
|
| 1266 |
-
self.subpatch_loc_k = nn.Linear(llm_dim, 9)
|
| 1267 |
-
elif self.config.patch_location is None:
|
| 1268 |
-
self.subpatch_loc_k = None
|
| 1269 |
-
else:
|
| 1270 |
-
raise NotImplementedError(f"Patch location {self.config.patch_location} not implemented")
|
| 1271 |
-
|
| 1272 |
-
if self.config.layer_norm_x:
|
| 1273 |
-
self.x_norm = Molmo2RMSNorm(llm_dim, eps=self.config.text_config.layer_norm_eps)
|
| 1274 |
-
else:
|
| 1275 |
-
self.x_norm = None
|
| 1276 |
-
|
| 1277 |
-
# Initialize weights and apply final processing
|
| 1278 |
-
self.post_init()
|
| 1279 |
-
|
| 1280 |
-
def get_input_embeddings(self) -> torch.nn.Module:
|
| 1281 |
-
return self.transformer.wte
|
| 1282 |
-
|
| 1283 |
-
def set_input_embeddings(self, value: torch.nn.Module) -> None:
|
| 1284 |
-
self.transformer.wte = value
|
| 1285 |
-
|
| 1286 |
-
def set_decoder(self, decoder):
|
| 1287 |
-
self.transformer = decoder
|
| 1288 |
-
|
| 1289 |
-
def get_decoder(self):
|
| 1290 |
-
return self.transformer
|
| 1291 |
-
|
| 1292 |
-
@property
|
| 1293 |
-
def device(self) -> torch.device:
|
| 1294 |
-
return self.transformer.ln_f.weight.device
|
| 1295 |
-
|
| 1296 |
-
def build_batched_images(
|
| 1297 |
-
self,
|
| 1298 |
-
input_ids: torch.LongTensor,
|
| 1299 |
-
pixel_values: torch.Tensor,
|
| 1300 |
-
image_token_pooling: torch.Tensor,
|
| 1301 |
-
image_grids: torch.Tensor,
|
| 1302 |
-
image_num_crops: torch.Tensor,
|
| 1303 |
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 1304 |
-
# 1) Count the number of images in each example
|
| 1305 |
-
raw_counts = (input_ids == self.config.image_end_token_id).sum(1) # [N]
|
| 1306 |
-
# Each image is represented by global view and high-res view
|
| 1307 |
-
# so we divide by 2 to get the number of images
|
| 1308 |
-
counts = raw_counts // 2
|
| 1309 |
-
N = counts.size(0)
|
| 1310 |
-
device = input_ids.device
|
| 1311 |
-
|
| 1312 |
-
# Total number of images in the batch
|
| 1313 |
-
num_images = int(counts.sum().item())
|
| 1314 |
-
|
| 1315 |
-
# Sanity check
|
| 1316 |
-
assert image_grids.size(0) == num_images, \
|
| 1317 |
-
f"Expected {num_images} image grids, but got {image_grids.size(0)}"
|
| 1318 |
-
assert image_num_crops.size(0) == num_images, \
|
| 1319 |
-
f"Expected {num_images} image num crops, but got {image_num_crops.size(0)}"
|
| 1320 |
-
|
| 1321 |
-
# 1-1) Compute per-image pooled patch count from image grids
|
| 1322 |
-
with torch.no_grad():
|
| 1323 |
-
first_prod = image_grids[:, :2].prod(dim=1) # [num_images]
|
| 1324 |
-
second_prod = image_grids[:, 2:].prod(dim=1) # [num_images]
|
| 1325 |
-
num_pooled_patches_per_image = (first_prod + second_prod).to(image_num_crops.dtype) # [num_images]
|
| 1326 |
-
|
| 1327 |
-
# pixel_values: [n_crops, n_patches, pixels_per_patch]
|
| 1328 |
-
n_crops, n_patches, pixels_per_patch = pixel_values.shape
|
| 1329 |
-
|
| 1330 |
-
# 2) Map each image index → example index
|
| 1331 |
-
# Example: if counts = [2, 1, 3], then this becomes [0,0,1,2,2,2]
|
| 1332 |
-
example_ids_for_image = torch.arange(N, device=device).repeat_interleave(counts) # [num_images]
|
| 1333 |
-
assert example_ids_for_image.numel() == num_images
|
| 1334 |
-
|
| 1335 |
-
# 2-1) Compute crops_per_example by summing per-image crop counts
|
| 1336 |
-
crops_per_example = torch.zeros(
|
| 1337 |
-
N, dtype=image_num_crops.dtype, device=image_num_crops.device
|
| 1338 |
-
)
|
| 1339 |
-
crops_per_example.index_add_(0, example_ids_for_image, image_num_crops) # [N]
|
| 1340 |
-
|
| 1341 |
-
# 2-2) Per-image number of patches = (crops per image) * n_patches
|
| 1342 |
-
patches_per_image = image_num_crops * n_patches # [num_images]
|
| 1343 |
-
|
| 1344 |
-
# 2-3) Compute per-example per-image patch offsets
|
| 1345 |
-
counts_list = counts.tolist()
|
| 1346 |
-
index_offset_per_example_list = []
|
| 1347 |
-
offset_img = 0
|
| 1348 |
-
for c in counts_list:
|
| 1349 |
-
per_img_patches = patches_per_image[offset_img:offset_img + c] # [c]
|
| 1350 |
-
# Offsets: [0, img0_total_patches, img0+img1_total_patches, ...]
|
| 1351 |
-
index_offset = [0] + per_img_patches.cumsum(0).tolist()[:-1]
|
| 1352 |
-
index_offset_per_example_list.append(index_offset)
|
| 1353 |
-
offset_img += c
|
| 1354 |
-
|
| 1355 |
-
# 2-4) Compute num_pooled_patches_per_example
|
| 1356 |
-
num_pooled_patches_per_example = torch.zeros(
|
| 1357 |
-
N, dtype=num_pooled_patches_per_image.dtype, device=num_pooled_patches_per_image.device
|
| 1358 |
-
)
|
| 1359 |
-
num_pooled_patches_per_example.index_add_(
|
| 1360 |
-
0, example_ids_for_image, num_pooled_patches_per_image
|
| 1361 |
-
)
|
| 1362 |
-
|
| 1363 |
-
# Sanity checks
|
| 1364 |
-
total_crops = int(crops_per_example.sum().item())
|
| 1365 |
-
assert total_crops == n_crops, \
|
| 1366 |
-
f"Expected {total_crops} crops, but got {n_crops}"
|
| 1367 |
-
|
| 1368 |
-
total_num_pooled_patches = int(num_pooled_patches_per_example.sum().item())
|
| 1369 |
-
assert total_num_pooled_patches == image_token_pooling.size(0), \
|
| 1370 |
-
f"Expected {total_num_pooled_patches} pooled patches, but got {image_token_pooling.size(0)}"
|
| 1371 |
-
|
| 1372 |
-
# 3) Build images tensor filled with -1
|
| 1373 |
-
M = int(crops_per_example.max().item())
|
| 1374 |
-
images = torch.full(
|
| 1375 |
-
(N, M, n_patches, pixels_per_patch),
|
| 1376 |
-
fill_value=-1,
|
| 1377 |
-
dtype=pixel_values.dtype,
|
| 1378 |
-
device=pixel_values.device,
|
| 1379 |
-
)
|
| 1380 |
-
|
| 1381 |
-
# 4) Fill images with per-example slices from pixel_values
|
| 1382 |
-
offset_crop = 0
|
| 1383 |
-
for i in range(N):
|
| 1384 |
-
num = int(crops_per_example[i].item())
|
| 1385 |
-
cur = pixel_values[offset_crop:offset_crop + num] # [num, n_patches, pixels_per_patch]
|
| 1386 |
-
images[i, :num] = cur
|
| 1387 |
-
offset_crop += num
|
| 1388 |
-
|
| 1389 |
-
# Sanity check
|
| 1390 |
-
assert offset_crop == n_crops
|
| 1391 |
-
|
| 1392 |
-
# 5) Build new_token_pooling tensor filled with -1
|
| 1393 |
-
P = int(num_pooled_patches_per_example.max().item())
|
| 1394 |
-
_, dim = image_token_pooling.shape
|
| 1395 |
-
new_token_pooling = torch.full(
|
| 1396 |
-
(N, P, dim),
|
| 1397 |
-
fill_value=-1,
|
| 1398 |
-
dtype=image_token_pooling.dtype,
|
| 1399 |
-
device=image_token_pooling.device,
|
| 1400 |
-
)
|
| 1401 |
-
|
| 1402 |
-
# 6) Fill token_pooling with per-example slices, adding per-image patch offsets
|
| 1403 |
-
patch_offset = 0
|
| 1404 |
-
img_offset = 0
|
| 1405 |
-
|
| 1406 |
-
for i, c in enumerate(counts_list):
|
| 1407 |
-
num_patches = int(num_pooled_patches_per_example[i].item())
|
| 1408 |
-
|
| 1409 |
-
# Subsequence of pooled tokens belonging to this example
|
| 1410 |
-
cur = image_token_pooling[patch_offset:patch_offset + num_patches].clone() # [num_patches, dim]
|
| 1411 |
-
|
| 1412 |
-
index_offset_per_example = index_offset_per_example_list[i] # length = c
|
| 1413 |
-
per_img_pooled = num_pooled_patches_per_image[img_offset:img_offset + c] # [c]
|
| 1414 |
-
|
| 1415 |
-
assert len(index_offset_per_example) == per_img_pooled.numel()
|
| 1416 |
-
|
| 1417 |
-
# Apply per-image offsets to the (ragged) subsequence
|
| 1418 |
-
offset = 0
|
| 1419 |
-
for j in range(c):
|
| 1420 |
-
index_offset = int(index_offset_per_example[j])
|
| 1421 |
-
n = int(per_img_pooled[j].item())
|
| 1422 |
-
cur_slice = cur[offset:offset + n]
|
| 1423 |
-
|
| 1424 |
-
# Apply offset across all columns
|
| 1425 |
-
cur[offset:offset + n] = torch.where(
|
| 1426 |
-
cur_slice >= 0,
|
| 1427 |
-
cur_slice + index_offset,
|
| 1428 |
-
cur_slice,
|
| 1429 |
-
)
|
| 1430 |
-
offset += n
|
| 1431 |
-
|
| 1432 |
-
new_token_pooling[i, :num_patches] = cur
|
| 1433 |
-
|
| 1434 |
-
patch_offset += num_patches
|
| 1435 |
-
img_offset += c
|
| 1436 |
-
|
| 1437 |
-
# Final sanity checks
|
| 1438 |
-
assert patch_offset == total_num_pooled_patches
|
| 1439 |
-
assert img_offset == num_images
|
| 1440 |
-
|
| 1441 |
-
return images, new_token_pooling
|
| 1442 |
-
|
| 1443 |
-
def build_batched_videos(
|
| 1444 |
-
self,
|
| 1445 |
-
input_ids: torch.LongTensor,
|
| 1446 |
-
pixel_values_videos: torch.Tensor,
|
| 1447 |
-
video_token_pooling: torch.Tensor,
|
| 1448 |
-
video_grids: torch.Tensor,
|
| 1449 |
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 1450 |
-
|
| 1451 |
-
# 1) Count the number of videos in each example
|
| 1452 |
-
if self.config.use_frame_special_tokens:
|
| 1453 |
-
end_token_id = self.config.frame_end_token_id
|
| 1454 |
-
else:
|
| 1455 |
-
end_token_id = self.config.image_end_token_id
|
| 1456 |
-
counts = (input_ids == end_token_id).any(dim=1).long() # [N]
|
| 1457 |
-
N = counts.size(0)
|
| 1458 |
-
device = input_ids.device
|
| 1459 |
-
|
| 1460 |
-
# Total number of videos in the batch
|
| 1461 |
-
num_videos = int(counts.sum().item())
|
| 1462 |
-
|
| 1463 |
-
# Sanity check
|
| 1464 |
-
assert video_grids.size(0) == num_videos, \
|
| 1465 |
-
f"Expected {num_videos} videos, but got {video_grids.size(0)}"
|
| 1466 |
-
|
| 1467 |
-
video_num_frames = video_grids[:, 0] # [num_videos]
|
| 1468 |
-
num_pooled_patches_per_video = video_grids.prod(dim=1) # [num_videos]
|
| 1469 |
-
|
| 1470 |
-
# pixel_values_videos: [n_frames, n_patches, pixels_per_patch]
|
| 1471 |
-
n_frames, n_patches, pixels_per_patch = pixel_values_videos.shape
|
| 1472 |
-
|
| 1473 |
-
# 2) Map each video index -> example index
|
| 1474 |
-
# Example: if counts = [2, 1, 3], then this becomes [0,0,1,2,2,2]
|
| 1475 |
-
example_ids_for_video = torch.arange(N, device=device).repeat_interleave(counts) # [num_videos]
|
| 1476 |
-
assert example_ids_for_video.numel() == num_videos
|
| 1477 |
-
|
| 1478 |
-
# 2-1) Compute frames_per_example by summing per-video frame counts
|
| 1479 |
-
frames_per_example = torch.zeros(
|
| 1480 |
-
N, dtype=video_num_frames.dtype, device=device,
|
| 1481 |
-
)
|
| 1482 |
-
frames_per_example.index_add_(0, example_ids_for_video, video_num_frames) # [N]
|
| 1483 |
-
|
| 1484 |
-
# 2-2) Compute num_pooled_patches_per_example
|
| 1485 |
-
num_pooled_patches_per_example = torch.zeros(
|
| 1486 |
-
N, dtype=num_pooled_patches_per_video.dtype, device=num_pooled_patches_per_video.device,
|
| 1487 |
-
)
|
| 1488 |
-
num_pooled_patches_per_example.index_add_(
|
| 1489 |
-
0, example_ids_for_video, num_pooled_patches_per_video,
|
| 1490 |
-
)
|
| 1491 |
-
|
| 1492 |
-
# Sanity checks
|
| 1493 |
-
total_frames = int(frames_per_example.sum().item())
|
| 1494 |
-
assert total_frames == n_frames, \
|
| 1495 |
-
f"Expected {total_frames} frames, but got {n_frames}"
|
| 1496 |
-
|
| 1497 |
-
total_num_pooled_patches = int(num_pooled_patches_per_example.sum().item())
|
| 1498 |
-
assert total_num_pooled_patches == video_token_pooling.size(0), \
|
| 1499 |
-
f"Expected {total_num_pooled_patches} pooled patches, but got {video_token_pooling.size(0)}"
|
| 1500 |
-
|
| 1501 |
-
# 3) Build videos tensor filled with -1
|
| 1502 |
-
M = int(frames_per_example.max().item())
|
| 1503 |
-
videos = torch.full(
|
| 1504 |
-
(N, M, n_patches, pixels_per_patch),
|
| 1505 |
-
fill_value=-1,
|
| 1506 |
-
dtype=pixel_values_videos.dtype,
|
| 1507 |
-
device=device,
|
| 1508 |
-
)
|
| 1509 |
-
|
| 1510 |
-
# 4) Fill videos with per-examples slices from pixel_values_videos
|
| 1511 |
-
offset_frame = 0
|
| 1512 |
-
for i in range(N):
|
| 1513 |
-
num = int(frames_per_example[i].item())
|
| 1514 |
-
cur = pixel_values_videos[offset_frame:offset_frame + num] # [num, n_patches, pixels_per_patch]
|
| 1515 |
-
videos[i, :num] = cur
|
| 1516 |
-
offset_frame += num
|
| 1517 |
-
|
| 1518 |
-
# Sanity check
|
| 1519 |
-
assert offset_frame == n_frames
|
| 1520 |
-
|
| 1521 |
-
# 5) Build new token_pooling tensor filled with -1
|
| 1522 |
-
P = int(num_pooled_patches_per_example.max().item())
|
| 1523 |
-
_, dim = video_token_pooling.shape
|
| 1524 |
-
new_token_pooling = torch.full(
|
| 1525 |
-
(N, P, dim),
|
| 1526 |
-
fill_value=-1,
|
| 1527 |
-
dtype=video_token_pooling.dtype,
|
| 1528 |
-
device=video_token_pooling.device,
|
| 1529 |
-
)
|
| 1530 |
-
|
| 1531 |
-
# 6) Fill new token_pooling with per-examples slices from video_token_pooling
|
| 1532 |
-
patch_offset = 0
|
| 1533 |
-
for i in range(N):
|
| 1534 |
-
num_patches = int(num_pooled_patches_per_example[i].item())
|
| 1535 |
-
cur = video_token_pooling[patch_offset:patch_offset + num_patches] # [num_patches, dim]
|
| 1536 |
-
new_token_pooling[i, :num_patches] = cur
|
| 1537 |
-
patch_offset += num_patches
|
| 1538 |
-
|
| 1539 |
-
# Final sanity checks
|
| 1540 |
-
assert patch_offset == total_num_pooled_patches
|
| 1541 |
-
|
| 1542 |
-
return videos, new_token_pooling
|
| 1543 |
-
|
| 1544 |
-
def merge_visual_inputs(
|
| 1545 |
-
self,
|
| 1546 |
-
input_ids: Optional[torch.LongTensor] = None,
|
| 1547 |
-
pixel_values: Optional[torch.Tensor] = None,
|
| 1548 |
-
image_token_pooling: Optional[torch.Tensor] = None,
|
| 1549 |
-
image_grids: Optional[torch.Tensor] = None,
|
| 1550 |
-
image_num_crops: Optional[torch.Tensor] = None,
|
| 1551 |
-
pixel_values_videos: Optional[torch.Tensor] = None,
|
| 1552 |
-
video_token_pooling: Optional[torch.Tensor] = None,
|
| 1553 |
-
video_grids: Optional[torch.Tensor] = None,
|
| 1554 |
-
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 1555 |
-
if pixel_values is not None and pixel_values_videos is not None:
|
| 1556 |
-
raise ValueError("pixel_values and pixel_values_videos are provided at the same time")
|
| 1557 |
-
elif pixel_values is not None:
|
| 1558 |
-
assert input_ids is not None
|
| 1559 |
-
images, token_pooling = self.build_batched_images(
|
| 1560 |
-
input_ids=input_ids,
|
| 1561 |
-
pixel_values=pixel_values,
|
| 1562 |
-
image_token_pooling=image_token_pooling,
|
| 1563 |
-
image_grids=image_grids,
|
| 1564 |
-
image_num_crops=image_num_crops,
|
| 1565 |
-
)
|
| 1566 |
-
elif pixel_values_videos is not None:
|
| 1567 |
-
assert input_ids is not None
|
| 1568 |
-
images, token_pooling = self.build_batched_videos(
|
| 1569 |
-
input_ids=input_ids,
|
| 1570 |
-
pixel_values_videos=pixel_values_videos,
|
| 1571 |
-
video_token_pooling=video_token_pooling,
|
| 1572 |
-
video_grids=video_grids,
|
| 1573 |
-
)
|
| 1574 |
-
else:
|
| 1575 |
-
images, token_pooling = None, None
|
| 1576 |
-
return images, token_pooling
|
| 1577 |
-
|
| 1578 |
-
def build_input_embeddings(
|
| 1579 |
-
self,
|
| 1580 |
-
input_ids: torch.LongTensor,
|
| 1581 |
-
images: Optional[torch.FloatTensor] = None, # image inputs
|
| 1582 |
-
token_pooling: Optional[torch.LongTensor] = None,
|
| 1583 |
-
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 1584 |
-
|
| 1585 |
-
# Get embeddings of input.
|
| 1586 |
-
# shape: (batch_size, seq_len, d_model)
|
| 1587 |
-
input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
|
| 1588 |
-
x = self.transformer.wte(input_ids)
|
| 1589 |
-
|
| 1590 |
-
image_features: Optional[torch.FloatTensor] = None
|
| 1591 |
-
if images is not None:
|
| 1592 |
-
image_features = self.vision_backbone(images, token_pooling).to(x.device)
|
| 1593 |
-
is_image_patch = input_ids.view(-1) == self.config.image_patch_id
|
| 1594 |
-
assert is_image_patch.sum() == len(image_features)
|
| 1595 |
-
x.view(-1, x.shape[-1])[is_image_patch] += image_features
|
| 1596 |
-
|
| 1597 |
-
# shape: (batch_size, seq_len, d_model)
|
| 1598 |
-
x = self.transformer.emb_drop(x) # type: ignore
|
| 1599 |
-
|
| 1600 |
-
return x, image_features
|
| 1601 |
-
|
| 1602 |
-
@can_return_tuple
|
| 1603 |
-
def forward(
|
| 1604 |
-
self,
|
| 1605 |
-
input_ids: Optional[torch.LongTensor] = None,
|
| 1606 |
-
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1607 |
-
image_token_pooling: Optional[torch.Tensor] = None,
|
| 1608 |
-
image_grids: Optional[torch.Tensor] = None,
|
| 1609 |
-
image_num_crops: Optional[torch.Tensor] = None,
|
| 1610 |
-
pixel_values_videos: Optional[torch.Tensor] = None,
|
| 1611 |
-
video_token_pooling: Optional[torch.Tensor] = None,
|
| 1612 |
-
video_grids: Optional[torch.Tensor] = None,
|
| 1613 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 1614 |
-
position_ids: Optional[torch.Tensor] = None,
|
| 1615 |
-
past_key_values: Optional[Cache] = None,
|
| 1616 |
-
token_type_ids: Optional[torch.LongTensor] = None,
|
| 1617 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1618 |
-
use_cache: Optional[bool] = None,
|
| 1619 |
-
output_attentions: Optional[bool] = None,
|
| 1620 |
-
output_hidden_states: Optional[bool] = None,
|
| 1621 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 1622 |
-
**kwargs: Unpack[TransformersKwargs],
|
| 1623 |
-
) -> Union[tuple, MolmoPointModelOutputWithPast]:
|
| 1624 |
-
|
| 1625 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1626 |
-
output_hidden_states = (
|
| 1627 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1628 |
-
)
|
| 1629 |
-
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 1630 |
-
|
| 1631 |
-
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 1632 |
-
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 1633 |
-
|
| 1634 |
-
images, token_pooling = self.merge_visual_inputs(
|
| 1635 |
-
input_ids=input_ids,
|
| 1636 |
-
pixel_values=pixel_values,
|
| 1637 |
-
image_token_pooling=image_token_pooling,
|
| 1638 |
-
image_grids=image_grids,
|
| 1639 |
-
image_num_crops=image_num_crops,
|
| 1640 |
-
pixel_values_videos=pixel_values_videos,
|
| 1641 |
-
video_token_pooling=video_token_pooling,
|
| 1642 |
-
video_grids=video_grids,
|
| 1643 |
-
)
|
| 1644 |
-
|
| 1645 |
-
if images is not None and inputs_embeds is not None:
|
| 1646 |
-
raise ValueError(
|
| 1647 |
-
"You cannot specify both images and inputs_embeds at the same time."
|
| 1648 |
-
)
|
| 1649 |
-
|
| 1650 |
-
if inputs_embeds is None:
|
| 1651 |
-
inputs_embeds, image_features = self.build_input_embeddings(
|
| 1652 |
-
input_ids, images, token_pooling,
|
| 1653 |
-
)
|
| 1654 |
-
|
| 1655 |
-
if cache_position is None:
|
| 1656 |
-
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1657 |
-
cache_position = torch.arange(
|
| 1658 |
-
past_seen_tokens,
|
| 1659 |
-
past_seen_tokens + inputs_embeds.shape[1],
|
| 1660 |
-
device=inputs_embeds.device,
|
| 1661 |
-
)
|
| 1662 |
-
|
| 1663 |
-
# Adapted from transformers.models.gemma3.modeling_gemma3
|
| 1664 |
-
# It may already have been prepared by e.g. `generate`
|
| 1665 |
-
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
| 1666 |
-
# Prepare mask arguments
|
| 1667 |
-
mask_kwargs = {
|
| 1668 |
-
"config": self.config.get_text_config(),
|
| 1669 |
-
"input_embeds": inputs_embeds,
|
| 1670 |
-
"attention_mask": attention_mask,
|
| 1671 |
-
"cache_position": cache_position,
|
| 1672 |
-
"past_key_values": past_key_values,
|
| 1673 |
-
"position_ids": position_ids,
|
| 1674 |
-
}
|
| 1675 |
-
|
| 1676 |
-
# NOTE: this `is_prefill` logic is not flawless, it fails when we're using a cache eagerly initialized
|
| 1677 |
-
# (e.g. compiled prefill) AND `images` are not provided. Determining prefill in that case requires
|
| 1678 |
-
# checking data values, which is not compile-compatible.
|
| 1679 |
-
is_prefill = (
|
| 1680 |
-
not use_cache
|
| 1681 |
-
or past_key_values is None
|
| 1682 |
-
or not past_key_values.is_initialized
|
| 1683 |
-
or images is not None
|
| 1684 |
-
)
|
| 1685 |
-
if token_type_ids is not None and is_prefill:
|
| 1686 |
-
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
| 1687 |
-
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
|
| 1688 |
-
token_type_ids.to(cache_position.device)
|
| 1689 |
-
)
|
| 1690 |
-
|
| 1691 |
-
# Create the mask
|
| 1692 |
-
causal_mask_mapping = create_causal_mask(**mask_kwargs)
|
| 1693 |
-
|
| 1694 |
-
outputs = self.transformer(
|
| 1695 |
-
attention_mask=causal_mask_mapping,
|
| 1696 |
-
position_ids=position_ids,
|
| 1697 |
-
past_key_values=past_key_values,
|
| 1698 |
-
inputs_embeds=inputs_embeds,
|
| 1699 |
-
use_cache=use_cache,
|
| 1700 |
-
output_attentions=output_attentions,
|
| 1701 |
-
output_hidden_states=output_hidden_states,
|
| 1702 |
-
cache_position=cache_position,
|
| 1703 |
-
**kwargs,
|
| 1704 |
-
)
|
| 1705 |
-
|
| 1706 |
-
return MolmoPointModelOutputWithPast(
|
| 1707 |
-
last_hidden_state=outputs.last_hidden_state,
|
| 1708 |
-
past_key_values=outputs.past_key_values,
|
| 1709 |
-
hidden_states=outputs.hidden_states,
|
| 1710 |
-
attentions=outputs.attentions,
|
| 1711 |
-
image_hidden_states=image_features if images is not None else None,
|
| 1712 |
-
)
|
| 1713 |
-
|
| 1714 |
-
|
| 1715 |
-
class MolmoPointForConditionalGeneration(MolmoPointPreTrainedModel, GenerationMixin):
|
| 1716 |
-
_checkpoint_conversion_mapping = {}
|
| 1717 |
-
_tied_weights_keys = [] # Weights are not tied
|
| 1718 |
-
# Reference: fix gemma3 grad acc #37208
|
| 1719 |
-
accepts_loss_kwargs = False
|
| 1720 |
-
config: MolmoPointConfig
|
| 1721 |
-
|
| 1722 |
-
def __init__(self, config: MolmoPointConfig):
|
| 1723 |
-
super().__init__(config)
|
| 1724 |
-
|
| 1725 |
-
self.model = MolmoPointModel(config)
|
| 1726 |
-
self.output_embeddings = nn.Parameter(torch.zeros([config.vocab_size, config.hidden_size]))
|
| 1727 |
-
self.new_output_embeddings = nn.Parameter(torch.zeros([128, config.hidden_size]))
|
| 1728 |
-
self.vocab_size = config.vocab_size
|
| 1729 |
-
|
| 1730 |
-
# Initialize weights and apply final processing
|
| 1731 |
-
self.post_init()
|
| 1732 |
-
|
| 1733 |
-
def get_input_embeddings(self) -> torch.nn.Module:
|
| 1734 |
-
return self.model.transformer.wte
|
| 1735 |
-
|
| 1736 |
-
def set_input_embeddings(self, value: torch.nn.Module) -> None:
|
| 1737 |
-
self.model.transformer.wte = value
|
| 1738 |
-
|
| 1739 |
-
def set_decoder(self, decoder):
|
| 1740 |
-
self.model.set_decoder(decoder)
|
| 1741 |
-
|
| 1742 |
-
def get_decoder(self):
|
| 1743 |
-
return self.model.get_decoder()
|
| 1744 |
-
|
| 1745 |
-
# Make modules available throught conditional class for BC
|
| 1746 |
-
@property
|
| 1747 |
-
def language_model(self) -> torch.nn.Module:
|
| 1748 |
-
return self.model.transformer
|
| 1749 |
-
|
| 1750 |
-
@property
|
| 1751 |
-
def vision_backbone(self) -> torch.nn.Module:
|
| 1752 |
-
return self.model.vision_backbone
|
| 1753 |
-
|
| 1754 |
-
@can_return_tuple
|
| 1755 |
-
def forward(
|
| 1756 |
-
self,
|
| 1757 |
-
input_ids: torch.LongTensor = None,
|
| 1758 |
-
pixel_values: Optional[torch.Tensor] = None,
|
| 1759 |
-
image_token_pooling: Optional[torch.Tensor] = None,
|
| 1760 |
-
image_grids: Optional[torch.Tensor] = None,
|
| 1761 |
-
image_num_crops: Optional[torch.Tensor] = None,
|
| 1762 |
-
pixel_values_videos: Optional[torch.Tensor] = None,
|
| 1763 |
-
video_token_pooling: Optional[torch.Tensor] = None,
|
| 1764 |
-
video_grids: Optional[torch.Tensor] = None,
|
| 1765 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 1766 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 1767 |
-
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
| 1768 |
-
token_type_ids: Optional[torch.LongTensor] = None,
|
| 1769 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1770 |
-
labels: Optional[torch.LongTensor] = None,
|
| 1771 |
-
use_cache: Optional[bool] = None,
|
| 1772 |
-
output_attentions: Optional[bool] = None,
|
| 1773 |
-
output_hidden_states: Optional[bool] = None,
|
| 1774 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 1775 |
-
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1776 |
-
**kwargs: Unpack[TransformersKwargs],
|
| 1777 |
-
) -> Union[tuple, MolmoPointCausalLMOutputWithPast]:
|
| 1778 |
-
r"""
|
| 1779 |
-
```python
|
| 1780 |
-
>>> from PIL import Image
|
| 1781 |
-
>>> import requests
|
| 1782 |
-
>>> from transformers import AutoProcessor, MolmoPointForConditionalGeneration
|
| 1783 |
-
|
| 1784 |
-
>>> model = Molmo2ForConditionalGeneration.from_pretrained("...")
|
| 1785 |
-
>>> processor = AutoProcessor.from_pretrained("...")
|
| 1786 |
-
|
| 1787 |
-
>>> prompt = "What's the content of the image?"
|
| 1788 |
-
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
| 1789 |
-
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1790 |
-
|
| 1791 |
-
>>> messages = [{"role": "user", "content": [{"type": "text", "text": prompt}, {"type": "image", "image": image}]}]
|
| 1792 |
-
|
| 1793 |
-
>>> inputs = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True)
|
| 1794 |
-
|
| 1795 |
-
>>> # Generate
|
| 1796 |
-
>>> generated_ids = model.generate(**inputs, max_new_tokens=15)
|
| 1797 |
-
>>> generated_tokens = generated_ids[:, inputs['input_ids'].size(1):]
|
| 1798 |
-
>>> processor.post_process_image_text_to_text(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1799 |
-
"The image shows a bustling street scene in what appears to be a Chinatown area. There's ..."
|
| 1800 |
-
```"""
|
| 1801 |
-
outputs = self.model(
|
| 1802 |
-
input_ids=input_ids,
|
| 1803 |
-
pixel_values=pixel_values,
|
| 1804 |
-
image_token_pooling=image_token_pooling,
|
| 1805 |
-
image_grids=image_grids,
|
| 1806 |
-
image_num_crops=image_num_crops,
|
| 1807 |
-
pixel_values_videos=pixel_values_videos,
|
| 1808 |
-
video_token_pooling=video_token_pooling,
|
| 1809 |
-
video_grids=video_grids,
|
| 1810 |
-
attention_mask=attention_mask,
|
| 1811 |
-
position_ids=position_ids,
|
| 1812 |
-
past_key_values=past_key_values,
|
| 1813 |
-
token_type_ids=token_type_ids,
|
| 1814 |
-
inputs_embeds=inputs_embeds,
|
| 1815 |
-
use_cache=use_cache,
|
| 1816 |
-
output_attentions=output_attentions,
|
| 1817 |
-
output_hidden_states=output_hidden_states,
|
| 1818 |
-
cache_position=cache_position,
|
| 1819 |
-
**kwargs,
|
| 1820 |
-
)
|
| 1821 |
-
|
| 1822 |
-
hidden_states = outputs.last_hidden_state
|
| 1823 |
-
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 1824 |
-
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1825 |
-
lm_head = torch.concatenate([self.output_embeddings, self.new_output_embeddings], dim=0)
|
| 1826 |
-
logits = F.linear(hidden_states[:, slice_indices, :], lm_head)
|
| 1827 |
-
|
| 1828 |
-
loss = None
|
| 1829 |
-
if labels is not None:
|
| 1830 |
-
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size)
|
| 1831 |
-
|
| 1832 |
-
return MolmoPointCausalLMOutputWithPast(
|
| 1833 |
-
loss=loss,
|
| 1834 |
-
logits=logits,
|
| 1835 |
-
past_key_values=outputs.past_key_values,
|
| 1836 |
-
hidden_states=outputs.hidden_states,
|
| 1837 |
-
attentions=outputs.attentions,
|
| 1838 |
-
image_hidden_states=outputs.image_hidden_states,
|
| 1839 |
-
)
|
| 1840 |
-
|
| 1841 |
-
def prepare_inputs_for_generation(
|
| 1842 |
-
self,
|
| 1843 |
-
input_ids: torch.LongTensor,
|
| 1844 |
-
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
| 1845 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1846 |
-
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1847 |
-
image_token_pooling: Optional[torch.Tensor] = None,
|
| 1848 |
-
image_grids: Optional[torch.Tensor] = None,
|
| 1849 |
-
image_num_crops: Optional[torch.Tensor] = None,
|
| 1850 |
-
pixel_values_videos: Optional[torch.Tensor] = None,
|
| 1851 |
-
video_token_pooling: Optional[torch.Tensor] = None,
|
| 1852 |
-
video_grids: Optional[torch.Tensor] = None,
|
| 1853 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 1854 |
-
token_type_ids: Optional[torch.LongTensor] = None,
|
| 1855 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 1856 |
-
logits_to_keep: Optional[Union[int, torch.Tensor]] = None,
|
| 1857 |
-
**kwargs,
|
| 1858 |
-
):
|
| 1859 |
-
|
| 1860 |
-
model_inputs = super().prepare_inputs_for_generation(
|
| 1861 |
-
input_ids,
|
| 1862 |
-
past_key_values=past_key_values,
|
| 1863 |
-
inputs_embeds=inputs_embeds,
|
| 1864 |
-
attention_mask=attention_mask,
|
| 1865 |
-
cache_position=cache_position,
|
| 1866 |
-
logits_to_keep=logits_to_keep,
|
| 1867 |
-
token_type_ids=token_type_ids,
|
| 1868 |
-
**kwargs,
|
| 1869 |
-
)
|
| 1870 |
-
|
| 1871 |
-
if cache_position[0] == 0:
|
| 1872 |
-
model_inputs["pixel_values"] = pixel_values
|
| 1873 |
-
model_inputs["image_token_pooling"] = image_token_pooling
|
| 1874 |
-
model_inputs["image_grids"] = image_grids
|
| 1875 |
-
model_inputs["image_num_crops"] = image_num_crops
|
| 1876 |
-
model_inputs["pixel_values_videos"] = pixel_values_videos
|
| 1877 |
-
model_inputs["video_token_pooling"] = video_token_pooling
|
| 1878 |
-
model_inputs["video_grids"] = video_grids
|
| 1879 |
-
|
| 1880 |
-
return model_inputs
|
| 1881 |
-
|
| 1882 |
-
# Adapted from transformers.models.gemma3.modeling_gemma3
|
| 1883 |
-
@staticmethod
|
| 1884 |
-
def create_masks_for_generate(
|
| 1885 |
-
config: PretrainedConfig,
|
| 1886 |
-
input_embeds: torch.Tensor,
|
| 1887 |
-
attention_mask: Optional[torch.Tensor],
|
| 1888 |
-
cache_position: torch.Tensor,
|
| 1889 |
-
past_key_values: Optional[Cache],
|
| 1890 |
-
position_ids: Optional[torch.Tensor],
|
| 1891 |
-
token_type_ids: Optional[torch.Tensor] = None,
|
| 1892 |
-
**kwargs,
|
| 1893 |
-
) -> dict:
|
| 1894 |
-
# Prepare mask arguments
|
| 1895 |
-
mask_kwargs = {
|
| 1896 |
-
"config": config.get_text_config(),
|
| 1897 |
-
"input_embeds": input_embeds,
|
| 1898 |
-
"attention_mask": attention_mask,
|
| 1899 |
-
"cache_position": cache_position,
|
| 1900 |
-
"past_key_values": past_key_values,
|
| 1901 |
-
"position_ids": position_ids,
|
| 1902 |
-
}
|
| 1903 |
-
# Add the token type ids mask for generate as well
|
| 1904 |
-
if token_type_ids is not None and input_embeds.shape[1] != 1:
|
| 1905 |
-
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
| 1906 |
-
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
|
| 1907 |
-
token_type_ids.to(cache_position.device)
|
| 1908 |
-
)
|
| 1909 |
-
|
| 1910 |
-
return create_masks_for_generate(**mask_kwargs)
|
| 1911 |
-
|
| 1912 |
-
|
| 1913 |
-
# Always register for multi-modal features
|
| 1914 |
-
AutoModelForImageTextToText.register(MolmoPointConfig, MolmoPointForConditionalGeneration)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unified_demo.py
DELETED
|
@@ -1,334 +0,0 @@
|
|
| 1 |
-
import functools
|
| 2 |
-
import os
|
| 3 |
-
import argparse
|
| 4 |
-
import logging
|
| 5 |
-
from collections import defaultdict
|
| 6 |
-
from PIL import Image, ImageFile, ImageDraw
|
| 7 |
-
import PIL
|
| 8 |
-
|
| 9 |
-
import numpy as np
|
| 10 |
-
import torch
|
| 11 |
-
from transformers import AutoProcessor, AutoModelForImageTextToText
|
| 12 |
-
|
| 13 |
-
from olmo.models.video_olmo.video_olmo import VideoOlmoConfig
|
| 14 |
-
from olmo.html_utils import postprocess_prompt
|
| 15 |
-
from olmo.util import (
|
| 16 |
-
prepare_cli_environment,
|
| 17 |
-
resource_path,
|
| 18 |
-
)
|
| 19 |
-
|
| 20 |
-
import gradio as gr
|
| 21 |
-
|
| 22 |
-
try:
|
| 23 |
-
from molmo_utils import process_vision_info
|
| 24 |
-
except ImportError:
|
| 25 |
-
# raise ImportError("molmo_utils not found. Please install it with `pip install molmo-utils`.")
|
| 26 |
-
pass
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
Image.MAX_IMAGE_PIXELS = None
|
| 30 |
-
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
CACHE = "model_cache"
|
| 34 |
-
log = logging.getLogger(__name__)
|
| 35 |
-
ALLOWED_PATH = [CACHE]
|
| 36 |
-
MAX_IMAGE_SIZE = 512
|
| 37 |
-
MAX_VIDEO_HEIGHT = 512
|
| 38 |
-
POINT_SIZE = 0.01
|
| 39 |
-
|
| 40 |
-
DEVICE = None
|
| 41 |
-
|
| 42 |
-
# load the model, processor
|
| 43 |
-
MODEL = None
|
| 44 |
-
PROCESSOR = None
|
| 45 |
-
POINT_FORMATTER = None
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def draw_points(image, points):
|
| 49 |
-
if isinstance(image, np.ndarray):
|
| 50 |
-
annotation = PIL.Image.fromarray(image)
|
| 51 |
-
else:
|
| 52 |
-
annotation = image.copy()
|
| 53 |
-
draw = ImageDraw.Draw(annotation)
|
| 54 |
-
w, h = annotation.size
|
| 55 |
-
size = max(5, int(max(w, h) * POINT_SIZE))
|
| 56 |
-
for x, y in points:
|
| 57 |
-
draw.ellipse((x-size, y-size, x+size, y+size), fill="rgb(240, 82, 156)", outline=None)
|
| 58 |
-
return annotation
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def get_message(
|
| 62 |
-
images: list[Image.Image] | None,
|
| 63 |
-
video_path: str | None,
|
| 64 |
-
max_frames: int,
|
| 65 |
-
frame_sample_mode: str,
|
| 66 |
-
max_fps: int | None,
|
| 67 |
-
sampling_fps: int | None,
|
| 68 |
-
input_text: str,
|
| 69 |
-
style: str,
|
| 70 |
-
):
|
| 71 |
-
content = [
|
| 72 |
-
dict(type="text", text=input_text, stye=style)
|
| 73 |
-
]
|
| 74 |
-
if images:
|
| 75 |
-
image_content = [
|
| 76 |
-
dict(type="image", image=image)
|
| 77 |
-
for image in images
|
| 78 |
-
]
|
| 79 |
-
content.extend(image_content)
|
| 80 |
-
if video_path:
|
| 81 |
-
video_kwargs = {
|
| 82 |
-
"num_frames": max_frames,
|
| 83 |
-
"frame_sample_mode": frame_sample_mode,
|
| 84 |
-
}
|
| 85 |
-
if max_fps is not None:
|
| 86 |
-
video_kwargs["max_fps"] = max_fps
|
| 87 |
-
if sampling_fps is not None:
|
| 88 |
-
video_kwargs["sampling_fps"] = sampling_fps
|
| 89 |
-
video_content = dict(type="video", video=video_path, **video_kwargs)
|
| 90 |
-
content.append(video_content)
|
| 91 |
-
|
| 92 |
-
return [
|
| 93 |
-
{
|
| 94 |
-
"role": "user",
|
| 95 |
-
"content": content,
|
| 96 |
-
}
|
| 97 |
-
]
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
def cast_float_dtype(t: torch.Tensor):
|
| 101 |
-
if torch.is_floating_point(t):
|
| 102 |
-
t = t.to(torch.bfloat16)
|
| 103 |
-
return t
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
def run_single_inference(*inputs, annotations=None):
|
| 107 |
-
video_path, images, input_text, style, frame_sample_mode, max_frames, max_fps, sampling_fps, max_steps = inputs
|
| 108 |
-
assert images is not None or video_path is not None, "Either images or video_path must be provided"
|
| 109 |
-
assert images is None or video_path is None, "Both images and video_path cannot be provided at the same time"
|
| 110 |
-
nimages = 0
|
| 111 |
-
if images:
|
| 112 |
-
images = [t[0] for t in images]
|
| 113 |
-
nimages = len(images)
|
| 114 |
-
logging.info(f"# of images: {nimages}")
|
| 115 |
-
|
| 116 |
-
messages = get_message(
|
| 117 |
-
images=images,
|
| 118 |
-
video_path=video_path,
|
| 119 |
-
max_frames=max_frames,
|
| 120 |
-
frame_sample_mode=frame_sample_mode,
|
| 121 |
-
max_fps=max_fps,
|
| 122 |
-
sampling_fps=sampling_fps,
|
| 123 |
-
input_text=input_text,
|
| 124 |
-
style=style,
|
| 125 |
-
)
|
| 126 |
-
images, videos, video_kwargs = process_vision_info(messages)
|
| 127 |
-
if videos:
|
| 128 |
-
videos, video_metadatas = zip(*videos)
|
| 129 |
-
videos, video_metadatas = list(videos), list(video_metadatas)
|
| 130 |
-
logging.info(
|
| 131 |
-
f"Videos: {videos[0].shape}, frame_sample_mode: {frame_sample_mode}, "
|
| 132 |
-
f"max_frames: {max_frames}, max_fps: {max_fps}, sampling_fps: {sampling_fps}"
|
| 133 |
-
)
|
| 134 |
-
else:
|
| 135 |
-
video_metadatas = None
|
| 136 |
-
logging.info(f"Running inference for prompt: \"{input_text}\", style={style} steps={max_steps}")
|
| 137 |
-
text = PROCESSOR.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 138 |
-
|
| 139 |
-
inputs = PROCESSOR(
|
| 140 |
-
images=images,
|
| 141 |
-
videos=videos,
|
| 142 |
-
video_metadata=video_metadatas,
|
| 143 |
-
text=text,
|
| 144 |
-
padding=True,
|
| 145 |
-
return_tensors="pt",
|
| 146 |
-
**video_kwargs,
|
| 147 |
-
)
|
| 148 |
-
|
| 149 |
-
if MODEL.config.dtype == torch.bfloat16:
|
| 150 |
-
inputs = {k: cast_float_dtype(v.to(DEVICE)) for k, v in inputs.items()}
|
| 151 |
-
else:
|
| 152 |
-
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
|
| 153 |
-
with torch.inference_mode():
|
| 154 |
-
if MODEL.config.dtype == torch.bfloat16:
|
| 155 |
-
output = MODEL.generate(**inputs, max_new_tokens=max_steps)
|
| 156 |
-
else:
|
| 157 |
-
with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
|
| 158 |
-
output = MODEL.generate(**inputs, max_new_tokens=max_steps)
|
| 159 |
-
prompts = output[0, :inputs['input_ids'].size(1)]
|
| 160 |
-
prompt_text = PROCESSOR.decode(prompts, skip_special_tokens=False)
|
| 161 |
-
prompt_text = postprocess_prompt(prompt_text)
|
| 162 |
-
logging.info(f"hf prompt: {prompt_text}")
|
| 163 |
-
generated_tokens = output[0, inputs['input_ids'].size(1):]
|
| 164 |
-
generated_text = PROCESSOR.decode(generated_tokens, skip_special_tokens=True)
|
| 165 |
-
logging.info(f"hf generated_text: {generated_text}")
|
| 166 |
-
if annotations:
|
| 167 |
-
if video_path is None and nimages == 1:
|
| 168 |
-
w, h = images[0].size
|
| 169 |
-
points = POINT_FORMATTER.extract_points(generated_text, w, h)
|
| 170 |
-
if points:
|
| 171 |
-
return generated_text, [draw_points(images[0], points)]
|
| 172 |
-
else:
|
| 173 |
-
return generated_text, []
|
| 174 |
-
elif video_path is None and nimages > 1:
|
| 175 |
-
w, h = [x.size[0] for x in images], [x.size[1] for x in images]
|
| 176 |
-
points = POINT_FORMATTER.extract_multi_image_points(generated_text, w, h)
|
| 177 |
-
if points:
|
| 178 |
-
group_by_index = defaultdict(list)
|
| 179 |
-
for ix, x, y in points:
|
| 180 |
-
group_by_index[ix].append((x, y))
|
| 181 |
-
out = []
|
| 182 |
-
for ix, points in group_by_index.items():
|
| 183 |
-
out.append(draw_points(images[ix-1], points))
|
| 184 |
-
return generated_text, out
|
| 185 |
-
else:
|
| 186 |
-
return generated_text, []
|
| 187 |
-
else:
|
| 188 |
-
h, w = videos[0].shape[1:3]
|
| 189 |
-
group_by_time = defaultdict(list)
|
| 190 |
-
points = POINT_FORMATTER.extract_multi_image_points(generated_text, w, h)
|
| 191 |
-
if points:
|
| 192 |
-
for ts, x, y in points:
|
| 193 |
-
group_by_time[ts].append((x, y))
|
| 194 |
-
else:
|
| 195 |
-
track = POINT_FORMATTER.extract_trajectories(generated_text, w, h, 30)
|
| 196 |
-
for ex in track:
|
| 197 |
-
group_by_time[ex["time"]] = [(x["x"], x["y"]) for x in ex["points"]]
|
| 198 |
-
grouped_by_frame = defaultdict(list)
|
| 199 |
-
for ts, points in group_by_time.items():
|
| 200 |
-
timestamps = video_metadatas[0]["frames_indices"] / video_metadatas[0]["fps"]
|
| 201 |
-
ix = int(np.argmin(np.abs(timestamps - ts)))
|
| 202 |
-
grouped_by_frame[ix] += points
|
| 203 |
-
out = []
|
| 204 |
-
for ix, points in grouped_by_frame.items():
|
| 205 |
-
out.append(draw_points(videos[0][ix], points))
|
| 206 |
-
return generated_text, out
|
| 207 |
-
return generated_text
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
def main():
|
| 211 |
-
parser = argparse.ArgumentParser()
|
| 212 |
-
parser.add_argument("ckpt_home", type=str)
|
| 213 |
-
parser.add_argument("--server_name")
|
| 214 |
-
parser.add_argument("--default_max_tokens", type=int, default=2048)
|
| 215 |
-
parser.add_argument("--cloudflare_tunnel", action="store_true")
|
| 216 |
-
parser.add_argument("--original_ckpt_home", type=str, default=None)
|
| 217 |
-
parser.add_argument("--annotations", action="store_true")
|
| 218 |
-
parser.add_argument("--no_share", action="store_true")
|
| 219 |
-
parser.add_argument("--port", type=int, default=7860)
|
| 220 |
-
args = parser.parse_args()
|
| 221 |
-
|
| 222 |
-
prepare_cli_environment()
|
| 223 |
-
|
| 224 |
-
global DEVICE, MODEL, PROCESSOR
|
| 225 |
-
if torch.cuda.is_available():
|
| 226 |
-
DEVICE = torch.device("cuda")
|
| 227 |
-
else:
|
| 228 |
-
logging.warning("No GPU available, using CPU")
|
| 229 |
-
DEVICE = torch.device("cpu")
|
| 230 |
-
if MODEL is not None:
|
| 231 |
-
MODEL.to(DEVICE)
|
| 232 |
-
|
| 233 |
-
MODEL = AutoModelForImageTextToText.from_pretrained(
|
| 234 |
-
args.ckpt_home,
|
| 235 |
-
trust_remote_code=True,
|
| 236 |
-
dtype="auto",
|
| 237 |
-
device_map="auto",
|
| 238 |
-
)
|
| 239 |
-
|
| 240 |
-
PROCESSOR = AutoProcessor.from_pretrained(
|
| 241 |
-
args.ckpt_home,
|
| 242 |
-
trust_remote_code=True,
|
| 243 |
-
dtype="auto",
|
| 244 |
-
device_map="auto",
|
| 245 |
-
padding_side="left",
|
| 246 |
-
)
|
| 247 |
-
|
| 248 |
-
if args.annotations:
|
| 249 |
-
assert args.original_ckpt_home is not None, "original_ckpt_home must be provided when annotations are enabled"
|
| 250 |
-
global POINT_FORMATTER
|
| 251 |
-
model_cfg_path = resource_path(args.original_ckpt_home, "config.yaml")
|
| 252 |
-
model_cfg = VideoOlmoConfig.load(model_cfg_path, key="model", validate_paths=False)
|
| 253 |
-
preprocessor = model_cfg.build_preprocessor(for_inference=True, is_training=False)
|
| 254 |
-
POINT_FORMATTER = preprocessor.formatter._point_formatter
|
| 255 |
-
|
| 256 |
-
CSS = """
|
| 257 |
-
#input_image image {
|
| 258 |
-
object-fit: contain !important;
|
| 259 |
-
}
|
| 260 |
-
#input_video video {
|
| 261 |
-
object-fit: contain !important;
|
| 262 |
-
}
|
| 263 |
-
"""
|
| 264 |
-
|
| 265 |
-
frame_sample_mode = PROCESSOR.video_processor.frame_sample_mode
|
| 266 |
-
max_frames = PROCESSOR.video_processor.num_frames
|
| 267 |
-
max_fps = PROCESSOR.video_processor.max_fps
|
| 268 |
-
sampling_fps = PROCESSOR.video_processor.sampling_fps
|
| 269 |
-
|
| 270 |
-
with gr.Blocks(css=CSS) as demo:
|
| 271 |
-
gr.Markdown(
|
| 272 |
-
f"""
|
| 273 |
-
## Molmo2 Demo
|
| 274 |
-
Provide either a video or images and a prompt for question answering.
|
| 275 |
-
"""
|
| 276 |
-
)
|
| 277 |
-
with gr.Row():
|
| 278 |
-
with gr.Tabs():
|
| 279 |
-
with gr.TabItem("video"):
|
| 280 |
-
video = gr.Video(label="Input Video", elem_id="input_video", height=MAX_VIDEO_HEIGHT)
|
| 281 |
-
with gr.TabItem("image(s)"):
|
| 282 |
-
images = gr.Gallery(label="Input Images", elem_id="input_image", type="pil", height=MAX_IMAGE_SIZE)
|
| 283 |
-
|
| 284 |
-
with gr.Row():
|
| 285 |
-
input_text = gr.Textbox(placeholder="Enter the prompt", label="Input text")
|
| 286 |
-
|
| 287 |
-
with gr.Row():
|
| 288 |
-
style = gr.Textbox(value="demo", label="style")
|
| 289 |
-
frame_sample_mode = gr.Textbox(value=frame_sample_mode, label="frame_sample_mode")
|
| 290 |
-
max_frames = gr.Number(value=max_frames, label="max_frames")
|
| 291 |
-
max_fps = gr.Number(value=max_fps, label="max_fps")
|
| 292 |
-
sampling_fps = gr.Number(value=sampling_fps, label="sampling_fps")
|
| 293 |
-
max_tok_slider = gr.Slider(label="max_tokens", minimum=1, maximum=4096, step=1, value=args.default_max_tokens)
|
| 294 |
-
|
| 295 |
-
with gr.Row():
|
| 296 |
-
submit_button = gr.Button("Submit", scale=3)
|
| 297 |
-
clear_all_button = gr.ClearButton(components=[video, images, input_text], value="Clear All", scale=1)
|
| 298 |
-
|
| 299 |
-
with gr.Row():
|
| 300 |
-
output_text = gr.Textbox(placeholder="Output text", label="Output text", lines=10)
|
| 301 |
-
|
| 302 |
-
if args.annotations:
|
| 303 |
-
with gr.Row():
|
| 304 |
-
output_annotations = gr.Gallery(label="Annotations", height=MAX_IMAGE_SIZE)
|
| 305 |
-
outputs = [output_text, output_annotations]
|
| 306 |
-
fn = functools.partial(run_single_inference, annotations="points")
|
| 307 |
-
else:
|
| 308 |
-
fn = run_single_inference
|
| 309 |
-
outputs = [output_text]
|
| 310 |
-
|
| 311 |
-
submit_button.click(
|
| 312 |
-
fn=fn,
|
| 313 |
-
inputs=[video, images, input_text, style, frame_sample_mode, max_frames, max_fps, sampling_fps, max_tok_slider],
|
| 314 |
-
outputs=outputs,
|
| 315 |
-
)
|
| 316 |
-
|
| 317 |
-
if args.cloudflare_tunnel:
|
| 318 |
-
import cloudflared_tunnel
|
| 319 |
-
with cloudflared_tunnel.run() as port:
|
| 320 |
-
demo.queue().launch(
|
| 321 |
-
share=False, show_error=True, max_threads=os.cpu_count() - 10, server_port=port,
|
| 322 |
-
allowed_paths=ALLOWED_PATH
|
| 323 |
-
)
|
| 324 |
-
else:
|
| 325 |
-
demo.queue().launch(
|
| 326 |
-
server_name=args.server_name,
|
| 327 |
-
share=not args.no_share, show_error=True, max_threads=os.cpu_count() - 10,
|
| 328 |
-
server_port=args.port,
|
| 329 |
-
allowed_paths=ALLOWED_PATH
|
| 330 |
-
)
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
if __name__ == "__main__":
|
| 334 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|