| from typing import List, Optional, Tuple, Union |
| import os |
| import torch |
| import math |
| from torchvision.transforms import functional as F |
| from transformers.image_processing_utils import BatchFeature |
| from transformers.image_processing_utils_fast import ( |
| BaseImageProcessorFast, |
| DefaultFastImageProcessorKwargs, |
| SizeDict, |
| ) |
| from transformers.image_utils import ( |
| ImageInput, |
| PILImageResampling, |
| ) |
| from transformers.processing_utils import Unpack |
| from transformers.utils import ( |
| TensorType, |
| add_start_docstrings, |
| is_torch_available, |
| is_torchvision_available, |
| is_torchvision_v2_available, |
| logging, |
| ) |
|
|
# Shared "Args:" section describing the constructor options of fast image processors; it is appended to class
# docstrings via `add_start_docstrings`.
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING = r"""

    Args:
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
            `do_resize` parameter in the `preprocess` method.
        size (`dict`, *optional*, defaults to `self.size`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        default_to_square (`bool`, *optional*, defaults to `self.default_to_square`):
            Whether to default to a square image when resizing, if size is an int.
        resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
            Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
            overridden by the `resample` parameter in the `preprocess` method.
        do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
            Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
            `preprocess` method.
        crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
            Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
            method.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
            `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
            Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
            overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
            Whether to convert the image to RGB.
        return_tensors (`str` or `TensorType`, *optional*, defaults to `self.return_tensors`):
            Returns stacked tensors if set to `pt`, otherwise returns a list of tensors.
        data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.data_format`):
            Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.
        input_data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.input_data_format`):
            The channel dimension format for the input image. If unset, the channel dimension format is inferred
            from the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        device (`torch.device`, *optional*, defaults to `self.device`):
            The device to process the images on. If unset, the device is inferred from the input images."""
|
|
# Shared docstring for the `preprocess` method of fast image processors; prepended via `add_start_docstrings`.
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS = r"""
    Preprocess an image or batch of images.

    Args:
        images (`ImageInput`):
            Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
            passing in images with pixel values between 0 and 1, set `do_rescale=False`.
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the image.
        size (`Dict[str, int]`, *optional*, defaults to `self.size`):
            Describes the maximum input dimensions to the model.
        resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to `self.resample`):
            Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
            has an effect if `do_resize` is set to `True`.
        do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
            Whether to center crop the image.
        crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
            Size of the output image after applying `center_crop`.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether to rescale the image.
        rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
            Rescale factor to rescale the image by if `do_rescale` is set to `True`.
        do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
            Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
        image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
            Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
            `True`.
        do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
            Whether to convert the image to RGB.
        return_tensors (`str` or `TensorType`, *optional*, defaults to `self.return_tensors`):
            Returns stacked tensors if set to `pt`, otherwise returns a list of tensors.
        data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.data_format`):
            Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.
        input_data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.input_data_format`):
            The channel dimension format for the input image. If unset, the channel dimension format is inferred
            from the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        device (`torch.device`, *optional*, defaults to `self.device`):
            The device to process the images on. If unset, the device is inferred from the input images."""
|
|
|
|
| if is_torch_available(): |
| import torch |
|
|
| if is_torchvision_available(): |
| if is_torchvision_v2_available(): |
| from torchvision.transforms.v2 import functional as F |
| else: |
| from torchvision.transforms import functional as F |
|
|
|
|
# Module-level logger obtained from `transformers.utils.logging`; no code shown in this module currently calls it.
logger = logging.get_logger(__name__)
|
|
|
|
def get_image_size_for_patches(
    image_height: int,
    image_width: int,
    patch_size: int,
    max_num_patches: int,
    merge_size: int = 2,
) -> Tuple[int, int]:
    """
    Compute the (height, width) an image should be resized to so that it fits within a patch budget.

    Starting from the original size (scale 1.0), the scale factor is lowered in steps of 0.02 until the image
    would be split into at most `max_num_patches` square patches of side `patch_size`. Each target side is
    rounded up to a multiple of `patch_size * merge_size`, so the patch grid can later be grouped into
    `merge_size` x `merge_size` blocks. Each side is never smaller than one such block. Because of this rounding,
    a small image can come out slightly larger than its original size.

    Args:
        image_height (`int`):
            Original image height.
        image_width (`int`):
            Original image width.
        patch_size (`int`):
            Patch size for processing.
        max_num_patches (`int`):
            Upper bound on `(target_height // patch_size) * (target_width // patch_size)`.
        merge_size (`int`, *optional*, defaults to 2):
            Number of adjacent patches (per side) that are merged downstream; target sides are multiples of
            `patch_size * merge_size`.

    Returns:
        Tuple: (target_height, target_width)

    Raises:
        ValueError: If `patch_size` or `merge_size` is not positive, or if `max_num_patches` is smaller than
            `merge_size ** 2` (the fewest patches any target size can have), in which case no size satisfies
            the budget.
    """
    if patch_size <= 0 or merge_size <= 0:
        raise ValueError(
            f"`patch_size` and `merge_size` must be positive, got patch_size={patch_size}, merge_size={merge_size}."
        )
    min_num_patches = merge_size * merge_size
    if max_num_patches < min_num_patches:
        raise ValueError(
            f"`max_num_patches` must be at least {min_num_patches} (one {merge_size}x{merge_size} block of "
            f"patches), got {max_num_patches}."
        )

    block_size = patch_size * merge_size

    def get_scaled_image_size(scale: float, size: int) -> int:
        # Round up to a whole number of blocks, but keep at least one block.
        scaled_size = math.ceil(size * scale / block_size) * block_size
        return int(max(block_size, scaled_size))

    # scale = (num_steps - step) / num_steps walks 1.00, 0.98, ..., 0.00 without the drift that repeated
    # `scale -= 0.02` accumulates (drift can push e.g. 0.5 slightly above, bumping a side up by one block).
    # At scale 0 both sides equal one block (merge_size**2 patches), which the guard above guarantees fits, so
    # the loop always terminates with a valid size.
    num_steps = 50
    target_height = target_width = block_size
    for step in range(num_steps + 1):
        scale = (num_steps - step) / num_steps
        target_height = get_scaled_image_size(scale, image_height)
        target_width = get_scaled_image_size(scale, image_width)
        num_patches = (target_height // patch_size) * (target_width // patch_size)
        if num_patches <= max_num_patches:
            break

    return target_height, target_width
|
|
|
|
def convert_image_to_patches(image: "torch.Tensor", patch_size: int, merge_size: int) -> "torch.Tensor":
    """
    Converts an input image into flattened patches, grouped into `merge_size` x `merge_size` blocks.

    Args:
        image: Input image tensor of shape (channels, height, width). Both `height` and `width` must be
            multiples of `patch_size * merge_size`.
        patch_size: Size of each square patch (in pixels).
        merge_size: Number of adjacent patches (per side) to merge into one block.

    Returns:
        Tensor of shape `(num_patches_height * num_patches_width, patch_size * patch_size * channels)`.
        Rows go block by block (row-major over the block grid) and, inside each block, row-major over its
        `merge_size` x `merge_size` patches. Within a row, values are ordered (pixel_row, pixel_col, channel),
        so the channel index varies fastest.

    Raises:
        ValueError: If `image` is not 3-D, or its height/width is not a multiple of `patch_size * merge_size`.
    """
    if image.ndim != 3:
        raise ValueError(f"Expected an image of shape (channels, height, width), got shape {tuple(image.shape)}.")

    num_channels, image_height, image_width = image.shape
    block_size = patch_size * merge_size
    # Without this check the reshape below fails with an opaque element-count error.
    if block_size <= 0 or image_height % block_size or image_width % block_size:
        raise ValueError(
            f"Image height and width ({image_height}, {image_width}) must be positive multiples of "
            f"patch_size * merge_size = {patch_size} * {merge_size}."
        )

    num_patches_height = image_height // patch_size
    num_patches_width = image_width // patch_size
    # (C, blocks_h, merge_h, patch_h, blocks_w, merge_w, patch_w)
    patched_image = image.reshape(
        num_channels,
        num_patches_height // merge_size,
        merge_size,
        patch_size,
        num_patches_width // merge_size,
        merge_size,
        patch_size,
    )
    # -> (blocks_h, blocks_w, merge_h, merge_w, patch_h, patch_w, C)
    patched_image = patched_image.permute(1, 4, 2, 5, 3, 6, 0)
    return patched_image.reshape(num_patches_height * num_patches_width, -1)
|
|
def pad_along_first_dim(
    tensor: "torch.Tensor", target_length: int, pad_value: int = 0
) -> Tuple["torch.Tensor", "torch.Tensor"]:
    """
    Pad the input tensor along its first dimension to a target length.

    Args:
        tensor (torch.Tensor): The input tensor to be padded.
        target_length (int): The desired length of the first dimension after padding.
        pad_value (int, optional): The value to use for padding. Defaults to 0.

    Returns:
        Tuple of `(padded_tensor, mask)`. `padded_tensor` has `target_length` rows. The new rows are appended at
        the end and filled with `pad_value`. `mask` is an int32 vector of length `target_length` on the same
        device as `tensor`, with 1 for original rows and 0 for padded rows.

    Raises:
        ValueError: If `tensor` already has more than `target_length` rows; padding cannot shorten it, and the
            mask would otherwise not match the tensor's length.
    """
    current_length = tensor.shape[0]
    padding_length = target_length - current_length
    if padding_length < 0:
        raise ValueError(
            f"Cannot pad a tensor of length {current_length} along its first dimension to a shorter "
            f"target length {target_length}."
        )
    mask = torch.ones((target_length,), dtype=torch.int32, device=tensor.device)
    if padding_length > 0:
        # `torch.nn.functional.pad` takes (before, after) pairs starting from the LAST dimension, so the final
        # pair applies to dim 0.
        padding = [0, 0] * (tensor.ndim - 1) + [0, padding_length]
        tensor = torch.nn.functional.pad(tensor, padding, mode="constant", value=pad_value)
        mask[-padding_length:] = 0
    return tensor, mask
|
|
|
|
class Siglip2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    # Extra keyword arguments accepted by `Siglip2ImageProcessorFast`, on top of the default fast-processor kwargs.
    # Side length in pixels of each square patch the resized image is split into.
    patch_size: Optional[int]
    # Upper bound on the number of `patch_size` patches per image, used when choosing the resize target
    # (passed to `get_image_size_for_patches`).
    max_num_patches: Optional[int]
|
|
|
|
@add_start_docstrings(
    r"Constructs a fast Siglip2 image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    """
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch the image will be split to.
        max_num_patches (`int`, *optional*, defaults to 256):
            The image will be resized to have at most this number of patches. Within a batch, the patch
            sequences are padded in the "patch" dimension to the length of the longest one.
    """,
)
class Siglip2ImageProcessorFast(BaseImageProcessorFast):
    # Processor defaults; each can be overridden via constructor or `preprocess` kwargs.
    resample = PILImageResampling.BILINEAR
    image_mean = [0.5, 0.5, 0.5]
    image_std = [0.5, 0.5, 0.5]
    do_resize = True
    do_rescale = True
    do_normalize = True
    patch_size = 16
    max_num_patches = 256
    valid_kwargs = Siglip2FastImageProcessorKwargs
    # Base-class options this processor does not use: the resize target is computed from the patch budget.
    unused_kwargs = ["size", "do_center_crop", "crop_size"]
    # Retained for backward compatibility only; this class no longer reads or modifies it.
    print_max_patched = True

    def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]):
        """Forward all keyword arguments to `BaseImageProcessorFast.__init__`."""
        super().__init__(**kwargs)

    def _validate_preprocess_kwargs(self, **kwargs) -> tuple:
        """
        Drop `do_resize` before delegating to the base-class validation.

        `_preprocess` always resizes, so `do_resize` is not meaningful here. Presumably this keeps the base
        validation from requiring `size`, which this processor ignores; confirm against `BaseImageProcessorFast`.
        """
        kwargs.pop("do_resize", None)
        return super()._validate_preprocess_kwargs(**kwargs)

    @add_start_docstrings(
        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
        """
        patch_size (`int`, *optional*, defaults to `self.patch_size`):
            The size (resolution) of each patch the image will be split to.
        max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`):
            The image will be resized to have at most this number of patches. Within a batch, the patch
            sequences are padded in the "patch" dimension to the length of the longest one.
        """,
    )
    def preprocess(self, images: ImageInput, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)

    def get_max_image_patches(self, images):
        """Return a fixed patch budget of 4096 * 6 * 6; `images` is ignored. Not called within this module."""
        return 4096 * 6 * 6

    def _preprocess(
        self,
        images: List["torch.Tensor"],
        do_resize: bool,
        patch_size: int,
        max_num_patches: int,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, List[float]]],
        image_std: Optional[Union[float, List[float]]],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Resize, rescale/normalize and patchify each image, then stack the batch.

        `do_resize` is accepted for signature compatibility but ignored. Every image is resized to the size
        returned by `get_image_size_for_patches` for the given `patch_size` / `max_num_patches`.

        Returns:
            `BatchFeature` with:
            - `pixel_values`: stacked patch sequences, each padded along the patch dimension to the longest
              sequence in the batch.
            - `pixel_attention_mask`: per-image mask from `pad_along_first_dim` (1 = real patch, 0 = padding).
            - `spatial_shapes`: `(num_patches_height, num_patches_width)` per image, measured after resizing.
        """
        patch_sequences = []
        spatial_shapes = []

        for image in images:
            # assumes `image` is channels-first (C, H, W), as indexed below — TODO confirm base-class guarantees
            target_height, target_width = get_image_size_for_patches(
                image_height=image.shape[1],
                image_width=image.shape[2],
                patch_size=patch_size,
                max_num_patches=max_num_patches,
            )
            size_dict = SizeDict(height=target_height, width=target_width)
            image = self.resize(image=image, size=size_dict, interpolation=interpolation)
            image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)

            # Merge factor 2 matches the `patch_size * 2` rounding applied by `get_image_size_for_patches`.
            patch_sequences.append(convert_image_to_patches(image, patch_size, 2))
            spatial_shapes.append((image.shape[1] // patch_size, image.shape[2] // patch_size))

        # Images with different aspect ratios produce different patch counts. Pad every sequence to the longest
        # one so `torch.stack` succeeds; same-size batches are unaffected (no padding is added).
        target_length = max((len(patches) for patches in patch_sequences), default=0)
        pixel_values = []
        pixel_masks = []
        for patches in patch_sequences:
            padded_patches, mask = pad_along_first_dim(patches, target_length)
            pixel_values.append(padded_patches)
            pixel_masks.append(mask)

        batch_feature = BatchFeature(
            data={
                "pixel_values": torch.stack(pixel_values, dim=0),
                "pixel_attention_mask": torch.stack(pixel_masks, dim=0),
                "spatial_shapes": torch.tensor(spatial_shapes),
            },
            tensor_type=return_tensors,
        )
        return batch_feature
|
|
|
|
# Public API: only the processor class is exported by `from <module> import *`; the helper functions above are not.
__all__ = ["Siglip2ImageProcessorFast"]
|
|