| | import math |
| | from enum import Enum |
| | from typing import List, Tuple, Optional, Dict |
| |
|
| | import torch |
| | from torch import Tensor |
| |
|
| | from torchvision.transforms import functional as F |
| | from torchvision.transforms.functional import InterpolationMode |
| |
|
__all__ = [
    "AutoAugmentPolicy",
    "AutoAugment",
    "RandAugment",
    "TrivialAugmentWide",
]
| |
|
| |
|
def _apply_op(
    img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
):
    """Apply one named augmentation operation to ``img``.

    Args:
        img (PIL Image or Tensor): Image to transform.
        op_name (str): Operation name, e.g. ``"ShearX"``, ``"Rotate"`` or ``"Posterize"``.
        magnitude (float): Signed strength of the operation; its meaning depends on ``op_name``:
            a shear factor for ``ShearX``/``ShearY``, a pixel offset (truncated to int) for
            ``TranslateX``/``TranslateY``, a rotation angle for ``Rotate``, an offset from the
            identity factor ``1.0`` for ``Brightness``/``Color``/``Contrast``/``Sharpness``,
            a bit count (truncated to int) for ``Posterize`` and a threshold for ``Solarize``.
            It is ignored by ``AutoContrast``, ``Equalize``, ``Invert`` and ``Identity``.
        interpolation (InterpolationMode): Interpolation used by the geometric operations.
        fill (list of floats, optional): Fill value for the area outside the transformed
            image, forwarded to the geometric operations.

    Returns:
        The transformed image (``img`` unchanged for ``Identity``).

    Raises:
        ValueError: If ``op_name`` is not a recognized operation.
    """
    if op_name == "ShearX":
        # ``F.affine`` takes the shear as an angle in degrees and puts tan(angle) into the
        # affine matrix, while ``magnitude`` is a shear *factor* (the reference AutoAugment
        # applies the matrix (1, level, 0, 0, 1, 0)). Converting through atan makes the
        # effective factor equal to ``magnitude`` instead of tan(magnitude).
        img = F.affine(
            img,
            angle=0.0,
            translate=[0, 0],
            scale=1.0,
            shear=[math.degrees(math.atan(magnitude)), 0.0],
            interpolation=interpolation,
            fill=fill,
        )
    elif op_name == "ShearY":
        # Same atan conversion as ShearX, applied to the vertical shear component.
        img = F.affine(
            img,
            angle=0.0,
            translate=[0, 0],
            scale=1.0,
            shear=[0.0, math.degrees(math.atan(magnitude))],
            interpolation=interpolation,
            fill=fill,
        )
    elif op_name == "TranslateX":
        img = F.affine(
            img,
            angle=0.0,
            translate=[int(magnitude), 0],
            scale=1.0,
            interpolation=interpolation,
            shear=[0.0, 0.0],
            fill=fill,
        )
    elif op_name == "TranslateY":
        img = F.affine(
            img,
            angle=0.0,
            translate=[0, int(magnitude)],
            scale=1.0,
            interpolation=interpolation,
            shear=[0.0, 0.0],
            fill=fill,
        )
    elif op_name == "Rotate":
        img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
    elif op_name == "Brightness":
        # Enhancement ops treat ``magnitude`` as a signed offset from the identity factor 1.0.
        img = F.adjust_brightness(img, 1.0 + magnitude)
    elif op_name == "Color":
        img = F.adjust_saturation(img, 1.0 + magnitude)
    elif op_name == "Contrast":
        img = F.adjust_contrast(img, 1.0 + magnitude)
    elif op_name == "Sharpness":
        img = F.adjust_sharpness(img, 1.0 + magnitude)
    elif op_name == "Posterize":
        img = F.posterize(img, int(magnitude))
    elif op_name == "Solarize":
        img = F.solarize(img, magnitude)
    elif op_name == "AutoContrast":
        img = F.autocontrast(img)
    elif op_name == "Equalize":
        img = F.equalize(img)
    elif op_name == "Invert":
        img = F.invert(img)
    elif op_name == "Identity":
        pass
    else:
        raise ValueError(f"The provided operator {op_name} is not recognized.")
    return img
| |
|
| |
|
class AutoAugmentPolicy(Enum):
    """Datasets for which a learned AutoAugment policy table is available.

    Pass one of these members as the ``policy`` argument of :class:`AutoAugment`
    to select the table learned on that dataset: ``IMAGENET``, ``CIFAR10`` or ``SVHN``.
    """

    # Each member's value is the lowercase dataset name.
    IMAGENET = "imagenet"
    CIFAR10 = "cifar10"
    SVHN = "svhn"
| |
|
| |
|
| | |
class AutoAugment(torch.nn.Module):
    r"""AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Every call picks one sub-policy (a pair of operations) uniformly at random from the
    selected policy table and applies each of its two operations with that operation's
    own probability.

    Args:
        policy (AutoAugmentPolicy): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.policy = policy
        self.interpolation = interpolation
        self.fill = fill
        # Resolve the policy table once; ``forward`` only indexes into it.
        self.policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        """Return the sub-policy table for ``policy``.

        Each entry is a pair of ``(op_name, probability, magnitude_bin)`` triples, where
        ``magnitude_bin`` is ``None`` for operations that take no magnitude.

        Raises:
            ValueError: If ``policy`` is not one of the known ``AutoAugmentPolicy`` members.
        """
        if policy == AutoAugmentPolicy.IMAGENET:
            table = [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]
        elif policy == AutoAugmentPolicy.CIFAR10:
            table = [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]
        elif policy == AutoAugmentPolicy.SVHN:
            table = [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")
        return table

    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Map each operation name to ``(magnitudes, signed)``.

        ``magnitudes`` holds one value per bin (a 0-dim placeholder for operations that take
        no magnitude), and ``signed`` tells ``forward`` whether to randomize the sign.
        ``image_size`` is assumed to be ``[width, height]`` as returned by ``F.get_image_size``;
        the 150/331 translation scale presumably mirrors the reference implementation — TODO confirm.
        """
        return {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
            "Invert": (torch.tensor(0.0), False),
        }

    @staticmethod
    def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
        """Draw the random parameters for one AutoAugment call.

        Args:
            transform_num (int): Number of sub-policies to choose from.

        Returns:
            ``(policy_id, probs, signs)``: the index of the chosen sub-policy, two uniform
            draws (from ``torch.rand``) deciding whether each of its operations fires, and two
            draws in ``{0, 1}`` deciding the sign of each operation's magnitude.
        """
        # The draw order (randint, rand, randint) is part of the RNG contract; keep it.
        chosen = int(torch.randint(transform_num, (1,)).item())
        apply_draws = torch.rand((2,))
        sign_draws = torch.randint(2, (2,))
        return chosen, apply_draws, sign_draws

    def forward(self, img: Tensor) -> Tensor:
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: AutoAugmented image.
        """
        fill = self.fill
        if isinstance(img, Tensor):
            # Tensor kernels take one float fill value per channel.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F.get_image_num_channels(img)
            elif fill is not None:
                fill = [float(f) for f in fill]

        sub_policy_id, apply_draws, sign_draws = self.get_params(len(self.policies))

        for step, (op_name, prob, bin_id) in enumerate(self.policies[sub_policy_id]):
            # The operation fires only when its uniform draw is at most its probability.
            if apply_draws[step] > prob:
                continue
            # The AutoAugment policy tables are expressed over 10 magnitude bins.
            op_meta = self._augmentation_space(10, F.get_image_size(img))
            magnitudes, signed = op_meta[op_name]
            magnitude = 0.0
            if bin_id is not None:
                magnitude = float(magnitudes[bin_id].item())
            # A zero sign draw flips the direction of a signed operation.
            if signed and sign_draws[step] == 0:
                magnitude = -magnitude
            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(policy={self.policy}, fill={self.fill})"
| |
|
| |
|
class RandAugment(torch.nn.Module):
    r"""RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int): Number of augmentation transformations to apply sequentially. Default is 2.
        magnitude (int): Magnitude bin used for all the transformations. Must lie in
            ``[0, num_magnitude_bins - 1]``. Default is 9.
        num_magnitude_bins (int): The number of different magnitude values. Default is 31.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

    Raises:
        ValueError: If ``magnitude`` is not a valid bin index for ``num_magnitude_bins``.
    """

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        # ``magnitude`` indexes the per-op magnitude tables in ``forward``; an out-of-range
        # value would otherwise surface as an IndexError mid-pipeline, and a negative one
        # would silently wrap around to the strongest bins.
        if not 0 <= magnitude < num_magnitude_bins:
            raise ValueError(
                f"magnitude should be in the range [0, {num_magnitude_bins - 1}] "
                f"for num_magnitude_bins={num_magnitude_bins}, but got {magnitude}."
            )
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Map each operation name to ``(magnitudes, signed)``.

        ``magnitudes`` holds one value per bin (a 0-dim placeholder for operations that take
        no magnitude) and ``signed`` tells ``forward`` whether to randomize the sign. The key
        order matters: ``forward`` picks an operation by its position in this dict.
        ``image_size`` is assumed to be ``[width, height]`` as returned by ``F.get_image_size``.
        """
        # Posterize spreads 8 -> 4 bits across the bins. Guard the divisor so that a single
        # bin yields 8 bits instead of 0 / 0 = NaN; with one bin every linspace-based op
        # likewise collapses to its start value.
        posterize_step = max(num_bins - 1, 1) / 4
        return {
            # op_name: (magnitudes, signed)
            "Identity": (torch.tensor(0.0), False),
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / posterize_step).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }

    def forward(self, img: Tensor) -> Tensor:
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        if isinstance(img, Tensor):
            # Tensor kernels take one float fill value per channel.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F.get_image_num_channels(img)
            elif fill is not None:
                fill = [float(f) for f in fill]

        for _ in range(self.num_ops):
            op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
            # Pick one operation uniformly at random.
            op_index = int(torch.randint(len(op_meta), (1,)).item())
            op_name = list(op_meta.keys())[op_index]
            magnitudes, signed = op_meta[op_name]
            # Every op uses the same fixed bin; 0-dim placeholders mean "no magnitude".
            magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
            if signed and torch.randint(2, (1,)):
                magnitude *= -1.0
            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

        return img

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_ops={num_ops}"
        s += ", magnitude={magnitude}"
        s += ", num_magnitude_bins={num_magnitude_bins}"
        s += ", interpolation={interpolation}"
        s += ", fill={fill}"
        s += ")"
        return s.format(**self.__dict__)
| |
|
| |
|
class TrivialAugmentWide(torch.nn.Module):
    r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Each call applies exactly one operation, chosen uniformly at random, with a magnitude bin
    also chosen uniformly at random.

    Args:
        num_magnitude_bins (int): The number of different magnitude values. Must be at least 1.
            Default is 31.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

    Raises:
        ValueError: If ``num_magnitude_bins`` is smaller than 1.
    """

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        # With zero bins ``forward`` would fail inside ``torch.randint(0, ...)`` for every
        # op that takes a magnitude; reject the configuration up front instead.
        if num_magnitude_bins < 1:
            raise ValueError(f"num_magnitude_bins should be at least 1, but got {num_magnitude_bins}.")
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
        """Map each operation name to ``(magnitudes, signed)``.

        ``magnitudes`` holds one value per bin (a 0-dim placeholder for operations that take
        no magnitude) and ``signed`` tells ``forward`` whether to randomize the sign. The key
        order matters: ``forward`` picks an operation by its position in this dict.
        """
        # Posterize spreads 8 -> 2 bits across the bins. Guard the divisor so that a single
        # bin yields 8 bits instead of 0 / 0 = NaN; with one bin every linspace-based op
        # likewise collapses to its start value.
        posterize_step = max(num_bins - 1, 1) / 6
        return {
            # op_name: (magnitudes, signed)
            "Identity": (torch.tensor(0.0), False),
            "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
            "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
            "Color": (torch.linspace(0.0, 0.99, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / posterize_step).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }

    def forward(self, img: Tensor) -> Tensor:
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        if isinstance(img, Tensor):
            # Tensor kernels take one float fill value per channel.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F.get_image_num_channels(img)
            elif fill is not None:
                fill = [float(f) for f in fill]

        op_meta = self._augmentation_space(self.num_magnitude_bins)
        # Pick one operation uniformly at random.
        op_index = int(torch.randint(len(op_meta), (1,)).item())
        op_name = list(op_meta.keys())[op_index]
        magnitudes, signed = op_meta[op_name]
        # Pick the magnitude bin uniformly at random; 0-dim placeholders mean "no magnitude".
        magnitude = (
            float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item())
            if magnitudes.ndim > 0
            else 0.0
        )
        if signed and torch.randint(2, (1,)):
            magnitude *= -1.0

        return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_magnitude_bins={num_magnitude_bins}"
        s += ", interpolation={interpolation}"
        s += ", fill={fill}"
        s += ")"
        return s.format(**self.__dict__)
| |
|