from typing import Any, List, Mapping, Optional, Tuple

import torch

import torchvision
from torchvision.transforms.v2 import functional as F

from utils import (
    find_supported_resolutions,
    get_canvas_best_fit,
    resize_with_pad,
    tile_crop,
)

from torchtitan.tools.logging import logger


class CLIPTransform:
    """
    This class accepts an image of any size and dynamically resizes, pads, normalizes
    and tiles it, based on the image's aspect ratio and the number of image tiles allowed.

    The algorithm will NOT distort the image to fit a certain aspect ratio, because
    that leads to a significant degradation in image quality.

    The user can choose whether to allow upscaling via the flag ``resize_to_max_canvas``.

    For example, if an input image is of size 300x800, and we want to allow
    a maximum of 16 image tiles, with side 224px, then:

    If ``resize_to_max_canvas=False``, then:
    best_resolution = (448, 896) -> smallest canvas, up to 16 tiles, that doesn't require downscaling
    image is NOT resized
    image is padded (300, 800) -> (448, 896)
    image is tiled 2x4, for a final output shape of (8, 3, 224, 224)

    If ``resize_to_max_canvas=True``, then:
    best_resolution = (448, 1344) # canvas that allows maximum upscaling, with minimum padding, up to 16 tiles
    image is resized without distortion (300, 800) -> (448, 1194) # 448 is the limiting side for the resize
    image is padded (448, 1194) -> (448, 1344)
    image is tiled 2x6, for a final output shape of (12, 3, 224, 224)

    Args:
        image_mean (Optional[List[float]]): Mean values of each channel, used for normalization.
            Should be the same used for the pre-trained model. If None, no normalization is performed. Default None.
        image_std (Optional[List[float]]): Standard deviation values of each channel, used for normalization.
            Should be the same used for the pre-trained model. If None, no normalization is performed. Default None.
        possible_resolutions (Optional[List[Tuple[int, int]]]): List of possible resolutions as tuples (height, width),
            where each tuple represents a possible canvas to fit the image into when calling ``get_canvas_best_fit``.
            If None, this will be calculated using max_num_tiles and tile_size. Default None.
        tile_size (int): Size of the tiles to divide the image into. Default 224.
        max_num_tiles (Optional[int]): Only used if possible_resolutions is NOT given.
            Maximum number of tiles to break an image into.
            This will be used to generate possible_resolutions,
            e.g. [(224, 224), (224, 448), (448, 224)] if max_num_tiles = 2 and tile_size = 224.
            Default 4.
        dtype (torch.dtype): Data type of the output image. Default torch.bfloat16.
        resample (str): Resampling method used when resizing images. Supports any enum of
            ``torchvision.transforms.InterpolationMode``, e.g. "nearest", "nearest_exact", "bilinear", "bicubic".
            Default "bilinear".
        resize_to_max_canvas (bool): If True, the image will be upscaled without distortion to fit the largest possible
            resolution from possible_resolutions.
            If False, it will pick the resolution that minimizes downscaling, including no downscaling at all.
            In this case, the image will only be upscaled if its size < tile_size. Default False.

    Examples:
        >>> image_transform = CLIPTransform(
        ...     image_mean=None,
        ...     image_std=None,
        ...     tile_size=224,
        ...     possible_resolutions=None,
        ...     max_num_tiles=4,
        ...     resample="bilinear",
        ...     resize_to_max_canvas=True,
        ... )
        >>> # create a random channels-first image tensor
        >>> image = torch.randint(0, 256, (3, 100, 200), dtype=torch.uint8)
        >>> output = image_transform(image)
        >>> output['image'].shape  # [num_tiles, num_channels, tile_size, tile_size]
        torch.Size([2, 3, 224, 224])
        >>> output['aspect_ratio']  # image best fits the canvas 224x448
        tensor([1, 2])
    """

    def __init__(
        self,
        *,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        possible_resolutions: Optional[List[Tuple[int, int]]] = None,
        tile_size: int = 224,
        max_num_tiles: Optional[int] = 4,
        dtype: torch.dtype = torch.bfloat16,
        resample: str = "bilinear",
        resize_to_max_canvas: bool = False,
    ) -> None:

        # get_canvas_best_fit
        assert (
            possible_resolutions is not None or max_num_tiles is not None
        ), f"Either possible_resolutions or max_num_tiles must be given. Got {possible_resolutions} and {max_num_tiles}"

        # if possible_resolutions is not given, generate it from max_num_tiles and tile_size
        if not possible_resolutions and max_num_tiles:
            possible_resolutions = find_supported_resolutions(
                max_num_tiles=max_num_tiles, tile_size=tile_size
            )

        self.possible_resolutions = torch.tensor(possible_resolutions).reshape(-1, 2)
        logger.debug(
            f"Found possible_resolutions: {self.possible_resolutions}. Will fit the images into the canvas with best fit."
        )
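        # For illustration, assuming a torchtune-style find_supported_resolutions:
        # max_num_tiles=4 and tile_size=224 yield one canvas per tile grid of at
        # most 4 tiles, e.g. (224, 224), (224, 448), (448, 448), (224, 896), (896, 224).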

        self.resize_to_max_canvas = resize_to_max_canvas

        # normalize
        assert (image_mean is None) == (
            image_std is None
        ), f"Need to provide both or none of image_mean and image_std. Got {image_mean=} and {image_std=}"
        self.mean = image_mean
        self.std = image_std

        # resize_with_pad
        self.max_size = None if resize_to_max_canvas else tile_size
        self.dtype = dtype
        self.resample = torchvision.transforms.InterpolationMode[resample.upper()]

        # tile_crop
        self.tile_size = tile_size

    def __call__(self, image: torch.Tensor) -> Mapping[str, Any]:
        """
        Apply the transformations to a single image.

        Args:
            image (torch.Tensor): The image to transform, in channels-first (CHW)
                format and of any spatial size.

        Returns:
            Mapping[str, Any]: A mapping with the transformed "image" field and
                the added "aspect_ratio" field.
        """
        assert isinstance(image, torch.Tensor), "Input image must be a torch.Tensor."

        # convert to a 3-channel image tensor, rescaled to the target dtype
        image = F.to_image(image)
        image = F.grayscale_to_rgb_image(image)
        image = F.to_dtype(image, dtype=self.dtype, scale=True)

        # find the best canvas to fit the image without distortion
        best_resolution = get_canvas_best_fit(
            image=image,
            possible_resolutions=self.possible_resolutions,
            resize_to_max_canvas=self.resize_to_max_canvas,
        )

        # resize without distortion, then pad to fit best_resolution
        image = resize_with_pad(
            image=image,
            target_size=best_resolution,
            resample=self.resample,
            max_size=self.max_size,
        )

        # normalize, if mean and std were provided
        if self.mean:
            image = F.normalize(image, mean=self.mean, std=self.std)

        # divide the image into tiles of shape (tile_size, tile_size)
        image = tile_crop(image=image, tile_size=self.tile_size)

        # the canvas size expressed in tiles, e.g. (448, 896) // 224 -> tensor([2, 4])
        aspect_ratio = torch.tensor(best_resolution).reshape(-1) // self.tile_size

        return {
            "image": image,
            "aspect_ratio": aspect_ratio,
        }
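

# Minimal usage sketch (an illustrative addition, not part of the transform itself):
# it assumes the `utils` helpers imported above are available, runs the transform on
# a random image, and prints the tiled output shape and the aspect ratio.
if __name__ == "__main__":
    transform = CLIPTransform(
        image_mean=None,
        image_std=None,
        possible_resolutions=None,
        tile_size=224,
        max_num_tiles=4,
        resample="bilinear",
        resize_to_max_canvas=False,
    )
    # random 300x500 RGB image, channels-first, uint8
    image = torch.randint(0, 256, (3, 300, 500), dtype=torch.uint8)
    output = transform(image)
    print(output["image"].shape)  # (num_tiles, 3, 224, 224)
    print(output["aspect_ratio"])  # canvas size in tiles, e.g. tensor([2, 2])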