| """Image processor for packed TIPSv2 vision inputs.""" |
|
|
| import math |
| from typing import Any, Optional |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from transformers import BatchFeature |
| from transformers.image_processing_utils import BaseImageProcessor |
|
|
| try: |
| from PIL import Image |
| except ImportError: |
| Image = None |
|
|
|
|
# Token-type ids written into ``input_ids`` by the image processor below.
# Each packed image contributes one CLS token, then its register tokens, then
# one token per patch; padding positions reuse PATCH_TOKEN_ID.
PATCH_TOKEN_ID, CLS_TOKEN_ID, REGISTER_TOKEN_ID = 0, 1, 2
|
|
|
|
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 56 * 56,
    max_pixels: int = 14 * 14 * 4 * 1280,
) -> tuple[int, int]:
    """Resize while preserving aspect ratio and divisibility by ``factor``.

    Each side is first rounded to the nearest multiple of ``factor``. If the
    resulting area exceeds ``max_pixels`` the image is scaled down (flooring
    each side to a multiple of ``factor``); if it falls below ``min_pixels``
    it is scaled up (ceiling each side to a multiple of ``factor``).

    Args:
        height: Original height in pixels.
        width: Original width in pixels.
        factor: Both returned sides are multiples of this value.
        min_pixels: Target lower bound on the returned area.
        max_pixels: Target upper bound on the returned area (a side is never
            shrunk below ``factor``, so very thin images may exceed it).

    Returns:
        ``(resized_height, resized_width)``, each a positive multiple of
        ``factor``.

    Raises:
        ValueError: If ``factor``, ``max_pixels``, ``height`` or ``width`` is
            not positive, or the aspect ratio exceeds 200.
    """
    if factor <= 0:
        raise ValueError(f"factor must be positive, got {factor}")
    if max_pixels <= 0:
        raise ValueError(f"max_pixels must be positive, got {max_pixels}")
    if height <= 0 or width <= 0:
        raise ValueError(f"height and width must be positive, got {(height, width)}")
    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        raise ValueError(
            "absolute aspect ratio must be smaller than 200, got "
            f"{aspect_ratio}"
        )

    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    # A side shorter than factor / 2 rounds to 0 above. The min_pixels branch
    # normally repairs that, but it is skipped when min_pixels <= 0, so guard
    # against returning an empty side. For every other input this is a no-op.
    return max(factor, h_bar), max(factor, w_bar)
|
|
|
|
class TIPSv2ImageProcessor(BaseImageProcessor):
    """Build packed patch sequences for TIPSv2 image encoder inputs.

    Each image is resized with ``smart_resize``, cut into square
    ``patch_size`` patches and laid out as one contiguous run of tokens: a CLS
    token, ``num_register_tokens`` register tokens, then the patches in
    row-major order. The runs of all images are concatenated into a single
    packed sequence (optionally padded to ``max_length``), together with the
    per-token metadata needed to tell the images apart.
    """

    model_input_names = [
        "pixel_values",
        "input_ids",
        "position_ids",
        "grid_sizes",
        "document_ids",
    ]

    def __init__(
        self,
        patch_size: int = 14,
        num_register_tokens: int = 1,
        min_pixels: int = 56 * 56,
        max_pixels: int = 14 * 14 * 4 * 1280,
        factor: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Store the default preprocessing options.

        Args:
            patch_size: Side length in pixels of each square patch.
            num_register_tokens: Register tokens inserted after each image's
                CLS token.
            min_pixels: Default lower bound on the resized image area.
            max_pixels: Default upper bound on the resized image area.
            factor: Default resize granularity; ``None`` means
                ``2 * patch_size`` when the processor is called.
            **kwargs: Forwarded to ``BaseImageProcessor``.
        """
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.factor = factor

    @staticmethod
    def _is_batched(images: Any) -> bool:
        """Return whether ``images`` holds several images.

        Lists and tuples are batches; so are 4-D numpy arrays / torch tensors,
        which are split along their first dimension.
        """
        if isinstance(images, (list, tuple)):
            return True
        return isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4

    def _to_tensor(self, image: Any) -> torch.Tensor:
        """Convert one image to a float32 ``(3, H, W)`` tensor in ``[0, 1]``.

        PIL images are converted to RGB. Numpy arrays and torch tensors must be
        3-D, channel-first ``(C, H, W)`` or channel-last ``(H, W, C)`` with
        ``C`` in ``{1, 3}``; one channel is broadcast to three. Integer inputs
        are always divided by 255 (like the PIL path); floating-point inputs
        are divided by 255 only when their maximum exceeds 1.0.

        Raises:
            TypeError: For unsupported image types.
            ValueError: For unsupported shapes.
        """
        if Image is not None and isinstance(image, Image.Image):
            if image.mode != "RGB":
                image = image.convert("RGB")
            # np.array with an explicit dtype already allocates a fresh,
            # writable buffer, so no extra copy is needed.
            array = np.array(image, dtype=np.float32)
            return torch.from_numpy(array).permute(2, 0, 1).div_(255.0)

        if isinstance(image, np.ndarray):
            # torch.from_numpy rejects negative strides (e.g. ``img[..., ::-1]``).
            tensor = torch.from_numpy(np.ascontiguousarray(image))
        elif isinstance(image, torch.Tensor):
            tensor = image.detach().clone()
        else:
            raise TypeError(
                "images must contain PIL.Image.Image, numpy.ndarray, or torch.Tensor "
                f"items, got {type(image)!r}"
            )

        if tensor.ndim != 3:
            raise ValueError(f"image tensor must be 3D, got shape {tuple(tensor.shape)}")
        if tensor.shape[0] in {1, 3}:
            pass  # already channel-first
        elif tensor.shape[-1] in {1, 3}:
            tensor = tensor.permute(2, 0, 1)
        else:
            raise ValueError(
                "image tensor must be channel-first or channel-last with 1 or 3 channels, "
                f"got shape {tuple(tensor.shape)}"
            )

        # Decide the scaling from the *input* dtype: a dark uint8 image whose
        # values are all <= 1 must still be divided by 255.
        integer_input = not tensor.is_floating_point() and tensor.dtype is not torch.bool
        tensor = tensor.float()
        if tensor.shape[0] == 1:
            tensor = tensor.expand(3, -1, -1)

        # numel() guard: max() raises on empty tensors; smart_resize reports
        # zero-sized images with a clearer error later.
        if integer_input or (tensor.numel() > 0 and tensor.max().item() > 1.0):
            tensor = tensor / 255.0
        return tensor.clamp(0.0, 1.0)

    def _resize_tensor(self, image: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Bicubically resize a ``(C, H, W)`` tensor to ``(height, width)``.

        ``antialias=True`` low-pass filters when downsampling so the tensor
        path agrees with the PIL BICUBIC resize used for PIL inputs.
        """
        if tuple(image.shape[-2:]) == (height, width):
            return image
        resized = F.interpolate(
            image.unsqueeze(0),
            size=(height, width),
            mode="bicubic",
            align_corners=False,
            antialias=True,
        )
        # Bicubic kernels overshoot; clamp back into the valid range.
        return resized.squeeze(0).clamp(0.0, 1.0)

    def _preprocess_image(
        self,
        image: Any,
        *,
        min_pixels: int,
        max_pixels: int,
        factor: int,
    ) -> tuple[torch.Tensor, tuple[int, int]]:
        """Resize one image and return it with its patch-grid size.

        Returns:
            ``(tensor, (grid_h, grid_w))`` where ``tensor`` is the resized
            ``(3, H', W')`` image and ``grid_* = side / patch_size``.

        Raises:
            ValueError: If the resized size is not divisible by ``patch_size``.
        """
        if Image is not None and isinstance(image, Image.Image):
            width, height = image.size
            resized_h, resized_w = smart_resize(
                height=height,
                width=width,
                factor=factor,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
            # Pillow >= 9.1 exposes filters under Image.Resampling.
            resampling = getattr(Image, "Resampling", Image).BICUBIC
            # Convert before resizing so palette images are resampled as RGB.
            image = image.convert("RGB").resize((resized_w, resized_h), resampling)
            tensor = self._to_tensor(image)
        else:
            tensor = self._to_tensor(image)
            height, width = tensor.shape[-2:]
            resized_h, resized_w = smart_resize(
                height=height,
                width=width,
                factor=factor,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
            tensor = self._resize_tensor(tensor, resized_h, resized_w)

        if resized_h % self.patch_size != 0 or resized_w % self.patch_size != 0:
            raise ValueError(
                f"resized image {(resized_h, resized_w)} must be divisible by "
                f"patch_size={self.patch_size}; use a factor divisible by patch_size"
            )

        return tensor, (resized_h // self.patch_size, resized_w // self.patch_size)

    def _patchify(self, image: torch.Tensor) -> torch.Tensor:
        """Split a ``(C, H, W)`` image into ``(N, C, p, p)`` patches, row-major."""
        patch_size = self.patch_size
        patches = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
        # (C, gh, gw, p, p) -> (gh, gw, C, p, p) -> (gh * gw, C, p, p)
        patches = patches.permute(1, 2, 0, 3, 4).reshape(-1, image.shape[0], patch_size, patch_size)
        return patches.contiguous()

    def _token_metadata(
        self, grid_h: int, grid_w: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build ``input_ids``, ``position_ids`` and ``grid_sizes`` for one image.

        The image occupies ``1 + num_register_tokens + grid_h * grid_w`` tokens:
        CLS, register tokens, then patches in row-major order. Special tokens
        get position ``(0, 0)``; patch ``k`` gets ``(k // grid_w, k % grid_w)``.
        Every token records ``(grid_h, grid_w)``.
        """
        special_tokens = 1 + self.num_register_tokens
        length = special_tokens + grid_h * grid_w

        input_ids = torch.empty(length, dtype=torch.int32)
        input_ids[0] = CLS_TOKEN_ID
        if self.num_register_tokens:
            input_ids[1:special_tokens] = REGISTER_TOKEN_ID
        input_ids[special_tokens:] = PATCH_TOKEN_ID

        position_ids = torch.zeros((length, 2), dtype=torch.int32)
        position_ids[special_tokens:, 0] = torch.arange(
            grid_h, dtype=torch.int32
        ).repeat_interleave(grid_w)
        position_ids[special_tokens:, 1] = torch.arange(
            grid_w, dtype=torch.int32
        ).repeat(grid_h)

        grid_sizes = torch.tensor([grid_h, grid_w], dtype=torch.int32).repeat(length, 1)
        return input_ids, position_ids, grid_sizes

    def __call__(
        self,
        images: Any,
        *,
        min_pixels: Optional[int] = None,
        mix_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        max_length: Optional[int] = None,
        padding: bool = True,
        factor: Optional[int] = None,
        return_tensors: str = "pt",
        **kwargs: Any,
    ) -> BatchFeature:
        """Pack one or more images into a single token sequence.

        Args:
            images: A single image, a list/tuple of images, or a 4-D array /
                tensor whose first dimension indexes images. Each image is a
                PIL image, numpy array or torch tensor.
            min_pixels: Per-call override of ``self.min_pixels``.
            mix_pixels: Alias of ``min_pixels``; at most one may be given.
            max_pixels: Per-call override of ``self.max_pixels``.
            max_length: Optional token budget for the packed sequence. Images
                are packed in order; the first image that does not fit and
                every image after it are dropped and listed in
                ``truncated_images``.
            padding: When ``max_length`` is set, pad to exactly ``max_length``.
            factor: Resize granularity; defaults to ``self.factor`` or, if that
                is ``None``, ``2 * patch_size``. Must be a positive multiple of
                ``patch_size``.
            return_tensors: Only ``"pt"`` is supported.

        Returns:
            ``BatchFeature`` with, for packed length ``L`` and ``N`` images:
            ``pixel_values`` (L, 3, p, p), zeros at special/padding tokens;
            ``input_ids`` (L,) int32 token-type ids (padding uses
            ``PATCH_TOKEN_ID``); ``position_ids`` (L, 2) int32 patch
            ``(row, col)``; ``grid_sizes`` (L, 2) int32, ``(0, 0)`` for
            padding; ``document_ids`` (L,) int32 image index, ``-1`` for
            padding; ``image_token_spans`` (N, 2) int32 ``[start, end)``;
            ``image_grid_sizes`` (N, 2) int32; ``truncated_images`` list of
            dropped input indices.

        Raises:
            TypeError: On unexpected keyword arguments or image types.
            ValueError: On invalid options or shapes, or when one image alone
                needs more than ``max_length`` tokens.
        """
        if kwargs:
            unknown = ", ".join(sorted(kwargs))
            raise TypeError(f"Unexpected keyword argument(s): {unknown}")
        if return_tensors != "pt":
            raise ValueError("TIPSv2ImageProcessor currently supports return_tensors='pt' only.")

        if min_pixels is not None and mix_pixels is not None:
            raise ValueError("Specify only one of min_pixels or mix_pixels.")
        if mix_pixels is not None:
            min_pixels = mix_pixels
        min_pixels = self.min_pixels if min_pixels is None else min_pixels
        max_pixels = self.max_pixels if max_pixels is None else max_pixels
        factor = self.factor if factor is None else factor
        factor = 2 * self.patch_size if factor is None else factor

        # factor == 0 would pass the divisibility check below and then divide
        # by zero inside smart_resize.
        if factor <= 0:
            raise ValueError(f"factor must be positive, got {factor}")
        if factor % self.patch_size != 0:
            raise ValueError(
                f"factor={factor} must be divisible by patch_size={self.patch_size}"
            )

        image_list = list(images) if self._is_batched(images) else [images]

        pixel_chunks: list[torch.Tensor] = []
        input_id_chunks: list[torch.Tensor] = []
        position_id_chunks: list[torch.Tensor] = []
        grid_size_chunks: list[torch.Tensor] = []
        document_id_chunks: list[torch.Tensor] = []
        image_token_spans: list[tuple[int, int]] = []
        image_grid_sizes: list[tuple[int, int]] = []
        truncated_images: list[int] = []

        total_length = 0
        special_tokens = 1 + self.num_register_tokens

        for image_idx, image in enumerate(image_list):
            image_tensor, (grid_h, grid_w) = self._preprocess_image(
                image,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
                factor=factor,
            )
            patches = self._patchify(image_tensor)
            image_length = special_tokens + patches.shape[0]

            if max_length is not None and image_length > max_length:
                raise ValueError(
                    f"image at index {image_idx} needs {image_length} tokens, "
                    f"which exceeds max_length={max_length}"
                )
            if max_length is not None and total_length + image_length > max_length:
                # Packing is order-preserving: once an image does not fit, it
                # and every later image are dropped rather than skipped over.
                truncated_images.extend(range(image_idx, len(image_list)))
                break

            # Special tokens carry no pixels; give them all-zero patches.
            zero_special = patches.new_zeros(
                (special_tokens, image_tensor.shape[0], self.patch_size, self.patch_size)
            )
            pixel_chunks.append(torch.cat([zero_special, patches], dim=0))

            input_ids, position_ids, grid_sizes = self._token_metadata(grid_h, grid_w)
            input_id_chunks.append(input_ids)
            position_id_chunks.append(position_ids)
            grid_size_chunks.append(grid_sizes)

            # Images are numbered in packed order (== number packed so far).
            document_id_chunks.append(
                torch.full((image_length,), len(image_token_spans), dtype=torch.int32)
            )
            image_token_spans.append((total_length, total_length + image_length))
            image_grid_sizes.append((grid_h, grid_w))
            total_length += image_length

        if pixel_chunks:
            pixel_values = torch.cat(pixel_chunks, dim=0)
            input_ids = torch.cat(input_id_chunks, dim=0)
            position_ids = torch.cat(position_id_chunks, dim=0)
            grid_sizes = torch.cat(grid_size_chunks, dim=0)
            document_ids = torch.cat(document_id_chunks, dim=0)
        else:
            pixel_values = torch.empty((0, 3, self.patch_size, self.patch_size), dtype=torch.float32)
            input_ids = torch.empty((0,), dtype=torch.int32)
            position_ids = torch.empty((0, 2), dtype=torch.int32)
            grid_sizes = torch.empty((0, 2), dtype=torch.int32)
            document_ids = torch.empty((0,), dtype=torch.int32)

        if padding and max_length is not None and pixel_values.shape[0] < max_length:
            pad_len = max_length - pixel_values.shape[0]
            pad_pixels = pixel_values.new_zeros(
                (pad_len, pixel_values.shape[1], self.patch_size, self.patch_size)
            )
            pixel_values = torch.cat([pixel_values, pad_pixels], dim=0)
            input_ids = torch.cat(
                [input_ids, torch.full((pad_len,), PATCH_TOKEN_ID, dtype=torch.int32)],
                dim=0,
            )
            position_ids = torch.cat(
                [position_ids, torch.zeros((pad_len, 2), dtype=torch.int32)],
                dim=0,
            )
            grid_sizes = torch.cat(
                [grid_sizes, torch.zeros((pad_len, 2), dtype=torch.int32)],
                dim=0,
            )
            # document_id -1 marks padding so it can be masked out downstream.
            document_ids = torch.cat(
                [document_ids, torch.full((pad_len,), -1, dtype=torch.int32)],
                dim=0,
            )

        spans = torch.tensor(image_token_spans, dtype=torch.int32)
        grids = torch.tensor(image_grid_sizes, dtype=torch.int32)
        # torch.tensor([]) has shape (0,); keep a consistent (N, 2) shape.
        if spans.numel() == 0:
            spans = spans.reshape(0, 2)
        if grids.numel() == 0:
            grids = grids.reshape(0, 2)

        return BatchFeature(
            data={
                "pixel_values": pixel_values,
                "input_ids": input_ids,
                "position_ids": position_ids,
                "grid_sizes": grid_sizes,
                "document_ids": document_ids,
                "image_token_spans": spans,
                "image_grid_sizes": grids,
                "truncated_images": truncated_images,
            }
        )
|
|