| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Optional, Union |
| from types import SimpleNamespace |
| from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import Qwen2VLImageProcessorFast |
| from functools import partial, lru_cache |
| from transformers.image_processing_utils import BatchFeature |
| from transformers.image_utils import ( |
| ChannelDimension, |
| SizeDict, |
| make_flat_list_of_images, |
| valid_images, |
| pil_torch_interpolation_mapping, |
| ) |
| from torchvision.transforms.v2 import functional as F |
| import torch |
| from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize |
| |
| |
| |
| |
|
|
|
|
def rescale(image, scale):
    """Return ``image`` multiplied element-wise by ``scale``."""
    scaled_image = image * scale
    return scaled_image
|
|
|
|
def normalize(image, mean, std):
    """Normalize ``image`` per channel by delegating to torchvision's ``F.normalize``.

    Args:
        image: image tensor to normalize.
        mean: per-channel mean passed through to ``F.normalize``.
        std: per-channel standard deviation passed through to ``F.normalize``.
    """
    normalized = F.normalize(image, mean=mean, std=std)
    return normalized
|
|
|
|
| @lru_cache(maxsize=10) |
| def _fuse_mean_std_and_rescale_factor( |
| do_normalize: Optional[bool] = None, |
| image_mean: Optional[Union[float, list[float]]] = None, |
| image_std: Optional[Union[float, list[float]]] = None, |
| do_rescale: Optional[bool] = None, |
| rescale_factor: Optional[float] = None, |
| device: Optional["torch.device"] = None, |
| ) -> tuple: |
| if do_rescale and do_normalize: |
| |
| image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor) |
| image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor) |
| do_rescale = False |
| return image_mean, image_std, do_rescale |
|
|
|
|
def rescale_and_normalize(
    images: "torch.Tensor",
    do_rescale: bool,
    rescale_factor: float,
    do_normalize: bool,
    image_mean: Union[float, list[float]],
    image_std: Union[float, list[float]],
) -> "torch.Tensor":
    """Rescale and/or normalize images, then cast them to the processor dtype.

    When both ``do_rescale`` and ``do_normalize`` are set, the rescale factor is
    folded into the mean/std by ``_fuse_mean_std_and_rescale_factor``. A single
    normalize pass then performs both steps.

    Args:
        images: image tensor to transform.
        do_rescale: whether to multiply by ``rescale_factor``.
        rescale_factor: multiplier used for rescaling.
        do_normalize: whether to normalize with ``image_mean`` / ``image_std``.
        image_mean: per-channel mean (float, list or tuple).
        image_std: per-channel standard deviation (float, list or tuple).

    Returns:
        The transformed tensor cast to ``OpenPanguVLImageProcessorFast.dtype``.
    """
    # _fuse_mean_std_and_rescale_factor is lru_cached, so all of its arguments
    # must be hashable; the annotated list[float] inputs would otherwise raise
    # "TypeError: unhashable type: 'list'".
    if isinstance(image_mean, list):
        image_mean = tuple(image_mean)
    if isinstance(image_std, list):
        image_std = tuple(image_std)

    image_mean, image_std, do_rescale = _fuse_mean_std_and_rescale_factor(
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
        do_rescale=do_rescale,
        rescale_factor=rescale_factor,
        device=images.device,
    )

    # if/elif: when both flags were set, the fused normalize already rescales.
    if do_normalize:
        images = normalize(images.to(dtype=torch.float32), image_mean, image_std)
    elif do_rescale:
        images = rescale(images, rescale_factor)

    # NOTE(review): this cast is assumed to apply on every path (not only the
    # rescale-only branch); the original indentation was ambiguous — confirm.
    images = images.to(OpenPanguVLImageProcessorFast.dtype)
    return images
|
|
| |
| from collections import defaultdict |
| def _group_images_by_shape(nested_images, is_nested: bool = False): |
| """Helper function to flatten a single level of nested image structures and group by shape.""" |
| grouped_images = defaultdict(list) |
| grouped_images_index = {} |
| nested_images = [nested_images] if not is_nested else nested_images |
| for i, sublist in enumerate(nested_images): |
| for j, image in enumerate(sublist): |
| key = (i, j) if is_nested else j |
| shape = image.shape[1:] |
| grouped_images[shape].append(image) |
| grouped_images_index[key] = (shape, len(grouped_images[shape]) - 1) |
|
|
| return grouped_images, grouped_images_index |
|
|
|
|
| def _reconstruct_nested_structure(indices, processed_images): |
| """Helper function to reconstruct a single level nested structure.""" |
| |
| max_outer_idx = max(idx[0] for idx in indices.keys()) |
|
|
| |
| result = [None] * (max_outer_idx + 1) |
|
|
| |
| nested_indices = defaultdict(list) |
| for i, j in indices.keys(): |
| nested_indices[i].append(j) |
|
|
| for i in range(max_outer_idx + 1): |
| if i in nested_indices: |
| inner_max_idx = max(nested_indices[i]) |
| inner_list = [None] * (inner_max_idx + 1) |
| for j in range(inner_max_idx + 1): |
| if (i, j) in indices: |
| shape, idx = indices[(i, j)] |
| inner_list[j] = processed_images[shape][idx] |
| result[i] = inner_list |
|
|
| return result |
|
|
|
|
def group_images_by_shape(
    images: Union[list["torch.Tensor"], "torch.Tensor"],
    disable_grouping: Optional[bool],
    is_nested: bool = False,
) -> tuple[
    dict[tuple[int, int], list["torch.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]
]:
    """Group images by shape so same-sized images can be processed as one batch.

    Args:
        images: a list of image tensors, or a list of lists when ``is_nested``.
        disable_grouping: if True, every image gets its own group (with a
            leading batch axis of size 1). If None, grouping is disabled when
            the first image lives on the CPU.
        is_nested: whether ``images`` has one extra level of nesting.

    Returns:
        ``(groups, index)``. ``groups`` maps a group key to a stacked tensor.
        ``index`` maps each image's position to ``(group_key, position)``,
        suitable for ``reorder_images``.
    """
    if len(images) == 0:
        # Nothing to group; avoids IndexError on images[0] below.
        return {}, {}

    if disable_grouping is None:
        first_image = images[0][0] if is_nested else images[0]
        # A torch.device never compares equal to a plain str, so the original
        # `device == "cpu"` was always False; compare the device *type* instead.
        disable_grouping = first_image.device.type == "cpu"

    if disable_grouping:
        if is_nested:
            return {(i, j): images[i][j].unsqueeze(0) for i in range(len(images)) for j in range(len(images[i]))}, {
                (i, j): ((i, j), 0) for i in range(len(images)) for j in range(len(images[i]))
            }
        return {i: images[i].unsqueeze(0) for i in range(len(images))}, {i: (i, 0) for i in range(len(images))}

    grouped_images, grouped_images_index = _group_images_by_shape(images, is_nested)
    # Stack each same-shape bucket into a single batched tensor along dim 0.
    grouped_images = {shape: torch.stack(images_list, dim=0) for shape, images_list in grouped_images.items()}
    return grouped_images, grouped_images_index
|
|
|
|
def reorder_images(
    processed_images: dict[tuple[int, int], "torch.Tensor"],
    grouped_images_index: dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]],
    is_nested: bool = False,
) -> Union[list["torch.Tensor"], "torch.Tensor"]:
    """Undo grouping: put processed images back in their original order.

    For flat input, ``grouped_images_index`` is keyed by ``0..len-1``. Each
    value ``(group_key, pos)`` selects ``processed_images[group_key][pos]``.
    Nested input is delegated to ``_reconstruct_nested_structure``.
    """
    if is_nested:
        return _reconstruct_nested_structure(grouped_images_index, processed_images)

    ordered = []
    for position in range(len(grouped_images_index)):
        group_key, offset = grouped_images_index[position]
        ordered.append(processed_images[group_key][offset])
    return ordered
|
|
|
|
class OpenPanguVLImageProcessorFast(Qwen2VLImageProcessorFast):
    """Fast image processor for openPangu-VL, derived from Qwen2-VL's.

    Behaviour implemented in this class that differs from the parent:

    * ``preprocess`` always uses the class-level ``temporal_patch_size`` (1),
      ignoring any value passed by the caller.
    * PIL images whose width or height is ``<= min_pxl`` are first upscaled,
      approximately preserving aspect ratio, so their shorter edge becomes
      ``min_edge``.
    * ``_preprocess`` resizes and patchifies, but does not rescale or
      normalize pixel values.
    """

    # Each image is a single temporal patch; preprocess() forces this value.
    temporal_patch_size = 1
    # Images with width or height <= min_pxl get upscaled before processing.
    min_pxl = 28
    # Shorter-edge length (pixels) that such small images are upscaled to.
    min_edge = 56
    # Target dtype used by the module-level rescale_and_normalize().
    dtype = torch.bfloat16

    def _upscale_if_too_small(self, image):
        """Upscale a too-small PIL-style image; return anything else unchanged.

        If the image's width or height is ``<= min_pxl``, it is resized so the
        shorter edge equals ``min_edge``. The longer edge is scaled by the same
        ratio and truncated with ``int``.

        Only images whose ``size`` attribute is a ``(width, height)`` tuple (as
        on PIL images) are handled. Torch tensors expose ``size`` as a method
        and numpy arrays as an int; indexing either used to raise
        ``TypeError``, so they are passed through unchanged.
        """
        size = getattr(image, "size", None)
        if not isinstance(size, tuple):
            return image

        width, height = size
        cls = OpenPanguVLImageProcessorFast
        if width > cls.min_pxl and height > cls.min_pxl:
            return image

        if width >= height:
            aspect_ratio = cls.min_edge * 1.0 / height
            new_image_width, new_image_height = int(aspect_ratio * width), cls.min_edge
        else:
            aspect_ratio = cls.min_edge * 1.0 / width
            new_image_width, new_image_height = cls.min_edge, int(aspect_ratio * height)
        return image.resize((new_image_width, new_image_height))

    def _prepare_input_images(
        self,
        images,
        do_convert_rgb,
        input_data_format,
        device,
    ) -> list["torch.Tensor"]:
        """Upscale too-small images, then run the parent's per-image conversion.

        Args:
            images: images accepted by ``_prepare_images_structure``.
            do_convert_rgb: forwarded to ``_process_image``.
            input_data_format: forwarded to ``_process_image``.
            device: forwarded to ``_process_image``.

        Returns:
            One processed image per input, in input order.
        """
        images = self._prepare_images_structure(images)
        process_image_fn = partial(
            self._process_image,
            do_convert_rgb=do_convert_rgb,
            input_data_format=input_data_format,
            device=device,
        )
        return [process_image_fn(self._upscale_if_too_small(image)) for image in images]

    def preprocess(
        self,
        images = None,
        videos = None,
        do_resize = None,
        size = None,
        resample = None,
        do_rescale = None,
        rescale_factor = None,
        do_normalize = None,
        image_mean = None,
        image_std = None,
        min_pixels = None,
        max_pixels = None,
        patch_size = None,
        temporal_patch_size = None,
        merge_size = None,
        do_convert_rgb = None,
        return_tensors = None,
        data_format = ChannelDimension.FIRST,
        input_data_format = None,
        device = None,
        disable_grouping = False,
        **kwargs,
    ):
        """Preprocess ``images`` into patch features plus per-image grid sizes.

        Arguments left as None fall back to the processor's own attributes
        (see ``_resolve_preprocess_params``). ``temporal_patch_size`` is always
        replaced by the class attribute. ``videos``, ``data_format`` and extra
        ``**kwargs`` are accepted for interface compatibility but not used here.

        Returns:
            The ``BatchFeature`` produced by ``_preprocess``.
        """
        # The model consumes single-frame patches, so override any caller value.
        temporal_patch_size = OpenPanguVLImageProcessorFast.temporal_patch_size
        params = self._resolve_preprocess_params(
            do_resize=do_resize,
            size=size,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            resample=resample,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            merge_size=merge_size,
            do_convert_rgb=do_convert_rgb,
        )
        return self._process_images(
            images,
            params,
            input_data_format,
            device,
            disable_grouping,
            return_tensors,
        )

    def _resolve_preprocess_params(self, **kwargs):
        """Merge call-time arguments with the processor's defaults.

        Every keyword whose value is None falls back to the attribute of the
        same name on ``self``.

        ``size`` handling:
        * If no size is available, it is built from ``min_pixels`` /
          ``max_pixels``.
        * If the caller passed ``min_pixels`` and/or ``max_pixels`` but no
          ``size``, those explicit bounds override the corresponding edges.
        * The result is always a ``SizeDict``.

        ``image_mean`` / ``image_std`` become tuples, or None when falsy.

        Returns:
            A ``SimpleNamespace`` holding the resolved values.
        """
        params = SimpleNamespace()
        for key, value in kwargs.items():
            setattr(params, key, value if value is not None else getattr(self, key))

        if params.size is None:
            params.size = {"shortest_edge": params.min_pixels, "longest_edge": params.max_pixels}
        if not isinstance(params.size, SizeDict):
            params.size = SizeDict(**params.size)

        # Previously explicit min_pixels/max_pixels were silently ignored
        # whenever self.size was set; honour them when size was not passed.
        explicit_min = kwargs.get("min_pixels")
        explicit_max = kwargs.get("max_pixels")
        if kwargs.get("size") is None and (explicit_min is not None or explicit_max is not None):
            params.size = SizeDict(
                shortest_edge=explicit_min if explicit_min is not None else params.size.shortest_edge,
                longest_edge=explicit_max if explicit_max is not None else params.size.longest_edge,
            )

        # Tuples keep these hashable (they may feed lru_cached helpers).
        params.image_mean = tuple(params.image_mean) if params.image_mean else None
        params.image_std = tuple(params.image_std) if params.image_std else None
        return params

    def _process_images(self, images, params, input_data_format, device, disable_grouping, return_tensors):
        """Validate and prepare ``images``, then run ``_preprocess`` with ``params``.

        Raises:
            ValueError: if ``valid_images`` rejects the flattened input.
        """
        images = make_flat_list_of_images(images)
        if not valid_images(images):
            raise ValueError("Invalid image type.")

        images = self._prepare_input_images(
            images=images,
            do_convert_rgb=params.do_convert_rgb,
            input_data_format=input_data_format,
            device=device,
        )

        return self._preprocess(
            images=images,
            do_resize=params.do_resize,
            size=params.size,
            # PIL resample codes are mapped to torch interpolation modes;
            # unknown values are passed through as-is.
            interpolation=pil_torch_interpolation_mapping.get(params.resample, params.resample),
            do_rescale=params.do_rescale,
            rescale_factor=params.rescale_factor,
            do_normalize=params.do_normalize,
            image_mean=params.image_mean,
            image_std=params.image_std,
            patch_size=params.patch_size,
            temporal_patch_size=params.temporal_patch_size,
            merge_size=params.merge_size,
            do_convert_rgb=params.do_convert_rgb,
            input_data_format=input_data_format,
            device=device,
            disable_grouping=disable_grouping,
            return_tensors=return_tensors,
        )

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        patch_size: int,
        temporal_patch_size: int,
        merge_size: int,
        disable_grouping: Optional[bool],
        return_tensors,
        **kwargs,
    ):
        """Resize images and cut them into flattened patches.

        Steps:
        1. Group images by shape.
        2. If ``do_resize``, resize each group to the ``smart_resize`` target
           (multiple of ``patch_size * merge_size``, with
           ``size["shortest_edge"]`` / ``size["longest_edge"]`` as the
           min/max pixel bounds).
        3. Regroup and reshape each batch into patches of
           ``channel * temporal_patch_size * patch_size * patch_size`` values.

        Returns:
            ``BatchFeature`` with:
            * ``pixel_values``: all images' patches concatenated along dim 0.
            * ``image_grid_thw``: one ``[grid_t, grid_h, grid_w]`` row per image.
        """
        # --- Resize: each same-shape group is resized in one call. ---
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            height, width = stacked_images.shape[-2:]
            if do_resize:
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=patch_size * merge_size,
                    min_pixels=size["shortest_edge"],
                    max_pixels=size["longest_edge"],
                )
                stacked_images = self.resize(
                    image=stacked_images,
                    size=SizeDict(height=resized_height, width=resized_width),
                    interpolation=interpolation,
                )
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # --- Patchify: regroup, since resizing changes the shapes. ---
        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        processed_grids = {}
        for shape, stacked_images in grouped_images.items():
            resized_height, resized_width = stacked_images.shape[-2:]
            # NOTE(review): do_rescale / rescale_factor / do_normalize /
            # image_mean / image_std are not applied here (the parent class
            # calls rescale_and_normalize at this point). Presumably
            # normalization happens elsewhere, e.g. via the module-level
            # rescale_and_normalize — confirm.
            patches = stacked_images
            if patches.ndim == 4:
                # (batch, channel, H, W) -> (batch, 1, channel, H, W)
                patches = patches.unsqueeze(1)
            if patches.shape[1] % temporal_patch_size != 0:
                # Pad the temporal axis by repeating the last frame.
                repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
                patches = torch.cat([patches, repeats], dim=1)
            batch_size, grid_t, channel = patches.shape[:3]
            grid_t = grid_t // temporal_patch_size
            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size

            patches = patches.view(
                batch_size,
                grid_t,
                temporal_patch_size,
                channel,
                grid_h // merge_size,
                merge_size,
                patch_size,
                grid_w // merge_size,
                merge_size,
                patch_size,
            )
            # Reorder to (batch, t, h_block, w_block, merge_h, merge_w,
            # channel, temporal, patch_h, patch_w): the patches of each
            # merge_size x merge_size window become consecutive rows.
            patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
            flatten_patches = patches.reshape(
                batch_size,
                grid_t * grid_h * grid_w,
                channel * temporal_patch_size * patch_size * patch_size,
            )

            processed_images_grouped[shape] = flatten_patches
            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        processed_grids = reorder_images(processed_grids, grouped_images_index)
        pixel_values = torch.cat(processed_images, dim=0)
        image_grid_thw = torch.tensor(processed_grids)

        return BatchFeature(
            data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw},
            tensor_type=return_tensors,
        )