IQA-Interpretation / training /utils /utils_data.py
dvarfe's picture
sync with github version
0705c62
Raw
History Blame Contribute Delete
3.36 kB
from random import randrange
import torchvision.transforms.functional as TF
from typing import List, Callable, Union
from PIL.Image import Image as PILImage
import torch
# def resize_crop(img: PILImage, crop_size: int = 224, downscale_factor: int = 1) -> PILImage:
# """
# Resize the image with the desired downscale factor and optionally crop it to the desired size. The crop is randomly
# sampled from the image. If crop_size is None, no crop is applied. If the crop is out of bounds, the image is
# automatically padded with zeros.
# Args:
# img (PIL Image): image to resize and crop
# crop_size (int): size of the crop. If None, no crop is applied
# downscale_factor (int): downscale factor to apply to the image
# Returns:
# img (PIL Image): resized and/or cropped image
# """
# w, h = img.size
# if downscale_factor > 1:
# img = img.resize((w // downscale_factor, h // downscale_factor))
# w, h = img.size
# if crop_size is not None:
# top = randrange(0, max(1, h - crop_size))
# left = randrange(0, max(1, w - crop_size))
# img = TF.crop(img, top, left, crop_size, crop_size) # Automatically pad with zeros if the crop is out of bounds
# return img
def resize_crop(img: torch.Tensor, crop_size: Union[int, None] = 224, downscale_factor: int = 1) -> torch.Tensor:
    """
    Resize the image with the desired downscale factor and optionally crop it to the desired size. The crop is randomly
    sampled from the image. If crop_size is None, no crop is applied. If the crop is out of bounds, the image is
    automatically padded with zeros.
    Args:
        img (torch.Tensor): image to resize and crop, in torchvision's (..., H, W) layout (e.g. (C, H, W))
        crop_size (int or None): side length of the square crop. If None, no crop is applied
        downscale_factor (int): factor by which both spatial dimensions are divided (integer division).
            Values <= 1 leave the size unchanged
    Returns:
        img (torch.Tensor): resized and/or cropped image
    """
    # torchvision tensors are (..., H, W): height precedes width. Reading the last two dims also supports
    # unbatched (C, H, W) as well as batched or 2-D inputs.
    h, w = img.shape[-2:]
    if downscale_factor > 1:
        # TF.resize takes the target size as (height, width). Clamp to 1 so very small images do not collapse to a
        # zero-sized dimension.
        new_h = max(1, h // downscale_factor)
        new_w = max(1, w // downscale_factor)
        img = TF.resize(img, (new_h, new_w))
        h, w = img.shape[-2:]
    if crop_size is not None:
        # Valid top-left offsets for a fully in-bounds crop are 0..(size - crop_size) inclusive, hence the +1.
        # When the image is smaller than the crop the offset is 0, and TF.crop zero-pads the bottom/right.
        top = randrange(max(0, h - crop_size) + 1)
        left = randrange(max(0, w - crop_size) + 1)
        img = TF.crop(img, top, left, crop_size, crop_size)  # Automatically pad with zeros if the crop is out of bounds
    return img
def center_corners_crop(img: Union[PILImage, torch.Tensor], crop_size: int = 224) -> List[Union[PILImage, torch.Tensor]]:
    """
    Return the center crop and the four corners of the image.
    Args:
        img (PIL.Image or torch.Tensor): image to crop. Tensors are expected in torchvision's (..., H, W) layout
        crop_size (int): side length of each square crop
    Returns:
        crops (List[PIL.Image or torch.Tensor]): the five crops, in the order center, top-left, bottom-left,
            top-right, bottom-right. Each crop is produced by TF.crop on ``img``
    """
    # PIL exposes (width, height) through the ``.size`` attribute, while tensors are (..., H, W). On a tensor,
    # ``.size`` is a method and cannot be unpacked, so read the spatial dims from ``.shape`` instead.
    if isinstance(img, torch.Tensor):
        height, width = img.shape[-2:]
    else:
        width, height = img.size
    # Calculate the coordinates for the center crop and the four corners
    cx = width // 2
    cy = height // 2
    bottom = height - crop_size
    right = width - crop_size
    crops = [
        TF.crop(img, cy - crop_size // 2, cx - crop_size // 2, crop_size, crop_size),  # Center
        TF.crop(img, 0, 0, crop_size, crop_size),  # Top-left corner
        TF.crop(img, bottom, 0, crop_size, crop_size),  # Bottom-left corner
        TF.crop(img, 0, right, crop_size, crop_size),  # Top-right corner
        TF.crop(img, bottom, right, crop_size, crop_size),  # Bottom-right corner
    ]
    return crops