import torch
from collections import defaultdict
from functools import lru_cache, partial
from types import SimpleNamespace
from typing import Optional, Union

from torchvision.transforms.v2 import functional as F

from transformers.image_processing_utils import BatchFeature
from transformers.image_utils import (
    ChannelDimension,
    SizeDict,
    make_flat_list_of_images,
    pil_torch_interpolation_mapping,
    valid_images,
)
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import Qwen2VLImageProcessorFast


def rescale(image, scale):
    """Rescale pixel values by a constant factor (e.g. 1 / 255)."""
    return image * scale


def normalize(image, mean, std):
    """Normalize channel-wise via torchvision: (image - mean) / std."""
    return F.normalize(image, mean, std)


@lru_cache(maxsize=10)
def _fuse_mean_std_and_rescale_factor(
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, list[float]]] = None,
    image_std: Optional[Union[float, list[float]]] = None,
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    device: Optional["torch.device"] = None,
) -> tuple:
    """
    Fold the rescale step into the normalization constants: rescaling by `r`
    and then normalizing with `(mean, std)` computes `(x * r - mean) / std`,
    which equals `(x - mean / r) / (std / r)` applied to the raw tensor.
    Callers must pass hashable values (tuples, not lists) because the results
    are cached with `lru_cache`.
    """
    if do_rescale and do_normalize:
        image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
        image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
        do_rescale = False
    return image_mean, image_std, do_rescale
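

# Worked example of the fusion above (illustrative): with rescale_factor
# r = 1/255 and mean = std = 0.5, a raw pixel x = 255 maps to
#   (255 * (1/255) - 0.5) / 0.5 = 1.0,
# and the fused constants mean' = std' = 0.5 / r = 127.5 give the same result
# directly on the raw value: (255 - 127.5) / 127.5 = 1.0.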


def rescale_and_normalize(
    images: "torch.Tensor",
    do_rescale: bool,
    rescale_factor: float,
    do_normalize: bool,
    image_mean: Union[float, list[float]],
    image_std: Union[float, list[float]],
) -> "torch.Tensor":
    """
    Rescale and normalize images in a fused pass, then cast to the processor's
    target dtype.
    """
    image_mean, image_std, do_rescale = _fuse_mean_std_and_rescale_factor(
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
        do_rescale=do_rescale,
        rescale_factor=rescale_factor,
        device=images.device,
    )
    if do_normalize:
        # If rescaling was requested, it has been folded into the constants
        # above, so a single normalize covers both steps.
        images = normalize(images.to(dtype=torch.float32), image_mean, image_std)
    elif do_rescale:
        images = rescale(images, rescale_factor)
    # Class attribute is resolved at call time, after the class below exists.
    images = images.to(OpenPanguVLImageProcessorFast.dtype)
    return images


def _group_images_by_shape(nested_images, is_nested: bool = False):
    """Helper function to flatten a single level of nested image structures and group by shape."""
    grouped_images = defaultdict(list)
    grouped_images_index = {}
    nested_images = [nested_images] if not is_nested else nested_images
    for i, sublist in enumerate(nested_images):
        for j, image in enumerate(sublist):
            key = (i, j) if is_nested else j
            # Group by everything after the channel dim, i.e. (height, width).
            shape = image.shape[1:]
            grouped_images[shape].append(image)
            grouped_images_index[key] = (shape, len(grouped_images[shape]) - 1)
    return grouped_images, grouped_images_index


def _reconstruct_nested_structure(indices, processed_images):
    """Helper function to reconstruct a single-level nested structure."""
    # Size the outer list from the largest outer index present.
    max_outer_idx = max(idx[0] for idx in indices.keys())
    result = [None] * (max_outer_idx + 1)

    # Collect the inner indices present under each outer index.
    nested_indices = defaultdict(list)
    for i, j in indices.keys():
        nested_indices[i].append(j)

    for i in range(max_outer_idx + 1):
        if i in nested_indices:
            inner_max_idx = max(nested_indices[i])
            inner_list = [None] * (inner_max_idx + 1)
            for j in range(inner_max_idx + 1):
                if (i, j) in indices:
                    shape, idx = indices[(i, j)]
                    inner_list[j] = processed_images[shape][idx]
            result[i] = inner_list

    return result


def group_images_by_shape(
    images: Union[list["torch.Tensor"], "torch.Tensor"],
    disable_grouping: bool,
    is_nested: bool = False,
) -> tuple[
    dict[tuple[int, int], list["torch.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]
]:
    """
    Group images by shape so that same-sized images can be processed as one
    batched tensor; also return an index for restoring the original order.
    """
    if disable_grouping is None:
        # Grouping mainly pays off on accelerators; default to per-image
        # processing on CPU.
        device = images[0][0].device if is_nested else images[0].device
        disable_grouping = device == "cpu"

    if disable_grouping:
        if is_nested:
            return {(i, j): images[i][j].unsqueeze(0) for i in range(len(images)) for j in range(len(images[i]))}, {
                (i, j): ((i, j), 0) for i in range(len(images)) for j in range(len(images[i]))
            }
        return {i: images[i].unsqueeze(0) for i in range(len(images))}, {i: (i, 0) for i in range(len(images))}

    grouped_images, grouped_images_index = _group_images_by_shape(images, is_nested)

    # Stack each same-shape group into a single batched tensor.
    grouped_images = {shape: torch.stack(images_list, dim=0) for shape, images_list in grouped_images.items()}

    return grouped_images, grouped_images_index


def reorder_images(
    processed_images: dict[tuple[int, int], "torch.Tensor"],
    grouped_images_index: dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]],
    is_nested: bool = False,
) -> Union[list["torch.Tensor"], "torch.Tensor"]:
    """Restore the original image order after grouped processing."""
    if not is_nested:
        return [
            processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]]
            for i in range(len(grouped_images_index))
        ]

    return _reconstruct_nested_structure(grouped_images_index, processed_images)
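

# Round-trip sketch for the grouping helpers (illustrative only):
#
#   imgs = [torch.rand(3, 64, 48), torch.rand(3, 32, 32), torch.rand(3, 64, 48)]
#   grouped, index = group_images_by_shape(imgs, disable_grouping=False)
#   # grouped holds one stacked tensor per (H, W) shape:
#   #   {(64, 48): tensor of shape (2, 3, 64, 48), (32, 32): (1, 3, 32, 32)}
#   doubled = {shape: batch * 2 for shape, batch in grouped.items()}
#   restored = reorder_images(doubled, index)
#   # restored[k] is the processed counterpart of imgs[k], in the input order.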


class OpenPanguVLImageProcessorFast(Qwen2VLImageProcessorFast):
    # Single-frame inputs: patches are never stacked across time.
    temporal_patch_size = 1
    # Images with a width or height at or below `min_pxl` are upscaled so that
    # their shorter edge reaches `min_edge` before any other processing.
    min_pxl = 28
    min_edge = 56
    # Target dtype of the returned pixel values.
    dtype = torch.bfloat16

    def _prepare_input_images(
        self,
        images,
        do_convert_rgb,
        input_data_format,
        device,
    ) -> list["torch.Tensor"]:
        """
        Prepare the input images for processing, upscaling any image whose
        width or height is at or below `min_pxl` so that its shorter edge
        equals `min_edge`.
        """
        images = self._prepare_images_structure(images)
        process_image_fn = partial(
            self._process_image,
            do_convert_rgb=do_convert_rgb,
            input_data_format=input_data_format,
            device=device,
        )

        processed_images = []
        for image in images:
            # PIL convention: image.size is (width, height).
            if image.size[0] <= self.min_pxl or image.size[1] <= self.min_pxl:
                if image.size[0] >= image.size[1]:
                    # Height is the shorter edge: scale it up to min_edge.
                    scale = self.min_edge * 1.0 / image.size[1]
                    new_image_height = self.min_edge
                    new_image_width = int(scale * image.size[0])
                else:
                    # Width is the shorter edge: scale it up to min_edge.
                    scale = self.min_edge * 1.0 / image.size[0]
                    new_image_height = int(scale * image.size[1])
                    new_image_width = self.min_edge
                image = image.resize((new_image_width, new_image_height))

            processed_images.append(process_image_fn(image))
        return processed_images
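
    # Upscaling example (illustrative): a 100x20 (width x height) input trips
    # the min_pxl check, so scale = 56 / 20 = 2.8 and the image is resized to
    # 280x56 before patch extraction.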

    def preprocess(
        self,
        images=None,
        videos=None,  # accepted for interface compatibility; not used here
        do_resize=None,
        size=None,
        resample=None,
        do_rescale=None,
        rescale_factor=None,
        do_normalize=None,
        image_mean=None,
        image_std=None,
        min_pixels=None,
        max_pixels=None,
        patch_size=None,
        temporal_patch_size=None,
        merge_size=None,
        do_convert_rgb=None,
        return_tensors=None,
        data_format=ChannelDimension.FIRST,
        input_data_format=None,
        device=None,
        disable_grouping=False,
        **kwargs,
    ):
        # Single-frame processing: pin the temporal patch size to the class
        # constant rather than any per-call override.
        temporal_patch_size = OpenPanguVLImageProcessorFast.temporal_patch_size
        params = self._resolve_preprocess_params(
            do_resize=do_resize,
            size=size,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            resample=resample,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            merge_size=merge_size,
            do_convert_rgb=do_convert_rgb,
        )

        data = self._process_images(
            images,
            params,
            input_data_format,
            device,
            disable_grouping,
            return_tensors,
        )

        return data

    def _resolve_preprocess_params(self, **kwargs):
        """Resolve each parameter, falling back to the instance config when None."""
        params = SimpleNamespace()
        for key, value in kwargs.items():
            setattr(params, key, value if value is not None else getattr(self, key))
        if params.size is None:
            params.size = {"shortest_edge": params.min_pixels, "longest_edge": params.max_pixels}
        params.size = SizeDict(**params.size)
        # Tuples, not lists, so the lru_cache'd fusion helper can hash them.
        params.image_mean = tuple(params.image_mean) if params.image_mean else None
        params.image_std = tuple(params.image_std) if params.image_std else None
        return params
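
    # Resolution example (illustrative): preprocess(images, do_rescale=None)
    # picks up self.do_rescale from the loaded config, while an explicit
    # do_rescale=False wins; when size is None it is rebuilt from
    # min_pixels / max_pixels as {"shortest_edge": ..., "longest_edge": ...}.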

    def _process_images(self, images, params, input_data_format, device, disable_grouping, return_tensors):
        images = make_flat_list_of_images(images)
        if not valid_images(images):
            raise ValueError("Invalid image type.")

        images = self._prepare_input_images(
            images=images,
            do_convert_rgb=params.do_convert_rgb,
            input_data_format=input_data_format,
            device=device,
        )

        data = self._preprocess(
            images=images,
            do_resize=params.do_resize,
            size=params.size,
            interpolation=pil_torch_interpolation_mapping.get(params.resample, params.resample),
            do_rescale=params.do_rescale,
            rescale_factor=params.rescale_factor,
            do_normalize=params.do_normalize,
            image_mean=params.image_mean,
            image_std=params.image_std,
            patch_size=params.patch_size,
            temporal_patch_size=params.temporal_patch_size,
            merge_size=params.merge_size,
            do_convert_rgb=params.do_convert_rgb,
            input_data_format=input_data_format,
            device=device,
            disable_grouping=disable_grouping,
            return_tensors=return_tensors,
        )

        return data

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        patch_size: int,
        temporal_patch_size: int,
        merge_size: int,
        disable_grouping: Optional[bool],
        return_tensors,
        **kwargs,
    ):
        # First pass: resize, batching same-shaped images together.
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            height, width = stacked_images.shape[-2:]
            if do_resize:
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=patch_size * merge_size,
                    min_pixels=size["shortest_edge"],
                    max_pixels=size["longest_edge"],
                )
                stacked_images = self.resize(
                    image=stacked_images,
                    size=SizeDict(height=resized_height, width=resized_width),
                    interpolation=interpolation,
                )
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
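
        # smart_resize (from Qwen2-VL) snaps both sides to multiples of
        # patch_size * merge_size (28 with the usual patch_size=14 and
        # merge_size=2) and rescales when the area falls outside
        # [shortest_edge, longest_edge]. E.g. a 1000x700 image becomes
        # 1008x700: 1000 rounds to 36 * 28 while 700 is already 25 * 28.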

        # Second pass: fused rescale/normalize, then patch extraction.
        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        processed_grids = {}
        for shape, stacked_images in grouped_images.items():
            resized_height, resized_width = stacked_images.shape[-2:]

            # Apply the module-level fused rescale + normalize; these flags and
            # constants are threaded into this method for exactly this step.
            stacked_images = rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )

            patches = stacked_images
            if patches.ndim == 4:
                # Add a singleton temporal axis: (B, C, H, W) -> (B, T=1, C, H, W).
                patches = patches.unsqueeze(1)
            if patches.shape[1] % temporal_patch_size != 0:
                # Pad with copies of the last frame so that T becomes divisible
                # by temporal_patch_size.
                repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
                patches = torch.cat([patches, repeats], dim=1)
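
            # E.g. with temporal_patch_size=2, a single frame (T=1) would be
            # duplicated to T=2; with this class's value of 1, the branch
            # above never fires.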

            batch_size, grid_t, channel = patches.shape[:3]
            grid_t = grid_t // temporal_patch_size
            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size

            # Carve the image into a (grid_t, grid_h, grid_w) lattice of patches,
            # keeping each merge_size x merge_size spatial neighborhood contiguous
            # in the flattened sequence.
            patches = patches.view(
                batch_size,
                grid_t,
                temporal_patch_size,
                channel,
                grid_h // merge_size,
                merge_size,
                patch_size,
                grid_w // merge_size,
                merge_size,
                patch_size,
            )
            patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
            flatten_patches = patches.reshape(
                batch_size,
                grid_t * grid_h * grid_w,
                channel * temporal_patch_size * patch_size * patch_size,
            )
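
            # Worked shape example (illustrative): with patch_size=14,
            # merge_size=2, temporal_patch_size=1 and one 448x448 RGB image,
            # grid_t=1 and grid_h = grid_w = 448 // 14 = 32, so flatten_patches
            # has shape (1, 1 * 32 * 32, 3 * 1 * 14 * 14) = (1, 1024, 588).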

            processed_images_grouped[shape] = flatten_patches
            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        processed_grids = reorder_images(processed_grids, grouped_images_index)
        # Concatenate every image's patch sequence along the patch axis; keep
        # one (t, h, w) grid triple per image so the model can recover layout.
        pixel_values = torch.cat(processed_images, dim=0)
        image_grid_thw = torch.tensor(processed_grids)

        return BatchFeature(
            data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw},
            tensor_type=return_tensors,
        )
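

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the processor API):
    # construct the processor with the defaults inherited from
    # Qwen2VLImageProcessorFast and run one synthetic PIL image through it.
    from PIL import Image

    processor = OpenPanguVLImageProcessorFast()
    image = Image.new("RGB", (640, 480), color=(127, 127, 127))
    outputs = processor.preprocess(images=image, return_tensors="pt")
    # pixel_values: (total_patches, channel * temporal_patch_size * patch_size**2)
    # image_grid_thw: one (t, h, w) grid triple per input image
    print(outputs["pixel_values"].shape, outputs["image_grid_thw"])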