# File size: 5,909 Bytes
# 8f993ed
# (line-number gutter from the web file viewer omitted)
from networkx import to_numpy_array
import numpy as np
import torch
from PIL import Image
import math
from functools import partial, reduce
from transformers.image_transforms import (
convert_to_rgb,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from transformers.processing_utils import ImagesKwargs
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_utils import ImageInput, ChannelDimension, PILImageResampling, to_numpy_array
from einops import rearrange
class LlavaUHDV3ImageProcessor(BaseImageProcessor):
    """Image processor that resizes images and cuts them into square patches.

    Each image is resized to a patch-aligned resolution chosen by
    ``find_best_resize``, converted to RGB, rescaled, normalized and finally
    split into ``patch_size`` x ``patch_size`` patches. The patches of all
    images are concatenated along the first dimension; ``grid_hws`` records
    how many patch rows/columns each image contributed.
    """

    model_input_names = ["pixel_values", "grid_hws"]

    def __init__(
        self,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        size=(400, 400),
        crop_size=None,
        resample=PILImageResampling.BICUBIC,
        rescale_factor=1 / 255,
        data_format=ChannelDimension.FIRST,
        scale_resolution=1580,
        patch_size=10,
        any_res=True,
        allow_upscale=True,
        **kwargs,
    ):
        """Store the preprocessing configuration.

        Args:
            image_mean: Per-channel mean passed to ``normalize``.
            image_std: Per-channel standard deviation passed to ``normalize``.
            size: Stored as ``self.size``; not read by the methods of this class.
            crop_size: Size spec normalized through ``get_size_dict``; defaults
                to ``{"height": 400, "width": 400}``. Stored but not read by
                the methods of this class.
            resample: Stored as ``self.resample``. NOTE: ``_preprocess``
                currently always resizes with ``Image.Resampling.BICUBIC`` and
                does not consult this attribute.
            rescale_factor: Multiplier passed to ``rescale``.
            data_format: Channel layout requested from the transforms.
                ``preprocess`` unpacks the result as ``(C, H, W)``, so it
                relies on channels-first output.
            scale_resolution: Default resolution budget handed to
                ``find_best_resize`` when ``preprocess`` gets no
                ``max_resolution``.
            patch_size: Edge length (in pixels) of each output patch.
            any_res: Forwarded to ``find_best_resize``.
            allow_upscale: Forwarded to ``find_best_resize``.
            **kwargs: Forwarded to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        crop_size = crop_size if crop_size is not None else {"height": 400, "width": 400}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.crop_size = crop_size
        self.scale_resolution = scale_resolution
        self.patch_size = patch_size
        self.any_res = any_res
        self.allow_upscale = allow_upscale

    def preprocess(self, images, max_resolution=None, upscale_rate=1.4, return_tensors='pt', **kwargs) -> BatchFeature:
        """Turn one image or a batch of images into flattened patch tensors.

        Args:
            images: A single image, or a list/tuple of images. Each image is
                handed to ``_preprocess``, which uses ``.size`` and
                ``.resize`` on it.
            max_resolution: Overrides ``self.scale_resolution`` when not None.
            upscale_rate: Forwarded to ``find_best_resize``.
            return_tensors: ``tensor_type`` given to ``BatchFeature``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``BatchFeature`` with ``pixel_values`` (all patches of all
            images concatenated on dim 0, each ``(C, patch, patch)``) and
            ``grid_hws`` (one ``(grid_h, grid_w)`` row per image).

        Raises:
            ValueError: If ``images`` is None or an empty list/tuple.
        """
        # Fail early with a clear message instead of the UnboundLocalError /
        # opaque torch.concat error that None / [] used to trigger further down.
        if images is None:
            raise ValueError("`images` must be an image or a non-empty list of images, got None.")
        batch = list(images) if isinstance(images, (list, tuple)) else [images]
        if not batch:
            raise ValueError("`images` must contain at least one image, got an empty sequence.")

        scale_resolution = self.scale_resolution if max_resolution is None else max_resolution
        patch = self.patch_size

        pixel_values, grid_hws = [], []
        for image in batch:
            image = self._preprocess(
                image,
                scale_resolution,
                patch,
                self.any_res,
                self.allow_upscale,
                upscale_rate=upscale_rate,
            )
            if not torch.is_tensor(image):
                image = torch.tensor(image)
            _, height, width = image.shape
            grid_h = int(height // patch)
            grid_w = int(width // patch)
            # Cut the image into a grid_h x grid_w grid and flatten the grid
            # row-major so each patch becomes one entry on dim 0.
            patches = rearrange(image, "c (h p1) (w p2) -> (h w) c p1 p2", h=grid_h, w=grid_w)
            pixel_values.append(patches)
            grid_hws.append((grid_h, grid_w))

        data = {
            "pixel_values": torch.concat(pixel_values, dim=0),
            "grid_hws": torch.tensor(grid_hws),
        }
        return BatchFeature(data=data, tensor_type=return_tensors)

    def _preprocess(self, image, scale_resolution=1580, patch_size=10, any_res=True, allow_upscale=True, upscale_rate=1.4):
        """Resize a single image and run it through the pixel transforms.

        Args:
            image: Image object exposing ``.size`` and ``.resize`` (used below
                with a PIL resampling constant).
            scale_resolution: Resolution budget for ``find_best_resize``.
            patch_size: Patch edge length; the resize target is aligned to
                ``patch_size * 8``.
            any_res: Forwarded to ``find_best_resize``.
            allow_upscale: Forwarded to ``find_best_resize``.
            upscale_rate: Forwarded to ``find_best_resize``.

        Returns:
            The transformed image as produced by the transform chain
            (RGB conversion, array conversion, rescale, normalize).
        """
        # Align to a coarser grid than a single patch. The factor 8 is
        # hard-coded here; presumably it matches a downstream merge/stride
        # factor -- TODO confirm against the model code.
        soft_patch_size = patch_size * 8
        best_size = self.find_best_resize(
            image.size,
            scale_resolution,
            soft_patch_size,
            allow_upscale=allow_upscale,
            upscale_rate=upscale_rate,
            any_res=any_res,
        )
        # NOTE: bicubic is hard-coded; self.resample is not consulted here.
        resized = image.resize(best_size, Image.Resampling.BICUBIC)

        transforms = [
            convert_to_rgb,
            to_numpy_array,
            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
            # Input and target layouts are the same value here, so this step
            # looks like a no-op; kept for parity with the original pipeline.
            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
        ]
        result = resized
        for transform in transforms:
            result = transform(result)
        return result

    def ensure_divide(self, length, patch_size):
        """Round ``length`` down to a multiple of ``patch_size``, but never below ``patch_size``."""
        return max(math.floor(length / patch_size) * patch_size, patch_size)

    def find_best_resize(self, original_size, scale_resolution, patch_size, allow_upscale=False, upscale_rate=1.4, any_res=False):
        """Pick a patch-aligned ``(width, height)`` target for resizing.

        Args:
            original_size: ``(width, height)`` of the source image.
            scale_resolution: Upper resolution budget; the target area is
                capped at ``scale_resolution ** 2`` (aspect ratio preserved
                up to integer truncation).
            patch_size: Both output edges are rounded down to a multiple of
                this value (minimum one multiple).
            allow_upscale: With ``any_res``, multiplies both edges by
                ``upscale_rate`` first and raises the minimum-area floor from
                512**2 to 560**2. Without ``any_res``, forces a resize to the
                ``scale_resolution`` budget even for smaller images.
            upscale_rate: Pre-scaling factor used when ``any_res`` and
                ``allow_upscale`` are both set.
            any_res: Selects the branch that clamps the area between the low
                floor and ``scale_resolution`` and limits the longest edge.

        Returns:
            ``(best_width, best_height)``, each a multiple of ``patch_size``
            and at most 5120.
        """
        width, height = original_size
        max_edge = 5120
        scale_resolution_low = 512

        if any_res:
            if allow_upscale:
                width *= upscale_rate
                height *= upscale_rate
                scale_resolution_low = 560
            r = width / height
            # Shrink to the area budget if the image is too large ...
            if width * height > scale_resolution * scale_resolution:
                height = int(scale_resolution / math.sqrt(r))
                width = int(height * r)
            # ... and enlarge to the minimum area if it is too small.
            if width * height < scale_resolution_low * scale_resolution_low:
                height = int(scale_resolution_low / math.sqrt(r))
                width = int(height * r)
            # Finally cap the longest edge.
            if max(width, height) > max_edge:
                scale = max_edge / max(width, height)
                width = int(width * scale)
                height = int(height * scale)
        elif (width * height > scale_resolution * scale_resolution) or allow_upscale:
            # Example: width=672, height=448 -> r=1.5; with
            # scale_resolution=336 this yields height=274, width=411.
            r = width / height
            height = int(scale_resolution / math.sqrt(r))
            width = int(height * r)

        best_width = min(self.ensure_divide(width, patch_size), max_edge)
        best_height = min(self.ensure_divide(height, patch_size), max_edge)
        return (best_width, best_height)
# Public API of this module.
__all__ = ["LlavaUHDV3ImageProcessor"]