# tipsv2-b14-vision / image_processing_tips.py
# (Hub page header: toilaluan, commit d1941eb "update")
"""Image processor for packed TIPSv2 vision inputs."""
import math
from typing import Any, Optional
import numpy as np
import torch
import torch.nn.functional as F
from transformers import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor
try:
from PIL import Image
except ImportError: # pragma: no cover - depends on optional runtime dependency.
Image = None
# Token-type ids written into ``input_ids`` for every packed position:
#   PATCH_TOKEN_ID    -> an image patch (padding positions reuse this id too)
#   CLS_TOKEN_ID      -> the first slot of each image
#   REGISTER_TOKEN_ID -> register slots that directly follow the CLS slot
PATCH_TOKEN_ID, CLS_TOKEN_ID, REGISTER_TOKEN_ID = 0, 1, 2
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 56 * 56,
    max_pixels: int = 14 * 14 * 4 * 1280,
) -> tuple[int, int]:
    """Resize while preserving aspect ratio and divisibility by ``factor``.

    Each side is first rounded to the nearest multiple of ``factor``. If the
    resulting area is larger than ``max_pixels``, both sides are scaled down
    by a common ratio and floored to a multiple of ``factor``. If it is smaller
    than ``min_pixels``, both sides are scaled up and ceiled to a multiple of
    ``factor``.

    Args:
        height: Original height in pixels; must be positive.
        width: Original width in pixels; must be positive.
        factor: Granularity of the returned sides.
        min_pixels: Area below which the image is scaled up.
        max_pixels: Area above which the image is scaled down.

    Returns:
        ``(new_height, new_width)``. Each side is a multiple of ``factor`` and
        at least ``factor``.

    Raises:
        ValueError: If ``height`` or ``width`` is not positive, or the aspect
            ratio exceeds 200.
    """
    if height <= 0 or width <= 0:
        raise ValueError(f"height and width must be positive, got {(height, width)}")
    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        raise ValueError(
            "absolute aspect ratio must be smaller than 200, got "
            f"{aspect_ratio}"
        )
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    # A side shorter than factor/2 rounds to 0 above. When min_pixels <= 0,
    # neither branch corrects it, so clamp to one factor to avoid returning an
    # empty resize target.
    return max(factor, h_bar), max(factor, w_bar)
class TIPSv2ImageProcessor(BaseImageProcessor):
    """Build packed patch sequences for TIPSv2 image encoder inputs.

    Every image is resized with ``smart_resize`` so both sides are multiples
    of ``factor``. It is then cut into non-overlapping
    ``patch_size`` x ``patch_size`` patches and prefixed with one CLS slot and
    ``num_register_tokens`` register slots.

    All images are concatenated along a single sequence axis. Per-position
    metadata (token type, 2-D patch position, patch-grid size and image index)
    describes where each image starts and how it is laid out.
    """

    model_input_names = [
        "pixel_values",
        "input_ids",
        "position_ids",
        "grid_sizes",
        "document_ids",
    ]

    def __init__(
        self,
        patch_size: int = 14,
        num_register_tokens: int = 1,
        min_pixels: int = 56 * 56,
        max_pixels: int = 14 * 14 * 4 * 1280,
        factor: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Store the default preprocessing configuration.

        Args:
            patch_size: Side length, in pixels, of each square patch.
            num_register_tokens: Register slots inserted after each image's
                CLS slot.
            min_pixels: Default lower area bound passed to ``smart_resize``.
            max_pixels: Default upper area bound passed to ``smart_resize``.
            factor: Default resize granularity. ``None`` means
                ``2 * patch_size``, resolved at call time.
            **kwargs: Forwarded to ``BaseImageProcessor``.
        """
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.factor = factor

    @staticmethod
    def _is_batched(images: Any) -> bool:
        """Return True when ``images`` is a list/tuple of images rather than one image."""
        return isinstance(images, (list, tuple))

    def _to_tensor(self, image: Any) -> torch.Tensor:
        """Convert one image to a float ``(3, H, W)`` tensor with values in [0, 1].

        PIL images are converted to RGB and divided by 255.

        ndarrays and tensors may be 2-D grayscale ``(H, W)``, channel-first
        ``(C, H, W)`` or channel-last ``(H, W, C)``, with ``C`` in {1, 3}.
        Single-channel input is broadcast to three channels.

        ``uint8`` data is always divided by 255. Other dtypes are divided by
        255 only when their maximum exceeds 1.0; otherwise they are treated as
        already normalized.

        Raises:
            TypeError: If ``image`` is not a PIL image, ndarray or tensor.
            ValueError: If the array/tensor has an unsupported shape.
        """
        if Image is not None and isinstance(image, Image.Image):
            image = image.convert("RGB")
            array = np.asarray(image, dtype=np.float32).copy()
            return torch.from_numpy(array).permute(2, 0, 1).div_(255.0)
        if isinstance(image, np.ndarray):
            # torch.from_numpy rejects negative strides (e.g. an OpenCV BGR->RGB
            # flip ``img[..., ::-1]``). ascontiguousarray copies only when needed.
            tensor = torch.from_numpy(np.ascontiguousarray(image))
        elif isinstance(image, torch.Tensor):
            tensor = image.detach().clone()
        else:
            raise TypeError(
                "images must contain PIL.Image.Image, numpy.ndarray, or torch.Tensor "
                f"items, got {type(image)!r}"
            )
        if tensor.ndim == 2:
            # A 2-D array is a single-channel (grayscale) image, mirroring how
            # grayscale PIL images are accepted above.
            tensor = tensor.unsqueeze(0)
        if tensor.ndim != 3:
            raise ValueError(f"image tensor must be 3D, got shape {tuple(tensor.shape)}")
        # Record before the float cast: 8-bit data is always on a 0..255 scale,
        # even for a dark image whose maximum never exceeds 1.
        is_uint8 = tensor.dtype == torch.uint8
        if tensor.shape[0] in {1, 3}:
            tensor = tensor.float()
        elif tensor.shape[-1] in {1, 3}:
            tensor = tensor.permute(2, 0, 1).float()
        else:
            raise ValueError(
                "image tensor must be channel-first or channel-last with 1 or 3 channels, "
                f"got shape {tuple(tensor.shape)}"
            )
        if tensor.shape[0] == 1:
            tensor = tensor.expand(3, -1, -1)
        # Skip max() on empty tensors (it would raise) so smart_resize can
        # report the zero-sized image with a clear ValueError.
        if is_uint8 or (tensor.numel() > 0 and tensor.max().item() > 1.0):
            tensor = tensor / 255.0
        return tensor.clamp(0.0, 1.0)

    def _resize_tensor(self, image: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Bicubic-resize a ``(C, H, W)`` tensor to ``(C, height, width)``, clamped to [0, 1].

        Returns ``image`` unchanged when it already has the target size.
        """
        if tuple(image.shape[-2:]) == (height, width):
            return image
        resized = F.interpolate(
            image.unsqueeze(0),
            size=(height, width),
            mode="bicubic",
            align_corners=False,
            # Antialiasing (which only affects downsampling) makes ndarray and
            # tensor inputs filtered like the PIL BICUBIC path in
            # _preprocess_image, so the container type does not change the patches.
            antialias=True,
        )
        # Bicubic overshoot can leave values slightly outside [0, 1].
        return resized.squeeze(0).clamp(0.0, 1.0)

    def _preprocess_image(
        self,
        image: Any,
        *,
        min_pixels: int,
        max_pixels: int,
        factor: int,
    ) -> tuple[torch.Tensor, tuple[int, int]]:
        """Resize one image to a ``smart_resize`` target and return it with its patch grid.

        PIL images are resized with PIL's BICUBIC filter before tensor
        conversion. Other inputs are converted first and resized with
        ``_resize_tensor``.

        Returns:
            ``(tensor, (grid_h, grid_w))``. ``tensor`` is ``(3, H, W)`` and the
            grid is the number of patches along each axis.

        Raises:
            ValueError: If the resized size is not divisible by ``patch_size``.
        """
        if Image is not None and isinstance(image, Image.Image):
            width, height = image.size
            resized_h, resized_w = smart_resize(
                height=height,
                width=width,
                factor=factor,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
            # Pillow >= 9.1 exposes filters under Image.Resampling; older
            # versions keep them on the Image module itself.
            resampling = getattr(Image, "Resampling", Image).BICUBIC
            image = image.convert("RGB").resize((resized_w, resized_h), resampling)
            tensor = self._to_tensor(image)
        else:
            tensor = self._to_tensor(image)
            height, width = tensor.shape[-2:]
            resized_h, resized_w = smart_resize(
                height=height,
                width=width,
                factor=factor,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
            tensor = self._resize_tensor(tensor, resized_h, resized_w)
        if resized_h % self.patch_size != 0 or resized_w % self.patch_size != 0:
            raise ValueError(
                f"resized image {(resized_h, resized_w)} must be divisible by "
                f"patch_size={self.patch_size}; use a factor divisible by patch_size"
            )
        return tensor, (resized_h // self.patch_size, resized_w // self.patch_size)

    def _patchify(self, image: torch.Tensor) -> torch.Tensor:
        """Split a ``(C, H, W)`` tensor into row-major ``(H/p * W/p, C, p, p)`` patches.

        Assumes ``H`` and ``W`` are multiples of ``patch_size``, which
        ``_preprocess_image`` enforces.
        """
        patch_size = self.patch_size
        # (C, H, W) -> (C, H/p, W/p, p, p)
        patches = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
        # -> (H/p, W/p, C, p, p) -> (num_patches, C, p, p), rows outermost.
        patches = patches.permute(1, 2, 0, 3, 4).reshape(-1, image.shape[0], patch_size, patch_size)
        return patches.contiguous()

    def __call__(
        self,
        images: Any,
        *,
        min_pixels: Optional[int] = None,
        mix_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        max_length: Optional[int] = None,
        padding: bool = True,
        factor: Optional[int] = None,
        return_tensors: str = "pt",
        **kwargs: Any,
    ) -> BatchFeature:
        """Preprocess one image, or a list/tuple of images, into one packed sequence.

        Args:
            images: A single PIL image, ndarray or tensor, or a list/tuple of them.
            min_pixels: Override for ``self.min_pixels``.
            mix_pixels: Alias for ``min_pixels``. Passing both raises
                ``ValueError``.
            max_pixels: Override for ``self.max_pixels``.
            max_length: Token budget for the whole packed sequence.
                - An image that alone exceeds it raises ``ValueError``.
                - Once adding an image would exceed the running total, that
                  image and all later ones are skipped. Their indices are
                  returned in ``truncated_images``.
            padding: When true and ``max_length`` is set, pad up to
                ``max_length``. Padded positions get zero pixels,
                ``PATCH_TOKEN_ID``, zero positions and grids, and document id -1.
            factor: Override for the resize granularity. Defaults to
                ``self.factor``, then ``2 * patch_size``. Must be a positive
                multiple of ``patch_size``.
            return_tensors: Only ``"pt"`` is supported.

        Returns:
            A ``BatchFeature`` with:
                - ``pixel_values``: ``(L, 3, P, P)`` float.
                - ``input_ids``: ``(L,)`` int32.
                - ``position_ids``: ``(L, 2)`` int32, as (row, col).
                - ``grid_sizes``: ``(L, 2)`` int32.
                - ``document_ids``: ``(L,)`` int32.
                - ``image_token_spans``: ``(N, 2)`` int32, [start, end) offsets
                  of each kept image.
                - ``image_grid_sizes``: ``(N, 2)`` int32.
                - ``truncated_images``: a Python list of skipped image indices.

        Raises:
            TypeError: On unexpected keyword arguments or unsupported image types.
            ValueError: On invalid options or image shapes, or an image longer
                than ``max_length``.
        """
        if kwargs:
            unknown = ", ".join(sorted(kwargs))
            raise TypeError(f"Unexpected keyword argument(s): {unknown}")
        if return_tensors != "pt":
            raise ValueError("TIPSv2ImageProcessor currently supports return_tensors='pt' only.")
        if min_pixels is not None and mix_pixels is not None:
            raise ValueError("Specify only one of min_pixels or mix_pixels.")
        if mix_pixels is not None:
            min_pixels = mix_pixels
        min_pixels = self.min_pixels if min_pixels is None else min_pixels
        max_pixels = self.max_pixels if max_pixels is None else max_pixels
        factor = self.factor if factor is None else factor
        factor = 2 * self.patch_size if factor is None else factor
        # 0 (or a negative value) passes the divisibility test below but would
        # break smart_resize (division by zero), so reject it explicitly.
        if factor <= 0:
            raise ValueError(f"factor must be positive, got {factor}")
        if factor % self.patch_size != 0:
            raise ValueError(
                f"factor={factor} must be divisible by patch_size={self.patch_size}"
            )
        image_list = list(images) if self._is_batched(images) else [images]

        pixel_chunks: list[torch.Tensor] = []
        input_id_chunks: list[torch.Tensor] = []
        position_id_chunks: list[torch.Tensor] = []
        grid_size_chunks: list[torch.Tensor] = []
        document_id_chunks: list[torch.Tensor] = []
        image_token_spans: list[tuple[int, int]] = []
        image_grid_sizes: list[tuple[int, int]] = []
        truncated_images: list[int] = []
        total_length = 0
        processed_docs = 0
        # One CLS slot followed by the register slots precede every image's patches.
        special_tokens = 1 + self.num_register_tokens

        for image_idx, image in enumerate(image_list):
            image_tensor, (grid_h, grid_w) = self._preprocess_image(
                image,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
                factor=factor,
            )
            patches = self._patchify(image_tensor)
            num_patches = patches.shape[0]
            image_length = special_tokens + num_patches
            if max_length is not None and image_length > max_length:
                raise ValueError(
                    f"image at index {image_idx} needs {image_length} tokens, "
                    f"which exceeds max_length={max_length}"
                )
            if max_length is not None and total_length + image_length > max_length:
                # Packing stops at the first image that does not fit; it and
                # every later image are reported as truncated.
                truncated_images.extend(range(image_idx, len(image_list)))
                break

            # Special slots carry no pixels; zero-filled placeholders keep
            # pixel_values aligned one-to-one with input_ids.
            zero_special = patches.new_zeros(
                (special_tokens, image_tensor.shape[0], self.patch_size, self.patch_size)
            )
            pixel_chunks.append(torch.cat([zero_special, patches], dim=0))

            input_ids = torch.empty(image_length, dtype=torch.int32)
            input_ids[0] = CLS_TOKEN_ID
            if self.num_register_tokens:
                input_ids[1:special_tokens] = REGISTER_TOKEN_ID
            input_ids[special_tokens:] = PATCH_TOKEN_ID
            input_id_chunks.append(input_ids)

            # Patch (row, col) in row-major order, matching _patchify.
            # Special slots stay at (0, 0).
            position_ids = torch.zeros((image_length, 2), dtype=torch.int32)
            rows = torch.arange(grid_h, dtype=torch.int32).repeat_interleave(grid_w)
            cols = torch.arange(grid_w, dtype=torch.int32).repeat(grid_h)
            position_ids[special_tokens:, 0] = rows
            position_ids[special_tokens:, 1] = cols
            position_id_chunks.append(position_ids)

            grid_sizes = torch.empty((image_length, 2), dtype=torch.int32)
            grid_sizes[:, 0] = grid_h
            grid_sizes[:, 1] = grid_w
            grid_size_chunks.append(grid_sizes)

            document_id_chunks.append(
                torch.full((image_length,), processed_docs, dtype=torch.int32)
            )
            image_token_spans.append((total_length, total_length + image_length))
            image_grid_sizes.append((grid_h, grid_w))
            total_length += image_length
            processed_docs += 1

        if pixel_chunks:
            pixel_values = torch.cat(pixel_chunks, dim=0)
            input_ids = torch.cat(input_id_chunks, dim=0)
            position_ids = torch.cat(position_id_chunks, dim=0)
            grid_sizes = torch.cat(grid_size_chunks, dim=0)
            document_ids = torch.cat(document_id_chunks, dim=0)
        else:
            pixel_values = torch.empty((0, 3, self.patch_size, self.patch_size), dtype=torch.float32)
            input_ids = torch.empty((0,), dtype=torch.int32)
            position_ids = torch.empty((0, 2), dtype=torch.int32)
            grid_sizes = torch.empty((0, 2), dtype=torch.int32)
            document_ids = torch.empty((0,), dtype=torch.int32)

        if padding and max_length is not None and pixel_values.shape[0] < max_length:
            pad_len = max_length - pixel_values.shape[0]
            pad_pixels = pixel_values.new_zeros(
                (pad_len, pixel_values.shape[1], self.patch_size, self.patch_size)
            )
            pixel_values = torch.cat([pixel_values, pad_pixels], dim=0)
            input_ids = torch.cat(
                [input_ids, torch.full((pad_len,), PATCH_TOKEN_ID, dtype=torch.int32)],
                dim=0,
            )
            position_ids = torch.cat(
                [position_ids, torch.zeros((pad_len, 2), dtype=torch.int32)],
                dim=0,
            )
            grid_sizes = torch.cat(
                [grid_sizes, torch.zeros((pad_len, 2), dtype=torch.int32)],
                dim=0,
            )
            # Document id -1 marks padding so it can be told apart from real
            # patches, which share PATCH_TOKEN_ID.
            document_ids = torch.cat(
                [document_ids, torch.full((pad_len,), -1, dtype=torch.int32)],
                dim=0,
            )

        spans = torch.tensor(image_token_spans, dtype=torch.int32)
        grids = torch.tensor(image_grid_sizes, dtype=torch.int32)
        # torch.tensor([]) has shape (0,); keep these 2-D when no image was kept.
        if spans.numel() == 0:
            spans = spans.reshape(0, 2)
        if grids.numel() == 0:
            grids = grids.reshape(0, 2)
        return BatchFeature(
            data={
                "pixel_values": pixel_values,
                "input_ids": input_ids,
                "position_ids": position_ids,
                "grid_sizes": grid_sizes,
                "document_ids": document_ids,
                "image_token_spans": spans,
                "image_grid_sizes": grids,
                "truncated_images": truncated_images,
            }
        )