Update processing_markupdm.py
Browse files- processing_markupdm.py +2 -1
processing_markupdm.py
CHANGED
|
@@ -117,7 +117,8 @@ class MarkupDMProcessor(ProcessorMixin): # type: ignore
|
|
| 117 |
example = self.preprocess_images(example["images"])
|
| 118 |
|
| 119 |
assert vision_model is not None, "Vision model must be provided."
|
| 120 |
-
image = example.pop("image")
|
|
|
|
| 121 |
with torch.inference_mode():
|
| 122 |
_, _, (_, _, image_ids) = vision_model.model.encode(image)
|
| 123 |
example["image_ids"] = list(image_ids.view(image.size(0), -1).cpu())
|
|
|
|
| 117 |
example = self.preprocess_images(example["images"])
|
| 118 |
|
| 119 |
assert vision_model is not None, "Vision model must be provided."
|
| 120 |
+
image = example.pop("image")
|
| 121 |
+
image = image.to(dtype=vision_model.dtype, device=vision_model.device)
|
| 122 |
with torch.inference_mode():
|
| 123 |
_, _, (_, _, image_ids) = vision_model.model.encode(image)
|
| 124 |
example["image_ids"] = list(image_ids.view(image.size(0), -1).cpu())
|