# LLaVA-UHD-v3 / image_processing_llava_uhd_v3.py
# Uploaded by Sishxo ("Upload 12 files", commit 8f993ed, verified).
from networkx import to_numpy_array
import numpy as np
import torch
from PIL import Image
import math
from functools import partial, reduce
from transformers.image_transforms import (
convert_to_rgb,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from transformers.processing_utils import ImagesKwargs
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_utils import ImageInput, ChannelDimension, PILImageResampling, to_numpy_array
from einops import rearrange
class LlavaUHDV3ImageProcessor(BaseImageProcessor):
    """Image processor for LLaVA-UHD-v3.

    Each image is resized (approximately preserving its aspect ratio) so that both
    sides are multiples of ``patch_size * 8``. It is then converted to RGB,
    rescaled, normalized, laid out channels-first, and cut into non-overlapping
    ``patch_size x patch_size`` patches.

    ``preprocess`` returns the patches of all images concatenated along the first
    dimension (``pixel_values``). It also returns the per-image patch grid
    ``(grid_h, grid_w)`` (``grid_hws``).
    """

    model_input_names = ["pixel_values", "grid_hws"]

    def __init__(
        self,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        size=(400, 400),
        crop_size=None,
        resample=PILImageResampling.BICUBIC,
        rescale_factor=1 / 255,
        data_format=ChannelDimension.FIRST,
        scale_resolution=1580,
        patch_size=10,
        any_res=True,
        allow_upscale=True,
        **kwargs,
    ):
        """Store the preprocessing configuration.

        Args:
            image_mean: per-channel mean passed to ``normalize``.
            image_std: per-channel std passed to ``normalize``.
            size: stored as ``self.size``; not read by the methods of this class.
            crop_size: stored as a size dict (default 400x400); not read by the
                methods of this class.
            resample: stored as ``self.resample``. Note that ``_preprocess`` always
                resizes with bicubic interpolation, regardless of this value.
            rescale_factor: multiplier applied by ``rescale``. The default 1/255
                maps 0..255 to 0..1.
            data_format: channel layout of the processed arrays. ``preprocess``
                unpacks them as ``(C, H, W)``, so it relies on the default
                ``ChannelDimension.FIRST``.
            scale_resolution: the resized image area is capped at
                ``scale_resolution ** 2``. It can be overridden per call via
                ``preprocess(max_resolution=...)``.
            patch_size: side length of the square patches that ``preprocess`` emits.
            any_res: resize-policy flag forwarded to ``find_best_resize``.
            allow_upscale: resize-policy flag forwarded to ``find_best_resize``.
            **kwargs: forwarded to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        crop_size = crop_size if crop_size is not None else {"height": 400, "width": 400}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.crop_size = crop_size
        self.scale_resolution = scale_resolution
        self.patch_size = patch_size
        self.any_res = any_res
        self.allow_upscale = allow_upscale

    def preprocess(self, images, max_resolution=None, upscale_rate=1.4, return_tensors='pt', **kwargs) -> BatchFeature:
        """Resize, normalize and patchify one image or a sequence of images.

        Args:
            images: a single image, or a list/tuple of images. ``_preprocess`` reads
                ``.size`` as ``(width, height)`` and calls ``.resize``, so PIL images
                are expected. ``None`` or an empty sequence yields an empty
                ``BatchFeature``.
            max_resolution: when given, overrides ``self.scale_resolution`` for this
                call.
            upscale_rate: factor applied to both sides before fitting, when upscaling
                is allowed in any-res mode.
            return_tensors: forwarded to ``BatchFeature`` as ``tensor_type``.
            **kwargs: accepted for API compatibility and ignored.

        Returns:
            A ``BatchFeature`` with two entries:

            * ``pixel_values``: ``(total_patches, C, patch_size, patch_size)``. It
              holds the patches of every image, in row-major grid order within each
              image, with images concatenated in input order.
            * ``grid_hws``: ``(num_images, 2)``, holding ``(grid_h, grid_w)`` for
              each image.
        """
        scale_resolution = max_resolution if max_resolution is not None else self.scale_resolution

        if images is None:
            images = []
        elif not isinstance(images, (list, tuple)):
            images = [images]
        if not images:
            # Nothing to process. Returning here also avoids `torch.concat` on an
            # empty list, which raises.
            return BatchFeature(data={}, tensor_type=return_tensors)

        patch = self.patch_size
        pixel_values, grid_hws = [], []
        for image in images:
            processed = self._preprocess(
                image, scale_resolution, patch, self.any_res, self.allow_upscale, upscale_rate=upscale_rate
            )
            # The array is freshly computed, so sharing its memory is safe and
            # avoids a copy.
            tensor = processed if torch.is_tensor(processed) else torch.as_tensor(processed)
            _, height, width = tensor.shape
            # `find_best_resize` makes both sides multiples of `patch`, so these
            # divisions are exact.
            grid_h, grid_w = height // patch, width // patch
            # (C, H, W) -> (grid_h * grid_w, C, patch, patch), patches in row-major
            # grid order.
            patches = rearrange(tensor, "c (h p1) (w p2) -> (h w) c p1 p2", h=grid_h, w=grid_w)
            pixel_values.append(patches)
            grid_hws.append((grid_h, grid_w))

        data = {
            "pixel_values": torch.concat(pixel_values, dim=0),
            "grid_hws": torch.tensor(grid_hws),
        }
        return BatchFeature(data=data, tensor_type=return_tensors)

    def _preprocess(self, image, scale_resolution=1580, patch_size=10, any_res=True, allow_upscale=True, upscale_rate=1.4):
        """Resize one image and turn it into a normalized array.

        The target size comes from ``find_best_resize`` with a granularity of
        ``patch_size * 8`` pixels, so both output sides are multiples of
        ``patch_size``.

        Returns:
            The image after ``convert_to_rgb`` -> ``to_numpy_array`` -> ``rescale``
            -> ``normalize`` -> ``to_channel_dimension_format``, in
            ``self.data_format`` layout.
        """
        # Sides are snapped to multiples of 8 patches, not just 1. Presumably this
        # lets the patch grid be grouped/merged downstream — TODO confirm against
        # the model.
        soft_patch_size = patch_size * 8
        best_size = self.find_best_resize(
            image.size,
            scale_resolution,
            soft_patch_size,
            allow_upscale=allow_upscale,
            upscale_rate=upscale_rate,
            any_res=any_res,
        )
        # NOTE: always bicubic; `self.resample` is not consulted here.
        processed = image.resize(best_size, Image.Resampling.BICUBIC)
        transforms = (
            convert_to_rgb,
            to_numpy_array,
            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
        )
        for transform in transforms:
            processed = transform(processed)
        return processed

    def ensure_divide(self, length, patch_size):
        """Round ``length`` down to a multiple of ``patch_size``, never below one ``patch_size``."""
        return max(math.floor(length / patch_size) * patch_size, patch_size)

    def find_best_resize(self, original_size, scale_resolution, patch_size, allow_upscale=False, upscale_rate=1.4, any_res=False):
        """Choose an output ``(width, height)`` for an image of size ``original_size``.

        When ``any_res`` is true, the steps are:

        1. If ``allow_upscale`` is set, scale both sides by ``upscale_rate``.
        2. Keeping the aspect ratio, shrink to area ``scale_resolution ** 2`` if
           the image is larger.
        3. Grow to area ``low ** 2`` if it is smaller. ``low`` is 560 when
           upscaling is allowed and 512 otherwise.
        4. Scale down proportionally so the longest side is at most 5120.

        When ``any_res`` is false, the image is rescaled to area
        ``scale_resolution ** 2`` (keeping the aspect ratio) if it is larger than
        that or if ``allow_upscale`` is set. Otherwise its size is kept.

        Finally, each side is capped at 5120 and rounded down to a multiple of
        ``patch_size``, with a minimum of one ``patch_size``.
        """
        width, height = original_size
        max_edge = 5120
        scale_resolution_low = 512
        if any_res:
            if allow_upscale:
                width *= upscale_rate
                height *= upscale_rate
                scale_resolution_low = 560
            aspect = width / height
            if width * height > scale_resolution * scale_resolution:
                height = int(scale_resolution / math.sqrt(aspect))
                width = int(height * aspect)
            if width * height < scale_resolution_low * scale_resolution_low:
                height = int(scale_resolution_low / math.sqrt(aspect))
                width = int(height * aspect)
            if max(width, height) > max_edge:
                scale = max_edge / max(width, height)
                width = int(width * scale)
                height = int(height * scale)
        elif width * height > scale_resolution * scale_resolution or allow_upscale:
            aspect = width / height  # e.g. 672x448 -> 1.5
            height = int(scale_resolution / math.sqrt(aspect))
            width = int(height * aspect)
        # Cap before rounding, so the result stays a multiple of `patch_size` even
        # when `max_edge` is not one. (Rounding first and then capping could yield
        # 5120 for a patch size that does not divide it.)
        # NOTE(review): in the non-any_res path this per-side cap does not preserve
        # the aspect ratio for extremely elongated images.
        best_width = self.ensure_divide(min(width, max_edge), patch_size)
        best_height = self.ensure_divide(min(height, max_edge), patch_size)
        return (best_width, best_height)
# Public API of this module: only the image processor class is exported.
__all__ = ["LlavaUHDV3ImageProcessor"]