# --- Repository page metadata (non-code text preserved as comments) ---
# Tags: Image-Text-to-Text · Transformers · Safetensors · youtu_vl ·
#       text-generation · conversational · custom_code
# File: Youtu-Parsing / image_processing_siglip2_fast.py
# Uploaded by Yinsongliu — "Upload model with LFS assets" (commit c13c3aa)
from typing import List, Optional, Tuple, Union
import os
import torch
import math
from torchvision.transforms import functional as F
from transformers.image_processing_utils import BatchFeature
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
SizeDict,
)
from transformers.image_utils import (
ImageInput,
PILImageResampling,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
logging,
)
# Shared docstring describing the constructor kwargs of the fast image processor.
# Fixes: closed the broken backtick markup around `pt`; removed duplicated
# "Can be overridden ..." sentences under do_normalize/image_mean/image_std.
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING = r"""
    Args:
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
            `do_resize` parameter in the `preprocess` method.
        size (`dict`, *optional*, defaults to `self.size`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        default_to_square (`bool`, *optional*, defaults to `self.default_to_square`):
            Whether to default to a square image when resizing, if size is an int.
        resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
            Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
            overridden by the `resample` parameter in the `preprocess` method.
        do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
            Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
            `preprocess` method.
        crop_size (`Dict[str, int]` *optional*, defaults to `self.crop_size`):
            Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
            method.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
            `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
            Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
            overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
            Whether to convert the image to RGB.
        return_tensors (`str` or `TensorType`, *optional*, defaults to `self.return_tensors`):
            Returns stacked tensors if set to `pt`, otherwise returns a list of tensors.
        data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.data_format`):
            Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.
        input_data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.input_data_format`):
            The channel dimension format for the input image. If unset, the channel dimension format is inferred
            from the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        device (`torch.device`, *optional*, defaults to `self.device`):
            The device to process the images on. If unset, the device is inferred from the input images."""
# Shared docstring describing the `preprocess` method arguments.
# Fix: closed the broken backtick markup around `pt` under return_tensors.
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS = r"""
    Preprocess an image or batch of images.

    Args:
        images (`ImageInput`):
            Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
            passing in images with pixel values between 0 and 1, set `do_rescale=False`.
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the image.
        size (`Dict[str, int]`, *optional*, defaults to `self.size`):
            Describes the maximum input dimensions to the model.
        resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to `self.resample`):
            Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
            has an effect if `do_resize` is set to `True`.
        do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
            Whether to center crop the image.
        crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
            Size of the output image after applying `center_crop`.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether to rescale the image.
        rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
            Rescale factor to rescale the image by if `do_rescale` is set to `True`.
        do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
            Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
        image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
            Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
            `True`.
        do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
            Whether to convert the image to RGB.
        return_tensors (`str` or `TensorType`, *optional*, defaults to `self.return_tensors`):
            Returns stacked tensors if set to `pt`, otherwise returns a list of tensors.
        data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.data_format`):
            Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.
        input_data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.input_data_format`):
            The channel dimension format for the input image. If unset, the channel dimension format is inferred
            from the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        device (`torch.device`, *optional*, defaults to `self.device`):
            The device to process the images on. If unset, the device is inferred from the input images."""
if is_torch_available():
import torch
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
logger = logging.get_logger(__name__)
def get_image_size_for_patches(
    image_height: int, image_width: int, patch_size: int, max_num_patches: int
) -> Tuple[int, int]:
    """
    Compute a target image size whose patch grid fits within `max_num_patches`.

    The scale factor starts at 1.0 and is lowered in 2% steps until the number
    of `patch_size` x `patch_size` patches no longer exceeds `max_num_patches`.
    Both target dimensions are rounded up to multiples of `2 * patch_size` so
    the resized image supports 2x2 patch merging.

    Args:
        image_height (`int`):
            Original image height.
        image_width (`int`):
            Original image width.
        patch_size (`int`):
            Patch size for processing.
        max_num_patches (`int`):
            Maximum number of patches the resized image may produce.

    Returns:
        Tuple: (target_height, target_width)
    """

    def get_scaled_image_size(scale: float, size: int, patch_size: int) -> int:
        # Align to 2 * patch_size (floor of 2 * patch_size) so a 2x2 merge is possible.
        patch_size = patch_size * 2
        scaled_size = size * scale
        scaled_size = math.ceil(scaled_size / patch_size) * patch_size
        scaled_size = max(patch_size, scaled_size)
        return int(scaled_size)

    scale = 1.0
    while True:
        target_height = get_scaled_image_size(scale, image_height, patch_size)
        target_width = get_scaled_image_size(scale, image_width, patch_size)
        num_patches = (target_height / patch_size) * (target_width / patch_size)
        # Guard `scale > 0.02`: each dimension bottoms out at 2 * patch_size
        # (i.e. 4 patches total), so for max_num_patches < 4 the original loop
        # could never terminate. Stop at the minimum size instead.
        if num_patches > max_num_patches and scale > 0.02:
            scale -= 0.02
        else:
            break
    return target_height, target_width
def convert_image_to_patches(image: "torch.Tensor", patch_size: int, merge_size: int) -> "torch.Tensor":
    """
    Split a channels-first image tensor into a flat sequence of patch rows.

    Patches belonging to the same `merge_size` x `merge_size` group are laid
    out contiguously, with the channel dimension moved last before flattening.

    Args:
        image: Input image tensor of shape (channels, height, width).
        patch_size: Side length (in pixels) of each square patch.
        merge_size: Number of adjacent patches grouped along each spatial axis.
    """
    channels, height, width = image.shape
    grid_h = height // patch_size
    grid_w = width // patch_size
    # Split each spatial axis into (blocks, intra-block patches, pixels).
    blocks = image.reshape(
        channels,
        grid_h // merge_size,
        merge_size,
        patch_size,
        grid_w // merge_size,
        merge_size,
        patch_size,
    )
    # Reorder to (h-block, w-block, intra-h, intra-w, pixel-h, pixel-w, channel).
    blocks = blocks.permute(1, 4, 2, 5, 3, 6, 0)
    # One row per patch; each row holds patch_size * patch_size * channels values.
    return blocks.reshape(grid_h * grid_w, -1)
def pad_along_first_dim(
    tensor: "torch.Tensor", target_length: int, pad_value: int = 0
) -> Tuple["torch.Tensor", "torch.Tensor"]:
    """
    Pad the input tensor along its first dimension to a target length.

    Args:
        tensor (torch.Tensor): The input tensor to be padded.
        target_length (int): The desired length of the first dimension after padding.
        pad_value (int, optional): The value to use for padding. Defaults to 0.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The (possibly padded) tensor and an
        int32 mask of shape (target_length,) with 1 for real entries and 0 for
        padding. If the tensor is already at least `target_length` long, it is
        returned unchanged and the mask is all ones.
    """
    current_length = tensor.shape[0]
    padding_length = target_length - current_length
    # Fix: keep the mask on the same device as the data so downstream
    # `torch.stack` calls work for tensors already moved to an accelerator.
    mask = torch.ones((target_length,), dtype=torch.int32, device=tensor.device)
    if padding_length > 0:
        # `F.pad` consumes (left, right) pairs starting from the LAST dimension,
        # so the first-dimension amount goes at the end of the list.
        padding = [0, 0] * (tensor.ndim - 1) + [0, padding_length]
        tensor = torch.nn.functional.pad(tensor, padding, mode="constant", value=pad_value)
        mask[-padding_length:] = 0
    return tensor, mask
class Siglip2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    """Extra keyword arguments accepted by `Siglip2ImageProcessorFast`."""

    # Side length (in pixels) of each square patch the image is split into.
    patch_size: Optional[int]
    # Upper bound on the number of patches the resized image may produce.
    max_num_patches: Optional[int]
@add_start_docstrings(
    r"Constructs a fast Siglip2 image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    """
    patch_size (`int`, *optional*, defaults to 16):
        The size (resolution) of each patch the image will be split to.
    max_num_patches (`int`, *optional*, defaults to 256):
        The image will be resized to have at most this number of patches,
        and then padded in "patch" dimension to match this number exactly.
    """,
)
class Siglip2ImageProcessorFast(BaseImageProcessorFast):
    # Default preprocessing configuration; individual values can be overridden
    # via constructor or `preprocess` kwargs (see class docstring above).
    resample = PILImageResampling.BILINEAR
    image_mean = [0.5, 0.5, 0.5]
    image_std = [0.5, 0.5, 0.5]
    do_resize = True
    do_rescale = True
    do_normalize = True
    patch_size = 16
    max_num_patches = 256
    valid_kwargs = Siglip2FastImageProcessorKwargs
    # Base-class kwargs this processor intentionally ignores: the output size
    # is derived from patch_size/max_num_patches, and no cropping is done.
    unused_kwargs = ["size", "do_center_crop", "crop_size"]
    # Class-level one-shot flag cleared on the first `_preprocess` call.
    # NOTE(review): nothing is printed or logged when it is cleared — looks
    # like leftover debug scaffolding; confirm before removing.
    print_max_patched = True

    def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]):
        super().__init__(**kwargs)

    def _validate_preprocess_kwargs(self, **kwargs) -> tuple:
        # Drop `do_resize` before base validation: resizing here is always
        # governed by `patch_size`/`max_num_patches` inside `_preprocess`.
        kwargs.pop("do_resize", None)
        return super()._validate_preprocess_kwargs(**kwargs)

    @add_start_docstrings(
        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
        """
        patch_size (`int`, *optional*, defaults to `self.patch_size`):
            The size (resolution) of each patch the image will be split to.
        max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`):
            The image will be resized to have at most this number of patches,
            and then padded in "patch" dimension to match this number exactly.
        """,
    )
    def preprocess(self, images: ImageInput, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]) -> BatchFeature:
        # Thin wrapper kept so the decorator can attach the full docstring.
        return super().preprocess(images, **kwargs)

    def get_max_image_patches(self, images):
        # Fixed upper bound (4096 * 36 = 147456 patches); `images` is unused.
        return 4096 * 6 * 6

    def _preprocess(
        self,
        images: List["torch.Tensor"],
        do_resize: bool,
        patch_size: int,
        max_num_patches: int,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, List[float]]],
        image_std: Optional[Union[float, List[float]]],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Resize each image to a patch-aligned size, rescale/normalize it, split
        it into flattened patches, and stack per-image results into a batch.

        Returns a `BatchFeature` with `pixel_values`, `pixel_attention_mask`
        and `spatial_shapes` (per-image patch-grid (rows, cols)).
        Note: `do_resize` is accepted but not consulted — every image is
        resized to its patch-aligned target size.
        """
        pixel_masks = []
        pixel_values = []
        spatial_shapes = []
        if Siglip2ImageProcessorFast.print_max_patched:
            # One-shot flag clear; produces no output (see class attribute note).
            Siglip2ImageProcessorFast.print_max_patched = False
        for i, image in enumerate(images):
            # Images are expected channels-first: shape (C, H, W) — so
            # shape[1]/shape[2] are height/width.
            height, width, = get_image_size_for_patches(
                image_height=image.shape[1],
                image_width=image.shape[2],
                patch_size=patch_size,
                max_num_patches=max_num_patches,
            )
            side_dict = SizeDict(height=height, width=width)
            image = self.resize(image=image, size=side_dict, interpolation=interpolation)
            image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
            # merge_size is fixed to 2; target sizes above are aligned to 2 * patch_size.
            patches = convert_image_to_patches(image, patch_size, 2)
            # NOTE(review): the padding target equals the current length, so this
            # is a no-op and `mask` is all ones. `torch.stack` below therefore
            # requires every image in the batch to yield the same patch count;
            # padding to `max_num_patches` may have been intended — confirm.
            patches, mask = pad_along_first_dim(patches, len(patches))
            num_patches_height = image.shape[1] // patch_size
            num_patches_width = image.shape[2] // patch_size
            spatial_shapes.append((num_patches_height, num_patches_width))
            pixel_values.append(patches)
            pixel_masks.append(mask)
        pixel_values = torch.stack(pixel_values, dim=0)
        pixel_masks = torch.stack(pixel_masks, dim=0)
        spatial_shapes = torch.tensor(spatial_shapes)
        batch_feature = BatchFeature(
            data={
                "pixel_values": pixel_values,
                "pixel_attention_mask": pixel_masks,
                "spatial_shapes": spatial_shapes,
            },
            tensor_type=return_tensors,
        )
        return batch_feature
__all__ = ["Siglip2ImageProcessorFast"]