chrisc36 commited on
Commit
448b75b
·
verified ·
1 Parent(s): c57624d

Update modeling_molmo_point.py

Browse files
Files changed (1) hide show
  1. modeling_molmo_point.py +1 -1
modeling_molmo_point.py CHANGED
@@ -1497,7 +1497,7 @@ class MolmoPointModel(MolmoPointPreTrainedModel):
1497
  assert last_predicted_patch_id is not None, "Patch should always be generated before a subpatch"
1498
  for_patches = (last_predicted_patch_id.view(batch_size) + image_token_offset)[input_subpatch_ids.view(batch_size) >= 0]
1499
  vit_features_to_embed = vit_features_flat[for_patches, input_subpatch_ids]
1500
- x.view(-1, dim)[is_subpatch.view(-1)] = self.build_vit_embedding(vit_features_to_embed).to(device=device)
1501
 
1502
  # shape: (batch_size, seq_len, d_model)
1503
  x = self.transformer.emb_drop(x) # type: ignore
 
1497
  assert last_predicted_patch_id is not None, "Patch should always be generated before a subpatch"
1498
  for_patches = (last_predicted_patch_id.view(batch_size) + image_token_offset)[input_subpatch_ids.view(batch_size) >= 0]
1499
  vit_features_to_embed = vit_features_flat[for_patches, input_subpatch_ids]
1500
+ x.view(-1, dim)[is_subpatch.view(-1)] = self.build_vit_embedding(vit_features_to_embed).to(device=device, dtype=x.dtype)
1501
 
1502
  # shape: (batch_size, seq_len, d_model)
1503
  x = self.transformer.emb_drop(x) # type: ignore