Spaces:
Running
Running
from typing import Tuple, List
import random
import torch
import torch.nn as nn
class Intensity(nn.Module):
    """
    Overview:
        Intensity transformation for data augmentation. Scale the image intensity by a random factor.
    """

    def __init__(self, scale: float) -> None:
        """
        Arguments:
            - scale (:obj:`float`): The scale factor for intensity transformation.
        """
        super().__init__()
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            - x (:obj:`torch.Tensor`): The input image tensor with shape (B, C, H, W).
            - output (:obj:`torch.Tensor`): The output image tensor with shape (B, C, H, W).
        """
        batch = x.size(0)
        # Draw one Gaussian factor per sample (broadcast over C, H, W), truncated to [-2, 2]
        # so that a single outlier cannot blow up the image intensity.
        eps = torch.randn((batch, 1, 1, 1), device=x.device).clamp_(-2.0, 2.0)
        return x * (1.0 + self.scale * eps)
class RandomCrop(nn.Module):
    """
    Overview:
        Random crop the image to the given size.
    """

    def __init__(self, image_shape: Tuple[int]) -> None:
        """
        Arguments:
            - image_shape (:obj:`Tuple[int]`): The target shape of the image to be cropped.
        """
        super().__init__()
        self.image_shape = image_shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            - x (:obj:`torch.Tensor`): The input image tensor with shape (B, C, H, W), where H and W are \
                the original image shape.
            - output (:obj:`torch.Tensor`): The output image tensor with shape (B, C, H_, W_), where H_ and W_ are \
                the target image shape indicated by `image_shape`.
        """
        target_h, target_w = self.image_shape
        # Largest valid top-left offsets; random.randint is inclusive on both ends,
        # so every window position (including flush to the bottom/right edge) is reachable.
        top = random.randint(0, x.shape[-2] - target_h)
        left = random.randint(0, x.shape[-1] - target_w)
        return x[..., top:top + target_h, left:left + target_w]
| class ImageTransforms(object): | |
| """ | |
| Overview: | |
| Image transformation for data augmentation. Including image normalization (divide 255), random crop and | |
| intensity transformation. | |
| """ | |
| def __init__(self, augmentation: List[str], shift_delta: int = 4, image_shape: Tuple[int] = (96, 96)) -> None: | |
| """ | |
| Arguments: | |
| - augmentation (:obj:`List[str]`): The list of augmentation types. Now support "shift" and "intensity". | |
| - shift_delta (:obj:`int`): The delta value for random shift padding before crop. Use ReplicationPad2d \ | |
| to pad the image without the loss of information. | |
| - image_shape (:obj:`Tuple[int]`): The target shape of the image to be cropped. | |
| """ | |
| self.augmentation = augmentation | |
| self.image_transforms = [] | |
| for aug in self.augmentation: | |
| if aug == "shift": | |
| # TODO validate the effectiveness of ReflectionPad2d | |
| transformation = nn.Sequential(nn.ReplicationPad2d(shift_delta), RandomCrop(image_shape)) | |
| elif aug == "intensity": | |
| transformation = Intensity(scale=0.05) | |
| else: | |
| raise NotImplementedError("not support augmentation type: {}".format(aug)) | |
| self.image_transforms.append(transformation) | |
| def transform(self, images: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Shapes: | |
| - x (:obj:`torch.Tensor`): The input image tensor with shape (B, C, H, W), where H and W are \ | |
| the original image shape. | |
| - output (:obj:`torch.Tensor`): The output image tensor with shape (B, C, H_, W_), where H_ and W_ are \ | |
| the target image shape indicated by `image_shape`. | |
| .. note:: | |
| Use torch.no_grad() to save cuda memory. Transformations are not trainable. | |
| """ | |
| images = images.float() / 255. if images.dtype == torch.uint8 else images | |
| processed_images = images.reshape(-1, *images.shape[-3:]) | |
| for transform in self.image_transforms: | |
| processed_images = transform(processed_images) | |
| processed_images = processed_images.view(*images.shape[:-3], *processed_images.shape[1:]) | |
| return processed_images | |