import math
from functools import partial, reduce

import torch
from einops import rearrange
from PIL import Image

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_transforms import (
    convert_to_rgb,
    normalize,
    rescale,
    to_channel_dimension_format,
)
from transformers.image_utils import ChannelDimension, ImageInput, PILImageResampling, to_numpy_array


class LlavaUHDV3ImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values", "grid_hws"]

    def __init__(
        self,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        size=(400, 400),
        crop_size=None,
        resample=PILImageResampling.BICUBIC,
        rescale_factor=1 / 255,
        data_format=ChannelDimension.FIRST,
        scale_resolution=1580,
        patch_size=10,
        any_res=True,
        allow_upscale=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        crop_size = crop_size if crop_size is not None else {"height": 400, "width": 400}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.crop_size = crop_size
        self.scale_resolution = scale_resolution
        self.patch_size = patch_size
        self.any_res = any_res
        self.allow_upscale = allow_upscale

    def preprocess(self, images: ImageInput, max_resolution=None, upscale_rate=1.4, return_tensors="pt", **kwargs) -> BatchFeature:
        if images is None:
            raise ValueError("`images` must be provided.")
        scale_resolution = max_resolution if max_resolution is not None else self.scale_resolution

        pixel_values, grid_hws = [], []
        for image in images if isinstance(images, list) else [images]:
            image = self._preprocess(
                image, scale_resolution, self.patch_size, self.any_res, self.allow_upscale, upscale_rate=upscale_rate
            )
            if not torch.is_tensor(image):
                image = torch.tensor(image)
            # Split the (C, H, W) image into a flat sequence of patch_size x patch_size
            # patches, recording the per-image grid shape so spatial layout can be restored.
            _, H, W = image.shape
            grid_h = H // self.patch_size
            grid_w = W // self.patch_size
            patches = rearrange(image, "c (h p1) (w p2) -> h w c p1 p2", h=grid_h, w=grid_w)
            patches = rearrange(patches, "h w c p1 p2 -> (h w) c p1 p2")
            pixel_values.append(patches)
            grid_hws.append((grid_h, grid_w))

        pixel_values = torch.concat(pixel_values, dim=0)
        grid_hws = torch.tensor(grid_hws)
        data = {"pixel_values": pixel_values, "grid_hws": grid_hws}
        return BatchFeature(data=data, tensor_type=return_tensors)
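    # Shape walkthrough (illustrative numbers, assuming the defaults above): a
    # 1024x768 input is resized by `_preprocess` to 1360x1040, so with
    # patch_size=10 the grid is (grid_h, grid_w) = (104, 136) and `preprocess`
    # contributes 104 * 136 = 14144 patches of shape (3, 10, 10) to
    # `pixel_values`, plus one (104, 136) row in `grid_hws`.
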
    def _preprocess(self, image, scale_resolution=1580, patch_size=10, any_res=True, allow_upscale=True, upscale_rate=1.4):
        original_size = image.size
        # Snap the resize target to multiples of 8 patches so the final patch grid stays well formed.
        soft_patch_size = patch_size * 8
        best_size = self.find_best_resize(
            original_size, scale_resolution, soft_patch_size, allow_upscale=allow_upscale, upscale_rate=upscale_rate, any_res=any_res
        )

        source_image = image.resize(best_size, self.resample)
        # Left-fold the conversion pipeline (RGB -> numpy -> rescale -> normalize -> channel-first)
        # over the single-element image list.
        transforms = [
            convert_to_rgb,
            to_numpy_array,
            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
        ]
        images = reduce(lambda x, f: [*map(f, x)], transforms, [source_image])
        return images[0] if len(images) == 1 else images

    def ensure_divide(self, length, patch_size):
        # Round down to a multiple of patch_size, but never below a single patch.
        return max(math.floor(length / patch_size) * patch_size, patch_size)

    def find_best_resize(self, original_size, scale_resolution, patch_size, allow_upscale=False, upscale_rate=1.4, any_res=False):
        width, height = original_size
        max_edge = 5120
        scale_resolution_low = 512
        if any_res:
            if allow_upscale:
                width *= upscale_rate
                height *= upscale_rate
                scale_resolution_low = 560
            r = width / height
            # Shrink to roughly scale_resolution ** 2 pixels if the area exceeds it,
            # preserving the aspect ratio r.
            if width * height > scale_resolution * scale_resolution:
                height = int(scale_resolution / math.sqrt(r))
                width = int(height * r)
            # Grow to roughly scale_resolution_low ** 2 pixels if the area falls below it.
            if width * height < scale_resolution_low * scale_resolution_low:
                height = int(scale_resolution_low / math.sqrt(r))
                width = int(height * r)
            # Clamp the longer edge to max_edge.
            if max(width, height) > max_edge:
                scale = max_edge / max(width, height)
                width = int(width * scale)
                height = int(height * scale)
        else:
            if (width * height > scale_resolution * scale_resolution) or allow_upscale:
                r = width / height
                height = int(scale_resolution / math.sqrt(r))
                width = int(height * r)
        best_width = self.ensure_divide(width, patch_size)
        best_height = self.ensure_divide(height, patch_size)
        best_width = min(best_width, max_edge)
        best_height = min(best_height, max_edge)
        return (best_width, best_height)
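    # Worked example (illustrative, using the class defaults): a 4000x3000 input
    # with allow_upscale=True is first scaled by 1.4 to 5600x4200; its area then
    # exceeds scale_resolution ** 2 (1580 ** 2), so it is shrunk to roughly
    # 1824x1368, and snapping to the 80-pixel soft grid yields (1760, 1360).
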
__all__ = ["LlavaUHDV3ImageProcessor"]
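
# A minimal usage sketch (illustrative, not part of the public API): it builds a
# synthetic RGB image rather than loading a real file, and only checks the shapes
# the processor produces under the default settings described above.
if __name__ == "__main__":
    processor = LlavaUHDV3ImageProcessor()
    demo_image = Image.new("RGB", (1024, 768), color=(128, 64, 32))
    batch = processor.preprocess(demo_image, return_tensors="pt")
    # pixel_values: (num_patches, 3, patch_size, patch_size); grid_hws: (num_images, 2)
    print(batch["pixel_values"].shape)  # torch.Size([14144, 3, 10, 10])
    print(batch["grid_hws"])            # tensor([[104, 136]])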