from typing import List, Optional, Tuple, Unpack, cast

import numpy as np
import transformers.image_transforms as image_transforms
import transformers.image_utils as image_utils
from numpy.typing import NDArray
from PIL.Image import Image
from torch import Tensor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_processing_utils_fast import BaseImageProcessorFast
from transformers.image_utils import ImageInput, VideoInput
from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, TextInput


class VILAProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {}  # type: ignore


class VILAProcessorOutput(BatchFeature):
    input_ids: List[List[int]] | NDArray[np.int64] | Tensor
    attention_mask: List[List[int]] | NDArray[np.int64] | Tensor
    pixel_values: Optional[List[NDArray[np.float32]] | NDArray[np.float32] | Tensor]


class VILAProcessor(ProcessorMixin):
    attributes: List[str] = [
        "image_processor",
        "tokenizer",
    ]
    image_processor_class: str = "AutoImageProcessor"
    tokenizer_class: str = "AutoTokenizer"

    # Attributes.
    image_processor: BaseImageProcessor | BaseImageProcessorFast
    tokenizer: PreTrainedTokenizerBase

    # Configuration parameters.
    image_pad_len: int
    image_token: str
    max_tiles: int
    min_tiles: int

    def __init__(
        self,
        image_processor: BaseImageProcessor,
        tokenizer: PreTrainedTokenizer,
        *,
        image_pad_len: int,
        image_token: str,
        max_tiles: int,
        min_tiles: int,
        **kwargs,
    ):
        super().__init__(
            image_processor,
            tokenizer,
            **kwargs,
        )

        self.image_pad_len = image_pad_len
        self.image_token = image_token
        self.max_tiles = max_tiles
        self.min_tiles = min_tiles

    def __call__(
        self,
        images: Optional[ImageInput] = None,
        text: Optional[TextInput | List[TextInput]] = None,
        audio: None = None,
        videos: Optional[VideoInput] = None,
        **kwargs: Unpack[VILAProcessorKwargs],
    ) -> VILAProcessorOutput:
        # Validate arguments.
        assert text is not None and text != [], "text must be provided"
        assert not kwargs.get(
            "is_split_into_words", False
        ), "is_split_into_words=True is not supported"

        output_kwargs = self._merge_kwargs(
            VILAProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Process images.
        if images is not None and images != []:
            image_inputs, num_cropped_images = self._process_images(
                images=images,
                **output_kwargs["images_kwargs"],
            )
        else:
            # If no images are provided, do not define pixel_values.
            image_inputs = BatchFeature()
            num_cropped_images = []

        # TODO: video processing.

        # Process text.
        text = text if isinstance(text, list) else [text]
        text = self._pad_image_tokens_by_num_crops(
            text,
            num_cropped_images=num_cropped_images,
        )
        text = self._pad_image_tokens_by_num_embeddings(text)
        text_inputs = self.tokenizer(
            text,
            **output_kwargs["text_kwargs"],
        )

        return VILAProcessorOutput(
            data={
                **text_inputs,
                **image_inputs,
            }
        )

    def _crop_image(
        self,
        image: Image,
    ) -> List[Image]:
        """Crops the image into multiple tiles.

        Args:
            image: The image to be cropped.

        Returns:
            The cropped images.
        """
        # TODO: Support more image processors.
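        # Only SiglipImageProcessor is handled for now: it consumes fixed-size
        # square tiles, so dynamic_preprocess (defined at module level below)
        # picks a tile grid whose aspect ratio best matches the input image.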
        assert isinstance(self.image_processor, SiglipImageProcessor)
        assert (
            self.image_processor.size["height"] == self.image_processor.size["width"]
        )
        cropped_size = self.image_processor.size["height"]

        cropped_images: List[Image] = dynamic_preprocess(
            image,
            min_num=self.min_tiles,
            max_num=self.max_tiles,
            image_size=cropped_size,
        )
        return cropped_images

    def _pad_image_tokens_by_num_crops(
        self,
        text: List[TextInput],
        *,
        num_cropped_images: List[int],
    ) -> List[TextInput]:
        """Expands each image token into one copy per crop, each followed by a newline.

        Args:
            text: The text to be padded.
            num_cropped_images: The number of cropped images for each image token.

        Returns:
            The padded text.
        """
        # Validate arguments.
        num_images = len(num_cropped_images)
        num_image_tokens = sum(item.count(self.image_token) for item in text)
        assert num_images == num_image_tokens, (
            f"Number of image tokens ({num_image_tokens}) in text does not match "
            f"the number of images ({num_images})."
        )
        assert all(
            num_crops > 0 for num_crops in num_cropped_images
        ), "Each image must be cropped into at least one tile."

        # Pad image tokens.
        image_idx = 0
        padded_text: List[TextInput] = []
        for text_item in text:
            padded_text_item = ""
            remaining_text = text_item
            while True:
                token_pos = remaining_text.find(self.image_token)
                if token_pos == -1:
                    padded_text_item += remaining_text
                    break
                padded_text_item += remaining_text[:token_pos] + (
                    (self.image_token + "\n") * num_cropped_images[image_idx]
                )
                image_idx += 1
                remaining_text = remaining_text[token_pos + len(self.image_token) :]
            padded_text.append(padded_text_item)
        return padded_text

    def _pad_image_tokens_by_num_embeddings(
        self,
        text: List[TextInput],
    ) -> List[TextInput]:
        """Expands each image token into image_pad_len consecutive copies of itself.

        Args:
            text: The text to be padded.

        Returns:
            The padded text.
        """
        padded_text: List[TextInput] = []
        for text_item in text:
            padded_text_item = ""
            remaining_text = text_item
            while True:
                token_pos = remaining_text.find(self.image_token)
                if token_pos == -1:
                    padded_text_item += remaining_text
                    break
                padded_text_item += remaining_text[:token_pos] + (
                    self.image_token * self.image_pad_len
                )
                remaining_text = remaining_text[token_pos + len(self.image_token) :]
            padded_text.append(padded_text_item)
        return padded_text

    def _process_images(
        self,
        images: ImageInput,
        **kwargs: Unpack[VILAProcessorKwargs],
    ) -> Tuple[BatchFeature, List[int]]:
        """Flattens, crops, and preprocesses the images.

        Returns the processed pixel values together with the number of crops
        produced for each input image.
        """
        images_flatten = cast(
            List[Image] | List[NDArray] | List[Tensor],
            image_utils.make_flat_list_of_images(images),
        )

        cropped_images: List[Image] = []
        num_cropped_images: List[int] = []
        for image in images_flatten:
            pil_image: Image = image_transforms.to_pil_image(image)
            single_cropped_images = self._crop_image(pil_image)
            cropped_images.extend(single_cropped_images)
            num_cropped_images.append(len(single_cropped_images))

        image_inputs = self.image_processor(
            cropped_images,
            **kwargs,
        )
        return image_inputs, num_cropped_images


def dynamic_preprocess(
    image, min_num=1, max_num=12, image_size=384, use_thumbnail=True
):
    # calculate the existing image aspect ratio
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # enumerate all candidate (cols, rows) tile grids within the tile budget
    target_ratios = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    }
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size
    )
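    # Example: an 800x600 input (aspect ratio ~1.33) with image_size=384 and
    # max_num=12 selects the (4, 3) grid: the image is resized to 1536x1152
    # and cut into 12 tiles, plus a thumbnail when use_thumbnail is set.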
    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # On a tie, prefer the larger grid when the source image has
            # enough pixels to fill more than half of it.
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio
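# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The checkpoint path, image token, and
# padding lengths below are hypothetical placeholders; take the real values
# from the model's preprocessor and tokenizer configs.
#
#   from PIL import Image as PILImage
#   from transformers import AutoImageProcessor, AutoTokenizer
#
#   image_processor = AutoImageProcessor.from_pretrained("path/to/vila-checkpoint")
#   tokenizer = AutoTokenizer.from_pretrained("path/to/vila-checkpoint")
#   processor = VILAProcessor(
#       image_processor,
#       tokenizer,
#       image_pad_len=196,      # hypothetical: vision embeddings per tile
#       image_token="<image>",  # hypothetical: model-specific placeholder token
#       max_tiles=12,
#       min_tiles=1,
#   )
#   outputs = processor(
#       images=PILImage.open("example.jpg"),
#       text="<image> Describe this image.",
#   )
#   # outputs.input_ids / outputs.attention_mask / outputs.pixel_values
# ---------------------------------------------------------------------------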