import math
from typing import Tuple

import torch
from torch import Tensor
from torchvision.transforms import functional as F


class RandomMixup(torch.nn.Module):
    """Randomly apply Mixup to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()

        if num_classes < 1:
            raise ValueError(
                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
            )

        if alpha <= 0:
            raise ValueError("Alpha param can't be zero or negative.")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        # Convert integer class labels into one-hot soft labels so they can be mixed.
        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        # Skip the transform with probability 1 - p.
        if torch.rand(1).item() >= self.p:
            return batch, target

        # Rolling the batch by one position is a cheap way to pair every sample
        # with another sample from the same batch.
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Sample the mixing coefficient lambda ~ Beta(alpha, alpha); the first
        # component of a two-dimensional Dirichlet sample is Beta-distributed.
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s


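# A minimal usage sketch (illustration only, not part of the original file):
# because mixup blends pairs of samples, it is applied to whole batches,
# typically from inside the DataLoader's collate_fn. The class count and
# hyperparameters below are arbitrary placeholders.
#
#     mixup = RandomMixup(num_classes=1000, p=0.5, alpha=0.2)
#
#     def collate_fn(samples):
#         images, targets = torch.utils.data.default_collate(samples)
#         return mixup(images, targets)
#
# The returned target is a soft (B, num_classes) label distribution, so the
# training loss must accept probabilistic targets.

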
class RandomCutmix(torch.nn.Module):
    """Randomly apply Cutmix to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        if num_classes < 1:
            raise ValueError("Please provide a valid positive value for the num_classes.")
        if alpha <= 0:
            raise ValueError("Alpha param can't be zero or negative.")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        # Convert integer class labels into one-hot soft labels so they can be mixed.
        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        # Skip the transform with probability 1 - p.
        if torch.rand(1).item() >= self.p:
            return batch, target

        # Rolling the batch by one position is a cheap way to pair every sample
        # with another sample from the same batch.
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Sample lambda ~ Beta(alpha, alpha) (via a two-dimensional Dirichlet),
        # then cut a box whose area fraction is approximately 1 - lambda.
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        _, H, W = F.get_dimensions(batch)

        # Pick a random box center ...
        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        # ... and half-extents so the box covers a (1 - lambda) fraction of the image.
        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        # Clip the box to the image boundaries.
        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        # Paste the box from the rolled batch, then recompute lambda from the
        # actual (clipped) box area so the label weights match the pixel ratio.
        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

|
| | def __repr__(self) -> str: |
| | s = ( |
| | f"{self.__class__.__name__}(" |
| | f"num_classes={self.num_classes}" |
| | f", p={self.p}" |
| | f", alpha={self.alpha}" |
| | f", inplace={self.inplace}" |
| | f")" |
| | ) |
| | return s |
| |
|
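# Minimal smoke-test sketch (illustration only, not part of the original file):
# exercises both transforms on a tiny random batch. The batch shape, class
# count, and hyperparameters below are arbitrary assumptions.
if __name__ == "__main__":
    num_classes = 10
    images = torch.rand(8, 3, 32, 32)
    targets = torch.randint(0, num_classes, (8,), dtype=torch.int64)

    mixup = RandomMixup(num_classes=num_classes, p=1.0, alpha=0.2)
    cutmix = RandomCutmix(num_classes=num_classes, p=1.0, alpha=1.0)

    mixed_images, mixed_targets = mixup(images, targets)
    cut_images, cut_targets = cutmix(images, targets)

    # Both transforms keep the image shape and return soft (B, num_classes) targets.
    print(mixed_images.shape, mixed_targets.shape)  # (8, 3, 32, 32) / (8, 10)
    print(cut_images.shape, cut_targets.shape)      # (8, 3, 32, 32) / (8, 10)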