|
|
from typing import List, Optional, Tuple, Union |
|
|
import os |
|
|
import torch |
|
|
import math |
|
|
from torchvision.transforms import functional as F |
|
|
from transformers.image_processing_utils import BatchFeature |
|
|
from transformers.image_processing_utils_fast import ( |
|
|
BaseImageProcessorFast, |
|
|
DefaultFastImageProcessorKwargs, |
|
|
SizeDict, |
|
|
) |
|
|
from transformers.image_utils import ( |
|
|
ImageInput, |
|
|
PILImageResampling, |
|
|
) |
|
|
from transformers.processing_utils import Unpack |
|
|
from transformers.utils import ( |
|
|
TensorType, |
|
|
add_start_docstrings, |
|
|
is_torch_available, |
|
|
is_torchvision_available, |
|
|
is_torchvision_v2_available, |
|
|
logging, |
|
|
) |
|
|
|
|
|
# Shared docstring for the common fast-image-processor arguments; interpolated
# into class docs via `add_start_docstrings`. (Fixes: unclosed backtick around
# `pt`, duplicated sentences in do_normalize/image_mean/image_std, missing
# comma in the crop_size type line.)
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING = r"""
    Args:
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
            `do_resize` parameter in the `preprocess` method.
        size (`dict`, *optional*, defaults to `self.size`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        default_to_square (`bool`, *optional*, defaults to `self.default_to_square`):
            Whether to default to a square image when resizing, if size is an int.
        resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
            Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
            overridden by the `resample` parameter in the `preprocess` method.
        do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
            Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
            `preprocess` method.
        crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
            Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
            method.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
            `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
            Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
            overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
            Whether to convert the image to RGB.
        return_tensors (`str` or `TensorType`, *optional*, defaults to `self.return_tensors`):
            Returns stacked tensors if set to `pt`, otherwise returns a list of tensors.
        data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.data_format`):
            Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.
        input_data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.input_data_format`):
            The channel dimension format for the input image. If unset, the channel dimension format is inferred
            from the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        device (`torch.device`, *optional*, defaults to `self.device`):
            The device to process the images on. If unset, the device is inferred from the input images."""
|
|
|
|
|
# Shared docstring for `preprocess` methods; interpolated via
# `add_start_docstrings`. (Fixes the unclosed backtick around `pt`.)
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS = r"""
    Preprocess an image or batch of images.

    Args:
        images (`ImageInput`):
            Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
            passing in images with pixel values between 0 and 1, set `do_rescale=False`.
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the image.
        size (`Dict[str, int]`, *optional*, defaults to `self.size`):
            Describes the maximum input dimensions to the model.
        resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to `self.resample`):
            Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
            has an effect if `do_resize` is set to `True`.
        do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
            Whether to center crop the image.
        crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
            Size of the output image after applying `center_crop`.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether to rescale the image.
        rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
            Rescale factor to rescale the image by if `do_rescale` is set to `True`.
        do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
            Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
        image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
            Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
            `True`.
        do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
            Whether to convert the image to RGB.
        return_tensors (`str` or `TensorType`, *optional*, defaults to `self.return_tensors`):
            Returns stacked tensors if set to `pt`, otherwise returns a list of tensors.
        data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.data_format`):
            Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.
        input_data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.input_data_format`):
            The channel dimension format for the input image. If unset, the channel dimension format is inferred
            from the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        device (`torch.device`, *optional*, defaults to `self.device`):
            The device to process the images on. If unset, the device is inferred from the input images."""
|
|
|
|
|
|
|
|
# Guarded backend imports so the module can be imported even when the optional
# torch / torchvision backends are missing.
# NOTE(review): `F` is also imported unconditionally at the top of the file
# (`from torchvision.transforms import functional as F`), which bypasses the
# guard below when torchvision is unavailable — confirm that top-level import
# is intentional.
if is_torch_available():
    import torch


if is_torchvision_available():
    # Prefer the torchvision v2 functional API when it is available.
    if is_torchvision_v2_available():
        from torchvision.transforms.v2 import functional as F
    else:
        from torchvision.transforms import functional as F


# Module-level logger, following the transformers convention.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
def get_image_size_for_patches(
    image_height: int, image_width: int, patch_size: int, max_num_patches: int
) -> Tuple[int, int]:
    """
    Find target dimensions for an image so that, after resizing, it contains at
    most `max_num_patches` patches of side `patch_size`.

    Both target dimensions are rounded up to multiples of `2 * patch_size`
    (so patches can later be grouped in 2x2 blocks) and the scale is reduced in
    steps of 0.02 until the patch count fits the budget.

    Args:
        image_height (`int`):
            Original image height.
        image_width (`int`):
            Original image width.
        patch_size (`int`):
            Patch size for processing.
        max_num_patches (`int`):
            Maximum allowed number of `patch_size` x `patch_size` patches.

    Returns:
        Tuple: (target_height, target_width)

    Raises:
        ValueError: if `max_num_patches` is too small to fit even the minimum
            `2 x 2` patch grid (previously this case looped forever).
    """

    def get_scaled_image_size(scale: float, size: int, patch_size: int) -> int:
        # Align to 2 * patch_size so the patch grid divides evenly into 2x2 blocks.
        patch_size = patch_size * 2
        scaled_size = size * scale
        scaled_size = math.ceil(scaled_size / patch_size) * patch_size
        scaled_size = max(patch_size, scaled_size)
        return int(scaled_size)

    scale = 1.0
    while True:
        target_height = get_scaled_image_size(scale, image_height, patch_size)
        target_width = get_scaled_image_size(scale, image_width, patch_size)
        # Patch count is measured in units of the original (non-doubled) patch_size.
        num_patches = (target_height / patch_size) * (target_width / patch_size)
        if num_patches <= max_num_patches:
            break
        scale -= 0.02
        if scale <= 0:
            # The minimum grid is 2x2 = 4 patches; a smaller budget can never fit.
            raise ValueError(
                f"max_num_patches={max_num_patches} is too small for patch_size={patch_size}; "
                "at least 4 patches are required."
            )

    return target_height, target_width
|
|
|
|
|
|
|
|
def convert_image_to_patches(image: "torch.Tensor", patch_size: int, merge_size: int) -> "torch.Tensor":
    """
    Split an image into flattened square patches, emitting them grouped in
    `merge_size` x `merge_size` neighbourhoods.

    Args:
        image: Input image tensor of shape (channels, height, width); both
            spatial dims must be divisible by `patch_size * merge_size`.
        patch_size: Side length (in pixels) of each square patch.
        merge_size: Number of adjacent patches per side grouped together, which
            determines the row ordering of the output.

    Returns:
        Tensor of shape (num_patches, patch_size * patch_size * channels),
        where num_patches = (height // patch_size) * (width // patch_size).
    """
    channels, height, width = image.shape
    grid_h = height // patch_size
    grid_w = width // patch_size

    # Factor each spatial dim into (blocks, patches-per-block, pixels-per-patch).
    blocks = image.reshape(
        channels,
        grid_h // merge_size,
        merge_size,
        patch_size,
        grid_w // merge_size,
        merge_size,
        patch_size,
    )
    # Reorder to (h_block, w_block, h_in_block, w_in_block, h_pixel, w_pixel, channel)
    # so each output row is one patch, and block-local patches are contiguous.
    blocks = blocks.permute(1, 4, 2, 5, 3, 6, 0)
    return blocks.reshape(grid_h * grid_w, -1)
|
|
|
|
|
def pad_along_first_dim(
    tensor: "torch.Tensor", target_length: int, pad_value: int = 0
) -> Tuple["torch.Tensor", "torch.Tensor"]:
    """
    Pad `tensor` along its first dimension up to `target_length`.

    Args:
        tensor (torch.Tensor): The input tensor to be padded.
        target_length (int): The desired length of the first dimension after padding.
        pad_value (int, optional): The value to use for padding. Defaults to 0.

    Returns:
        Tuple of (padded tensor, int32 mask of length `target_length` that is 1
        for original rows and 0 for padded rows).

    NOTE(review): the mask is always allocated on the default device, not on
    `tensor.device` — confirm callers expect that. If `target_length` is smaller
    than the current length the tensor is returned unchanged (no truncation).
    """
    pad_amount = target_length - tensor.shape[0]
    valid_mask = torch.ones((target_length,), dtype=torch.int32)
    if pad_amount > 0:
        # F.pad consumes (last-dim-left, last-dim-right, ..., first-dim-left,
        # first-dim-right), so only the final pair is non-zero.
        pad_spec = [0, 0] * (tensor.ndim - 1) + [0, pad_amount]
        tensor = torch.nn.functional.pad(tensor, pad_spec, mode="constant", value=pad_value)
        valid_mask[-pad_amount:] = 0
    return tensor, valid_mask
|
|
|
|
|
|
|
|
class Siglip2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    """
    Keyword arguments accepted by `Siglip2ImageProcessorFast` on top of the
    defaults inherited from `DefaultFastImageProcessorKwargs`.
    """

    # Side length (in pixels) of each square patch the image is split into.
    patch_size: Optional[int]
    # Upper bound on the number of patches; images are resized to fit this
    # budget and padded along the patch dimension to match it exactly.
    max_num_patches: Optional[int]
|
|
|
|
|
|
|
|
@add_start_docstrings(
    r"Constructs a fast Siglip2 image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    """
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch the image will be split to.
        max_num_patches (`int`, *optional*, defaults to 256):
            The image will be resized to have at most this number of patches,
            and then padded in "patch" dimension to match this number exactly.
    """,
)
class Siglip2ImageProcessorFast(BaseImageProcessorFast):
    # Default preprocessing configuration for Siglip2.
    resample = PILImageResampling.BILINEAR
    image_mean = [0.5, 0.5, 0.5]
    image_std = [0.5, 0.5, 0.5]
    do_resize = True
    do_rescale = True
    do_normalize = True
    patch_size = 16
    max_num_patches = 256
    valid_kwargs = Siglip2FastImageProcessorKwargs
    # Siglip2 sizes images by a patch budget, so the fixed-size and
    # center-crop options are accepted but ignored.
    unused_kwargs = ["size", "do_center_crop", "crop_size"]
    # Legacy debug flag kept for backward compatibility; nothing reads it.
    print_max_patched = True

    def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]):
        super().__init__(**kwargs)

    def _validate_preprocess_kwargs(self, **kwargs) -> tuple:
        # Resizing is always driven by `patch_size` / `max_num_patches`, so a
        # user-supplied `do_resize` is dropped before base-class validation.
        kwargs.pop("do_resize", None)
        return super()._validate_preprocess_kwargs(**kwargs)

    @add_start_docstrings(
        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
        """
        patch_size (`int`, *optional*, defaults to `self.patch_size`):
            The size (resolution) of each patch the image will be split to.
        max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`):
            The image will be resized to have at most this number of patches,
            and then padded in "patch" dimension to match this number exactly.
        """,
    )
    def preprocess(self, images: ImageInput, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)

    def get_max_image_patches(self, images):
        # NOTE(review): hard-coded upper bound (4096 * 36 = 147456); `images`
        # is ignored and nothing in this class calls this method — confirm
        # whether it is still needed.
        return 4096 * 6 * 6

    def _preprocess(
        self,
        images: List["torch.Tensor"],
        do_resize: bool,
        patch_size: int,
        max_num_patches: int,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, List[float]]],
        image_std: Optional[Union[float, List[float]]],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Resize each image to fit the `max_num_patches` budget, rescale and
        normalize it, split it into flattened patches, and pad every sample to
        exactly `max_num_patches` patches so the batch can be stacked.

        Returns:
            `BatchFeature` with:
            - `pixel_values`: (batch, max_num_patches, patch_dim) flattened patches.
            - `pixel_attention_mask`: 1 for real patches, 0 for padding.
            - `spatial_shapes`: per-image (patch-grid height, patch-grid width).
        """
        pixel_masks = []
        pixel_values = []
        spatial_shapes = []

        for image in images:
            height, width = get_image_size_for_patches(
                image_height=image.shape[1],
                image_width=image.shape[2],
                patch_size=patch_size,
                max_num_patches=max_num_patches,
            )

            side_dict = SizeDict(height=height, width=width)
            image = self.resize(image=image, size=side_dict, interpolation=interpolation)
            image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)

            # merge_size=2 matches the 2*patch_size alignment applied in
            # `get_image_size_for_patches`.
            patches = convert_image_to_patches(image, patch_size, 2)
            # Bug fix: pad to `max_num_patches` — the previous
            # `pad_along_first_dim(patches, len(patches))` was a no-op, which
            # broke `torch.stack` for batches of differently sized images and
            # violated the documented "padded to match this number exactly".
            patches, mask = pad_along_first_dim(patches, max_num_patches)

            num_patches_height = image.shape[1] // patch_size
            num_patches_width = image.shape[2] // patch_size

            spatial_shapes.append((num_patches_height, num_patches_width))
            pixel_values.append(patches)
            pixel_masks.append(mask)

        pixel_values = torch.stack(pixel_values, dim=0)
        pixel_masks = torch.stack(pixel_masks, dim=0)
        spatial_shapes = torch.tensor(spatial_shapes)

        batch_feature = BatchFeature(
            data={
                "pixel_values": pixel_values,
                "pixel_attention_mask": pixel_masks,
                "spatial_shapes": spatial_shapes,
            },
            tensor_type=return_tensors,
        )
        return batch_feature
|
|
|
|
|
|
|
|
__all__ = ["Siglip2ImageProcessorFast"] |
|
|
|