# NVILA-Lite-2B-hf-preview / processing_vila.py
# Standalone processor module: does not depend on any other files in the repo.

from typing import List, Optional, Tuple, Unpack, cast
import numpy as np
import transformers.image_transforms as image_transforms
import transformers.image_utils as image_utils
from numpy.typing import NDArray
from PIL.Image import Image
from torch import Tensor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_processing_utils_fast import BaseImageProcessorFast
from transformers.image_utils import ImageInput, VideoInput
from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, TextInput


class VILAProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {}  # type: ignore


class VILAProcessorOutput(BatchFeature):
input_ids: List[List[int]] | NDArray[np.int64] | Tensor
attention_mask: List[List[int]] | NDArray[np.int64] | Tensor
    pixel_values: Optional[List[NDArray[np.float32]] | NDArray[np.float32] | Tensor]


class VILAProcessor(ProcessorMixin):
attributes: List[str] = [
"image_processor",
"tokenizer",
]
image_processor_class: str = "AutoImageProcessor"
    tokenizer_class: str = "AutoTokenizer"

    # Attributes.
image_processor: BaseImageProcessor | BaseImageProcessorFast
tokenizer: PreTrainedTokenizerBase
    # Configuration parameters.
    image_pad_len: int  # How many times each image token is repeated per tile.
    image_token: str  # Placeholder token marking image positions in the text.
    max_tiles: int  # Maximum number of tiles per image in dynamic tiling.
    min_tiles: int  # Minimum number of tiles per image in dynamic tiling.

    def __init__(
self,
image_processor: BaseImageProcessor,
tokenizer: PreTrainedTokenizer,
*,
image_pad_len: int,
image_token: str,
max_tiles: int,
min_tiles: int,
**kwargs,
):
super().__init__(
image_processor,
tokenizer,
**kwargs,
)
self.image_pad_len = image_pad_len
self.image_token = image_token
self.max_tiles = max_tiles
        self.min_tiles = min_tiles

    def __call__(
self,
images: Optional[ImageInput] = None,
text: Optional[TextInput | List[TextInput]] = None,
audio: None = None,
videos: Optional[VideoInput] = None,
**kwargs: Unpack[VILAProcessorKwargs],
) -> VILAProcessorOutput:
# Validate arguments.
assert text is not None and text != [], "text must be provided"
assert not kwargs.get(
"is_split_into_words", False
), "is_split_into_words=True is not supported"
output_kwargs = self._merge_kwargs(
VILAProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
# Process images.
if images is not None and images != []:
image_inputs, num_cropped_images = self._process_images(
images=images,
**output_kwargs["images_kwargs"],
)
else:
# If no images are provided, do not define pixel_values.
image_inputs = BatchFeature()
num_cropped_images = []
# TODO: video processing.
# Process text.
text = text if isinstance(text, list) else [text]
text = self._pad_image_tokens_by_num_crops(
text,
num_cropped_images=num_cropped_images,
)
text = self._pad_image_tokens_by_num_embeddings(
text,
)
        text_inputs = self.tokenizer(
            text,
            **output_kwargs["text_kwargs"],
        )
return VILAProcessorOutput(
data={
**text_inputs,
**image_inputs,
}
        )

    def _crop_image(
self,
image: Image,
    ) -> List[Image]:
        """Crops the image into multiple tiles.

        Args:
            image: The image to be cropped.

        Returns:
            The cropped images.
        """
# TODO: Support more image processors.
assert isinstance(self.image_processor, SiglipImageProcessor)
assert self.image_processor.size["height"] == self.image_processor.size["width"]
cropped_size = self.image_processor.size["height"]
cropped_images: List[Image] = dynamic_preprocess(
image,
min_num=self.min_tiles,
max_num=self.max_tiles,
image_size=cropped_size,
)
        return cropped_images

    def _pad_image_tokens_by_num_crops(
self,
text: List[TextInput],
*,
num_cropped_images: List[int],
    ) -> List[TextInput]:
        """Replaces each image token with num_cropped_images copies of "<image>\n".

        Args:
            text: The text to be padded.
            num_cropped_images: The number of cropped images (tiles) for each image token.

        Returns:
            The padded text.
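
        Example:
            With image_token "<image>" and num_cropped_images == [3],
            "describe <image>" becomes "describe <image>\n<image>\n<image>\n".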
"""
# Validate arguments.
num_images = len(num_cropped_images)
num_image_tokens = sum([item.count(self.image_token) for item in text])
assert num_images == num_image_tokens, (
f"Number of image tokens ({num_image_tokens}) in text does not match "
f"the number of images ({num_images})."
)
assert all(
image_pad_len > 0 for image_pad_len in num_cropped_images
), "All image padding lengths should be positive integers."
# Pad image tokens.
image_idx = 0
padded_text: List[TextInput] = []
for i in range(len(text)):
padded_text_item = ""
remaining_text = text[i]
while True:
token_pos = remaining_text.find(self.image_token)
if token_pos == -1:
padded_text_item += remaining_text
break
padded_text_item += remaining_text[:token_pos] + (
(self.image_token + "\n") * num_cropped_images[image_idx]
)
image_idx += 1
remaining_text = remaining_text[token_pos + len(self.image_token) :]
padded_text.append(padded_text_item)
        return padded_text

    def _pad_image_tokens_by_num_embeddings(
self,
text: List[TextInput],
    ) -> List[TextInput]:
        """Replaces each image token with image_pad_len copies of itself.

        Args:
            text: The text to be padded.

        Returns:
            The padded text.
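
        Example:
            With image_pad_len == 2, each "<image>" becomes "<image><image>",
            so every tile occupies image_pad_len token positions in the prompt.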
"""
padded_text: List[TextInput] = []
for i in range(len(text)):
padded_text_item = ""
remaining_text = text[i]
while True:
token_pos = remaining_text.find(self.image_token)
if token_pos == -1:
padded_text_item += remaining_text
break
padded_text_item += remaining_text[:token_pos] + (
self.image_token * self.image_pad_len
)
remaining_text = remaining_text[token_pos + len(self.image_token) :]
padded_text.append(padded_text_item)
        return padded_text

    def _process_images(
self,
images: ImageInput,
**kwargs: Unpack[VILAProcessorKwargs],
) -> Tuple[BatchFeature, List[int]]:
images_flatten = cast(
List[Image] | List[NDArray] | List[Tensor],
image_utils.make_flat_list_of_images(images),
)
cropped_images: List[Image] = []
num_cropped_images: List[int] = []
for image in images_flatten:
pil_image: Image = image_transforms.to_pil_image(image)
single_cropped_images = self._crop_image(pil_image)
cropped_images.extend(single_cropped_images)
num_cropped_images.append(len(single_cropped_images))
image_inputs = self.image_processor(
cropped_images,
**kwargs,
)
return image_inputs, num_cropped_images
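

# Dynamic tiling: choose the tile grid (columns x rows) whose aspect ratio is
# closest to the input image's, resize the image to fill that grid, and split
# it into image_size x image_size tiles. When more than one tile is produced,
# a resized thumbnail of the full image is appended as a global view.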
def dynamic_preprocess(
image, min_num=1, max_num=12, image_size=384, use_thumbnail=True
):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
    # Enumerate every candidate tile grid (i columns x j rows) whose total
    # tile count lies within [min_num, max_num].
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size
)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
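

# Picks the candidate grid whose aspect ratio is closest to the image's own.
# Ties are broken in favor of the later (larger) grid when the original image
# area exceeds half the area the tiled grid would cover.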
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
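

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The checkpoint id below is an
# assumption, not confirmed by this file, and "<image>" is assumed to be the
# configured image token; adjust both to your actual setup.
if __name__ == "__main__":
    from PIL import Image as PILImage
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(
        "Efficient-Large-Model/NVILA-Lite-2B-hf-preview",  # hypothetical repo id
        trust_remote_code=True,
    )
    demo_image = PILImage.new("RGB", (640, 480))
    outputs = processor(
        images=[demo_image],
        text=["<image>\nDescribe this image."],
        return_tensors="pt",
    )
    print(outputs.input_ids.shape)
    print(outputs.pixel_values.shape)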