Update modelling_magiv2.py
Browse files- modelling_magiv2.py +4 -2
modelling_magiv2.py
CHANGED
|
@@ -29,13 +29,15 @@ class Magiv2Model(PreTrainedModel):
|
|
| 29 |
def move_to_device(self, input):
|
| 30 |
return move_to_device(input, self.device)
|
| 31 |
|
| 32 |
-
def forward(self, images, move_to_device_fn=None, mask_ratio=0.0, batch_size=256):
|
| 33 |
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
| 34 |
if len(images) == 0:
|
| 35 |
return move_to_device_fn(torch.zeros(len(images), self.config.crop_embedding_model_config.hidden_size))
|
| 36 |
|
| 37 |
assert all(isinstance(image, PIL.Image.Image) for image in images), "please provide a list of PIL images"
|
| 38 |
-
|
|
|
|
|
|
|
| 39 |
images = self.processor(images, return_tensors="pt").pixel_values
|
| 40 |
images = move_to_device_fn(images)
|
| 41 |
|
|
|
|
| 29 |
def move_to_device(self, input):
|
| 30 |
return move_to_device(input, self.device)
|
| 31 |
|
| 32 |
+
def forward(self, images, move_to_device_fn=None, mask_ratio=0.0, batch_size=256, convert_to_grayscale=True):
|
| 33 |
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
| 34 |
if len(images) == 0:
|
| 35 |
return move_to_device_fn(torch.zeros(len(images), self.config.crop_embedding_model_config.hidden_size))
|
| 36 |
|
| 37 |
assert all(isinstance(image, PIL.Image.Image) for image in images), "please provide a list of PIL images"
|
| 38 |
+
if convert_to_grayscale:
|
| 39 |
+
images = [x.convert("L") for x in images]
|
| 40 |
+
images = [np.array(image.convert("RGB")) for image in images]
|
| 41 |
images = self.processor(images, return_tensors="pt").pixel_values
|
| 42 |
images = move_to_device_fn(images)
|
| 43 |
|