"""Image processor for packed TIPSv2 vision inputs.""" import math from typing import Any, Optional import numpy as np import torch import torch.nn.functional as F from transformers import BatchFeature from transformers.image_processing_utils import BaseImageProcessor try: from PIL import Image except ImportError: # pragma: no cover - depends on optional runtime dependency. Image = None PATCH_TOKEN_ID = 0 CLS_TOKEN_ID = 1 REGISTER_TOKEN_ID = 2 def smart_resize( height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280, ) -> tuple[int, int]: """Resize while preserving aspect ratio and divisibility by ``factor``.""" if height <= 0 or width <= 0: raise ValueError(f"height and width must be positive, got {(height, width)}") if max(height, width) / min(height, width) > 200: raise ValueError( "absolute aspect ratio must be smaller than 200, got " f"{max(height, width) / min(height, width)}" ) h_bar = round(height / factor) * factor w_bar = round(width / factor) * factor if h_bar * w_bar > max_pixels: beta = math.sqrt((height * width) / max_pixels) h_bar = max(factor, math.floor(height / beta / factor) * factor) w_bar = max(factor, math.floor(width / beta / factor) * factor) elif h_bar * w_bar < min_pixels: beta = math.sqrt(min_pixels / (height * width)) h_bar = math.ceil(height * beta / factor) * factor w_bar = math.ceil(width * beta / factor) * factor return h_bar, w_bar class TIPSv2ImageProcessor(BaseImageProcessor): """Build packed patch sequences for TIPSv2 image encoder inputs.""" model_input_names = [ "pixel_values", "input_ids", "position_ids", "grid_sizes", "document_ids", ] def __init__( self, patch_size: int = 14, num_register_tokens: int = 1, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280, factor: Optional[int] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) self.patch_size = patch_size self.num_register_tokens = num_register_tokens self.min_pixels = min_pixels self.max_pixels = max_pixels self.factor = factor @staticmethod def _is_batched(images: Any) -> bool: return isinstance(images, (list, tuple)) def _to_tensor(self, image: Any) -> torch.Tensor: if Image is not None and isinstance(image, Image.Image): image = image.convert("RGB") array = np.asarray(image, dtype=np.float32).copy() return torch.from_numpy(array).permute(2, 0, 1).div_(255.0) if isinstance(image, np.ndarray): tensor = torch.from_numpy(image) elif isinstance(image, torch.Tensor): tensor = image.detach().clone() else: raise TypeError( "images must contain PIL.Image.Image, numpy.ndarray, or torch.Tensor " f"items, got {type(image)!r}" ) if tensor.ndim != 3: raise ValueError(f"image tensor must be 3D, got shape {tuple(tensor.shape)}") if tensor.shape[0] in {1, 3}: tensor = tensor.float() if tensor.shape[0] == 1: tensor = tensor.expand(3, -1, -1) elif tensor.shape[-1] in {1, 3}: tensor = tensor.permute(2, 0, 1).float() if tensor.shape[0] == 1: tensor = tensor.expand(3, -1, -1) else: raise ValueError( "image tensor must be channel-first or channel-last with 1 or 3 channels, " f"got shape {tuple(tensor.shape)}" ) if tensor.max().item() > 1.0: tensor = tensor / 255.0 return tensor.clamp(0.0, 1.0) def _resize_tensor(self, image: torch.Tensor, height: int, width: int) -> torch.Tensor: if tuple(image.shape[-2:]) == (height, width): return image image = image.unsqueeze(0) image = F.interpolate( image, size=(height, width), mode="bicubic", align_corners=False, ) return image.squeeze(0).clamp(0.0, 1.0) def _preprocess_image( self, image: Any, *, min_pixels: int, max_pixels: int, factor: int, ) -> tuple[torch.Tensor, tuple[int, int]]: if Image is not None and isinstance(image, Image.Image): width, height = image.size resized_h, resized_w = smart_resize( height=height, width=width, factor=factor, min_pixels=min_pixels, max_pixels=max_pixels, ) resampling = getattr(Image, "Resampling", Image).BICUBIC image = image.convert("RGB").resize((resized_w, resized_h), resampling) tensor = self._to_tensor(image) else: tensor = self._to_tensor(image) height, width = tensor.shape[-2:] resized_h, resized_w = smart_resize( height=height, width=width, factor=factor, min_pixels=min_pixels, max_pixels=max_pixels, ) tensor = self._resize_tensor(tensor, resized_h, resized_w) if resized_h % self.patch_size != 0 or resized_w % self.patch_size != 0: raise ValueError( f"resized image {(resized_h, resized_w)} must be divisible by " f"patch_size={self.patch_size}; use a factor divisible by patch_size" ) return tensor, (resized_h // self.patch_size, resized_w // self.patch_size) def _patchify(self, image: torch.Tensor) -> torch.Tensor: patch_size = self.patch_size patches = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size) patches = patches.permute(1, 2, 0, 3, 4).reshape(-1, image.shape[0], patch_size, patch_size) return patches.contiguous() def __call__( self, images: Any, *, min_pixels: Optional[int] = None, mix_pixels: Optional[int] = None, max_pixels: Optional[int] = None, max_length: Optional[int] = None, padding: bool = True, factor: Optional[int] = None, return_tensors: str = "pt", **kwargs: Any, ) -> BatchFeature: if kwargs: unknown = ", ".join(sorted(kwargs)) raise TypeError(f"Unexpected keyword argument(s): {unknown}") if return_tensors != "pt": raise ValueError("TIPSv2ImageProcessor currently supports return_tensors='pt' only.") if min_pixels is not None and mix_pixels is not None: raise ValueError("Specify only one of min_pixels or mix_pixels.") if mix_pixels is not None: min_pixels = mix_pixels min_pixels = self.min_pixels if min_pixels is None else min_pixels max_pixels = self.max_pixels if max_pixels is None else max_pixels factor = self.factor if factor is None else factor factor = 2 * self.patch_size if factor is None else factor if factor % self.patch_size != 0: raise ValueError( f"factor={factor} must be divisible by patch_size={self.patch_size}" ) image_list = list(images) if self._is_batched(images) else [images] pixel_chunks: list[torch.Tensor] = [] input_id_chunks: list[torch.Tensor] = [] position_id_chunks: list[torch.Tensor] = [] grid_size_chunks: list[torch.Tensor] = [] document_id_chunks: list[torch.Tensor] = [] image_token_spans: list[tuple[int, int]] = [] image_grid_sizes: list[tuple[int, int]] = [] truncated_images: list[int] = [] total_length = 0 processed_docs = 0 special_tokens = 1 + self.num_register_tokens for image_idx, image in enumerate(image_list): image_tensor, (grid_h, grid_w) = self._preprocess_image( image, min_pixels=min_pixels, max_pixels=max_pixels, factor=factor, ) patches = self._patchify(image_tensor) num_patches = patches.shape[0] image_length = special_tokens + num_patches if max_length is not None and image_length > max_length: raise ValueError( f"image at index {image_idx} needs {image_length} tokens, " f"which exceeds max_length={max_length}" ) if max_length is not None and total_length + image_length > max_length: truncated_images.extend(range(image_idx, len(image_list))) break zero_special = patches.new_zeros( (special_tokens, image_tensor.shape[0], self.patch_size, self.patch_size) ) pixel_chunks.append(torch.cat([zero_special, patches], dim=0)) input_ids = torch.empty(image_length, dtype=torch.int32) input_ids[0] = CLS_TOKEN_ID if self.num_register_tokens: input_ids[1:special_tokens] = REGISTER_TOKEN_ID input_ids[special_tokens:] = PATCH_TOKEN_ID input_id_chunks.append(input_ids) position_ids = torch.zeros((image_length, 2), dtype=torch.int32) rows = torch.arange(grid_h, dtype=torch.int32).repeat_interleave(grid_w) cols = torch.arange(grid_w, dtype=torch.int32).repeat(grid_h) position_ids[special_tokens:, 0] = rows position_ids[special_tokens:, 1] = cols position_id_chunks.append(position_ids) grid_sizes = torch.empty((image_length, 2), dtype=torch.int32) grid_sizes[:, 0] = grid_h grid_sizes[:, 1] = grid_w grid_size_chunks.append(grid_sizes) document_id_chunks.append( torch.full((image_length,), processed_docs, dtype=torch.int32) ) image_token_spans.append((total_length, total_length + image_length)) image_grid_sizes.append((grid_h, grid_w)) total_length += image_length processed_docs += 1 if pixel_chunks: pixel_values = torch.cat(pixel_chunks, dim=0) input_ids = torch.cat(input_id_chunks, dim=0) position_ids = torch.cat(position_id_chunks, dim=0) grid_sizes = torch.cat(grid_size_chunks, dim=0) document_ids = torch.cat(document_id_chunks, dim=0) else: pixel_values = torch.empty((0, 3, self.patch_size, self.patch_size), dtype=torch.float32) input_ids = torch.empty((0,), dtype=torch.int32) position_ids = torch.empty((0, 2), dtype=torch.int32) grid_sizes = torch.empty((0, 2), dtype=torch.int32) document_ids = torch.empty((0,), dtype=torch.int32) if padding and max_length is not None and pixel_values.shape[0] < max_length: pad_len = max_length - pixel_values.shape[0] pad_pixels = pixel_values.new_zeros( (pad_len, pixel_values.shape[1], self.patch_size, self.patch_size) ) pixel_values = torch.cat([pixel_values, pad_pixels], dim=0) input_ids = torch.cat( [input_ids, torch.full((pad_len,), PATCH_TOKEN_ID, dtype=torch.int32)], dim=0, ) position_ids = torch.cat( [position_ids, torch.zeros((pad_len, 2), dtype=torch.int32)], dim=0, ) grid_sizes = torch.cat( [grid_sizes, torch.zeros((pad_len, 2), dtype=torch.int32)], dim=0, ) document_ids = torch.cat( [document_ids, torch.full((pad_len,), -1, dtype=torch.int32)], dim=0, ) spans = torch.tensor(image_token_spans, dtype=torch.int32) grids = torch.tensor(image_grid_sizes, dtype=torch.int32) if spans.numel() == 0: spans = spans.reshape(0, 2) if grids.numel() == 0: grids = grids.reshape(0, 2) return BatchFeature( data={ "pixel_values": pixel_values, "input_ids": input_ids, "position_ids": position_ids, "grid_sizes": grid_sizes, "document_ids": document_ids, "image_token_spans": spans, "image_grid_sizes": grids, "truncated_images": truncated_images, } )