Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| import kornia.augmentation as K | |
| class ImageAugmentations(nn.Module): | |
| def __init__(self, output_size, augmentations_number, p=0.7): | |
| super().__init__() | |
| self.output_size = output_size | |
| self.augmentations_number = augmentations_number | |
| self.augmentations = nn.Sequential( | |
| K.RandomAffine(degrees=15, translate=0.1, p=p, padding_mode="border"), # type: ignore | |
| K.RandomPerspective(0.7, p=p), | |
| ) | |
| self.avg_pool = nn.AdaptiveAvgPool2d((self.output_size, self.output_size)) | |
| def forward(self, input): | |
| """Extents the input batch with augmentations | |
| If the input is consists of images [I1, I2] the extended augmented output | |
| will be [I1_resized, I2_resized, I1_aug1, I2_aug1, I1_aug2, I2_aug2 ...] | |
| Args: | |
| input ([type]): input batch of shape [batch, C, H, W] | |
| Returns: | |
| updated batch: of shape [batch * augmentations_number, C, H, W] | |
| """ | |
| # We want to multiply the number of images in the batch in contrast to regular augmantations | |
| # that do not change the number of samples in the batch) | |
| resized_images = self.avg_pool(input) | |
| resized_images = torch.tile(resized_images, dims=(self.augmentations_number, 1, 1, 1)) | |
| batch_size = input.shape[0] | |
| # We want at least one non augmented image | |
| non_augmented_batch = resized_images[:batch_size] | |
| augmented_batch = self.augmentations(resized_images[batch_size:]) | |
| updated_batch = torch.cat([non_augmented_batch, augmented_batch], dim=0) | |
| return updated_batch | |