import torch
from PIL import Image, ImageOps
from functools import partial, reduce
from transformers.image_transforms import (
    convert_to_rgb,
    center_crop,
    rescale,
    to_channel_dimension_format,
)

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_utils import ChannelDimension, PILImageResampling, to_numpy_array

class UMMImageProcessor(BaseImageProcessor):
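    """Image processor producing `pixel_values` and a per-image patch-grid
    size (`grid_hws`), with two preprocessing paths: `und=True` resizes to
    fit and zero-pads to a square; `und=False` resizes the shortest side
    and center-crops to a square."""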
    model_input_names = ["pixel_values", "grid_hws"]

    def __init__(
        self,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        size=(256, 256),
        crop_size=None,
        resample=PILImageResampling.BICUBIC,
        rescale_factor=1 / 255,
        data_format=ChannelDimension.FIRST,
        scale_resolution=256,
        patch_size=16,
        **kwargs,
    ):
        super().__init__(**kwargs)
        crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.crop_size = crop_size
        self.scale_resolution = scale_resolution
        self.patch_size = patch_size

    def preprocess(self, image: Image.Image, max_resolution=None, return_tensors="pt", und=True, **kwargs) -> BatchFeature:
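        """Preprocess a single PIL image into `pixel_values` and `grid_hws`.

        `max_resolution`, when given, overrides `self.scale_resolution`.
        `und=True` selects the pad-to-square path (`_preprocess_und`);
        `und=False` selects the center-crop path (`_preprocess_gen`).
        """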
        scale_resolution = max_resolution if max_resolution is not None else self.scale_resolution
        if image is None:
            raise ValueError("`preprocess` requires an image, got None.")
        if und:
            image = self._preprocess_und(image, scale_resolution)
        else:
            image = self._preprocess_gen(image, scale_resolution)
        if not torch.is_tensor(image):
            image = torch.tensor(image)
        # The preprocessed image is CHW; the grid is the patch count per side.
        _, H, W = image.shape
        grid_h = int(H // self.patch_size)
        grid_w = int(W // self.patch_size)
        pixel_values = torch.stack([image], dim=0)  # add a batch dimension
        grid_hws = torch.tensor([(grid_h, grid_w)])
        data = {
            "pixel_values": pixel_values,
            "grid_hws": grid_hws,
        }
        return BatchFeature(data=data, tensor_type=return_tensors)

    def _preprocess_gen(self, source_image, scale_resolution):
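        """Scale the shortest side to `scale_resolution`, center-crop to a
        square, then rescale by `self.rescale_factor` and convert to
        `self.data_format` (channels first by default)."""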
        w, h = source_image.size
        scale = scale_resolution / min(h, w)
        new_h = int(round(h * scale))
        new_w = int(round(w * scale))
        source_image = source_image.resize((new_w, new_h), Image.Resampling.BICUBIC)
        source_image = [source_image]
        transforms = [
            convert_to_rgb,
            to_numpy_array,
            partial(center_crop, size=(scale_resolution, scale_resolution)),
            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
        ]
        image = reduce(lambda x, f: [*map(f, x)], transforms, source_image)
        return image[0] if len(image) == 1 else image

    def _preprocess_und(self, source_image, scale_resolution):
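        """Resize so the longer side fits `scale_resolution` (preserving the
        aspect ratio), zero-pad to a square, then rescale by
        `self.rescale_factor` and convert to `self.data_format`."""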
        w, h = source_image.size
        scale = min(scale_resolution / h, scale_resolution / w)
        new_h = int(round(h * scale))
        new_w = int(round(w * scale))
        resized_image = source_image.resize((new_w, new_h), Image.Resampling.BICUBIC)

        # Pad symmetrically with black to a scale_resolution x scale_resolution square.
        pad_w = scale_resolution - new_w
        pad_h = scale_resolution - new_h
        left = pad_w // 2
        right = pad_w - left
        top = pad_h // 2
        bottom = pad_h - top
        new_image = ImageOps.expand(resized_image, border=(left, top, right, bottom), fill=(0, 0, 0))

        source_image = [new_image]
        transforms = [
            convert_to_rgb,
            to_numpy_array,
            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
        ]
        image = reduce(lambda x, f: [*map(f, x)], transforms, source_image)
        return image[0] if len(image) == 1 else image

__all__ = ["UMMImageProcessor"]
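

# Minimal usage sketch (illustrative only; the dummy image and its size are
# placeholders, not values required by the model). It exercises both paths
# and shows the resulting shapes under the default 256-pixel resolution and
# 16-pixel patches.
if __name__ == "__main__":
    processor = UMMImageProcessor()
    dummy = Image.new("RGB", (640, 480))  # hypothetical input image
    # Understanding path: resize to fit, zero-pad to 256x256.
    und_out = processor.preprocess(dummy, und=True)
    print(und_out["pixel_values"].shape)  # torch.Size([1, 3, 256, 256])
    print(und_out["grid_hws"])            # tensor([[16, 16]])
    # Generation path: shortest-side resize, center-crop to 256x256.
    gen_out = processor.preprocess(dummy, und=False)
    print(gen_out["pixel_values"].shape)  # torch.Size([1, 3, 256, 256])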