Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import Tensor | |
| from torchvision.transforms import ColorJitter as _ColorJitter | |
| import torchvision.transforms.functional as TF | |
| import numpy as np | |
| from typing import Tuple, Union, Optional, Callable | |
def _crop(
    image: Tensor,
    label: Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Tuple[Tensor, Tensor]:
    """
    Crop ``image`` to the window starting at ``(top, left)`` of size ``(height, width)`` and shift the points.

    ``label`` holds one point per row, with ``label[:, 0]`` the x (column) coordinate and
    ``label[:, 1]`` the y (row) coordinate. After shifting by ``(-left, -top)``, points that fall
    outside ``[0, width) x [0, height)`` are dropped.

    The caller's ``label`` tensor is never modified; the shift is applied to a copy.

    Returns:
        The cropped image and the shifted, filtered points.
    """
    image = TF.crop(image, top, left, height, width)
    if len(label) > 0:
        # Work on a copy: shifting in place would corrupt the caller's tensor
        # (e.g. a label cached by a dataset or sharing memory with a numpy array).
        label = label.clone()
        label[:, 0] -= left
        label[:, 1] -= top
        inside = (label[:, 0] >= 0) & (label[:, 0] < width) & (label[:, 1] >= 0) & (label[:, 1] < height)
        label = label[inside]
    return image, label
| def _resize( | |
| image: Tensor, | |
| label: Tensor, | |
| height: int, | |
| width: int, | |
| ) -> Tuple[Tensor, Tensor]: | |
| image_height, image_width = image.shape[-2:] | |
| image = TF.resize(image, (height, width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True) if (image_height != height or image_width != width) else image | |
| if len(label) > 0 and (image_height != height or image_width != width): | |
| label[:, 0] = label[:, 0] * width / image_width | |
| label[:, 1] = label[:, 1] * height / image_height | |
| label[:, 0] = label[:, 0].clamp(min=0, max=width - 1) | |
| label[:, 1] = label[:, 1].clamp(min=0, max=height - 1) | |
| return image, label | |
class RandomCrop(object):
    """
    Crop a uniformly random window of fixed size out of the image, keeping only the points inside it.

    Args:
        size: crop size ``(h, w)``; a single number is interpreted as the square ``(size, size)``.
    """
    def __init__(self, size: Union[int, Tuple[int, int]]) -> None:
        # Accept a scalar size, consistent with Resize2Multiple / ZeroPad2Multiple.
        self.size = (int(size), int(size)) if isinstance(size, (int, float)) else tuple(size)
        assert len(self.size) == 2, f"size should be a tuple (h, w), got {self.size}."
        assert all(s > 0 for s in self.size), f"size should be positive, got {self.size}."

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        crop_height, crop_width = self.size
        image_height, image_width = image.shape[-2:]
        assert crop_height <= image_height and crop_width <= image_width, \
            f"crop size should be no larger than image size, got crop size {self.size} and image size {image.shape}."
        # Upper bounds are exclusive in torch.randint, hence the +1 so the window may touch the border.
        top = torch.randint(0, image_height - crop_height + 1, (1,)).item()
        left = torch.randint(0, image_width - crop_width + 1, (1,)).item()
        return _crop(image, label, top, left, crop_height, crop_width)
class Resize(object):
    """
    Resize the image to a fixed size and rescale the points accordingly.

    Args:
        size: target size ``(h, w)``; a single number is interpreted as the square ``(size, size)``.
    """
    def __init__(self, size: Union[int, Tuple[int, int]]) -> None:
        # Accept a scalar size, consistent with Resize2Multiple / ZeroPad2Multiple.
        self.size = (int(size), int(size)) if isinstance(size, (int, float)) else tuple(size)
        assert len(self.size) == 2, f"size should be a tuple (h, w), got {self.size}."
        assert all(s > 0 for s in self.size), f"size should be positive, got {self.size}."

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        return _resize(image, label, self.size[0], self.size[1])
class Resize2Multiple(object):
    """
    Resize the image (and rescale its points) so that a sliding window tiles it exactly:

        img_h = window_h + stride_h * n_h
        img_w = window_w + stride_w * n_w

    ``n_h`` / ``n_w`` are obtained with Python's ``round`` (exact halves go to the even count)
    and floored at 0, i.e. the image is snapped to the nearest valid size. An image that already
    has that size is returned unchanged.

    Args:
        window_size: window size ``(h, w)`` or a single number for a square window.
        stride: stride ``(h, w)`` or a single number; each component must not exceed the window.
    """
    def __init__(
        self,
        window_size: Tuple[int, int],
        stride: Tuple[int, int],
    ) -> None:
        def _as_pair(value):
            # Scalars become a square pair; sequences are just turned into tuples.
            return (int(value), int(value)) if isinstance(value, (int, float)) else tuple(value)

        window_size = _as_pair(window_size)
        stride = _as_pair(stride)
        assert len(window_size) == 2, f"window_size should be a tuple (h, w), got {window_size}."
        assert len(stride) == 2, f"stride should be a tuple (h, w), got {stride}."
        assert all(s > 0 for s in window_size), f"window_size should be positive, got {window_size}."
        assert all(s > 0 for s in stride), f"stride should be positive, got {stride}."
        assert stride[0] <= window_size[0] and stride[1] <= window_size[1], f"stride should be no larger than window_size, got {stride} and {window_size}."
        self.window_size = window_size
        self.stride = stride

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        in_h, in_w = image.shape[-2:]
        (win_h, win_w), (step_h, step_w) = self.window_size, self.stride

        def _snap(length, window, step):
            # Nearest number of extra strides (never negative) beyond one full window.
            steps = max(round((length - window) / step), 0)
            return int(steps * step + window)

        target_h = _snap(in_h, win_h, step_h)
        target_w = _snap(in_w, win_w, step_w)
        if (target_h, target_w) == (in_h, in_w):
            return image, label
        return _resize(image, label, target_h, target_w)
class ZeroPad2Multiple(object):
    """
    Zero-pad the image on the right and bottom so that a sliding window tiles it exactly:

        img_h = window_h + stride_h * n_h
        img_w = window_w + stride_w * n_w

    with ``n_h`` / ``n_w`` the smallest non-negative integers that make the image no smaller
    than before. Only the right and bottom edges are padded, so point coordinates in
    ``label`` stay valid and are returned untouched.

    Args:
        window_size: window size ``(h, w)`` or a single number for a square window.
        stride: stride ``(h, w)`` or a single number; each component must not exceed the window.
    """
    def __init__(
        self,
        window_size: Tuple[int, int],
        stride: Tuple[int, int],
    ) -> None:
        if isinstance(window_size, (int, float)):
            window_size = (int(window_size), int(window_size))
        if isinstance(stride, (int, float)):
            stride = (int(stride), int(stride))
        window_size, stride = tuple(window_size), tuple(stride)
        assert len(window_size) == 2, f"window_size should be a tuple (h, w), got {window_size}."
        assert len(stride) == 2, f"stride should be a tuple (h, w), got {stride}."
        assert all(s > 0 for s in window_size), f"window_size should be positive, got {window_size}."
        assert all(s > 0 for s in stride), f"stride should be positive, got {stride}."
        assert stride[0] <= window_size[0] and stride[1] <= window_size[1], f"stride should be no larger than window_size, got {stride} and {window_size}."
        self.window_size = window_size
        self.stride = stride

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        current = tuple(image.shape[-2:])
        # Round the number of extra strides up so the padded size always covers the image.
        target = tuple(
            int(max(np.ceil((size - window) / step), 0) * step + window)
            for size, window, step in zip(current, self.window_size, self.stride)
        )
        if target == current:
            return image, label
        new_height, new_width = target
        image_height, image_width = current
        assert new_height >= image_height and new_width >= image_width, f"new size should be no less than the original size, got {new_height} and {new_width}."
        # TF.pad takes (left, top, right, bottom): pad only right/bottom so point coordinates are unaffected.
        return TF.pad(image, (0, 0, new_width - image_width, new_height - image_height), fill=0), label
class RandomResizedCrop(object):
    """
    Randomly crop a region of size ``size * scale`` and resize it to ``size``.

    The crop region has the same aspect ratio as the output, so image content is not
    stretched. A sampled ``scale < 1`` zooms in (a smaller region is upsampled); ``scale > 1``
    zooms out. If the crop does not fit inside the image, the image is first upscaled
    (keeping its aspect ratio) until it does. Points are shifted, filtered and rescaled
    along with the image.

    Args:
        size: output size ``(h, w)``; a single number is interpreted as ``(size, size)``.
        scale: range ``(min, max)`` from which the scale factor is sampled uniformly per call.
    """
    def __init__(
        self,
        size: Union[int, Tuple[int, int]],
        scale: Tuple[float, float] = (0.75, 1.25),
    ) -> None:
        self.size = (int(size), int(size)) if isinstance(size, (int, float)) else tuple(size)
        self.scale = scale
        assert len(self.size) == 2, f"size should be a tuple (h, w), got {self.size}."
        assert 0 < self.scale[0] <= self.scale[1], f"scale should satisfy 0 < scale[0] <= scale[1], got {self.scale}."

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        out_height, out_width = self.size
        scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
        in_height, in_width = image.shape[-2:]
        # Keep at least one pixel per side so small outputs with scale < 1 never yield an empty crop.
        crop_height = max(int(out_height * scale), 1)
        crop_width = max(int(out_width * scale), 1)
        if crop_height > in_height or crop_width > in_width:
            # Upscale uniformly so both sides fit; the +1 absorbs float truncation in int().
            ratio = max(crop_height / in_height, crop_width / in_width)
            resize_height, resize_width = int(in_height * ratio) + 1, int(in_width * ratio) + 1
            image, label = _resize(image, label, resize_height, resize_width)
            in_height, in_width = resize_height, resize_width
        top = torch.randint(0, in_height - crop_height + 1, (1,)).item()
        left = torch.randint(0, in_width - crop_width + 1, (1,)).item()
        image, label = _crop(image, label, top, left, crop_height, crop_width)
        return _resize(image, label, out_height, out_width)
class RandomHorizontalFlip(object):
    """
    With probability ``p``, mirror the image left-right and mirror the points' x coordinates.

    A point at column ``x`` moves to ``width - 1 - x`` (e.g. for width 256: 0 -> 255, 1 -> 254),
    and the result is clamped to ``[0, width - 1]``. The caller's ``label`` tensor is never
    modified; flipping is applied to a copy.

    Args:
        p: flip probability, in ``[0, 1]``.
    """
    def __init__(self, p: float = 0.5) -> None:
        self.p = p
        assert 0 <= self.p <= 1, f"p should be in range [0, 1], got {self.p}."

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        if torch.rand(1) < self.p:
            image = TF.hflip(image)
            if len(label) > 0:
                # Copy first so the caller's tensor is not mirrored behind its back.
                label = label.clone()
                width = image.shape[-1]
                label[:, 0] = (width - 1 - label[:, 0]).clamp(min=0, max=width - 1)
        return image, label
class ColorJitter(object):
    """
    Adapter that applies ``torchvision.transforms.ColorJitter`` to the image only.

    Colour changes do not move anything, so ``label`` is passed through untouched.
    The four arguments are forwarded verbatim to torchvision; see its documentation
    for the accepted forms of each.
    """
    def __init__(
        self,
        brightness: Union[float, Tuple[float, float]] = 0.4,
        contrast: Union[float, Tuple[float, float]] = 0.4,
        saturation: Union[float, Tuple[float, float]] = 0.4,
        hue: Union[float, Tuple[float, float]] = 0.2,
    ) -> None:
        jitter_kwargs = dict(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            hue=hue,
        )
        self.color_jitter = _ColorJitter(**jitter_kwargs)

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        jittered = self.color_jitter(image)
        return jittered, label
class RandomGrayscale(object):
    """
    With probability ``p``, convert the image to grayscale replicated over 3 channels.

    Points in ``label`` are returned unchanged.

    Args:
        p: conversion probability, in ``[0, 1]``.
    """
    def __init__(self, p: float = 0.1) -> None:
        self.p = p
        assert 0 <= self.p <= 1, f"p should be in range [0, 1], got {self.p}."

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        should_convert = torch.rand(1) < self.p
        if not should_convert:
            return image, label
        return TF.rgb_to_grayscale(image, num_output_channels=3), label
class GaussianBlur(object):
    """
    Blur the image with an isotropic Gaussian whose standard deviation is drawn at random.

    Points in ``label`` are returned unchanged.

    Args:
        kernel_size: Gaussian kernel size, forwarded unchanged to ``TF.gaussian_blur``.
        sigma: either a single positive number (fixed standard deviation) or a ``(min, max)``
            range from which the standard deviation is sampled uniformly on every call.
            This matches the meaning of ``sigma`` in ``torchvision.transforms.GaussianBlur``.
    """
    def __init__(self, kernel_size: Union[int, Tuple[int, int]], sigma: Union[float, Tuple[float, float]] = (0.1, 2.0)) -> None:
        if isinstance(sigma, (int, float)):
            sigma = (float(sigma), float(sigma))
        sigma = tuple(sigma)
        assert len(sigma) == 2, f"sigma should be a float or a tuple (min, max), got {sigma}."
        assert 0 < sigma[0] <= sigma[1], f"sigma should satisfy 0 < sigma[0] <= sigma[1], got {sigma}."
        self.kernel_size = kernel_size
        self.sigma = sigma

    def _sample_sigma(self) -> float:
        """Draw one standard deviation uniformly from ``[sigma[0], sigma[1]]``."""
        return torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item()

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        sigma = self._sample_sigma()
        # TF.gaussian_blur reads a 2-sequence as (sigma_x, sigma_y), not as a range; passing the
        # (min, max) tuple directly would apply a fixed anisotropic blur. Use the sampled value on both axes.
        return TF.gaussian_blur(image, self.kernel_size, [sigma, sigma]), label
class RandomApply(object):
    """
    Apply each transform independently with its own probability, in order.

    Args:
        transforms: sequence of callables mapping ``(image, label)`` to ``(image, label)``.
        p: a single probability shared by all transforms (int or float), or one probability
            per transform. Every probability must lie in ``[0, 1]``.
    """
    def __init__(self, transforms: Tuple[Callable, ...], p: Union[float, Tuple[float, ...]] = 0.5) -> None:
        self.transforms = transforms
        # Broadcast any scalar; checking only `float` made p=0 / p=1 crash when iterated below.
        p = [p] * len(transforms) if isinstance(p, (int, float)) else p
        assert all(0 <= p_ <= 1 for p_ in p), f"p should be in range [0, 1], got {p}."
        assert len(p) == len(transforms), f"p should be a float or a tuple of floats with the same length as transforms, got {p}."
        self.p = p

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        for transform, p in zip(self.transforms, self.p):
            if torch.rand(1) < p:
                image, label = transform(image, label)
        return image, label
class PepperSaltNoise(object):
    """
    Add salt-and-pepper noise element-wise to the image.

    One uniform draw in ``[0, 1)`` is made per tensor element. Elements whose draw is below
    ``saltiness`` become 1 (salt). Elements whose draw is above ``1 - spiciness`` become 0
    (pepper); pepper wins where both apply. Points in ``label`` are returned unchanged.
    The 0/1 fill values presumably assume a float image scaled to ``[0, 1]`` — confirm with callers.

    Args:
        saltiness: fraction of elements set to 1, in ``[0, 1]``.
        spiciness: fraction of elements set to 0, in ``[0, 1]``.
    """
    def __init__(self, saltiness: float = 0.001, spiciness: float = 0.001) -> None:
        self.saltiness = saltiness
        self.spiciness = spiciness
        assert 0 <= self.saltiness <= 1, f"saltiness should be in range [0, 1], got {self.saltiness}."
        assert 0 <= self.spiciness <= 1, f"spiciness should be in range [0, 1], got {self.spiciness}."

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        draw = torch.rand_like(image)
        salt_mask = draw < self.saltiness
        pepper_mask = draw > 1 - self.spiciness
        # Pepper is filled last, so it overrides salt where the two masks overlap.
        noisy = image.masked_fill(salt_mask, 1.).masked_fill(pepper_mask, 0.)
        return noisy, label