Update: cast to device instead of explicitly cude
Browse files- modeling_deepseekocr2.py +7 -2
modeling_deepseekocr2.py
CHANGED
|
@@ -491,8 +491,13 @@ class DeepseekOCR2Model(DeepseekV2Model):
|
|
| 491 |
if images_in_this_batch:
|
| 492 |
images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
|
| 493 |
# exit()
|
| 494 |
-
|
| 495 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
|
| 497 |
idx += 1
|
| 498 |
|
|
|
|
| 491 |
if images_in_this_batch:
|
| 492 |
images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
|
| 493 |
# exit()
|
| 494 |
+
if torch.backends.mps.is_available():
|
| 495 |
+
device = torch.device("mps")
|
| 496 |
+
elif torch.cuda.is_available():
|
| 497 |
+
device = torch.device("cuda")
|
| 498 |
+
else:
|
| 499 |
+
device = torch.device("cpu")
|
| 500 |
+
inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).to(device), images_in_this_batch)
|
| 501 |
|
| 502 |
idx += 1
|
| 503 |
|