# File size: 4,630 Bytes
# bb46e5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from networkx import to_numpy_array
import numpy as np
import torch
from PIL import Image, ImageOps
import math
from functools import partial, reduce
from transformers.image_transforms import (
    convert_to_rgb,
    center_crop,
    normalize,
    rescale,
    resize,
    to_channel_dimension_format,
)

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_utils import ImageInput, ChannelDimension, PILImageResampling, to_numpy_array

class UMMImageProcessor(BaseImageProcessor):
    """Image processor that turns a PIL image into patch-grid model inputs.

    Produces a ``pixel_values`` tensor of shape ``(1, C, H, W)`` and a
    ``grid_hws`` tensor of shape ``(1, 2)`` holding the patch-grid height and
    width (``H // patch_size``, ``W // patch_size``).

    Two preprocessing paths exist:
      * understanding (``und=True``): aspect-preserving resize so the image
        fits inside a ``scale_resolution`` square, then zero-pad to square.
      * generation (``und=False``): aspect-preserving resize so the shorter
        side reaches ``scale_resolution``, then center-crop to square.
    """

    model_input_names = ["pixel_values", "grid_hws"]

    def __init__(
        self,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        size=(256, 256),
        crop_size=None,
        resample=PILImageResampling.BICUBIC,
        rescale_factor=1 / 255,
        data_format=ChannelDimension.FIRST,
        scale_resolution=256,
        patch_size=16,
        **kwargs,
    ):
        """Store preprocessing configuration.

        Args:
            image_mean: Per-channel mean (stored; not applied in `preprocess`).
            image_std: Per-channel std (stored; not applied in `preprocess`).
            size: Nominal output size (stored; not applied in `preprocess`).
            crop_size: Crop size dict or int; defaults to 256x256.
            resample: PIL resampling filter (stored; `_preprocess_*` currently
                hard-code BICUBIC — kept as-is for behavioral compatibility).
            rescale_factor: Multiplier applied to pixel values (default 1/255).
            data_format: Output channel layout (channels-first by default).
            scale_resolution: Target square resolution for preprocessing.
            patch_size: Side length of a model patch, used to derive grid_hws.
        """
        super().__init__(**kwargs)
        # Normalize crop_size into the canonical {"height": ..., "width": ...} form.
        crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.crop_size = crop_size
        self.scale_resolution = scale_resolution
        self.patch_size = patch_size

    def preprocess(self, image, max_resolution=None, return_tensors='pt', und=True, **kwargs) -> BatchFeature:
        """Preprocess a single PIL image into model inputs.

        Args:
            image: The input PIL image. Must not be None.
            max_resolution: Optional override for ``self.scale_resolution``.
            return_tensors: Tensor type for the returned ``BatchFeature``.
            und: Select the understanding (pad) path when True, otherwise the
                generation (center-crop) path.

        Returns:
            ``BatchFeature`` with ``pixel_values`` (1, C, H, W) and
            ``grid_hws`` (1, 2).

        Raises:
            ValueError: If ``image`` is None (previously this fell through to
                a confusing ``NameError`` on an undefined local).
        """
        if image is None:
            raise ValueError("`image` must not be None.")
        scale_resolution = max_resolution if max_resolution is not None else self.scale_resolution
        if und:
            processed = self._preprocess_und(image, scale_resolution)
        else:
            processed = self._preprocess_gen(image, scale_resolution)
        if not torch.is_tensor(processed):
            processed = torch.tensor(processed)
        # Channels-first layout: shape is (C, H, W).
        _, height, width = processed.shape
        grid_hw = (int(height // self.patch_size), int(width // self.patch_size))
        data = {
            "pixel_values": torch.stack([processed], dim=0),
            "grid_hws": torch.tensor([grid_hw]),
        }
        return BatchFeature(data=data, tensor_type=return_tensors)

    def _preprocess_gen(self, source_image, scale_resolution):
        """Generation path: scale the shorter side to ``scale_resolution``,
        center-crop to a square, rescale, and emit channels-first.
        """
        w, h = source_image.size
        scale = scale_resolution / min(h, w)
        new_w = int(round(w * scale))
        new_h = int(round(h * scale))
        resized = source_image.resize((new_w, new_h), Image.Resampling.BICUBIC)
        arr = to_numpy_array(convert_to_rgb(resized))
        arr = center_crop(arr, size=(scale_resolution, scale_resolution))
        # rescale already emits self.data_format; the final call only
        # re-asserts the channel layout (no-op transpose), kept for parity
        # with the understanding path.
        arr = rescale(arr, scale=self.rescale_factor, data_format=self.data_format)
        arr = to_channel_dimension_format(arr, channel_dim=self.data_format, input_channel_dim=self.data_format)
        return arr

    def _preprocess_und(self, source_image, scale_resolution):
        """Understanding path: fit the image inside a ``scale_resolution``
        square (aspect preserved), zero-pad symmetrically to a full square,
        rescale, and emit channels-first.
        """
        w, h = source_image.size
        scale = min(scale_resolution / h, scale_resolution / w)
        new_w = int(round(w * scale))
        new_h = int(round(h * scale))
        resized = source_image.resize((new_w, new_h), Image.Resampling.BICUBIC)

        # Symmetric black padding; the extra pixel (odd remainder) goes to
        # the right/bottom edge.
        pad_w = scale_resolution - new_w
        pad_h = scale_resolution - new_h
        left = pad_w // 2
        top = pad_h // 2
        border = (left, top, pad_w - left, pad_h - top)
        padded = ImageOps.expand(resized, border=border, fill=(0, 0, 0))

        arr = to_numpy_array(convert_to_rgb(padded))
        arr = rescale(arr, scale=self.rescale_factor, data_format=self.data_format)
        arr = to_channel_dimension_format(arr, channel_dim=self.data_format, input_channel_dim=self.data_format)
        return arr

__all__ = ["UMMImageProcessor"]