chrisc36 commited on
Commit
74a5036
·
verified ·
1 Parent(s): ccfc55c

Upload folder using huggingface_hub

Browse files
convert_molmo_point_to_hf.py CHANGED
@@ -207,7 +207,14 @@ def convert_molmo2(
207
  new_key = "lm_head.new_output_embeddings"
208
  elif key == "transformer.ff_out.weight":
209
  new_key = "lm_head.output_embeddings"
 
 
 
 
 
210
  else:
 
 
211
  new_key = f"{base_model_prefix}.{key}"
212
  new_state_dict[new_key] = val
213
 
@@ -215,10 +222,6 @@ def convert_molmo2(
215
  qkv_bias = config.qkv_bias if isinstance(config, Molmo2TextConfig) else config.text_config.qkv_bias
216
  use_qk_norm = config.use_qk_norm if isinstance(config, Molmo2TextConfig) else config.text_config.use_qk_norm
217
 
218
- for param in list(new_state_dict.keys()):
219
- if param.startswith(f"model.connectors.0"):
220
- new_state_dict[param.replace("model.connectors.0", "model.connector")] = new_state_dict.pop(param)
221
-
222
  for layer_i in range(config.num_hidden_layers):
223
  prefix = f"{model_prefix}.blocks.{layer_i}"
224
 
 
207
  new_key = "lm_head.new_output_embeddings"
208
  elif key == "transformer.ff_out.weight":
209
  new_key = "lm_head.output_embeddings"
210
+ elif key.split(".")[0] in [
211
+ "subpatch_k", "subpatch_q", "patch_k", "patch_q", "add_no_point_class_embed",
212
+ "subpatch_loc_k", "x_norm"
213
+ ]:
214
+ new_key = f"{base_model_prefix}.point_predictor.{key}"
215
  else:
216
+ if key.startswith(f"connectors.0"):
217
+ key = key.replace("connectors.0", "connector")
218
  new_key = f"{base_model_prefix}.{key}"
219
  new_state_dict[new_key] = val
220
 
 
222
  qkv_bias = config.qkv_bias if isinstance(config, Molmo2TextConfig) else config.text_config.qkv_bias
223
  use_qk_norm = config.use_qk_norm if isinstance(config, Molmo2TextConfig) else config.text_config.use_qk_norm
224
 
 
 
 
 
225
  for layer_i in range(config.num_hidden_layers):
226
  prefix = f"{model_prefix}.blocks.{layer_i}"
227
 
model-00007-of-00008.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:30485f5c86cefbfccb5eef21be09adf1149e4ec6f9cc75072c162900e6972226
3
- size 4091924852
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0af4940b1b827e995397de33114310c838b8edece415819de119d4f403d420f
3
+ size 4091925052
model.safetensors.index.json CHANGED
@@ -6,7 +6,6 @@
6
  "weight_map": {
7
  "lm_head.new_output_embeddings": "model-00008-of-00008.safetensors",
8
  "lm_head.output_embeddings": "model-00008-of-00008.safetensors",
9
- "model.add_no_point_class_embed.vector": "model-00007-of-00008.safetensors",
10
  "model.build_vit_embedding.bias": "model-00007-of-00008.safetensors",
11
  "model.build_vit_embedding.weight": "model-00007-of-00008.safetensors",
12
  "model.connector.image_pooling_2d.wk.bias": "model-00007-of-00008.safetensors",
@@ -18,16 +17,18 @@
18
  "model.connector.image_projector.w1.weight": "model-00007-of-00008.safetensors",
19
  "model.connector.image_projector.w2.weight": "model-00007-of-00008.safetensors",
20
  "model.connector.image_projector.w3.weight": "model-00007-of-00008.safetensors",
21
- "model.patch_k.bias": "model-00007-of-00008.safetensors",
22
- "model.patch_k.weight": "model-00007-of-00008.safetensors",
23
- "model.patch_q.bias": "model-00007-of-00008.safetensors",
24
- "model.patch_q.weight": "model-00007-of-00008.safetensors",
25
- "model.subpatch_k.bias": "model-00007-of-00008.safetensors",
26
- "model.subpatch_k.weight": "model-00007-of-00008.safetensors",
27
- "model.subpatch_loc_k.bias": "model-00007-of-00008.safetensors",
28
- "model.subpatch_loc_k.weight": "model-00007-of-00008.safetensors",
29
- "model.subpatch_q.bias": "model-00007-of-00008.safetensors",
30
- "model.subpatch_q.weight": "model-00007-of-00008.safetensors",
 
 
31
  "model.transformer.blocks.0.attn_norm.weight": "model-00001-of-00008.safetensors",
32
  "model.transformer.blocks.0.ff_norm.weight": "model-00001-of-00008.safetensors",
33
  "model.transformer.blocks.0.mlp.ff_out.weight": "model-00001-of-00008.safetensors",
@@ -721,7 +722,6 @@
721
  "model.vit.transformer.resblocks.9.feed_forward.w2.bias": "model-00007-of-00008.safetensors",
722
  "model.vit.transformer.resblocks.9.feed_forward.w2.weight": "model-00007-of-00008.safetensors",
723
  "model.vit.transformer.resblocks.9.ffn_norm.bias": "model-00007-of-00008.safetensors",
724
- "model.vit.transformer.resblocks.9.ffn_norm.weight": "model-00007-of-00008.safetensors",
725
- "model.x_norm.weight": "model-00007-of-00008.safetensors"
726
  }
727
  }
 
6
  "weight_map": {
7
  "lm_head.new_output_embeddings": "model-00008-of-00008.safetensors",
8
  "lm_head.output_embeddings": "model-00008-of-00008.safetensors",
 
9
  "model.build_vit_embedding.bias": "model-00007-of-00008.safetensors",
10
  "model.build_vit_embedding.weight": "model-00007-of-00008.safetensors",
11
  "model.connector.image_pooling_2d.wk.bias": "model-00007-of-00008.safetensors",
 
17
  "model.connector.image_projector.w1.weight": "model-00007-of-00008.safetensors",
18
  "model.connector.image_projector.w2.weight": "model-00007-of-00008.safetensors",
19
  "model.connector.image_projector.w3.weight": "model-00007-of-00008.safetensors",
20
+ "model.point_predictor.add_no_point_class_embed.vector": "model-00007-of-00008.safetensors",
21
+ "model.point_predictor.patch_k.bias": "model-00007-of-00008.safetensors",
22
+ "model.point_predictor.patch_k.weight": "model-00007-of-00008.safetensors",
23
+ "model.point_predictor.patch_q.bias": "model-00007-of-00008.safetensors",
24
+ "model.point_predictor.patch_q.weight": "model-00007-of-00008.safetensors",
25
+ "model.point_predictor.subpatch_k.bias": "model-00007-of-00008.safetensors",
26
+ "model.point_predictor.subpatch_k.weight": "model-00007-of-00008.safetensors",
27
+ "model.point_predictor.subpatch_loc_k.bias": "model-00007-of-00008.safetensors",
28
+ "model.point_predictor.subpatch_loc_k.weight": "model-00007-of-00008.safetensors",
29
+ "model.point_predictor.subpatch_q.bias": "model-00007-of-00008.safetensors",
30
+ "model.point_predictor.subpatch_q.weight": "model-00007-of-00008.safetensors",
31
+ "model.point_predictor.x_norm.weight": "model-00007-of-00008.safetensors",
32
  "model.transformer.blocks.0.attn_norm.weight": "model-00001-of-00008.safetensors",
33
  "model.transformer.blocks.0.ff_norm.weight": "model-00001-of-00008.safetensors",
34
  "model.transformer.blocks.0.mlp.ff_out.weight": "model-00001-of-00008.safetensors",
 
722
  "model.vit.transformer.resblocks.9.feed_forward.w2.bias": "model-00007-of-00008.safetensors",
723
  "model.vit.transformer.resblocks.9.feed_forward.w2.weight": "model-00007-of-00008.safetensors",
724
  "model.vit.transformer.resblocks.9.ffn_norm.bias": "model-00007-of-00008.safetensors",
725
+ "model.vit.transformer.resblocks.9.ffn_norm.weight": "model-00007-of-00008.safetensors"
 
726
  }
727
  }
modeling_molmo_point.py CHANGED
@@ -9,7 +9,7 @@ import torch
9
  from torch import nn
10
 
11
  from torch.nn import functional as F
12
- from transformers import LogitsProcessorList, LogitsProcessor, AutoProcessor
13
  from transformers.image_utils import PILImageResampling
14
 
15
  from transformers.models.auto import AutoModelForImageTextToText
@@ -347,6 +347,150 @@ class ViTMultiHeadDotProductAttention(nn.Module):
347
  return attn_output
348
 
349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  class MolmoPointPreTrainedModel(PreTrainedModel):
351
  config: MolmoPointConfig
352
  base_model_prefix = "model"
@@ -356,6 +500,7 @@ class MolmoPointPreTrainedModel(PreTrainedModel):
356
  "Molmo2PostNormDecoderLayer",
357
  "Molmo2VisionBlock",
358
  "ViTMultiHeadDotProductAttention",
 
359
  ]
360
  _skip_keys_device_placement = "past_key_values"
361
  _supports_flash_attn = True
@@ -412,6 +557,7 @@ class GeneratedTokenBounds:
412
 
413
 
414
  class MolmoPointLogitProcessor(LogitsProcessor):
 
415
 
416
  def __init__(self, bounds: GeneratedTokenBounds,
417
  prevent_repeats, force_patch_sorted, force_subpatch_sorted):
@@ -868,6 +1014,9 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
868
  def __init__(self, config: MolmoPointConfig):
869
  super().__init__(config)
870
  self.transformer: MolmoPointTextModel = MolmoPointTextModel(config.text_config)
 
 
 
871
 
872
  vit_config = config.vit_config
873
  adapter_config = config.adapter_config
@@ -886,45 +1035,14 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
886
  else:
887
  self.vit = Molmo2VisionTransformer(vit_config)
888
 
889
- if self.config.token_prediction_rotary == "none":
890
- self.patch_rotary = None
891
- else:
892
- theta = self.config.token_prediction_rotary_theta or self.config.llm.rope_theta
893
- if self.config.token_prediction_rotary == "one_d":
894
- self.patch_rotary = MolmoPointPatchRope(theta, self.config.patch_embed_dim)
895
- else:
896
- raise NotImplementedError()
897
-
898
  self.connector = MolmoPointConnector(adapter_config, vit_config)
899
-
900
- vit_dim = self.config.vit_config.hidden_size * len(self.config.adapter_config.vit_layers)
901
- llm_dim = self.config.text_config.hidden_size
902
- self.patch_q = nn.Linear(llm_dim, config.patch_embed_dim)
903
- self.patch_k = nn.Linear(llm_dim, config.patch_embed_dim)
904
- self.subpatch_q = nn.Linear(llm_dim, config.patch_embed_dim)
905
- self.subpatch_k = nn.Linear(vit_dim, config.patch_embed_dim)
906
- self.add_no_point_class_embed = MolmoPointPadWithLearnedVector(config.patch_embed_dim)
907
-
908
- self.patch_token_id = self.config.patch_token_id
909
- self.subpatch_token_id = self.config.subpatch_token_id
910
- self.location_token_id = self.config.location_token_id
911
-
912
  if self.config.embed_selected_vit_patch == "linear":
 
 
913
  self.build_vit_embedding = nn.Linear(vit_dim, llm_dim, bias=True)
914
  else:
915
  raise NotImplementedError(f"Embedding {self.config.embed_selected_vit_patch} not implemented")
916
-
917
- if self.config.patch_location == "3x3":
918
- self.subpatch_loc_k = nn.Linear(llm_dim, 9)
919
- elif self.config.patch_location is None:
920
- self.subpatch_loc_k = None
921
- else:
922
- raise NotImplementedError(f"Patch location {self.config.patch_location} not implemented")
923
-
924
- if self.config.layer_norm_x:
925
- self.x_norm = Molmo2RMSNorm(llm_dim, eps=self.config.text_config.layer_norm_eps)
926
- else:
927
- self.x_norm = None
928
 
929
  # Initialize weights and apply final processing
930
  self.post_init()
@@ -1291,6 +1409,8 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
1291
  input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
1292
 
1293
  if image_data is not None:
 
 
1294
  can_point = True
1295
  bounds = self.build_token_bounds(image_data.token_pooling)
1296
  expanded_inputs = input_ids
@@ -1304,7 +1424,11 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
1304
  input_ids = torch.where(is_subpatch, self.subpatch_token_id, input_ids)
1305
  input_ids = torch.where(is_location, self.location_token_id, input_ids)
1306
  else:
 
 
1307
  input_patch_ids = None
 
 
1308
  can_point = False
1309
 
1310
  device = input_ids.device
@@ -1312,7 +1436,6 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
1312
  batch_size, _, dim = x.shape
1313
  batch_idx = torch.arange(batch_size, device=device)
1314
 
1315
- # TODO update embeddings for patch/subpatch tokens
1316
  vit_features_flat: Optional[torch.FloatTensor] = None
1317
  if images is not None:
1318
  is_indexable_image_token = input_ids == self.config.image_patch_id
@@ -1342,27 +1465,23 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
1342
  vit_features_flat = vit_features_flat[image_features_mask.view(-1)]
1343
  vit_features_to_flat_mask = vit_features_mask.view(-1, token_pooling.shape[-1])[image_features_mask.view(-1)]
1344
 
1345
- # Finally apply the connector and add to input embeddings
1346
  image_features = self.connector(vit_features_flat, vit_features_to_flat_mask).to(device=device)
1347
  x = x.clone()
1348
  x.view(-1, dim)[is_image_token.view(-1)] += image_features.view(-1, dim)
1349
-
1350
- # Build position ids for the image features, which we might need for rotary
1351
- # embeddings
1352
- image_token_indices = torch.cumsum(is_indexable_image_token, dim=-1) - 1
1353
- image_pos_ids_flat = image_token_indices.view(-1)[is_image_token.view(-1)]
1354
- image_pos_ids = torch.zeros([batch_size, token_pooling.shape[1]], dtype=torch.long, device=device)
1355
- image_pos_ids.view(-1)[image_features_mask.view(-1)] = image_pos_ids_flat
1356
- max_image_pos_id = image_pos_ids_flat.max() + 1
1357
- elif image_data is not None:
1358
- token_pooling = image_data.token_pooling
1359
- image_pos_ids = image_data.image_pos_ids
1360
- vit_features = image_data.vit_features
1361
- vit_features_mask = token_pooling >= 0
1362
- image_features_mask = torch.any(vit_features_mask, -1)
1363
- vit_features_flat = vit_features.reshape([-1, token_pooling.shape[-1], vit_features.shape[-1]])
1364
- vit_features_flat = vit_features_flat[image_features_mask.view(-1)]
1365
- vit_features_to_flat_mask = vit_features_mask.view(-1, token_pooling.shape[-1])[image_features_mask.view(-1)]
1366
 
1367
  # Embed the points
1368
  if can_point:
@@ -1372,6 +1491,9 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
1372
  x.view(-1, dim)[is_patch.view(-1)] += image_data.image_features0.view(-1, dim)[input_patch_ids_flat]
1373
 
1374
  if torch.any(is_subpatch):
 
 
 
1375
  assert last_predicted_patch_id is not None, "Patch should always be generated before a subpatch"
1376
  for_patches = (last_predicted_patch_id.view(batch_size) + image_token_offset)[input_subpatch_ids.view(batch_size) >= 0]
1377
  vit_features_to_embed = vit_features_flat[for_patches, input_subpatch_ids]
@@ -1386,7 +1508,7 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
1386
  past_seen_tokens,
1387
  past_seen_tokens + inputs_embeds.shape[1],
1388
  device=inputs_embeds.device,
1389
- )
1390
 
1391
  # NOTE: this `is_prefill` logic is not flawless, it fails when we're using a cache eagerly initialized
1392
  # (e.g. compiled prefill) AND `images` are not provided. Determining prefill in that case requires
@@ -1438,84 +1560,27 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
1438
  location_logits = None
1439
 
1440
  if images is not None or image_data is not None:
1441
- if self.x_norm:
1442
- x_norm = self.x_norm(x).to(device=device)
1443
- elif self.config.norm_x:
1444
- x_norm = x / math.sqrt(dim)
1445
- else:
1446
- x_norm = x
1447
-
1448
- # Build the keys, or get them from the cache
1449
- if image_data is not None:
1450
- patch_k, subpatch_k = image_data.patch_k, image_data.subpatch_k
1451
- patch_k_mask = image_data.patch_k_mask
1452
- else:
1453
- patch_k_flat = self.patch_k(x_norm.view(-1, dim)[is_image_token.view(-1)])
1454
- if self.patch_rotary is not None:
1455
- patch_k_flat = self.patch_rotary(patch_k_flat, image_pos_ids_flat)
1456
- patch_k_flat = patch_k_flat.to(device=device)
1457
- patch_k = torch.zeros([batch_size, image_features_mask.shape[1], patch_k_flat.shape[-1]], dtype=x.dtype, device=device)
1458
- patch_k.view(-1, patch_k_flat.shape[-1])[image_features_mask.flatten()] = patch_k_flat.to(dtype=x.dtype)
1459
-
1460
- patch_k_mask = image_features_mask.clone()
1461
- patch_k_mask.view(-1)[image_features_mask.view(-1)] = (
1462
- is_indexable_image_token.view(-1)[is_image_token.view(-1)])
1463
-
1464
- if self.config.no_more_points_class:
1465
- patch_k = self.add_no_point_class_embed(patch_k).to(device=device)
1466
- patch_k_mask = F.pad(patch_k_mask, (0, 1), value=True)
1467
-
1468
- subpatch_k = self.subpatch_k(vit_features).to(device=device)
1469
-
1470
- # Predict patch locations
1471
- if can_point:
1472
- image_q = self.patch_q(x_norm).to(device=device)
1473
- if self.patch_rotary is not None and last_predicted_patch_id is not None:
1474
- rotate_by = image_pos_ids[batch_idx, last_predicted_patch_id]
1475
- rotate_by = torch.where(last_predicted_patch_id >= 0, rotate_by, 0)
1476
- rotate_by = rotate_by.squeeze(-1)
1477
- image_q = self.patch_rotary(
1478
- image_q.view(-1, image_q.shape[-1]),
1479
- torch.clamp(rotate_by, min=0),
1480
- ).reshape(batch_size, -1, image_q.shape[-1]).to(device=device)
1481
-
1482
- dots = torch.matmul(image_q, patch_k.transpose(1, 2)) # [batch, 1, num_images]
1483
- if self.config.norm_logits:
1484
- dots = dots / math.sqrt(dots.shape[-1])
1485
-
1486
- valid = patch_k_mask[:, None, :]
1487
- patch_logits = torch.where(valid, dots, -100000000)
1488
-
1489
- if can_point and torch.any(is_patch):
1490
- if x_norm.shape[1] != 1:
1491
- raise NotImplementedError()
1492
- subpatch_point_q = self.subpatch_q(x_norm.squeeze(1)).to(device=device)
1493
- subpatch_k = subpatch_k[batch_idx, input_patch_ids.squeeze(1)]
1494
- subpatch_logits = torch.einsum("pd,pcd->pc", subpatch_point_q, subpatch_k)
1495
- if self.config.norm_logits:
1496
- subpatch_logits = subpatch_logits / math.sqrt(patch_k.shape[-1])
1497
- subpatch_mask = vit_features_mask[batch_idx, input_patch_ids.squeeze(1)]
1498
- subpatch_logits = torch.where(subpatch_mask, subpatch_logits, -100000)
1499
- subpatch_logits = subpatch_logits[:, None, :]
1500
-
1501
- if can_point and torch.any(is_subpatch):
1502
- location_logits = self.subpatch_loc_k(x).to(device=device)
1503
-
1504
- if is_prefill:
1505
  num_image_tokens = is_image_token.sum(-1)
1506
  image_token_offset = torch.cumsum(num_image_tokens[:-1], 0)
1507
  image_token_offset = F.pad(image_token_offset, [1, 0])
1508
- # Return the cache for image keys/features
1509
- image_data = ImageCache(
1510
- patch_k=patch_k,
1511
- subpatch_k=subpatch_k,
1512
- vit_features=vit_features,
1513
- patch_k_mask=patch_k_mask,
1514
- token_pooling=token_pooling,
1515
- image_pos_ids=image_pos_ids,
1516
- image_features0=image_features,
1517
- flat_image_tokens_to_flat_image_features=image_token_offset
1518
- )
1519
 
1520
  if last_predicted_patch_id is not None:
1521
  last_predicted_patch_id = torch.where(input_patch_ids == -1, last_predicted_patch_id, input_patch_ids)
@@ -1713,9 +1778,9 @@ class MolmoPointForConditionalGeneration(MolmoPointPreTrainedModel, GenerationMi
1713
  patch_token_logits = torch.clone(logits[:, :, self.config.patch_token_id])
1714
  logits[:, :, self.config.patch_token_id] = small_val
1715
  predicted_patch = predicted_tokens == self.config.patch_token_id
1716
- argmax_patch_logits = torch.full([bs, seq, n_patches], small_val, dtype=logits.dtype, device=logits.device)
1717
  if outputs.patch_logits is not None:
1718
- selected_patches = torch.argmax(outputs.patch_logits, -1)
1719
  bs, seq, n_patches = outputs.patch_logits.shape
1720
  batch_idx = torch.arange(outputs.patch_logits.shape[0], device=device)
1721
  seq_ix = torch.arange(outputs.patch_logits.shape[1], device=device)
@@ -1725,13 +1790,13 @@ class MolmoPointForConditionalGeneration(MolmoPointPreTrainedModel, GenerationMi
1725
  if outputs.subpatch_logits is not None:
1726
  subpatch_logits = outputs.subpatch_logits
1727
  else:
1728
- subpatch_logits = torch.full([bs, seq, n_subpatches], small_val, dtype=logits.dtype, device=logits.device)
1729
 
1730
  logits[:, :, self.config.location_token_id] = small_val
1731
  if outputs.location_logits is not None:
1732
  location_logits = outputs.location_logits
1733
  else:
1734
- location_logits = torch.full([bs, seq, 9], small_val, dtype=logits.dtype, device=logits.device)
1735
 
1736
  logits = torch.concatenate([
1737
  logits,
 
9
  from torch import nn
10
 
11
  from torch.nn import functional as F
12
+ from transformers import LogitsProcessorList, LogitsProcessor, AutoProcessor, ViTConfig
13
  from transformers.image_utils import PILImageResampling
14
 
15
  from transformers.models.auto import AutoModelForImageTextToText
 
347
  return attn_output
348
 
349
 
350
+ class PointPredictor(nn.Module):
351
+ """Point predictor logic"""
352
+ # We separate this out so accelerate will co-locate all these parameters on the same device
353
+
354
+ def __init__(self, config):
355
+ super().__init__()
356
+ self.config = config
357
+ llm_dim = config.text_config.hidden_size
358
+ patch_embed_dim = config.patch_embed_dim
359
+ vit_dim = self.config.vit_config.hidden_size * len(self.config.adapter_config.vit_layers)
360
+ if self.config.layer_norm_x:
361
+ self.x_norm = Molmo2RMSNorm(llm_dim, eps=self.config.text_config.layer_norm_eps)
362
+ else:
363
+ self.x_norm = None
364
+ if self.config.token_prediction_rotary == "none":
365
+ self.patch_rotary = None
366
+ else:
367
+ theta = self.config.token_prediction_rotary_theta or self.config.llm.rope_theta
368
+ if self.config.token_prediction_rotary == "one_d":
369
+ self.patch_rotary = MolmoPointPatchRope(theta, self.config.patch_embed_dim)
370
+ else:
371
+ raise NotImplementedError()
372
+ self.patch_q = nn.Linear(llm_dim, patch_embed_dim)
373
+ self.patch_k = nn.Linear(llm_dim, patch_embed_dim)
374
+ self.subpatch_q = nn.Linear(llm_dim, patch_embed_dim)
375
+ self.subpatch_k = nn.Linear(vit_dim, patch_embed_dim)
376
+ self.add_no_point_class_embed = MolmoPointPadWithLearnedVector(patch_embed_dim)
377
+ if self.config.patch_location == "3x3":
378
+ self.subpatch_loc_k = nn.Linear(llm_dim, 9)
379
+ elif self.config.patch_location is None:
380
+ self.subpatch_loc_k = None
381
+ else:
382
+ raise NotImplementedError(f"Patch location {self.config.patch_location} not implemented")
383
+
384
+ def forward(
385
+ self,
386
+ x,
387
+ token_pooling,
388
+ is_image_token,
389
+ is_patch,
390
+ is_subpatch,
391
+ is_indexable_image_token,
392
+ vit_features,
393
+ vit_features_mask,
394
+ image_features_mask,
395
+ input_patch_ids,
396
+ last_predicted_patch_id,
397
+ image_data: ImageCache
398
+ ):
399
+ dim = self.config.text_config.hidden_size
400
+ batch_size = x.shape[0]
401
+ if self.x_norm is not None:
402
+ x_norm = self.x_norm(x)
403
+ elif self.config.norm_x:
404
+ x_norm = x / math.sqrt(dim)
405
+ else:
406
+ x_norm = x
407
+
408
+ # Build the keys, or get them from the cache
409
+ if image_data is not None:
410
+ patch_k, subpatch_k = image_data.patch_k, image_data.subpatch_k
411
+ patch_k_mask = image_data.patch_k_mask
412
+ token_pooling = image_data.token_pooling
413
+ vit_features_mask = token_pooling >= 0
414
+ image_pos_ids = image_data.image_pos_ids
415
+ else:
416
+ # Build patch keys, this takes a bit of indexing trickery since we want the keys in
417
+ # shape [batch, n_image_tokens] not [batch, sequence_length]
418
+ n_image_tokens = token_pooling.shape[1]
419
+ patch_k_flat = self.patch_k(x_norm.view(-1, dim)[is_image_token.view(-1)])
420
+ if self.patch_rotary is not None:
421
+ image_token_indices = torch.cumsum(is_indexable_image_token, dim=-1) - 1
422
+ image_pos_ids_flat = image_token_indices.view(-1)[is_image_token.view(-1)]
423
+ patch_k_flat = self.patch_rotary(patch_k_flat, image_pos_ids_flat)
424
+
425
+ # Computed for use with the query vectors
426
+ image_pos_ids = torch.zeros([batch_size, n_image_tokens], dtype=torch.long,
427
+ device=image_pos_ids_flat.device)
428
+ image_pos_ids.view(-1)[image_features_mask.view(-1)] = image_pos_ids_flat
429
+ else:
430
+ image_pos_ids = None
431
+
432
+ patch_k = torch.zeros([batch_size, n_image_tokens, patch_k_flat.shape[-1]],
433
+ dtype=x.dtype, device=x.device)
434
+ patch_k.view(-1, patch_k_flat.shape[-1])[image_features_mask.flatten()] = patch_k_flat.to(dtype=x.dtype)
435
+
436
+ patch_k_mask = image_features_mask.clone()
437
+ patch_k_mask.view(-1)[image_features_mask.view(-1)] = (
438
+ is_indexable_image_token.view(-1)[is_image_token.view(-1)])
439
+
440
+ if self.config.no_more_points_class:
441
+ patch_k = self.add_no_point_class_embed(patch_k)
442
+ patch_k_mask = F.pad(patch_k_mask, (0, 1), value=True)
443
+
444
+ subpatch_k = self.subpatch_k(vit_features)
445
+
446
+ patch_logits, subpatch_logits, location_logits = None, None, None
447
+ if image_data is not None:
448
+ # Predict patch locations, only done after pre-filling
449
+ batch_idx = torch.arange(batch_size, device=x_norm.device)
450
+ image_q = self.patch_q(x_norm)
451
+ if self.patch_rotary is not None and last_predicted_patch_id is not None:
452
+ rotate_by = image_pos_ids[batch_idx, last_predicted_patch_id]
453
+ rotate_by = torch.where(last_predicted_patch_id >= 0, rotate_by, 0)
454
+ rotate_by = rotate_by.squeeze(-1)
455
+ image_q = self.patch_rotary(
456
+ image_q.view(-1, image_q.shape[-1]),
457
+ torch.clamp(rotate_by, min=0),
458
+ ).reshape(batch_size, -1, image_q.shape[-1])
459
+
460
+ dots = torch.matmul(image_q, patch_k.transpose(1, 2)) # [batch, 1, num_images]
461
+ if self.config.norm_logits:
462
+ dots = dots / math.sqrt(dots.shape[-1])
463
+
464
+ valid = patch_k_mask[:, None, :]
465
+ patch_logits = torch.where(valid, dots, -100000000)
466
+
467
+ if torch.any(is_patch):
468
+ if x_norm.shape[1] != 1:
469
+ raise NotImplementedError()
470
+ subpatch_point_q = self.subpatch_q(x_norm.squeeze(1))
471
+ subpatch_k = subpatch_k[batch_idx, input_patch_ids.squeeze(1)]
472
+ subpatch_logits = torch.einsum("pd,pcd->pc", subpatch_point_q, subpatch_k)
473
+ if self.config.norm_logits:
474
+ subpatch_logits = subpatch_logits / math.sqrt(patch_k.shape[-1])
475
+ subpatch_mask = vit_features_mask[batch_idx, input_patch_ids.squeeze(1)]
476
+ subpatch_logits = torch.where(subpatch_mask, subpatch_logits, -100000)
477
+ subpatch_logits = subpatch_logits[:, None, :]
478
+
479
+ if torch.any(is_subpatch):
480
+ location_logits = self.subpatch_loc_k(x)
481
+
482
+ if image_data is None:
483
+ image_data = ImageCache(
484
+ patch_k=patch_k,
485
+ subpatch_k=subpatch_k,
486
+ vit_features=vit_features,
487
+ patch_k_mask=patch_k_mask,
488
+ token_pooling=token_pooling,
489
+ image_pos_ids=image_pos_ids,
490
+ )
491
+ return patch_logits, subpatch_logits, location_logits, image_data
492
+
493
+
494
  class MolmoPointPreTrainedModel(PreTrainedModel):
495
  config: MolmoPointConfig
496
  base_model_prefix = "model"
 
500
  "Molmo2PostNormDecoderLayer",
501
  "Molmo2VisionBlock",
502
  "ViTMultiHeadDotProductAttention",
503
+ "PointPredictor"
504
  ]
505
  _skip_keys_device_placement = "past_key_values"
506
  _supports_flash_attn = True
 
557
 
558
 
559
  class MolmoPointLogitProcessor(LogitsProcessor):
560
+ """Force point-special tokens to be generated in a valid order"""
561
 
562
  def __init__(self, bounds: GeneratedTokenBounds,
563
  prevent_repeats, force_patch_sorted, force_subpatch_sorted):
 
1014
  def __init__(self, config: MolmoPointConfig):
1015
  super().__init__(config)
1016
  self.transformer: MolmoPointTextModel = MolmoPointTextModel(config.text_config)
1017
+ self.patch_token_id = self.config.patch_token_id
1018
+ self.subpatch_token_id = self.config.subpatch_token_id
1019
+ self.location_token_id = self.config.location_token_id
1020
 
1021
  vit_config = config.vit_config
1022
  adapter_config = config.adapter_config
 
1035
  else:
1036
  self.vit = Molmo2VisionTransformer(vit_config)
1037
 
 
 
 
 
 
 
 
 
 
1038
  self.connector = MolmoPointConnector(adapter_config, vit_config)
 
 
 
 
 
 
 
 
 
 
 
 
 
1039
  if self.config.embed_selected_vit_patch == "linear":
1040
+ llm_dim = config.text_config.hidden_size
1041
+ vit_dim = self.config.vit_config.hidden_size * len(self.config.adapter_config.vit_layers)
1042
  self.build_vit_embedding = nn.Linear(vit_dim, llm_dim, bias=True)
1043
  else:
1044
  raise NotImplementedError(f"Embedding {self.config.embed_selected_vit_patch} not implemented")
1045
+ self.point_predictor = PointPredictor(config)
 
 
 
 
 
 
 
 
 
 
 
1046
 
1047
  # Initialize weights and apply final processing
1048
  self.post_init()
 
1409
  input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
1410
 
1411
  if image_data is not None:
1412
+ # Figure out where the patch/subpatch/location are and their values, and then convert
1413
+ # the input_ids back into their original special token values
1414
  can_point = True
1415
  bounds = self.build_token_bounds(image_data.token_pooling)
1416
  expanded_inputs = input_ids
 
1424
  input_ids = torch.where(is_subpatch, self.subpatch_token_id, input_ids)
1425
  input_ids = torch.where(is_location, self.location_token_id, input_ids)
1426
  else:
1427
+ # No patch prediction during pre-filling
1428
+ input_subpatch_ids = None
1429
  input_patch_ids = None
1430
+ is_patch = None
1431
+ is_subpatch = None
1432
  can_point = False
1433
 
1434
  device = input_ids.device
 
1436
  batch_size, _, dim = x.shape
1437
  batch_idx = torch.arange(batch_size, device=device)
1438
 
 
1439
  vit_features_flat: Optional[torch.FloatTensor] = None
1440
  if images is not None:
1441
  is_indexable_image_token = input_ids == self.config.image_patch_id
 
1465
  vit_features_flat = vit_features_flat[image_features_mask.view(-1)]
1466
  vit_features_to_flat_mask = vit_features_mask.view(-1, token_pooling.shape[-1])[image_features_mask.view(-1)]
1467
 
1468
+ # Finally, apply the connector and add to input embeddings
1469
  image_features = self.connector(vit_features_flat, vit_features_to_flat_mask).to(device=device)
1470
  x = x.clone()
1471
  x.view(-1, dim)[is_image_token.view(-1)] += image_features.view(-1, dim)
1472
+ else:
1473
+ is_image_token = None
1474
+ is_indexable_image_token = None
1475
+ if image_data is not None:
1476
+ # Get the features/masks from the cache
1477
+ token_pooling = image_data.token_pooling.to(device=device)
1478
+ vit_features_mask = token_pooling >= 0
1479
+ image_features_mask = torch.any(vit_features_mask, -1)
1480
+ vit_features = image_data.vit_features.to(device=device)
1481
+ else:
1482
+ vit_features = None
1483
+ vit_features_mask = None
1484
+ image_features_mask = None
 
 
 
 
1485
 
1486
  # Embed the points
1487
  if can_point:
 
1491
  x.view(-1, dim)[is_patch.view(-1)] += image_data.image_features0.view(-1, dim)[input_patch_ids_flat]
1492
 
1493
  if torch.any(is_subpatch):
1494
+ vit_features_flat = vit_features.reshape([-1, token_pooling.shape[-1], vit_features.shape[-1]])
1495
+ vit_features_flat = vit_features_flat[image_features_mask.view(-1)]
1496
+
1497
  assert last_predicted_patch_id is not None, "Patch should always be generated before a subpatch"
1498
  for_patches = (last_predicted_patch_id.view(batch_size) + image_token_offset)[input_subpatch_ids.view(batch_size) >= 0]
1499
  vit_features_to_embed = vit_features_flat[for_patches, input_subpatch_ids]
 
1508
  past_seen_tokens,
1509
  past_seen_tokens + inputs_embeds.shape[1],
1510
  device=inputs_embeds.device,
1511
+ )
1512
 
1513
  # NOTE: this `is_prefill` logic is not flawless, it fails when we're using a cache eagerly initialized
1514
  # (e.g. compiled prefill) AND `images` are not provided. Determining prefill in that case requires
 
1560
  location_logits = None
1561
 
1562
  if images is not None or image_data is not None:
1563
+ patch_logits, subpatch_logits, location_logits, image_data = self.point_predictor(
1564
+ x,
1565
+ token_pooling,
1566
+ is_image_token,
1567
+ is_patch,
1568
+ is_subpatch,
1569
+ is_indexable_image_token,
1570
+ vit_features,
1571
+ vit_features_mask,
1572
+ image_features_mask,
1573
+ input_patch_ids,
1574
+ last_predicted_patch_id,
1575
+ image_data
1576
+ )
1577
+ if images is not None:
1578
+ # Also cache stuff we need to building the patch/subpatch token embeddings
1579
+ image_data.image_features0 = image_features
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1580
  num_image_tokens = is_image_token.sum(-1)
1581
  image_token_offset = torch.cumsum(num_image_tokens[:-1], 0)
1582
  image_token_offset = F.pad(image_token_offset, [1, 0])
1583
+ image_data.flat_image_tokens_to_flat_image_features = image_token_offset
 
 
 
 
 
 
 
 
 
 
1584
 
1585
  if last_predicted_patch_id is not None:
1586
  last_predicted_patch_id = torch.where(input_patch_ids == -1, last_predicted_patch_id, input_patch_ids)
 
1778
  patch_token_logits = torch.clone(logits[:, :, self.config.patch_token_id])
1779
  logits[:, :, self.config.patch_token_id] = small_val
1780
  predicted_patch = predicted_tokens == self.config.patch_token_id
1781
+ argmax_patch_logits = torch.full([bs, seq, n_patches], small_val, dtype=logits.dtype, device=device)
1782
  if outputs.patch_logits is not None:
1783
+ selected_patches = torch.argmax(outputs.patch_logits, -1).to(device=device)
1784
  bs, seq, n_patches = outputs.patch_logits.shape
1785
  batch_idx = torch.arange(outputs.patch_logits.shape[0], device=device)
1786
  seq_ix = torch.arange(outputs.patch_logits.shape[1], device=device)
 
1790
  if outputs.subpatch_logits is not None:
1791
  subpatch_logits = outputs.subpatch_logits
1792
  else:
1793
+ subpatch_logits = torch.full([bs, seq, n_subpatches], small_val, dtype=logits.dtype, device=device)
1794
 
1795
  logits[:, :, self.config.location_token_id] = small_val
1796
  if outputs.location_logits is not None:
1797
  location_logits = outputs.location_logits
1798
  else:
1799
+ location_logits = torch.full([bs, seq, 9], small_val, dtype=logits.dtype, device=device)
1800
 
1801
  logits = torch.concatenate([
1802
  logits,
processing_molmo2.py CHANGED
@@ -294,7 +294,6 @@ class Molmo2Processor(ProcessorMixin):
294
  Returned when `videos` is not `None`.
295
  - **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
296
  """
297
-
298
  output_kwargs = self._merge_kwargs(
299
  Molmo2ProcessorKwargs,
300
  tokenizer_init_kwargs=self.tokenizer.init_kwargs,
 
294
  Returned when `videos` is not `None`.
295
  - **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
296
  """
 
297
  output_kwargs = self._merge_kwargs(
298
  Molmo2ProcessorKwargs,
299
  tokenizer_init_kwargs=self.tokenizer.init_kwargs,
video_processing_molmo2.py CHANGED
@@ -826,7 +826,8 @@ class Molmo2VideoProcessor(BaseVideoProcessor):
826
  ) -> BatchFeature:
827
  validate_kwargs(
828
  captured_kwargs=kwargs.keys(),
829
- valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
 
830
  )
831
 
832
  # Set default kwargs from self. This ensures that if a kwarg is not provided
 
826
  ) -> BatchFeature:
827
  validate_kwargs(
828
  captured_kwargs=kwargs.keys(),
829
+ valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) +
830
+ ["return_tensors", "return_pointing_metadata"],
831
  )
832
 
833
  # Set default kwargs from self. This ensures that if a kwarg is not provided