from typing import List, Optional, Tuple, Union
import os
import torch
import math
from torchvision.transforms import functional as F
from transformers.image_processing_utils import BatchFeature
from transformers.image_processing_utils_fast import (
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    SizeDict,
)
from transformers.image_utils import (
    ImageInput,
    PILImageResampling,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
    TensorType,
    add_start_docstrings,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
    logging,
)


BASE_IMAGE_PROCESSOR_FAST_DOCSTRING = r"""
    Args:
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
            `do_resize` parameter in the `preprocess` method.
        size (`dict`, *optional*, defaults to `self.size`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        default_to_square (`bool`, *optional*, defaults to `self.default_to_square`):
            Whether to default to a square image when resizing, if size is an int.
        resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
            Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
            overridden by the `resample` parameter in the `preprocess` method.
        do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
            Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
            `preprocess` method.
        crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
            Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
            method.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
            `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
            Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
            overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
            Whether to convert the image to RGB.
        return_tensors (`str` or `TensorType`, *optional*, defaults to `self.return_tensors`):
            Returns stacked tensors if set to `pt`, otherwise returns a list of tensors.
        data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.data_format`):
            Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.
        input_data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.input_data_format`):
            The channel dimension format for the input image. If unset, the channel dimension format is inferred
            from the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        device (`torch.device`, *optional*, defaults to `self.device`):
            The device to process the images on. If unset, the device is inferred from the input images."""

BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS = r"""
    Preprocess an image or batch of images.

    Args:
        images (`ImageInput`):
            Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
            passing in images with pixel values between 0 and 1, set `do_rescale=False`.
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the image.
        size (`Dict[str, int]`, *optional*, defaults to `self.size`):
            Describes the maximum input dimensions to the model.
        resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to `self.resample`):
            Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
            has an effect if `do_resize` is set to `True`.
        do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
            Whether to center crop the image.
        crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
            Size of the output image after applying `center_crop`.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether to rescale the image.
        rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
            Rescale factor to rescale the image by if `do_rescale` is set to `True`.
        do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
            Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
        image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
            Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
            `True`.
        do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
            Whether to convert the image to RGB.
        return_tensors (`str` or `TensorType`, *optional*, defaults to `self.return_tensors`):
            Returns stacked tensors if set to `pt`, otherwise returns a list of tensors.
        data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.data_format`):
            Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.
        input_data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.input_data_format`):
            The channel dimension format for the input image. If unset, the channel dimension format is inferred
            from the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        device (`torch.device`, *optional*, defaults to `self.device`):
            The device to process the images on. If unset, the device is inferred from the input images."""


if is_torch_available():
    import torch

if is_torchvision_available():
    # Prefer the torchvision v2 functional API when present; this rebinds the `F` imported at the top of the file.
    if is_torchvision_v2_available():
        from torchvision.transforms.v2 import functional as F
    else:
        from torchvision.transforms import functional as F


logger = logging.get_logger(__name__)

# Number of adjacent patches merged along each spatial side. Target image sizes are rounded to multiples of
# `patch_size * _MERGE_SIZE` so that `convert_image_to_patches` can group patches into _MERGE_SIZE x _MERGE_SIZE blocks.
_MERGE_SIZE = 2


def get_image_size_for_patches(
    image_height: int,
    image_width: int,
    patch_size: int,
    max_num_patches: int,
    merge_size: int = _MERGE_SIZE,
) -> Tuple[int, int]:
    """
    Compute a resize target (height, width) that yields at most `max_num_patches` patches.

    Both sides are scaled by a common factor, rounded up to a multiple of `patch_size * merge_size` and clamped to at
    least one such block. Starting from the original size (scale 1.0), the scale is reduced in steps of 0.02 until
    the patch count fits; an image that already fits is never upscaled.

    Args:
        image_height (`int`):
            Original image height.
        image_width (`int`):
            Original image width.
        patch_size (`int`):
            Patch size for processing.
        max_num_patches (`int`):
            Upper bound on `(target_height / patch_size) * (target_width / patch_size)`.
        merge_size (`int`, *optional*, defaults to 2):
            Number of adjacent patches merged along each side; both target sides are multiples of
            `patch_size * merge_size`.

    Returns:
        Tuple: (target_height, target_width)

    Raises:
        ValueError: If `max_num_patches < merge_size ** 2`. That is the smallest patch count any target size can
            have, so no scale could satisfy the bound (the search would otherwise never terminate).
    """
    block_size = patch_size * merge_size
    min_num_patches = merge_size * merge_size
    if max_num_patches < min_num_patches:
        raise ValueError(
            f"max_num_patches ({max_num_patches}) must be at least {min_num_patches} "
            f"(merge_size={merge_size} squared)."
        )

    def get_scaled_image_size(scale: float, size: int) -> int:
        # Round up to a whole number of merge blocks, but never below a single block.
        scaled_size = size * scale
        scaled_size = math.ceil(scaled_size / block_size) * block_size
        scaled_size = max(block_size, scaled_size)
        return int(scaled_size)

    scale = 1.0
    while True:
        target_height = get_scaled_image_size(scale, image_height)
        target_width = get_scaled_image_size(scale, image_width)
        num_patches = (target_height / patch_size) * (target_width / patch_size)
        if num_patches > max_num_patches:
            # Terminates: once both sides clamp to one block, num_patches == min_num_patches <= max_num_patches.
            scale -= 0.02
        else:
            break
    return target_height, target_width


def convert_image_to_patches(image: "torch.Tensor", patch_size: int, merge_size: int) -> "torch.Tensor":
    """
    Converts an input image into flattened patches.

    Patches are emitted block by block: all `merge_size x merge_size` patches of one merge block are consecutive
    rows, and blocks are ordered row-major over the image.

    Args:
        image: Input image tensor of shape (channels, height, width). Both height and width must be multiples of
            `patch_size * merge_size`; otherwise the reshape below fails.
        patch_size: Size of each square patch (in pixels)
        merge_size: Number of adjacent patches to merge

    Returns:
        Tensor of shape (num_patches_height * num_patches_width, patch_size * patch_size * channels). Each row's
        values are laid out as (patch_row, patch_col, channel).
    """
    num_channels, image_height, image_width = image.shape
    num_patches_height = image_height // patch_size
    num_patches_width = image_width // patch_size
    patched_image = image.reshape(
        num_channels,
        num_patches_height // merge_size,
        merge_size,
        patch_size,
        num_patches_width // merge_size,
        merge_size,
        patch_size,
    )
    # -> (blocks_h, blocks_w, merge_h, merge_w, patch_h, patch_w, channels)
    patched_image = patched_image.permute(1, 4, 2, 5, 3, 6, 0)
    patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1)
    return patched_image


def pad_along_first_dim(
    tensor: "torch.Tensor", target_length: int, pad_value: int = 0
) -> Tuple["torch.Tensor", "torch.Tensor"]:
    """
    Pad the input tensor along its first dimension to a target length.

    Args:
        tensor (torch.Tensor): The input tensor to be padded.
        target_length (int): The desired length of the first dimension after padding.
        pad_value (int, optional): The value to use for padding. Defaults to 0.

    Returns:
        Tuple of (padded tensor, mask). The mask is an int32 tensor of shape (target_length,) with 1 for original
        rows and 0 for padded rows. It is allocated without an explicit device, i.e. on torch's default device.

    Raises:
        ValueError: If `target_length` is shorter than the tensor's first dimension.
    """
    current_length = tensor.shape[0]
    padding_length = target_length - current_length
    if padding_length < 0:
        # Without this check the returned tensor and mask would silently disagree in length.
        raise ValueError(
            f"Cannot pad a tensor of length {current_length} to a shorter target length {target_length}."
        )
    mask = torch.ones((target_length,), dtype=torch.int32)
    if padding_length > 0:
        # torch.nn.functional.pad lists (left, right) pairs starting from the LAST dimension, so the trailing
        # [0, padding_length] pair pads only the end of the first dimension.
        padding = [0, 0] * (tensor.ndim - 1) + [0, padding_length]
        tensor = torch.nn.functional.pad(tensor, padding, mode="constant", value=pad_value)
        mask[-padding_length:] = 0
    return tensor, mask


class Siglip2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    patch_size: Optional[int]
    max_num_patches: Optional[int]


@add_start_docstrings(
    r"Constructs a fast Siglip2 image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    """
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch the image will be split to.
        max_num_patches (`int`, *optional*, defaults to 256):
            The image will be resized to have at most this number of patches. Within a batch, patch sequences are
            padded in the "patch" dimension to the longest sequence (see `pixel_attention_mask`).
    """,
)
class Siglip2ImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BILINEAR
    image_mean = [0.5, 0.5, 0.5]
    image_std = [0.5, 0.5, 0.5]
    do_resize = True
    do_rescale = True
    do_normalize = True
    patch_size = 16
    max_num_patches = 256
    valid_kwargs = Siglip2FastImageProcessorKwargs
    unused_kwargs = ["size", "do_center_crop", "crop_size"]
    # NOTE(review): one-shot class-level flag. `_preprocess` only clears it and prints nothing; it looks like
    # leftover debug scaffolding. Kept because external code may read it.
    print_max_patched = True

    def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]):
        super().__init__(**kwargs)

    def _validate_preprocess_kwargs(self, **kwargs) -> tuple:
        # `do_resize` is dropped before base validation: this processor always resizes to a patch-count-derived
        # size and never uses `size` (presumably the base check would otherwise demand it — confirm).
        kwargs.pop("do_resize", None)
        return super()._validate_preprocess_kwargs(**kwargs)

    @add_start_docstrings(
        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
        """
        patch_size (`int`, *optional*, defaults to `self.patch_size`):
            The size (resolution) of each patch the image will be split to.
        max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`):
            The image will be resized to have at most this number of patches. Within a batch, patch sequences are
            padded in the "patch" dimension to the longest sequence (see `pixel_attention_mask`).
        """,
    )
    def preprocess(self, images: ImageInput, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)

    def get_max_image_patches(self, images):
        """Return a fixed patch budget of 4096 * 6 * 6 (= 147456); `images` is not inspected."""
        return 4096 * 6 * 6

    def _preprocess(
        self,
        images: List["torch.Tensor"],
        do_resize: bool,
        patch_size: int,
        max_num_patches: int,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, List[float]]],
        image_std: Optional[Union[float, List[float]]],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Resize, rescale/normalize and patchify each image, then batch the results.

        Every image is resized (regardless of `do_resize`) to the size returned by `get_image_size_for_patches`.
        Both sides of that size are multiples of `patch_size * 2`, and it yields at most `max_num_patches` patches.
        Patch sequences are zero-padded to the longest sequence in the batch so they can be stacked. A single image,
        or images of equal size, are therefore not padded at all.

        Args:
            images: Images as (channels, height, width) tensors.

        Returns:
            `BatchFeature` with:
            - `pixel_values`: (batch, num_patches, patch_size * patch_size * channels)
            - `pixel_attention_mask`: int32 (batch, num_patches), 1 for real patches and 0 for padding
            - `spatial_shapes`: (batch, 2), each image's (height, width) measured in patches
        """
        if Siglip2ImageProcessorFast.print_max_patched:
            Siglip2ImageProcessorFast.print_max_patched = False

        patch_sequences = []
        spatial_shapes = []
        for image in images:
            height, width = get_image_size_for_patches(
                image_height=image.shape[1],
                image_width=image.shape[2],
                patch_size=patch_size,
                max_num_patches=max_num_patches,
                merge_size=_MERGE_SIZE,
            )
            size_dict = SizeDict(height=height, width=width)
            image = self.resize(image=image, size=size_dict, interpolation=interpolation)
            image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
            patch_sequences.append(convert_image_to_patches(image, patch_size, _MERGE_SIZE))

            num_patches_height = image.shape[1] // patch_size
            num_patches_width = image.shape[2] // patch_size
            spatial_shapes.append((num_patches_height, num_patches_width))

        # Pad to the longest sequence in this batch. Padding each image only to its own length (the previous
        # behavior) made torch.stack fail whenever images produced different patch counts.
        target_length = max((patches.shape[0] for patches in patch_sequences), default=0)
        pixel_values = []
        pixel_masks = []
        for patches in patch_sequences:
            padded_patches, mask = pad_along_first_dim(patches, target_length)
            pixel_values.append(padded_patches)
            pixel_masks.append(mask)

        pixel_values = torch.stack(pixel_values, dim=0)
        pixel_masks = torch.stack(pixel_masks, dim=0)
        spatial_shapes = torch.tensor(spatial_shapes)

        batch_feature = BatchFeature(
            data={
                "pixel_values": pixel_values,
                "pixel_attention_mask": pixel_masks,
                "spatial_shapes": spatial_shapes,
            },
            tensor_type=return_tensors,
        )
        return batch_feature


__all__ = ["Siglip2ImageProcessorFast"]