JLFaller commited on
Commit
8c1b50a
·
verified ·
1 Parent(s): 2c8ec7c

Fix: explicitly cast to bool tensor before masked scatter

Browse files
Files changed (1) hide show
  1. modeling_deepseekocr2.py +4 -1
modeling_deepseekocr2.py CHANGED
@@ -497,7 +497,10 @@ class DeepseekOCR2Model(DeepseekV2Model):
497
  device = torch.device("cuda")
498
  else:
499
  device = torch.device("cpu")
500
- inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).to(device), images_in_this_batch)
 
 
 
501
 
502
  idx += 1
503
 
 
497
  device = torch.device("cuda")
498
  else:
499
  device = torch.device("cpu")
500
+
501
+ mask = images_seq_mask[idx].unsqueeze(-1).to(device=device, dtype=torch.bool)
502
+
503
+ inputs_embeds[idx].masked_scatter_(mask, images_in_this_batch)
504
 
505
  idx += 1
506