JLFaller commited on
Commit
2c8ec7c
·
verified ·
1 Parent(s): bef6664

Update: cast to device instead of explicitly cude

Browse files
Files changed (1) hide show
  1. modeling_deepseekocr2.py +7 -2
modeling_deepseekocr2.py CHANGED
@@ -491,8 +491,13 @@ class DeepseekOCR2Model(DeepseekV2Model):
491
  if images_in_this_batch:
492
  images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
493
  # exit()
494
-
495
- inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
 
 
 
 
 
496
 
497
  idx += 1
498
 
 
491
  if images_in_this_batch:
492
  images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
493
  # exit()
494
+ if torch.backends.mps.is_available():
495
+ device = torch.device("mps")
496
+ elif torch.cuda.is_available():
497
+ device = torch.device("cuda")
498
+ else:
499
+ device = torch.device("cpu")
500
+ inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).to(device), images_in_this_batch)
501
 
502
  idx += 1
503