Upload folder using huggingface_hub
Browse files
- convert_molmo_point_to_hf.py +7 -4
- model-00007-of-00008.safetensors +2 -2
- model.safetensors.index.json +13 -13
- modeling_molmo_point.py +199 -134
- processing_molmo2.py +0 -1
- video_processing_molmo2.py +2 -1
convert_molmo_point_to_hf.py
CHANGED
|
@@ -207,7 +207,14 @@ def convert_molmo2(
|
|
| 207 |
new_key = "lm_head.new_output_embeddings"
|
| 208 |
elif key == "transformer.ff_out.weight":
|
| 209 |
new_key = "lm_head.output_embeddings"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
else:
|
|
|
|
|
|
|
| 211 |
new_key = f"{base_model_prefix}.{key}"
|
| 212 |
new_state_dict[new_key] = val
|
| 213 |
|
|
@@ -215,10 +222,6 @@ def convert_molmo2(
|
|
| 215 |
qkv_bias = config.qkv_bias if isinstance(config, Molmo2TextConfig) else config.text_config.qkv_bias
|
| 216 |
use_qk_norm = config.use_qk_norm if isinstance(config, Molmo2TextConfig) else config.text_config.use_qk_norm
|
| 217 |
|
| 218 |
-
for param in list(new_state_dict.keys()):
|
| 219 |
-
if param.startswith(f"model.connectors.0"):
|
| 220 |
-
new_state_dict[param.replace("model.connectors.0", "model.connector")] = new_state_dict.pop(param)
|
| 221 |
-
|
| 222 |
for layer_i in range(config.num_hidden_layers):
|
| 223 |
prefix = f"{model_prefix}.blocks.{layer_i}"
|
| 224 |
|
|
|
|
| 207 |
new_key = "lm_head.new_output_embeddings"
|
| 208 |
elif key == "transformer.ff_out.weight":
|
| 209 |
new_key = "lm_head.output_embeddings"
|
| 210 |
+
elif key.split(".")[0] in [
|
| 211 |
+
"subpatch_k", "subpatch_q", "patch_k", "patch_q", "add_no_point_class_embed",
|
| 212 |
+
"subpatch_loc_k", "x_norm"
|
| 213 |
+
]:
|
| 214 |
+
new_key = f"{base_model_prefix}.point_predictor.{key}"
|
| 215 |
else:
|
| 216 |
+
if key.startswith(f"connectors.0"):
|
| 217 |
+
key = key.replace("connectors.0", "connector")
|
| 218 |
new_key = f"{base_model_prefix}.{key}"
|
| 219 |
new_state_dict[new_key] = val
|
| 220 |
|
|
|
|
| 222 |
qkv_bias = config.qkv_bias if isinstance(config, Molmo2TextConfig) else config.text_config.qkv_bias
|
| 223 |
use_qk_norm = config.use_qk_norm if isinstance(config, Molmo2TextConfig) else config.text_config.use_qk_norm
|
| 224 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
for layer_i in range(config.num_hidden_layers):
|
| 226 |
prefix = f"{model_prefix}.blocks.{layer_i}"
|
| 227 |
|
model-00007-of-00008.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b0af4940b1b827e995397de33114310c838b8edece415819de119d4f403d420f
|
| 3 |
+
size 4091925052
|
model.safetensors.index.json
CHANGED
|
@@ -6,7 +6,6 @@
|
|
| 6 |
"weight_map": {
|
| 7 |
"lm_head.new_output_embeddings": "model-00008-of-00008.safetensors",
|
| 8 |
"lm_head.output_embeddings": "model-00008-of-00008.safetensors",
|
| 9 |
-
"model.add_no_point_class_embed.vector": "model-00007-of-00008.safetensors",
|
| 10 |
"model.build_vit_embedding.bias": "model-00007-of-00008.safetensors",
|
| 11 |
"model.build_vit_embedding.weight": "model-00007-of-00008.safetensors",
|
| 12 |
"model.connector.image_pooling_2d.wk.bias": "model-00007-of-00008.safetensors",
|
|
@@ -18,16 +17,18 @@
|
|
| 18 |
"model.connector.image_projector.w1.weight": "model-00007-of-00008.safetensors",
|
| 19 |
"model.connector.image_projector.w2.weight": "model-00007-of-00008.safetensors",
|
| 20 |
"model.connector.image_projector.w3.weight": "model-00007-of-00008.safetensors",
|
| 21 |
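-
"model.patch_k.bias": "model-00007-of-00008.safetensors",
|
| 22 |
-
"model.patch_k.weight": "model-00007-of-00008.safetensors",
|
| 23 |
-
"model.patch_q.bias": "model-00007-of-00008.safetensors",
|
| 24 |
-
"model.patch_q.weight": "model-00007-of-00008.safetensors",
|
| 25 |
-
"model.subpatch_k.bias": "model-00007-of-00008.safetensors",
|
| 26 |
-
"model.subpatch_k.weight": "model-00007-of-00008.safetensors",
|
| 27 |
-
"model.subpatch_loc_k.bias": "model-00007-of-00008.safetensors",
|
| 28 |
-
"model.subpatch_loc_k.weight": "model-00007-of-00008.safetensors",
|
| 29 |
-
"model.subpatch_q.bias": "model-00007-of-00008.safetensors",
|
| 30 |
-
"model.subpatch_q.weight": "model-00007-of-00008.safetensors",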
|
|
|
|
|
|
|
| 31 |
"model.transformer.blocks.0.attn_norm.weight": "model-00001-of-00008.safetensors",
|
| 32 |
"model.transformer.blocks.0.ff_norm.weight": "model-00001-of-00008.safetensors",
|
| 33 |
"model.transformer.blocks.0.mlp.ff_out.weight": "model-00001-of-00008.safetensors",
|
|
@@ -721,7 +722,6 @@
|
|
| 721 |
"model.vit.transformer.resblocks.9.feed_forward.w2.bias": "model-00007-of-00008.safetensors",
|
| 722 |
"model.vit.transformer.resblocks.9.feed_forward.w2.weight": "model-00007-of-00008.safetensors",
|
| 723 |
"model.vit.transformer.resblocks.9.ffn_norm.bias": "model-00007-of-00008.safetensors",
|
| 724 |
-
"model.vit.transformer.resblocks.9.ffn_norm.weight": "model-00007-of-00008.safetensors"
|
| 725 |
-
"model.x_norm.weight": "model-00007-of-00008.safetensors"
|
| 726 |
}
|
| 727 |
}
|
|
|
|
| 6 |
"weight_map": {
|
| 7 |
"lm_head.new_output_embeddings": "model-00008-of-00008.safetensors",
|
| 8 |
"lm_head.output_embeddings": "model-00008-of-00008.safetensors",
|
|
|
|
| 9 |
"model.build_vit_embedding.bias": "model-00007-of-00008.safetensors",
|
| 10 |
"model.build_vit_embedding.weight": "model-00007-of-00008.safetensors",
|
| 11 |
"model.connector.image_pooling_2d.wk.bias": "model-00007-of-00008.safetensors",
|
|
|
|
| 17 |
"model.connector.image_projector.w1.weight": "model-00007-of-00008.safetensors",
|
| 18 |
"model.connector.image_projector.w2.weight": "model-00007-of-00008.safetensors",
|
| 19 |
"model.connector.image_projector.w3.weight": "model-00007-of-00008.safetensors",
|
| 20 |
+
"model.point_predictor.add_no_point_class_embed.vector": "model-00007-of-00008.safetensors",
|
| 21 |
+
"model.point_predictor.patch_k.bias": "model-00007-of-00008.safetensors",
|
| 22 |
+
"model.point_predictor.patch_k.weight": "model-00007-of-00008.safetensors",
|
| 23 |
+
"model.point_predictor.patch_q.bias": "model-00007-of-00008.safetensors",
|
| 24 |
+
"model.point_predictor.patch_q.weight": "model-00007-of-00008.safetensors",
|
| 25 |
+
"model.point_predictor.subpatch_k.bias": "model-00007-of-00008.safetensors",
|
| 26 |
+
"model.point_predictor.subpatch_k.weight": "model-00007-of-00008.safetensors",
|
| 27 |
+
"model.point_predictor.subpatch_loc_k.bias": "model-00007-of-00008.safetensors",
|
| 28 |
+
"model.point_predictor.subpatch_loc_k.weight": "model-00007-of-00008.safetensors",
|
| 29 |
+
"model.point_predictor.subpatch_q.bias": "model-00007-of-00008.safetensors",
|
| 30 |
+
"model.point_predictor.subpatch_q.weight": "model-00007-of-00008.safetensors",
|
| 31 |
+
"model.point_predictor.x_norm.weight": "model-00007-of-00008.safetensors",
|
| 32 |
"model.transformer.blocks.0.attn_norm.weight": "model-00001-of-00008.safetensors",
|
| 33 |
"model.transformer.blocks.0.ff_norm.weight": "model-00001-of-00008.safetensors",
|
| 34 |
"model.transformer.blocks.0.mlp.ff_out.weight": "model-00001-of-00008.safetensors",
|
|
|
|
| 722 |
"model.vit.transformer.resblocks.9.feed_forward.w2.bias": "model-00007-of-00008.safetensors",
|
| 723 |
"model.vit.transformer.resblocks.9.feed_forward.w2.weight": "model-00007-of-00008.safetensors",
|
| 724 |
"model.vit.transformer.resblocks.9.ffn_norm.bias": "model-00007-of-00008.safetensors",
|
| 725 |
+
"model.vit.transformer.resblocks.9.ffn_norm.weight": "model-00007-of-00008.safetensors"
|
|
|
|
| 726 |
}
|
| 727 |
}
|
modeling_molmo_point.py
CHANGED
|
@@ -9,7 +9,7 @@ import torch
|
|
| 9 |
from torch import nn
|
| 10 |
|
| 11 |
from torch.nn import functional as F
|
| 12 |
-
from transformers import LogitsProcessorList, LogitsProcessor, AutoProcessor
|
| 13 |
from transformers.image_utils import PILImageResampling
|
| 14 |
|
| 15 |
from transformers.models.auto import AutoModelForImageTextToText
|
|
@@ -347,6 +347,150 @@ class ViTMultiHeadDotProductAttention(nn.Module):
|
|
| 347 |
return attn_output
|
| 348 |
|
| 349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
class MolmoPointPreTrainedModel(PreTrainedModel):
|
| 351 |
config: MolmoPointConfig
|
| 352 |
base_model_prefix = "model"
|
|
@@ -356,6 +500,7 @@ class MolmoPointPreTrainedModel(PreTrainedModel):
|
|
| 356 |
"Molmo2PostNormDecoderLayer",
|
| 357 |
"Molmo2VisionBlock",
|
| 358 |
"ViTMultiHeadDotProductAttention",
|
|
|
|
| 359 |
]
|
| 360 |
_skip_keys_device_placement = "past_key_values"
|
| 361 |
_supports_flash_attn = True
|
|
@@ -412,6 +557,7 @@ class GeneratedTokenBounds:
|
|
| 412 |
|
| 413 |
|
| 414 |
class MolmoPointLogitProcessor(LogitsProcessor):
|
|
|
|
| 415 |
|
| 416 |
def __init__(self, bounds: GeneratedTokenBounds,
|
| 417 |
prevent_repeats, force_patch_sorted, force_subpatch_sorted):
|
|
@@ -868,6 +1014,9 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
|
|
| 868 |
def __init__(self, config: MolmoPointConfig):
|
| 869 |
super().__init__(config)
|
| 870 |
self.transformer: MolmoPointTextModel = MolmoPointTextModel(config.text_config)
|
|
|
|
|
|
|
|
|
|
| 871 |
|
| 872 |
vit_config = config.vit_config
|
| 873 |
adapter_config = config.adapter_config
|
|
@@ -886,45 +1035,14 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
|
|
| 886 |
else:
|
| 887 |
self.vit = Molmo2VisionTransformer(vit_config)
|
| 888 |
|
| 889 |
-
if self.config.token_prediction_rotary == "none":
|
| 890 |
-
self.patch_rotary = None
|
| 891 |
-
else:
|
| 892 |
-
theta = self.config.token_prediction_rotary_theta or self.config.llm.rope_theta
|
| 893 |
-
if self.config.token_prediction_rotary == "one_d":
|
| 894 |
-
self.patch_rotary = MolmoPointPatchRope(theta, self.config.patch_embed_dim)
|
| 895 |
-
else:
|
| 896 |
-
raise NotImplementedError()
|
| 897 |
-
|
| 898 |
self.connector = MolmoPointConnector(adapter_config, vit_config)
|
| 899 |
-
|
| 900 |
-
vit_dim = self.config.vit_config.hidden_size * len(self.config.adapter_config.vit_layers)
|
| 901 |
-
llm_dim = self.config.text_config.hidden_size
|
| 902 |
-
self.patch_q = nn.Linear(llm_dim, config.patch_embed_dim)
|
| 903 |
-
self.patch_k = nn.Linear(llm_dim, config.patch_embed_dim)
|
| 904 |
-
self.subpatch_q = nn.Linear(llm_dim, config.patch_embed_dim)
|
| 905 |
-
self.subpatch_k = nn.Linear(vit_dim, config.patch_embed_dim)
|
| 906 |
-
self.add_no_point_class_embed = MolmoPointPadWithLearnedVector(config.patch_embed_dim)
|
| 907 |
-
|
| 908 |
-
self.patch_token_id = self.config.patch_token_id
|
| 909 |
-
self.subpatch_token_id = self.config.subpatch_token_id
|
| 910 |
-
self.location_token_id = self.config.location_token_id
|
| 911 |
-
|
| 912 |
if self.config.embed_selected_vit_patch == "linear":
|
|
|
|
|
|
|
| 913 |
self.build_vit_embedding = nn.Linear(vit_dim, llm_dim, bias=True)
|
| 914 |
else:
|
| 915 |
raise NotImplementedError(f"Embedding {self.config.embed_selected_vit_patch} not implemented")
|
| 916 |
-
|
| 917 |
-
if self.config.patch_location == "3x3":
|
| 918 |
-
self.subpatch_loc_k = nn.Linear(llm_dim, 9)
|
| 919 |
-
elif self.config.patch_location is None:
|
| 920 |
-
self.subpatch_loc_k = None
|
| 921 |
-
else:
|
| 922 |
-
raise NotImplementedError(f"Patch location {self.config.patch_location} not implemented")
|
| 923 |
-
|
| 924 |
-
if self.config.layer_norm_x:
|
| 925 |
-
self.x_norm = Molmo2RMSNorm(llm_dim, eps=self.config.text_config.layer_norm_eps)
|
| 926 |
-
else:
|
| 927 |
-
self.x_norm = None
|
| 928 |
|
| 929 |
# Initialize weights and apply final processing
|
| 930 |
self.post_init()
|
|
@@ -1291,6 +1409,8 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
|
|
| 1291 |
input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
|
| 1292 |
|
| 1293 |
if image_data is not None:
|
|
|
|
|
|
|
| 1294 |
can_point = True
|
| 1295 |
bounds = self.build_token_bounds(image_data.token_pooling)
|
| 1296 |
expanded_inputs = input_ids
|
|
@@ -1304,7 +1424,11 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
|
|
| 1304 |
input_ids = torch.where(is_subpatch, self.subpatch_token_id, input_ids)
|
| 1305 |
input_ids = torch.where(is_location, self.location_token_id, input_ids)
|
| 1306 |
else:
|
|
|
|
|
|
|
| 1307 |
input_patch_ids = None
|
|
|
|
|
|
|
| 1308 |
can_point = False
|
| 1309 |
|
| 1310 |
device = input_ids.device
|
|
@@ -1312,7 +1436,6 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
|
|
| 1312 |
batch_size, _, dim = x.shape
|
| 1313 |
batch_idx = torch.arange(batch_size, device=device)
|
| 1314 |
|
| 1315 |
-
# TODO update embeddings for patch/subpatch tokens
|
| 1316 |
vit_features_flat: Optional[torch.FloatTensor] = None
|
| 1317 |
if images is not None:
|
| 1318 |
is_indexable_image_token = input_ids == self.config.image_patch_id
|
|
@@ -1342,27 +1465,23 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
|
|
| 1342 |
vit_features_flat = vit_features_flat[image_features_mask.view(-1)]
|
| 1343 |
vit_features_to_flat_mask = vit_features_mask.view(-1, token_pooling.shape[-1])[image_features_mask.view(-1)]
|
| 1344 |
|
| 1345 |
-
# Finally apply the connector and add to input embeddings
|
| 1346 |
image_features = self.connector(vit_features_flat, vit_features_to_flat_mask).to(device=device)
|
| 1347 |
x = x.clone()
|
| 1348 |
x.view(-1, dim)[is_image_token.view(-1)] += image_features.view(-1, dim)
|
| 1349 |
-
|
| 1350 |
-
|
| 1351 |
-
|
| 1352 |
-
|
| 1353 |
-
|
| 1354 |
-
|
| 1355 |
-
|
| 1356 |
-
|
| 1357 |
-
|
| 1358 |
-
|
| 1359 |
-
|
| 1360 |
-
|
| 1361 |
-
|
| 1362 |
-
image_features_mask = torch.any(vit_features_mask, -1)
|
| 1363 |
-
vit_features_flat = vit_features.reshape([-1, token_pooling.shape[-1], vit_features.shape[-1]])
|
| 1364 |
-
vit_features_flat = vit_features_flat[image_features_mask.view(-1)]
|
| 1365 |
-
vit_features_to_flat_mask = vit_features_mask.view(-1, token_pooling.shape[-1])[image_features_mask.view(-1)]
|
| 1366 |
|
| 1367 |
# Embed the points
|
| 1368 |
if can_point:
|
|
@@ -1372,6 +1491,9 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
|
|
| 1372 |
x.view(-1, dim)[is_patch.view(-1)] += image_data.image_features0.view(-1, dim)[input_patch_ids_flat]
|
| 1373 |
|
| 1374 |
if torch.any(is_subpatch):
|
|
|
|
|
|
|
|
|
|
| 1375 |
assert last_predicted_patch_id is not None, "Patch should always be generated before a subpatch"
|
| 1376 |
for_patches = (last_predicted_patch_id.view(batch_size) + image_token_offset)[input_subpatch_ids.view(batch_size) >= 0]
|
| 1377 |
vit_features_to_embed = vit_features_flat[for_patches, input_subpatch_ids]
|
|
@@ -1386,7 +1508,7 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
|
|
| 1386 |
past_seen_tokens,
|
| 1387 |
past_seen_tokens + inputs_embeds.shape[1],
|
| 1388 |
device=inputs_embeds.device,
|
| 1389 |
-
|
| 1390 |
|
| 1391 |
# NOTE: this `is_prefill` logic is not flawless, it fails when we're using a cache eagerly initialized
|
| 1392 |
# (e.g. compiled prefill) AND `images` are not provided. Determining prefill in that case requires
|
|
@@ -1438,84 +1560,27 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
|
|
| 1438 |
location_logits = None
|
| 1439 |
|
| 1440 |
if images is not None or image_data is not None:
|
| 1441 |
-
|
| 1442 |
-
|
| 1443 |
-
|
| 1444 |
-
|
| 1445 |
-
|
| 1446 |
-
|
| 1447 |
-
|
| 1448 |
-
|
| 1449 |
-
|
| 1450 |
-
|
| 1451 |
-
|
| 1452 |
-
|
| 1453 |
-
|
| 1454 |
-
|
| 1455 |
-
|
| 1456 |
-
|
| 1457 |
-
|
| 1458 |
-
patch_k.view(-1, patch_k_flat.shape[-1])[image_features_mask.flatten()] = patch_k_flat.to(dtype=x.dtype)
|
| 1459 |
-
|
| 1460 |
-
patch_k_mask = image_features_mask.clone()
|
| 1461 |
-
patch_k_mask.view(-1)[image_features_mask.view(-1)] = (
|
| 1462 |
-
is_indexable_image_token.view(-1)[is_image_token.view(-1)])
|
| 1463 |
-
|
| 1464 |
-
if self.config.no_more_points_class:
|
| 1465 |
-
patch_k = self.add_no_point_class_embed(patch_k).to(device=device)
|
| 1466 |
-
patch_k_mask = F.pad(patch_k_mask, (0, 1), value=True)
|
| 1467 |
-
|
| 1468 |
-
subpatch_k = self.subpatch_k(vit_features).to(device=device)
|
| 1469 |
-
|
| 1470 |
-
# Predict patch locations
|
| 1471 |
-
if can_point:
|
| 1472 |
-
image_q = self.patch_q(x_norm).to(device=device)
|
| 1473 |
-
if self.patch_rotary is not None and last_predicted_patch_id is not None:
|
| 1474 |
-
rotate_by = image_pos_ids[batch_idx, last_predicted_patch_id]
|
| 1475 |
-
rotate_by = torch.where(last_predicted_patch_id >= 0, rotate_by, 0)
|
| 1476 |
-
rotate_by = rotate_by.squeeze(-1)
|
| 1477 |
-
image_q = self.patch_rotary(
|
| 1478 |
-
image_q.view(-1, image_q.shape[-1]),
|
| 1479 |
-
torch.clamp(rotate_by, min=0),
|
| 1480 |
-
).reshape(batch_size, -1, image_q.shape[-1]).to(device=device)
|
| 1481 |
-
|
| 1482 |
-
dots = torch.matmul(image_q, patch_k.transpose(1, 2)) # [batch, 1, num_images]
|
| 1483 |
-
if self.config.norm_logits:
|
| 1484 |
-
dots = dots / math.sqrt(dots.shape[-1])
|
| 1485 |
-
|
| 1486 |
-
valid = patch_k_mask[:, None, :]
|
| 1487 |
-
patch_logits = torch.where(valid, dots, -100000000)
|
| 1488 |
-
|
| 1489 |
-
if can_point and torch.any(is_patch):
|
| 1490 |
-
if x_norm.shape[1] != 1:
|
| 1491 |
-
raise NotImplementedError()
|
| 1492 |
-
subpatch_point_q = self.subpatch_q(x_norm.squeeze(1)).to(device=device)
|
| 1493 |
-
subpatch_k = subpatch_k[batch_idx, input_patch_ids.squeeze(1)]
|
| 1494 |
-
subpatch_logits = torch.einsum("pd,pcd->pc", subpatch_point_q, subpatch_k)
|
| 1495 |
-
if self.config.norm_logits:
|
| 1496 |
-
subpatch_logits = subpatch_logits / math.sqrt(patch_k.shape[-1])
|
| 1497 |
-
subpatch_mask = vit_features_mask[batch_idx, input_patch_ids.squeeze(1)]
|
| 1498 |
-
subpatch_logits = torch.where(subpatch_mask, subpatch_logits, -100000)
|
| 1499 |
-
subpatch_logits = subpatch_logits[:, None, :]
|
| 1500 |
-
|
| 1501 |
-
if can_point and torch.any(is_subpatch):
|
| 1502 |
-
location_logits = self.subpatch_loc_k(x).to(device=device)
|
| 1503 |
-
|
| 1504 |
-
if is_prefill:
|
| 1505 |
num_image_tokens = is_image_token.sum(-1)
|
| 1506 |
image_token_offset = torch.cumsum(num_image_tokens[:-1], 0)
|
| 1507 |
image_token_offset = F.pad(image_token_offset, [1, 0])
|
| 1508 |
-
|
| 1509 |
-
image_data = ImageCache(
|
| 1510 |
-
patch_k=patch_k,
|
| 1511 |
-
subpatch_k=subpatch_k,
|
| 1512 |
-
vit_features=vit_features,
|
| 1513 |
-
patch_k_mask=patch_k_mask,
|
| 1514 |
-
token_pooling=token_pooling,
|
| 1515 |
-
image_pos_ids=image_pos_ids,
|
| 1516 |
-
image_features0=image_features,
|
| 1517 |
-
flat_image_tokens_to_flat_image_features=image_token_offset
|
| 1518 |
-
)
|
| 1519 |
|
| 1520 |
if last_predicted_patch_id is not None:
|
| 1521 |
last_predicted_patch_id = torch.where(input_patch_ids == -1, last_predicted_patch_id, input_patch_ids)
|
|
@@ -1713,9 +1778,9 @@ class MolmoPointForConditionalGeneration(MolmoPointPreTrainedModel, GenerationMi
|
|
| 1713 |
patch_token_logits = torch.clone(logits[:, :, self.config.patch_token_id])
|
| 1714 |
logits[:, :, self.config.patch_token_id] = small_val
|
| 1715 |
predicted_patch = predicted_tokens == self.config.patch_token_id
|
| 1716 |
-
argmax_patch_logits = torch.full([bs, seq, n_patches], small_val, dtype=logits.dtype, device=
|
| 1717 |
if outputs.patch_logits is not None:
|
| 1718 |
-
selected_patches = torch.argmax(outputs.patch_logits, -1)
|
| 1719 |
bs, seq, n_patches = outputs.patch_logits.shape
|
| 1720 |
batch_idx = torch.arange(outputs.patch_logits.shape[0], device=device)
|
| 1721 |
seq_ix = torch.arange(outputs.patch_logits.shape[1], device=device)
|
|
@@ -1725,13 +1790,13 @@ class MolmoPointForConditionalGeneration(MolmoPointPreTrainedModel, GenerationMi
|
|
| 1725 |
if outputs.subpatch_logits is not None:
|
| 1726 |
subpatch_logits = outputs.subpatch_logits
|
| 1727 |
else:
|
| 1728 |
-
subpatch_logits = torch.full([bs, seq, n_subpatches], small_val, dtype=logits.dtype, device=
|
| 1729 |
|
| 1730 |
logits[:, :, self.config.location_token_id] = small_val
|
| 1731 |
if outputs.location_logits is not None:
|
| 1732 |
location_logits = outputs.location_logits
|
| 1733 |
else:
|
| 1734 |
-
location_logits = torch.full([bs, seq, 9], small_val, dtype=logits.dtype, device=
|
| 1735 |
|
| 1736 |
logits = torch.concatenate([
|
| 1737 |
logits,
|
|
|
|
| 9 |
from torch import nn
|
| 10 |
|
| 11 |
from torch.nn import functional as F
|
| 12 |
+
from transformers import LogitsProcessorList, LogitsProcessor, AutoProcessor, ViTConfig
|
| 13 |
from transformers.image_utils import PILImageResampling
|
| 14 |
|
| 15 |
from transformers.models.auto import AutoModelForImageTextToText
|
|
|
|
| 347 |
return attn_output
|
| 348 |
|
| 349 |
|
| 350 |
+
class PointPredictor(nn.Module):
    """Compute patch, subpatch and location logits used for pointing.

    These parameters are kept in their own module so that accelerate
    co-locates all of them on the same device.
    """

    def __init__(self, config):
        """Build the query/key projections described by ``config``.

        Reads ``text_config.hidden_size``, ``vit_config.hidden_size``,
        ``adapter_config.vit_layers``, ``patch_embed_dim``, ``layer_norm_x``,
        ``token_prediction_rotary``, ``token_prediction_rotary_theta`` and
        ``patch_location`` from ``config``.

        Raises:
            NotImplementedError: for an unsupported ``token_prediction_rotary``
                or ``patch_location`` value.
        """
        super().__init__()
        self.config = config
        llm_dim = config.text_config.hidden_size
        patch_embed_dim = config.patch_embed_dim
        # Subpatch keys are projected from the concatenation of the selected ViT layers.
        vit_dim = config.vit_config.hidden_size * len(config.adapter_config.vit_layers)

        if config.layer_norm_x:
            self.x_norm = Molmo2RMSNorm(llm_dim, eps=config.text_config.layer_norm_eps)
        else:
            self.x_norm = None

        if config.token_prediction_rotary == "none":
            self.patch_rotary = None
        elif config.token_prediction_rotary == "one_d":
            theta = config.token_prediction_rotary_theta
            if not theta:
                # Fall back to the language model's rope theta. `config.llm` is kept
                # for configs that still expose it; every other access in this class
                # goes through `text_config`, so use that when `llm` is absent.
                llm_config = getattr(config, "llm", None) or config.text_config
                theta = llm_config.rope_theta
            self.patch_rotary = MolmoPointPatchRope(theta, patch_embed_dim)
        else:
            raise NotImplementedError()

        self.patch_q = nn.Linear(llm_dim, patch_embed_dim)
        self.patch_k = nn.Linear(llm_dim, patch_embed_dim)
        self.subpatch_q = nn.Linear(llm_dim, patch_embed_dim)
        self.subpatch_k = nn.Linear(vit_dim, patch_embed_dim)
        self.add_no_point_class_embed = MolmoPointPadWithLearnedVector(patch_embed_dim)

        if config.patch_location == "3x3":
            self.subpatch_loc_k = nn.Linear(llm_dim, 9)
        elif config.patch_location is None:
            self.subpatch_loc_k = None
        else:
            raise NotImplementedError(f"Patch location {config.patch_location} not implemented")

    def forward(
        self,
        x,
        token_pooling,
        is_image_token,
        is_patch,
        is_subpatch,
        is_indexable_image_token,
        vit_features,
        vit_features_mask,
        image_features_mask,
        input_patch_ids,
        last_predicted_patch_id,
        image_data: ImageCache
    ):
        """Build the point keys (prefill) or score them against the queries (decoding).

        When ``image_data`` is ``None`` the patch/subpatch keys are computed from
        ``x`` and ``vit_features`` and returned in a new ``ImageCache``; all three
        logits are ``None``. When ``image_data`` is given, the cached keys are used
        and ``token_pooling``/``vit_features_mask`` are taken from the cache instead
        of the arguments.

        Args:
            x: hidden states; assumed ``[batch, seq, hidden]`` -- TODO confirm.
            token_pooling: pooling indices; entries ``>= 0`` are treated as valid.
            is_image_token: boolean mask over ``x`` selecting image-token positions.
            is_patch: boolean mask of positions holding a patch token.
            is_subpatch: boolean mask of positions holding a subpatch token.
            is_indexable_image_token: boolean mask of image tokens that may be pointed at.
            vit_features: ViT features fed to the subpatch key projection.
            vit_features_mask: validity mask for ``vit_features``.
            image_features_mask: mask used to scatter flat image-token values into
                ``[batch, n_image_tokens]`` layout.
            input_patch_ids: patch index selected at each position.
            last_predicted_patch_id: previously predicted patch index, or ``None``.
            image_data: cached keys from a previous call, or ``None`` on prefill.

        Returns:
            ``(patch_logits, subpatch_logits, location_logits, image_data)``; any of
            the logits may be ``None``.

        Raises:
            NotImplementedError: if a patch token is present and ``x`` has more than
                one sequence position.
        """
        dim = self.config.text_config.hidden_size
        batch_size = x.shape[0]
        if self.x_norm is not None:
            x_norm = self.x_norm(x)
        elif self.config.norm_x:
            x_norm = x / math.sqrt(dim)
        else:
            x_norm = x

        # Build the keys, or get them from the cache
        if image_data is not None:
            patch_k, subpatch_k = image_data.patch_k, image_data.subpatch_k
            patch_k_mask = image_data.patch_k_mask
            token_pooling = image_data.token_pooling
            vit_features_mask = token_pooling >= 0
            image_pos_ids = image_data.image_pos_ids
        else:
            # Build patch keys, this takes a bit of indexing trickery since we want the keys in
            # shape [batch, n_image_tokens] not [batch, sequence_length]
            n_image_tokens = token_pooling.shape[1]
            patch_k_flat = self.patch_k(x_norm.view(-1, dim)[is_image_token.view(-1)])
            if self.patch_rotary is not None:
                image_token_indices = torch.cumsum(is_indexable_image_token, dim=-1) - 1
                image_pos_ids_flat = image_token_indices.view(-1)[is_image_token.view(-1)]
                patch_k_flat = self.patch_rotary(patch_k_flat, image_pos_ids_flat)

                # Computed for use with the query vectors
                image_pos_ids = torch.zeros([batch_size, n_image_tokens], dtype=torch.long,
                                            device=image_pos_ids_flat.device)
                image_pos_ids.view(-1)[image_features_mask.view(-1)] = image_pos_ids_flat
            else:
                image_pos_ids = None

            patch_k = torch.zeros([batch_size, n_image_tokens, patch_k_flat.shape[-1]],
                                  dtype=x.dtype, device=x.device)
            patch_k.view(-1, patch_k_flat.shape[-1])[image_features_mask.flatten()] = patch_k_flat.to(dtype=x.dtype)

            # Only indexable image tokens are valid pointing targets.
            patch_k_mask = image_features_mask.clone()
            patch_k_mask.view(-1)[image_features_mask.view(-1)] = (
                is_indexable_image_token.view(-1)[is_image_token.view(-1)])

            if self.config.no_more_points_class:
                # One extra, always-valid key is appended for the "no more points" class.
                patch_k = self.add_no_point_class_embed(patch_k)
                patch_k_mask = F.pad(patch_k_mask, (0, 1), value=True)

            subpatch_k = self.subpatch_k(vit_features)

        patch_logits, subpatch_logits, location_logits = None, None, None
        if image_data is not None:
            # Predict patch locations, only done after pre-filling
            batch_idx = torch.arange(batch_size, device=x_norm.device)
            image_q = self.patch_q(x_norm)
            if self.patch_rotary is not None and last_predicted_patch_id is not None:
                rotate_by = image_pos_ids[batch_idx, last_predicted_patch_id]
                # Negative ids mean "no previous patch"; rotate those queries by 0.
                rotate_by = torch.where(last_predicted_patch_id >= 0, rotate_by, 0)
                rotate_by = rotate_by.squeeze(-1)
                image_q = self.patch_rotary(
                    image_q.view(-1, image_q.shape[-1]),
                    torch.clamp(rotate_by, min=0),
                ).reshape(batch_size, -1, image_q.shape[-1])

            dots = torch.matmul(image_q, patch_k.transpose(1, 2))  # [batch, 1, num_images]
            if self.config.norm_logits:
                # NOTE(review): this scales by sqrt(number of keys), whereas the subpatch
                # logits below scale by sqrt(key embedding dim) -- left unchanged, confirm
                # against training before altering.
                dots = dots / math.sqrt(dots.shape[-1])

            valid = patch_k_mask[:, None, :]
            patch_logits = torch.where(valid, dots, -100000000)

            if torch.any(is_patch):
                if x_norm.shape[1] != 1:
                    raise NotImplementedError()
                subpatch_point_q = self.subpatch_q(x_norm.squeeze(1))
                subpatch_k = subpatch_k[batch_idx, input_patch_ids.squeeze(1)]
                subpatch_logits = torch.einsum("pd,pcd->pc", subpatch_point_q, subpatch_k)
                if self.config.norm_logits:
                    subpatch_logits = subpatch_logits / math.sqrt(patch_k.shape[-1])
                subpatch_mask = vit_features_mask[batch_idx, input_patch_ids.squeeze(1)]
                subpatch_logits = torch.where(subpatch_mask, subpatch_logits, -100000)
                subpatch_logits = subpatch_logits[:, None, :]

            # `subpatch_loc_k` is None when `patch_location` is None; leave the location
            # logits as None in that case instead of calling a missing layer.
            if self.subpatch_loc_k is not None and torch.any(is_subpatch):
                location_logits = self.subpatch_loc_k(x)

        if image_data is None:
            image_data = ImageCache(
                patch_k=patch_k,
                subpatch_k=subpatch_k,
                vit_features=vit_features,
                patch_k_mask=patch_k_mask,
                token_pooling=token_pooling,
                image_pos_ids=image_pos_ids,
            )
        return patch_logits, subpatch_logits, location_logits, image_data
|
| 492 |
+
|
| 493 |
+
|
| 494 |
class MolmoPointPreTrainedModel(PreTrainedModel):
|
| 495 |
config: MolmoPointConfig
|
| 496 |
base_model_prefix = "model"
|
|
|
|
| 500 |
"Molmo2PostNormDecoderLayer",
|
| 501 |
"Molmo2VisionBlock",
|
| 502 |
"ViTMultiHeadDotProductAttention",
|
| 503 |
+
"PointPredictor"
|
| 504 |
]
|
| 505 |
_skip_keys_device_placement = "past_key_values"
|
| 506 |
_supports_flash_attn = True
|
|
|
|
| 557 |
|
| 558 |
|
| 559 |
class MolmoPointLogitProcessor(LogitsProcessor):
|
| 560 |
+
"""Force point-special tokens to be generated in a valid order"""
|
| 561 |
|
| 562 |
def __init__(self, bounds: GeneratedTokenBounds,
|
| 563 |
prevent_repeats, force_patch_sorted, force_subpatch_sorted):
|
|
|
|
| 1014 |
def __init__(self, config: MolmoPointConfig):
|
| 1015 |
super().__init__(config)
|
| 1016 |
self.transformer: MolmoPointTextModel = MolmoPointTextModel(config.text_config)
|
| 1017 |
+
self.patch_token_id = self.config.patch_token_id
|
| 1018 |
+
self.subpatch_token_id = self.config.subpatch_token_id
|
| 1019 |
+
self.location_token_id = self.config.location_token_id
|
| 1020 |
|
| 1021 |
vit_config = config.vit_config
|
| 1022 |
adapter_config = config.adapter_config
|
|
|
|
| 1035 |
else:
|
| 1036 |
self.vit = Molmo2VisionTransformer(vit_config)
|
| 1037 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1038 |
self.connector = MolmoPointConnector(adapter_config, vit_config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1039 |
if self.config.embed_selected_vit_patch == "linear":
|
| 1040 |
+
llm_dim = config.text_config.hidden_size
|
| 1041 |
+
vit_dim = self.config.vit_config.hidden_size * len(self.config.adapter_config.vit_layers)
|
| 1042 |
self.build_vit_embedding = nn.Linear(vit_dim, llm_dim, bias=True)
|
| 1043 |
else:
|
| 1044 |
raise NotImplementedError(f"Embedding {self.config.embed_selected_vit_patch} not implemented")
|
| 1045 |
+
self.point_predictor = PointPredictor(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1046 |
|
| 1047 |
# Initialize weights and apply final processing
|
| 1048 |
self.post_init()
|
|
|
|
| 1409 |
input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
|
| 1410 |
|
| 1411 |
if image_data is not None:
|
| 1412 |
+
# Figure out where the patch/subpatch/location tokens are and their values, and then convert
|
| 1413 |
+
# the input_ids back into their original special token values
|
| 1414 |
can_point = True
|
| 1415 |
bounds = self.build_token_bounds(image_data.token_pooling)
|
| 1416 |
expanded_inputs = input_ids
|
|
|
|
| 1424 |
input_ids = torch.where(is_subpatch, self.subpatch_token_id, input_ids)
|
| 1425 |
input_ids = torch.where(is_location, self.location_token_id, input_ids)
|
| 1426 |
else:
|
| 1427 |
+
# No patch prediction during pre-filling
|
| 1428 |
+
input_subpatch_ids = None
|
| 1429 |
input_patch_ids = None
|
| 1430 |
+
is_patch = None
|
| 1431 |
+
is_subpatch = None
|
| 1432 |
can_point = False
|
| 1433 |
|
| 1434 |
device = input_ids.device
|
|
|
|
| 1436 |
batch_size, _, dim = x.shape
|
| 1437 |
batch_idx = torch.arange(batch_size, device=device)
|
| 1438 |
|
|
|
|
| 1439 |
vit_features_flat: Optional[torch.FloatTensor] = None
|
| 1440 |
if images is not None:
|
| 1441 |
is_indexable_image_token = input_ids == self.config.image_patch_id
|
|
|
|
| 1465 |
vit_features_flat = vit_features_flat[image_features_mask.view(-1)]
|
| 1466 |
vit_features_to_flat_mask = vit_features_mask.view(-1, token_pooling.shape[-1])[image_features_mask.view(-1)]
|
| 1467 |
|
| 1468 |
+
# Finally, apply the connector and add to input embeddings
|
| 1469 |
image_features = self.connector(vit_features_flat, vit_features_to_flat_mask).to(device=device)
|
| 1470 |
x = x.clone()
|
| 1471 |
x.view(-1, dim)[is_image_token.view(-1)] += image_features.view(-1, dim)
|
| 1472 |
+
else:
|
| 1473 |
+
is_image_token = None
|
| 1474 |
+
is_indexable_image_token = None
|
| 1475 |
+
if image_data is not None:
|
| 1476 |
+
# Get the features/masks from the cache
|
| 1477 |
+
token_pooling = image_data.token_pooling.to(device=device)
|
| 1478 |
+
vit_features_mask = token_pooling >= 0
|
| 1479 |
+
image_features_mask = torch.any(vit_features_mask, -1)
|
| 1480 |
+
vit_features = image_data.vit_features.to(device=device)
|
| 1481 |
+
else:
|
| 1482 |
+
vit_features = None
|
| 1483 |
+
vit_features_mask = None
|
| 1484 |
+
image_features_mask = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1485 |
|
| 1486 |
# Embed the points
|
| 1487 |
if can_point:
|
|
|
|
| 1491 |
x.view(-1, dim)[is_patch.view(-1)] += image_data.image_features0.view(-1, dim)[input_patch_ids_flat]
|
| 1492 |
|
| 1493 |
if torch.any(is_subpatch):
|
| 1494 |
+
vit_features_flat = vit_features.reshape([-1, token_pooling.shape[-1], vit_features.shape[-1]])
|
| 1495 |
+
vit_features_flat = vit_features_flat[image_features_mask.view(-1)]
|
| 1496 |
+
|
| 1497 |
assert last_predicted_patch_id is not None, "Patch should always be generated before a subpatch"
|
| 1498 |
for_patches = (last_predicted_patch_id.view(batch_size) + image_token_offset)[input_subpatch_ids.view(batch_size) >= 0]
|
| 1499 |
vit_features_to_embed = vit_features_flat[for_patches, input_subpatch_ids]
|
|
|
|
| 1508 |
past_seen_tokens,
|
| 1509 |
past_seen_tokens + inputs_embeds.shape[1],
|
| 1510 |
device=inputs_embeds.device,
|
| 1511 |
+
)
|
| 1512 |
|
| 1513 |
# NOTE: this `is_prefill` logic is not flawless, it fails when we're using a cache eagerly initialized
|
| 1514 |
# (e.g. compiled prefill) AND `images` are not provided. Determining prefill in that case requires
|
|
|
|
| 1560 |
location_logits = None
|
| 1561 |
|
| 1562 |
if images is not None or image_data is not None:
|
| 1563 |
+
patch_logits, subpatch_logits, location_logits, image_data = self.point_predictor(
|
| 1564 |
+
x,
|
| 1565 |
+
token_pooling,
|
| 1566 |
+
is_image_token,
|
| 1567 |
+
is_patch,
|
| 1568 |
+
is_subpatch,
|
| 1569 |
+
is_indexable_image_token,
|
| 1570 |
+
vit_features,
|
| 1571 |
+
vit_features_mask,
|
| 1572 |
+
image_features_mask,
|
| 1573 |
+
input_patch_ids,
|
| 1574 |
+
last_predicted_patch_id,
|
| 1575 |
+
image_data
|
| 1576 |
+
)
|
| 1577 |
+
if images is not None:
|
| 1578 |
+
# Also cache what we need for building the patch/subpatch token embeddings
|
| 1579 |
+
image_data.image_features0 = image_features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1580 |
num_image_tokens = is_image_token.sum(-1)
|
| 1581 |
image_token_offset = torch.cumsum(num_image_tokens[:-1], 0)
|
| 1582 |
image_token_offset = F.pad(image_token_offset, [1, 0])
|
| 1583 |
+
image_data.flat_image_tokens_to_flat_image_features = image_token_offset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1584 |
|
| 1585 |
if last_predicted_patch_id is not None:
|
| 1586 |
last_predicted_patch_id = torch.where(input_patch_ids == -1, last_predicted_patch_id, input_patch_ids)
|
|
|
|
| 1778 |
patch_token_logits = torch.clone(logits[:, :, self.config.patch_token_id])
|
| 1779 |
logits[:, :, self.config.patch_token_id] = small_val
|
| 1780 |
predicted_patch = predicted_tokens == self.config.patch_token_id
|
| 1781 |
+
argmax_patch_logits = torch.full([bs, seq, n_patches], small_val, dtype=logits.dtype, device=device)
|
| 1782 |
if outputs.patch_logits is not None:
|
| 1783 |
+
selected_patches = torch.argmax(outputs.patch_logits, -1).to(device=device)
|
| 1784 |
bs, seq, n_patches = outputs.patch_logits.shape
|
| 1785 |
batch_idx = torch.arange(outputs.patch_logits.shape[0], device=device)
|
| 1786 |
seq_ix = torch.arange(outputs.patch_logits.shape[1], device=device)
|
|
|
|
| 1790 |
if outputs.subpatch_logits is not None:
|
| 1791 |
subpatch_logits = outputs.subpatch_logits
|
| 1792 |
else:
|
| 1793 |
+
subpatch_logits = torch.full([bs, seq, n_subpatches], small_val, dtype=logits.dtype, device=device)
|
| 1794 |
|
| 1795 |
logits[:, :, self.config.location_token_id] = small_val
|
| 1796 |
if outputs.location_logits is not None:
|
| 1797 |
location_logits = outputs.location_logits
|
| 1798 |
else:
|
| 1799 |
+
location_logits = torch.full([bs, seq, 9], small_val, dtype=logits.dtype, device=device)
|
| 1800 |
|
| 1801 |
logits = torch.concatenate([
|
| 1802 |
logits,
|
processing_molmo2.py
CHANGED
|
@@ -294,7 +294,6 @@ class Molmo2Processor(ProcessorMixin):
|
|
| 294 |
Returned when `videos` is not `None`.
|
| 295 |
- **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
|
| 296 |
"""
|
| 297 |
-
|
| 298 |
output_kwargs = self._merge_kwargs(
|
| 299 |
Molmo2ProcessorKwargs,
|
| 300 |
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
|
|
|
| 294 |
Returned when `videos` is not `None`.
|
| 295 |
- **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
|
| 296 |
"""
|
|
|
|
| 297 |
output_kwargs = self._merge_kwargs(
|
| 298 |
Molmo2ProcessorKwargs,
|
| 299 |
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
video_processing_molmo2.py
CHANGED
|
@@ -826,7 +826,8 @@ class Molmo2VideoProcessor(BaseVideoProcessor):
|
|
| 826 |
) -> BatchFeature:
|
| 827 |
validate_kwargs(
|
| 828 |
captured_kwargs=kwargs.keys(),
|
| 829 |
-
valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) +
|
|
|
|
| 830 |
)
|
| 831 |
|
| 832 |
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
|
|
|
| 826 |
) -> BatchFeature:
|
| 827 |
validate_kwargs(
|
| 828 |
captured_kwargs=kwargs.keys(),
|
| 829 |
+
valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) +
|
| 830 |
+
["return_tensors", "return_pointing_metadata"],
|
| 831 |
)
|
| 832 |
|
| 833 |
# Set default kwargs from self. This ensures that if a kwarg is not provided
|