| | import re |
| | import torch |
| | from transformers import ProcessorMixin, BatchFeature, CLIPImageProcessorFast |
| | from transformers.image_processing_utils import BaseImageProcessor |
| | from transformers.image_utils import ImageInput |
| | from typing import Any, Dict, List, Optional, Union |
| | from PIL import Image |
| |
|
| | from .llava_qwen import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN |
| |
|
| | |
def expand_to_square(image: torch.Tensor, background_color=0) -> torch.Tensor:
    """
    Pad an image tensor to a square canvas, centring it along the shorter axis.

    Args:
        image: Tensor whose last two dimensions are ``(height, width)``.
            Any leading dimensions (e.g. channels) are carried over unchanged.
        background_color: Scalar used to fill the padded area.

    Returns:
        ``image`` itself when it is already square. Otherwise a new tensor,
        allocated on the same device as ``image``, whose last two dimensions
        both equal ``max(height, width)`` and which holds the original pixels
        centred along the axis that was padded.
    """
    height, width = image.shape[-2:]
    if width == height:
        return image

    side = max(height, width)
    # Allocate on the image's device so that non-CPU inputs are not silently
    # returned on the CPU. ``ones * color`` (rather than ``torch.full``) keeps
    # the dtype promotion behaviour callers already get for the fill value.
    canvas = torch.ones(
        (*image.shape[:-2], side, side), dtype=image.dtype, device=image.device
    ) * background_color

    # Only one of the two offsets is non-zero: the one along the shorter axis.
    top = (side - height) // 2
    left = (side - width) // 2
    canvas[..., top : top + height, left : left + width] = image
    return canvas
| |
|
| |
|
class FastVLMImageProcessor(CLIPImageProcessorFast):
    """
    CLIP fast image processor that pads each image to a square before running
    the parent preprocessing, and reports the pre-padding image sizes.
    """

    def _preprocess(self, images, **kwargs):
        """
        Square-pad ``images``, delegate to the parent ``_preprocess`` and
        return the result together with the original sizes.

        Args:
            images: Iterable of image tensors whose last two dimensions are
                ``(height, width)``.
            **kwargs: Forwarded unchanged to the parent ``_preprocess``.

        Returns:
            A ``BatchFeature`` with ``pixel_values`` (a single stacked tensor)
            and ``image_sizes`` (one ``(width, height)`` pair per input image,
            measured before padding).
        """
        # Capture sizes first: padding below changes the spatial dimensions.
        image_sizes = [image.shape[-2:][::-1] for image in images]
        squared = [expand_to_square(image) for image in images]
        processed = super()._preprocess(squared, **kwargs)

        pixel_values = processed.pixel_values
        # NOTE(review): the parent appears to return either a list of per-image
        # tensors or an already-batched tensor depending on ``return_tensors``
        # in kwargs -- confirm. Only stack in the list case so both work.
        if not isinstance(pixel_values, torch.Tensor):
            pixel_values = torch.stack(pixel_values, dim=0)

        return BatchFeature(data={"pixel_values": pixel_values, "image_sizes": image_sizes})
| |
|
class FastVLMProcessor(ProcessorMixin):
    """
    Processor that pairs a tokenizer with an image processor and turns
    prompts (optionally containing one image placeholder each) plus images
    into model inputs.
    """

    attributes = ["tokenizer", "image_processor"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        tokenizer,
        image_processor,
        chat_template=None,
        **kwargs
    ):
        """Store the tokenizer, image processor and optional chat template via ``ProcessorMixin``."""
        super().__init__(tokenizer, image_processor, chat_template=chat_template, **kwargs)

    def __call__(
        self,
        images: ImageInput = None,
        text: Optional[Union[str, List[str]]] = None,
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchFeature:
        """
        Tokenize ``text`` and preprocess ``images`` into a single ``BatchFeature``.

        Each prompt is split around ``DEFAULT_IMAGE_TOKEN``; the text on either
        side is tokenized without special tokens and ``IMAGE_TOKEN_INDEX`` is
        placed between the two halves. Prompts without the placeholder are
        tokenized as plain text and receive no image token.

        Prompts are not padded: all prompts in one call must tokenize to the
        same length, otherwise the final ``torch.cat`` raises.

        Args:
            images: Images handed to ``self.image_processor``; skipped when ``None``.
            text: A single prompt or a list (or tuple) of prompts.
            return_tensors: Passed to ``BatchFeature`` as ``tensor_type``.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            ``BatchFeature`` containing ``input_ids``, ``attention_mask`` (all
            ones) and every entry produced by the image processor.

        Raises:
            TypeError: If ``text`` is neither a string nor a sequence of strings.
            ValueError: If a prompt contains more than one image placeholder.
        """
        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, (list, tuple)) or not all(
            isinstance(prompt, str) for prompt in text
        ):
            # ``or`` (not ``and``): reject non-sequences *and* sequences holding
            # non-strings, without ever indexing into an invalid ``text``.
            raise TypeError("Invalid input text. Please provide a string, or a list of strings")

        image_inputs = {}
        if images is not None:
            image_inputs = self.image_processor(images=images)

        batch_ids = []
        for prompt in text:
            # Literal count, consistent with the literal ``partition`` below.
            num_image_tokens = prompt.count(DEFAULT_IMAGE_TOKEN)
            if num_image_tokens > 1:
                raise ValueError(
                    f"Expected up to 1 image tokens per prompt, got {num_image_tokens} instead."
                )

            pre, placeholder, post = prompt.partition(DEFAULT_IMAGE_TOKEN)
            # Tokenize to plain Python lists so that an empty segment simply
            # contributes no ids.
            sample_ids = list(self.tokenizer(pre, add_special_tokens=False).input_ids)
            if placeholder:
                # Only emit the image token when the prompt actually asked for one.
                sample_ids.append(IMAGE_TOKEN_INDEX)
                sample_ids.extend(self.tokenizer(post, add_special_tokens=False).input_ids)

            batch_ids.append(torch.tensor([sample_ids], dtype=torch.int64))

        if batch_ids:
            input_ids = torch.cat(batch_ids, dim=0)
        else:
            # Empty ``text`` list: keep returning empty int64 tensors.
            input_ids = torch.tensor([], dtype=torch.int64)
        attention_mask = torch.ones_like(input_ids)

        return BatchFeature(
            data={"input_ids": input_ids, "attention_mask": attention_mask, **image_inputs},
            tensor_type=return_tensors,
        )
| |
|