File size: 5,909 Bytes
8f993ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from networkx import to_numpy_array
import numpy as np
import torch
from PIL import Image
import math
from functools import partial, reduce
from transformers.image_transforms import (
    convert_to_rgb,
    normalize,
    rescale,
    resize,
    to_channel_dimension_format,
)
from transformers.processing_utils import ImagesKwargs
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_utils import ImageInput, ChannelDimension, PILImageResampling, to_numpy_array
from einops import rearrange

class LlavaUHDV3ImageProcessor(BaseImageProcessor):
    """Image processor for LLaVA-UHD v3.

    Each input image is resized to a patch-aligned resolution (optionally
    preserving its native aspect ratio — "any resolution" mode), converted
    to RGB, rescaled and normalized, then split into a flat sequence of
    ``patch_size`` x ``patch_size`` patches. The processor returns the patch
    sequence (``pixel_values``) together with the per-image patch-grid shape
    (``grid_hws``).
    """

    model_input_names = ["pixel_values", "grid_hws"]

    def __init__(
        self,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        size=(400, 400),
        crop_size=None,
        resample=PILImageResampling.BICUBIC,
        rescale_factor=1 / 255,
        data_format=ChannelDimension.FIRST,
        scale_resolution=1580,
        patch_size=10,
        any_res=True,
        allow_upscale=True,
        **kwargs,
    ):
        """
        Args:
            image_mean: Per-channel mean used for normalization.
            image_std: Per-channel standard deviation used for normalization.
            size: Nominal size kept for configuration compatibility; the
                actual output size is computed per image by
                :meth:`find_best_resize`.
            crop_size: Optional dict with ``"height"``/``"width"`` keys;
                defaults to 400x400. Stored for configuration compatibility.
            resample: PIL resampling filter used when resizing.
            rescale_factor: Factor applied to raw pixel values (1/255 maps
                uint8 pixels into [0, 1]).
            data_format: Channel-dimension layout of the output array
                (channels-first by default).
            scale_resolution: Upper bound on sqrt(width * height) of the
                resized image.
            patch_size: Side length of the square patches.
            any_res: If True, keep the native aspect ratio within the
                resolution bounds instead of forcing a fixed shape.
            allow_upscale: If True, small images may be upscaled by
                ``upscale_rate`` before bounding.
        """
        super().__init__(**kwargs)
        crop_size = crop_size if crop_size is not None else {"height": 400, "width": 400}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.crop_size = crop_size
        self.scale_resolution = scale_resolution
        self.patch_size = patch_size
        self.any_res = any_res
        self.allow_upscale = allow_upscale

    def preprocess(self, images, max_resolution=None, upscale_rate=1.4, return_tensors='pt', **kwargs) -> BatchFeature:
        """Preprocess one PIL image or a list of PIL images into model inputs.

        Args:
            images: A single PIL image or a list of PIL images.
            max_resolution: Optional override for ``self.scale_resolution``.
            upscale_rate: Upscale factor applied to small images when
                ``allow_upscale`` is enabled.
            return_tensors: Tensor type forwarded to :class:`BatchFeature`.

        Returns:
            A :class:`BatchFeature` with ``pixel_values`` of shape
            ``(total_patches, channels, patch_size, patch_size)`` and
            ``grid_hws`` of shape ``(num_images, 2)``.

        Raises:
            ValueError: If ``images`` is None.
        """
        if images is None:
            # Fail fast with a clear message. The previous implementation
            # fell through to an UnboundLocalError on `data` in this case.
            raise ValueError("`images` must be a PIL image or a list of PIL images, got None.")

        scale_resolution = self.scale_resolution if max_resolution is None else max_resolution
        if not isinstance(images, list):
            images = [images]

        pixel_values, grid_hws = [], []
        for image in images:
            image = self._preprocess(
                image, scale_resolution, self.patch_size, self.any_res,
                self.allow_upscale, upscale_rate=upscale_rate,
            )
            if not torch.is_tensor(image):
                image = torch.tensor(image)
            _, height, width = image.shape
            # Sizes are patch-aligned by find_best_resize, so // is exact.
            grid_h = int(height // self.patch_size)
            grid_w = int(width // self.patch_size)
            # Split (c, H, W) into a flat (grid_h * grid_w, c, p, p) patch
            # sequence in one rearrange (previously two chained calls).
            patches = rearrange(image, "c (h p1) (w p2) -> (h w) c p1 p2", h=grid_h, w=grid_w)
            pixel_values.append(patches)
            grid_hws.append((grid_h, grid_w))

        data = {
            "pixel_values": torch.concat(pixel_values, dim=0),
            "grid_hws": torch.tensor(grid_hws),
        }
        return BatchFeature(data=data, tensor_type=return_tensors)

    def _preprocess(self, image, scale_resolution=1580, patch_size=10, any_res=True, allow_upscale=True, upscale_rate=1.4):
        """Resize, convert, rescale and normalize a single PIL image.

        Returns a channel-first numpy array whose height and width are
        multiples of ``patch_size``.
        """
        # Resolution search aligns to 8x the patch size; the final image is
        # still patch-aligned since 8 * patch_size is a multiple of patch_size.
        soft_patch_size = patch_size * 8
        best_size = self.find_best_resize(
            image.size, scale_resolution, soft_patch_size,
            allow_upscale=allow_upscale, upscale_rate=upscale_rate, any_res=any_res,
        )

        # Use the configured filter; the previous code hard-coded BICUBIC
        # and silently ignored `self.resample`.
        resized = image.resize(best_size, self.resample)

        transforms = [
            convert_to_rgb,
            to_numpy_array,
            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
        ]
        # Apply the pipeline left-to-right to the single image (the previous
        # version wrapped it in a one-element list only to unwrap it).
        return reduce(lambda img, fn: fn(img), transforms, resized)

    def ensure_divide(self, length, patch_size):
        """Round ``length`` down to a multiple of ``patch_size``, at least ``patch_size``."""
        return max(math.floor(length / patch_size) * patch_size, patch_size)

    def find_best_resize(self, original_size, scale_resolution, patch_size, allow_upscale=False, upscale_rate=1.4, any_res=False):
        """Compute a patch-aligned ``(width, height)`` for an input image.

        In ``any_res`` mode the aspect ratio is preserved while the total area
        is clamped between ``scale_resolution_low**2`` and
        ``scale_resolution**2`` and the longest edge capped at ``max_edge``.
        Otherwise the image is only rescaled when it exceeds the area bound
        (or unconditionally when ``allow_upscale`` is set).
        """
        width, height = original_size
        max_edge = 5120          # hard cap on either edge of the output
        scale_resolution_low = 512
        if any_res:
            if allow_upscale:
                width *= upscale_rate
                height *= upscale_rate
                scale_resolution_low = 560
            r = width / height
            # Shrink to fit the area upper bound, preserving aspect ratio.
            if (width * height > scale_resolution * scale_resolution):
                height = int(scale_resolution / math.sqrt(r))
                width = int(height * r)
            # Grow to reach the area lower bound, preserving aspect ratio.
            if (width * height < scale_resolution_low * scale_resolution_low):
                height = int(scale_resolution_low / math.sqrt(r))
                width = int(height * r)
            if max(width, height) > max_edge:
                scale = max_edge / max(width, height)
                width = int(width * scale)
                height = int(height * scale)
        else:
            if (width * height > scale_resolution * scale_resolution) or allow_upscale:
                r = width / height
                height = int(scale_resolution / math.sqrt(r))
                width = int(height * r)
        best_width = min(self.ensure_divide(width, patch_size), max_edge)
        best_height = min(self.ensure_divide(height, patch_size), max_edge)
        return (best_width, best_height)

__all__ = ["LlavaUHDV3ImageProcessor"]