Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| import kornia.augmentation as K | |
class ImageAugmentations(nn.Module):
    """Expand (image, mask) inputs into a batch of jointly augmented copies.

    Each call resizes the inputs, tiles them ``augmentations_number`` times
    along the batch dimension, and runs every augmentation with the *same*
    sampled parameters on the images and on the masks, so each augmented mask
    stays aligned with its augmented image.

    Args:
        output_size: side length of the square output produced by adaptive
            average pooling when ``resize`` is True.
        augmentations_number: number of copies of every input sample in the
            output (the un-augmented copy counts when ``with_orig`` is True).
        p: probability passed to each kornia augmentation.
        resize: if False, inputs keep their original spatial size.
    """

    def __init__(self, output_size, augmentations_number, p=0.7, resize=True):
        super().__init__()
        self.output_size = output_size
        self.augmentations_number = augmentations_number
        # ModuleList (not a plain list) registers the augmentations as
        # submodules, so .to(), .train()/.eval() and repr() reach them.
        self.augmentations = nn.ModuleList(
            [
                K.RandomAffine(degrees=15, translate=0.1, p=p, padding_mode="border"),
                # K.RandomPerspective(0.7, p=p),
            ]
        )
        # nn.Identity instead of a lambda keeps the module picklable and makes
        # both branches proper submodules.
        self.resize = (
            nn.AdaptiveAvgPool2d((self.output_size, self.output_size))
            if resize
            else nn.Identity()
        )

    def _apply_augmentations(self, images, masks):
        """Run every augmentation on images and masks with shared parameters.

        Args:
            images: batch of images to augment.
            masks: batch of masks, same batch size and order as ``images``.

        Returns:
            tuple of (augmented images, augmented masks).
        """
        if images.shape[0] == 0:
            # Nothing to augment (e.g. augmentations_number == 1 with
            # with_orig=True); don't hand an empty batch to the transforms.
            return images, masks
        for trans in self.augmentations:
            # Sample once and reuse, so image i and mask i get the same warp.
            trans_params = trans.forward_parameters(images.shape)
            images = trans(images, trans_params)
            masks = trans(masks, trans_params)
        return images, masks

    def forward(self, image, mask, with_orig=True):
        """Extend the image and mask batches with identical augmentations.

        For inputs I and M the output is
            [I_aug1, I_aug2, I_aug3, ...], [M_aug1, M_aug2, M_aug3, ...]
        and, when ``with_orig=True``,
            [I, I_aug1, I_aug2, ...], [M, M_aug1, M_aug2, ...]

        With a batch of B samples, the whole batch is tiled, so the output
        holds B * augmentations_number rows. When ``with_orig=True``, the
        first B rows are the resized, un-augmented inputs.

        Args:
            image: tensor of shape [B, C, H, W] (typically B == 1).
            mask: tensor of shape [B, 1, H, W].
            with_orig: if True, the first B returned images and masks are the
                un-augmented (resized) inputs.

        Returns:
            tuple of (extended images of shape [B * augmentations_number, C, H', W'],
                      extended masks of shape [B * augmentations_number, 1, H', W']),
            where H' = W' = output_size when resizing, else H', W' = H, W.
        """
        batch_size = image.shape[0]
        # Duplicate the inputs: unlike regular augmentations, this changes the
        # number of samples. Tensor.repeat tiles the whole batch end to end.
        resized_images = self.resize(image).repeat(self.augmentations_number, 1, 1, 1)
        resized_masks = self.resize(mask).repeat(self.augmentations_number, 1, 1, 1)

        if not with_orig:
            return self._apply_augmentations(resized_images, resized_masks)

        # The first batch_size rows are one untouched copy of every input.
        aug_inputs, aug_masks = self._apply_augmentations(
            resized_images[batch_size:], resized_masks[batch_size:]
        )
        updated_input_batch = torch.cat([resized_images[:batch_size], aug_inputs], dim=0)
        updated_mask_batch = torch.cat([resized_masks[:batch_size], aug_masks], dim=0)
        return updated_input_batch, updated_mask_batch