from random import randrange import torchvision.transforms.functional as TF from typing import List, Callable, Union from PIL.Image import Image as PILImage import torch # def resize_crop(img: PILImage, crop_size: int = 224, downscale_factor: int = 1) -> PILImage: # """ # Resize the image with the desired downscale factor and optionally crop it to the desired size. The crop is randomly # sampled from the image. If crop_size is None, no crop is applied. If the crop is out of bounds, the image is # automatically padded with zeros. # Args: # img (PIL Image): image to resize and crop # crop_size (int): size of the crop. If None, no crop is applied # downscale_factor (int): downscale factor to apply to the image # Returns: # img (PIL Image): resized and/or cropped image # """ # w, h = img.size # if downscale_factor > 1: # img = img.resize((w // downscale_factor, h // downscale_factor)) # w, h = img.size # if crop_size is not None: # top = randrange(0, max(1, h - crop_size)) # left = randrange(0, max(1, w - crop_size)) # img = TF.crop(img, top, left, crop_size, crop_size) # Automatically pad with zeros if the crop is out of bounds # return img def resize_crop(img: torch.Tensor, crop_size: int = 224, downscale_factor: int = 1) -> PILImage: """ Resize the image with the desired downscale factor and optionally crop it to the desired size. The crop is randomly sampled from the image. If crop_size is None, no crop is applied. If the crop is out of bounds, the image is automatically padded with zeros. Args: img (torch.Tensor): image to resize and crop crop_size (int): size of the crop. If None, no crop is applied downscale_factor (int): downscale factor to apply to the image Returns: img (torch.Tensor): resized and/or cropped image """ _, w, h = img.shape if downscale_factor > 1: new_h, new_w = h // downscale_factor, w // downscale_factor img = TF.resize(img, (new_h, new_w)) _, h, w = img.shape if crop_size is not None: top = randrange(0, max(1, h - crop_size)) left = randrange(0, max(1, w - crop_size)) img = TF.crop(img, top, left, crop_size, crop_size) # Automatically pad with zeros if the crop is out of bounds return img def center_corners_crop(img: PILImage, crop_size: int = 224) -> List[PILImage]: """ Return the center crop and the four corners of the image. Args: img (PIL.Image): image to crop crop_size (int): size of each crop Returns: crops (List[PIL.Image]): list of the five crops """ width, height = img.size # Calculate the coordinates for the center crop and the four corners cx = width // 2 cy = height // 2 crops = [ TF.crop(img, cy - crop_size // 2, cx - crop_size // 2, crop_size, crop_size), # Center TF.crop(img, 0, 0, crop_size, crop_size), # Top-left corner TF.crop(img, height - crop_size, 0, crop_size, crop_size), # Bottom-left corner TF.crop(img, 0, width - crop_size, crop_size, crop_size), # Top-right corner TF.crop(img, height - crop_size, width - crop_size, crop_size, crop_size) # Bottom-right corner ] return crops