ZIP-B / datasets /transforms.py
Yiming-M's picture
2025-07-31 15:53 🐣
0ecb9aa verified
import torch
from torch import Tensor
from torchvision.transforms import ColorJitter as _ColorJitter
import torchvision.transforms.functional as TF
import numpy as np
from typing import Tuple, Union, Optional, Callable
def _crop(
image: Tensor,
label: Tensor,
top: int,
left: int,
height: int,
width: int,
) -> Tuple[Tensor, Tensor]:
image = TF.crop(image, top, left, height, width)
if len(label) > 0:
label[:, 0] -= left
label[:, 1] -= top
label_mask = (label[:, 0] >= 0) & (label[:, 0] < width) & (label[:, 1] >= 0) & (label[:, 1] < height)
label = label[label_mask]
return image, label
def _resize(
image: Tensor,
label: Tensor,
height: int,
width: int,
) -> Tuple[Tensor, Tensor]:
image_height, image_width = image.shape[-2:]
image = TF.resize(image, (height, width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True) if (image_height != height or image_width != width) else image
if len(label) > 0 and (image_height != height or image_width != width):
label[:, 0] = label[:, 0] * width / image_width
label[:, 1] = label[:, 1] * height / image_height
label[:, 0] = label[:, 0].clamp(min=0, max=width - 1)
label[:, 1] = label[:, 1].clamp(min=0, max=height - 1)
return image, label
class RandomCrop(object):
def __init__(self, size: Tuple[int, int]) -> None:
self.size = size
assert len(self.size) == 2, f"size should be a tuple (h, w), got {self.size}."
def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
crop_height, crop_width = self.size
image_height, image_width = image.shape[-2:]
assert crop_height <= image_height and crop_width <= image_width, \
f"crop size should be no larger than image size, got crop size {self.size} and image size {image.shape}."
top = torch.randint(0, image_height - crop_height + 1, (1,)).item()
left = torch.randint(0, image_width - crop_width + 1, (1,)).item()
return _crop(image, label, top, left, crop_height, crop_width)
class Resize(object):
def __init__(self, size: Tuple[int, int]) -> None:
self.size = size
assert len(self.size) == 2, f"size should be a tuple (h, w), got {self.size}."
def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
return _resize(image, label, self.size[0], self.size[1])
class Resize2Multiple(object):
"""
Resize the image so that it satisfies:
img_h = window_h + stride_h * n_h
img_w = window_w + stride_w * n_w
"""
def __init__(
self,
window_size: Tuple[int, int],
stride: Tuple[int, int],
) -> None:
window_size = (int(window_size), int(window_size)) if isinstance(window_size, (int, float)) else window_size
window_size = tuple(window_size)
stride = (int(stride), int(stride)) if isinstance(stride, (int, float)) else stride
stride = tuple(stride)
assert len(window_size) == 2, f"window_size should be a tuple (h, w), got {window_size}."
assert len(stride) == 2, f"stride should be a tuple (h, w), got {stride}."
assert all(s > 0 for s in window_size), f"window_size should be positive, got {window_size}."
assert all(s > 0 for s in stride), f"stride should be positive, got {stride}."
assert stride[0] <= window_size[0] and stride[1] <= window_size[1], f"stride should be no larger than window_size, got {stride} and {window_size}."
self.window_size = window_size
self.stride = stride
def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
image_height, image_width = image.shape[-2:]
window_height, window_width = self.window_size
stride_height, stride_width = self.stride
new_height = int(max(round((image_height - window_height) / stride_height), 0) * stride_height + window_height)
new_width = int(max(round((image_width - window_width) / stride_width), 0) * stride_width + window_width)
if new_height == image_height and new_width == image_width:
return image, label
else:
return _resize(image, label, new_height, new_width)
class ZeroPad2Multiple(object):
def __init__(
self,
window_size: Tuple[int, int],
stride: Tuple[int, int],
) -> None:
window_size = (int(window_size), int(window_size)) if isinstance(window_size, (int, float)) else window_size
window_size = tuple(window_size)
stride = (int(stride), int(stride)) if isinstance(stride, (int, float)) else stride
stride = tuple(stride)
assert len(window_size) == 2, f"window_size should be a tuple (h, w), got {window_size}."
assert len(stride) == 2, f"stride should be a tuple (h, w), got {stride}."
assert all(s > 0 for s in window_size), f"window_size should be positive, got {window_size}."
assert all(s > 0 for s in stride), f"stride should be positive, got {stride}."
assert stride[0] <= window_size[0] and stride[1] <= window_size[1], f"stride should be no larger than window_size, got {stride} and {window_size}."
self.window_size = window_size
self.stride = stride
def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
image_height, image_width = image.shape[-2:]
window_height, window_width = self.window_size
stride_height, stride_width = self.stride
new_height = int(max(np.ceil((image_height - window_height) / stride_height), 0) * stride_height + window_height)
new_width = int(max(np.ceil((image_width - window_width) / stride_width), 0) * stride_width + window_width)
if new_height == image_height and new_width == image_width:
return image, label
else:
assert new_height >= image_height and new_width >= image_width, f"new size should be no less than the original size, got {new_height} and {new_width}."
pad_height, pad_width = new_height - image_height, new_width - image_width
return TF.pad(image, (0, 0, pad_width, pad_height), fill=0), label # only pad the right and bottom sides so that the label coordinates are not affected
class RandomResizedCrop(object):
def __init__(
self,
size: Tuple[int, int],
scale: Tuple[float, float] = (0.75, 1.25),
) -> None:
"""
Randomly crop an image and resize it to a given size. The aspect ratio is preserved during this process.
"""
self.size = size
self.scale = scale
assert len(self.size) == 2, f"size should be a tuple (h, w), got {self.size}."
assert 0 < self.scale[0] <= self.scale[1], f"scale should satisfy 0 < scale[0] <= scale[1], got {self.scale}."
def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
out_height, out_width = self.size
# out_ratio = out_width / out_height
scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() # if scale < 1, then the image will be zoomed in, otherwise zoomed out
in_height, in_width = image.shape[-2:]
# if in_width / in_height < out_ratio: # Image is too tall
# crop_width = int(in_width * scale)
# crop_height = int(crop_width / out_ratio)
# else: # Image is too wide
# crop_height = int(in_height * scale)
# crop_width = int(crop_height * out_ratio)
crop_height, crop_width = int(out_height * scale), int(out_width * scale)
if crop_height <= in_height and crop_width <= in_width: # directly crop and resize the image
top = torch.randint(0, in_height - crop_height + 1, (1,)).item()
left = torch.randint(0, in_width - crop_width + 1, (1,)).item()
else: # resize the image and then crop
ratio = max(crop_height / in_height, crop_width / in_width) # keep the aspect ratio
resize_height, resize_width = int(in_height * ratio) + 1, int(in_width * ratio) + 1 # add 1 to make sure the resized image is no less than the crop size
image, label = _resize(image, label, resize_height, resize_width)
top = torch.randint(0, resize_height - crop_height + 1, (1,)).item()
left = torch.randint(0, resize_width - crop_width + 1, (1,)).item()
image, label = _crop(image, label, top, left, crop_height, crop_width)
return _resize(image, label, out_height, out_width)
class RandomHorizontalFlip(object):
def __init__(self, p: float = 0.5) -> None:
self.p = p
assert 0 <= self.p <= 1, f"p should be in range [0, 1], got {self.p}."
def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
if torch.rand(1) < self.p:
image = TF.hflip(image)
if len(label) > 0:
label[:, 0] = image.shape[-1] - 1 - label[:, 0] # if width is 256, then 0 -> 255, 1 -> 254, 2 -> 253, etc.
label[:, 0] = label[:, 0].clamp(min=0, max=image.shape[-1] - 1)
return image, label
class ColorJitter(object):
def __init__(
self,
brightness: Union[float, Tuple[float, float]] = 0.4,
contrast: Union[float, Tuple[float, float]] = 0.4,
saturation: Union[float, Tuple[float, float]] = 0.4,
hue: Union[float, Tuple[float, float]] = 0.2,
) -> None:
self.color_jitter = _ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
return self.color_jitter(image), label
class RandomGrayscale(object):
def __init__(self, p: float = 0.1) -> None:
self.p = p
assert 0 <= self.p <= 1, f"p should be in range [0, 1], got {self.p}."
def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
if torch.rand(1) < self.p:
image = TF.rgb_to_grayscale(image, num_output_channels=3)
return image, label
class GaussianBlur(object):
def __init__(self, kernel_size: int, sigma: Tuple[float, float] = (0.1, 2.0)) -> None:
self.kernel_size = kernel_size
self.sigma = sigma
def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
return TF.gaussian_blur(image, self.kernel_size, self.sigma), label
class RandomApply(object):
def __init__(self, transforms: Tuple[Callable, ...], p: Union[float, Tuple[float, ...]] = 0.5) -> None:
self.transforms = transforms
p = [p] * len(transforms) if isinstance(p, float) else p
assert all(0 <= p_ <= 1 for p_ in p), f"p should be in range [0, 1], got {p}."
assert len(p) == len(transforms), f"p should be a float or a tuple of floats with the same length as transforms, got {p}."
self.p = p
def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
for transform, p in zip(self.transforms, self.p):
if torch.rand(1) < p:
image, label = transform(image, label)
return image, label
class PepperSaltNoise(object):
def __init__(self, saltiness: float = 0.001, spiciness: float = 0.001) -> None:
self.saltiness = saltiness
self.spiciness = spiciness
assert 0 <= self.saltiness <= 1, f"saltiness should be in range [0, 1], got {self.saltiness}."
assert 0 <= self.spiciness <= 1, f"spiciness should be in range [0, 1], got {self.spiciness}."
def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
noise = torch.rand_like(image)
image = torch.where(noise < self.saltiness, 1., image) # Salt
image = torch.where(noise > 1 - self.spiciness, 0., image) # Pepper
return image, label