|
|
from typing import List, Optional, Tuple, Unpack, cast |
|
|
|
|
|
import numpy as np |
|
|
import transformers.image_transforms as image_transforms |
|
|
import transformers.image_utils as image_utils |
|
|
from numpy.typing import NDArray |
|
|
from PIL.Image import Image |
|
|
from torch import Tensor |
|
|
from transformers.feature_extraction_utils import BatchFeature |
|
|
from transformers.image_processing_utils import BaseImageProcessor |
|
|
from transformers.image_processing_utils_fast import BaseImageProcessorFast |
|
|
from transformers.image_utils import ImageInput, VideoInput |
|
|
from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor |
|
|
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin |
|
|
from transformers.tokenization_utils import PreTrainedTokenizer |
|
|
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, TextInput |
|
|
|
|
|
|
|
|
class VILAProcessorKwargs(ProcessingKwargs, total=False):
    """Typed keyword arguments accepted by :class:`VILAProcessor.__call__`.

    Inherits the standard text/image kwargs structure from ``ProcessingKwargs``;
    no VILA-specific defaults are registered.
    """

    _defaults = {}
|
|
|
|
|
|
|
|
class VILAProcessorOutput(BatchFeature):
    """Batch feature returned by :class:`VILAProcessor`.

    Declares the keys produced by the processor; the concrete container type
    (list / numpy / torch) depends on the requested ``return_tensors``.
    """

    # Tokenized (image-token-expanded) text, one sequence per batch item.
    input_ids: List[List[int]] | NDArray[np.int64] | Tensor

    # Attention mask aligned with ``input_ids``.
    attention_mask: List[List[int]] | NDArray[np.int64] | Tensor

    # Preprocessed image tiles; absent/None when no images were passed.
    pixel_values: Optional[List[NDArray[np.float32]] | NDArray[np.float32] | Tensor]
|
|
|
|
|
|
|
|
class VILAProcessor(ProcessorMixin):
    """Combines an image processor and a tokenizer for VILA-style inputs.

    Each input image is cropped into square tiles (``_crop_image`` /
    ``dynamic_preprocess``), and every occurrence of ``image_token`` in the
    text is expanded — first once per tile, then ``image_pad_len`` times per
    tile — before tokenization, so the token stream reserves one position per
    image embedding.
    """

    # Sub-processor attribute names managed by ProcessorMixin.
    attributes: List[str] = [
        "image_processor",
        "tokenizer",
    ]

    image_processor_class: str = "AutoImageProcessor"
    tokenizer_class: str = "AutoTokenizer"

    image_processor: BaseImageProcessor | BaseImageProcessorFast
    tokenizer: PreTrainedTokenizerBase

    # Repetitions of the image token per tile (one per image embedding).
    image_pad_len: int
    # Placeholder token in the text marking where an image goes.
    image_token: str
    # Upper bound on tiles per image.
    max_tiles: int
    # Lower bound on tiles per image.
    min_tiles: int

    def __init__(
        self,
        image_processor: BaseImageProcessor,
        tokenizer: PreTrainedTokenizer,
        *,
        image_pad_len: int,
        image_token: str,
        max_tiles: int,
        min_tiles: int,
        **kwargs,
    ):
        """Initializes the processor.

        Args:
            image_processor: Processor applied to the cropped image tiles.
            tokenizer: Tokenizer applied to the expanded text.
            image_pad_len: Repetitions of ``image_token`` per tile.
            image_token: Placeholder token for images in the text.
            max_tiles: Maximum number of tiles per image.
            min_tiles: Minimum number of tiles per image.
            **kwargs: Forwarded to ``ProcessorMixin.__init__``.
        """
        super().__init__(
            image_processor,
            tokenizer,
            **kwargs,
        )

        self.image_pad_len = image_pad_len
        self.image_token = image_token
        self.max_tiles = max_tiles
        self.min_tiles = min_tiles

    def __call__(
        self,
        images: Optional[ImageInput] = None,
        text: Optional[TextInput | List[TextInput]] = None,
        audio: None = None,  # unsupported; kept for signature compatibility
        videos: Optional[VideoInput] = None,  # NOTE(review): ignored by this body
        **kwargs: Unpack[VILAProcessorKwargs],
    ) -> VILAProcessorOutput:
        """Builds model inputs from images and text.

        Args:
            images: Optional image(s); each is tiled and preprocessed.
            text: Prompt(s) containing one ``image_token`` per input image.
            audio: Unused; audio input is not supported.
            videos: Unused by this implementation.
            **kwargs: Processing options, split between tokenizer and image
                processor via ``_merge_kwargs``.

        Returns:
            Tokenized text features merged with the image tile features.
        """
        assert text is not None and text != [], "text must be provided"
        assert not kwargs.get(
            "is_split_into_words", False
        ), "is_split_into_words=True is not supported"

        output_kwargs = self._merge_kwargs(
            VILAProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if images is not None and images != []:
            image_inputs, num_cropped_images = self._process_images(
                images=images,
                **output_kwargs["images_kwargs"],
            )
        else:
            # No images: empty features; text must then contain no image tokens
            # (enforced inside _pad_image_tokens_by_num_crops).
            image_inputs = BatchFeature()
            num_cropped_images = []

        text = text if isinstance(text, list) else [text]

        # Expand each image token once per cropped tile of its image ...
        text = self._pad_image_tokens_by_num_crops(
            text,
            num_cropped_images=num_cropped_images,
        )

        # ... then repeat every (tile) image token image_pad_len times.
        text = self._pad_image_tokens_by_num_embeddings(
            text,
        )

        text_inputs = self.tokenizer.__call__(
            text,
            **output_kwargs["text_kwargs"],
        )

        return VILAProcessorOutput(
            data={
                **text_inputs,
                **image_inputs,
            }
        )

    def _crop_image(
        self,
        image: Image,
    ) -> List[Image]:
        """Crops the image into multiple tiles.

        Args:
            image: The image to be cropped.

        Returns:
            The cropped images.
        """
        # Tiling relies on SigLIP's square `size` configuration for the tile edge.
        assert isinstance(self.image_processor, SiglipImageProcessor)

        assert self.image_processor.size["height"] == self.image_processor.size["width"]
        cropped_size = self.image_processor.size["height"]

        cropped_images: List[Image] = dynamic_preprocess(
            image,
            min_num=self.min_tiles,
            max_num=self.max_tiles,
            image_size=cropped_size,
        )

        return cropped_images

    def _pad_image_tokens_by_num_crops(
        self,
        text: List[TextInput],
        *,
        num_cropped_images: List[int],
    ) -> List[TextInput]:
        """Pads each <image> to num_cropped_images copies of "<image>\n".

        Args:
            text: The text to be padded.
            num_cropped_images: The number of cropped images for each image token.

        Returns:
            The padded text.
        """
        num_images = len(num_cropped_images)
        num_image_tokens = sum([item.count(self.image_token) for item in text])
        assert num_images == num_image_tokens, (
            f"Number of image tokens ({num_image_tokens}) in text does not match "
            f"the number of images ({num_images})."
        )

        assert all(
            image_pad_len > 0 for image_pad_len in num_cropped_images
        ), "All image padding lengths should be positive integers."

        # Walk every text item; the i-th image token in reading order (across
        # the whole batch) is replaced using num_cropped_images[image_idx].
        image_idx = 0
        padded_text: List[TextInput] = []

        for i in range(len(text)):
            padded_text_item = ""
            remaining_text = text[i]

            while True:
                token_pos = remaining_text.find(self.image_token)
                if token_pos == -1:
                    # No more image tokens: keep the tail verbatim.
                    padded_text_item += remaining_text
                    break

                padded_text_item += remaining_text[:token_pos] + (
                    (self.image_token + "\n") * num_cropped_images[image_idx]
                )

                image_idx += 1
                remaining_text = remaining_text[token_pos + len(self.image_token) :]

            padded_text.append(padded_text_item)

        return padded_text

    def _pad_image_tokens_by_num_embeddings(
        self,
        text: List[TextInput],
    ) -> List[TextInput]:
        """Pads each <image> to image_pad_len times of "<image>".

        Args:
            text: The text to be padded.

        Returns:
            The padded text.
        """
        padded_text: List[TextInput] = []

        for i in range(len(text)):
            padded_text_item = ""
            remaining_text = text[i]

            while True:
                token_pos = remaining_text.find(self.image_token)
                if token_pos == -1:
                    padded_text_item += remaining_text
                    break

                # Each image token becomes image_pad_len consecutive copies,
                # one per image embedding produced for the tile.
                padded_text_item += remaining_text[:token_pos] + (
                    self.image_token * self.image_pad_len
                )

                remaining_text = remaining_text[token_pos + len(self.image_token) :]

            padded_text.append(padded_text_item)

        return padded_text

    def _process_images(
        self,
        images: ImageInput,
        **kwargs: Unpack[VILAProcessorKwargs],
    ) -> Tuple[BatchFeature, List[int]]:
        """Tiles every input image and preprocesses all tiles as one batch.

        Args:
            images: Image input in any layout accepted by
                ``make_flat_list_of_images``.
            **kwargs: Forwarded to the image processor.

        Returns:
            The processed tile features and, per original image, the number of
            tiles it was cropped into.
        """
        images_flatten = cast(
            List[Image] | List[NDArray] | List[Tensor],
            image_utils.make_flat_list_of_images(images),
        )

        cropped_images: List[Image] = []
        num_cropped_images: List[int] = []
        for image in images_flatten:
            pil_image: Image = image_transforms.to_pil_image(image)
            single_cropped_images = self._crop_image(pil_image)

            cropped_images.extend(single_cropped_images)
            num_cropped_images.append(len(single_cropped_images))

        # Tiles from all images are flattened into a single processor call.
        image_inputs = self.image_processor(
            cropped_images,
            **kwargs,
        )

        return image_inputs, num_cropped_images
|
|
|
|
|
|
|
|
def dynamic_preprocess(
    image, min_num=1, max_num=12, image_size=384, use_thumbnail=True
):
    """Splits *image* into a grid of square tiles matching its aspect ratio.

    The grid (columns x rows) is chosen from all grids whose tile count lies
    in ``[min_num, max_num]``, picking the one closest to the image's aspect
    ratio. The image is resized to fill the grid exactly and cut into
    ``image_size`` x ``image_size`` tiles.

    Args:
        image: Source image (PIL-style: exposes ``size``/``resize``/``crop``).
        min_num: Minimum allowed number of tiles.
        max_num: Maximum allowed number of tiles.
        image_size: Edge length of each square tile.
        use_thumbnail: When True and more than one tile was produced, a square
            thumbnail of the whole image is appended as an extra tile.

    Returns:
        The list of tile images.
    """
    src_width, src_height = image.size
    src_aspect = src_width / src_height

    # Every (cols, rows) grid whose tile count falls within [min_num, max_num].
    candidate_grids = {
        (cols, rows)
        for total in range(min_num, max_num + 1)
        for cols in range(1, total + 1)
        for rows in range(1, total + 1)
        if min_num <= cols * rows <= max_num
    }
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    # Grid whose aspect ratio best matches the source image.
    best_grid = find_closest_aspect_ratio(
        src_aspect, candidate_grids, src_width, src_height, image_size
    )
    grid_cols, grid_rows = best_grid

    # Resize so the grid covers the image exactly, then cut it tile by tile.
    resized = image.resize((image_size * grid_cols, image_size * grid_rows))
    tiles = []
    for index in range(grid_cols * grid_rows):
        row, col = divmod(index, grid_cols)
        left = col * image_size
        top = row * image_size
        tiles.append(resized.crop((left, top, left + image_size, top + image_size)))
    assert len(tiles) == grid_cols * grid_rows

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
|
|
|
|
|
|
|
|
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Returns the (cols, rows) grid whose aspect ratio is nearest *aspect_ratio*.

    Candidates are scanned in order; on an exact tie in aspect-ratio distance,
    a later candidate wins only when the source image area exceeds half the
    area that candidate's grid would cover (preferring more tiles for large
    images).

    Args:
        aspect_ratio: Width/height ratio of the source image.
        target_ratios: Iterable of (cols, rows) candidate grids.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Edge length of one square tile.

    Returns:
        The best-matching (cols, rows) tuple; (1, 1) if no candidate improves
        on the initial state.
    """
    image_area = width * height
    best_grid = (1, 1)
    smallest_diff = float("inf")

    for cols, rows in target_ratios:
        diff = abs(aspect_ratio - cols / rows)
        if diff < smallest_diff:
            smallest_diff = diff
            best_grid = (cols, rows)
        elif diff == smallest_diff and image_area > 0.5 * image_size * image_size * cols * rows:
            best_grid = (cols, rows)

    return best_grid
|
|
|