Spaces:
Runtime error
Runtime error
| """ | |
| Adapted from https://github.com/nv-nguyen/template-pose/blob/main/src/utils/augmentation.py | |
| """ | |
| from torchvision import transforms | |
| from PIL import ImageEnhance, ImageFilter, Image | |
| import numpy as np | |
| import random | |
| import logging | |
| from torchvision.transforms import RandomResizedCrop, ToTensor | |
class PillowRGBAugmentation:
    """Randomly apply a Pillow enhancer with a randomly sampled factor.

    On each call, with probability ``p`` (``random.random() <= p``), a factor is
    drawn uniformly from ``factor_interval`` and the image is replaced by
    ``pillow_fn(image).enhance(factor=factor)`` converted to RGB. Otherwise the
    image is returned unchanged.

    Args:
        pillow_fn: Enhancer factory (e.g. ``ImageEnhance.Sharpness``). It is
            called with the image, and its result must provide
            ``.enhance(factor=...)`` returning an image.
        p: Probability of applying the enhancement on a call.
        factor_interval: ``(low, high)`` bounds passed to ``random.uniform``.
    """

    def __init__(self, pillow_fn, p, factor_interval):
        self._pillow_fn = pillow_fn
        self.p = p
        self.factor_interval = factor_interval

    def __call__(self, PIL_image):
        if random.random() <= self.p:
            factor = random.uniform(*self.factor_interval)
            if PIL_image.mode != "RGB":
                logging.warning(
                    f"Error when apply data aug, image mode: {PIL_image.mode}"
                )
                # Bug fix: this previously read `imgs = imgs.convert("RGB")`.
                # `imgs` is undefined, so every non-RGB input crashed, and the
                # converted image was never stored back into PIL_image.
                PIL_image = PIL_image.convert("RGB")
                logging.warning(f"Success to change to {PIL_image.mode}")
            PIL_image = (
                self._pillow_fn(PIL_image).enhance(factor=factor).convert("RGB")
            )
        return PIL_image
class PillowSharpness(PillowRGBAugmentation):
    """Random sharpness change backed by ``ImageEnhance.Sharpness``.

    Defaults: applied with probability 0.3, with the enhancement factor drawn
    uniformly from ``[0, 40.0]``.
    """

    def __init__(self, p=0.3, factor_interval=(0, 40.0)):
        # Base-class parameter order is (pillow_fn, p, factor_interval).
        super().__init__(ImageEnhance.Sharpness, p, factor_interval)
class PillowContrast(PillowRGBAugmentation):
    """Random contrast change backed by ``ImageEnhance.Contrast``.

    Defaults: applied with probability 0.3, with the enhancement factor drawn
    uniformly from ``[0.5, 1.6]``.
    """

    def __init__(self, p=0.3, factor_interval=(0.5, 1.6)):
        # Base-class parameter order is (pillow_fn, p, factor_interval).
        super().__init__(ImageEnhance.Contrast, p, factor_interval)
class PillowBrightness(PillowRGBAugmentation):
    """Random brightness change backed by ``ImageEnhance.Brightness``.

    Defaults: applied with probability 0.5, with the enhancement factor drawn
    uniformly from ``[0.5, 2.0]``.
    """

    def __init__(self, p=0.5, factor_interval=(0.5, 2.0)):
        # Base-class parameter order is (pillow_fn, p, factor_interval).
        super().__init__(ImageEnhance.Brightness, p, factor_interval)
class PillowColor(PillowRGBAugmentation):
    """Random colour-saturation change backed by ``ImageEnhance.Color``.

    Defaults: always applied (p=1), with the enhancement factor drawn uniformly
    from ``[0.0, 20.0]``.
    """

    def __init__(self, p=1, factor_interval=(0.0, 20.0)):
        # Base-class parameter order is (pillow_fn, p, factor_interval).
        super().__init__(ImageEnhance.Color, p, factor_interval)
class PillowBlur:
    """Randomly apply a Gaussian blur whose radius is re-sampled per image.

    Args:
        p: Probability of blurring on a call (``random.random() <= p``).
        factor_interval: Inclusive integer ``(low, high)`` bounds for the blur
            radius, sampled with ``random.randint`` and passed to
            ``ImageFilter.GaussianBlur``.

    Attributes:
        k: Radius used by the most recent blur. It is initialised with one draw
            so the attribute exists before the first call, as it always has.
    """

    def __init__(self, p=0.4, factor_interval=(1, 3)):
        self.p = p
        self.factor_interval = factor_interval
        self.k = random.randint(*factor_interval)

    def __call__(self, PIL_image):
        if random.random() <= self.p:
            # Draw a fresh radius on every application. Previously the radius
            # was drawn only once in __init__, so every blurred image got the
            # same radius for the lifetime of the transform (and for every
            # worker process it is copied to).
            self.k = random.randint(*self.factor_interval)
            PIL_image = PIL_image.filter(ImageFilter.GaussianBlur(self.k))
        return PIL_image
class NumpyGaussianNoise:
    """Randomly add clipped Gaussian noise to an image; always return a PIL image.

    Args:
        p: Probability of adding noise on a call (``random.random() <= p``).
        factor_interval: ``(low, high)`` bounds for ``noise_ratio``. The ratio is
            drawn once at construction. Each noisy call then draws a sigma from
            ``[0, noise_ratio]``, which is scaled by 255 before being added to
            the pixel values.
    """

    def __init__(self, p, factor_interval=(0.01, 0.3)):
        self.noise_ratio = random.uniform(*factor_interval)
        self.p = p

    def __call__(self, img):
        apply_noise = random.random() <= self.p
        if apply_noise:
            pixels = np.copy(img)
            sigma = random.uniform(0, self.noise_ratio)
            noise = 255 * np.random.normal(0, sigma, pixels.shape)
            # Clamp back into the valid 8-bit range before the uint8 cast.
            img = np.clip(pixels + noise, 0, 255)
        return Image.fromarray(np.uint8(img))
class StandardAugmentation:
    """Apply a configurable sequence of photometric augmentations to an image.

    Args:
        names: Comma-separated augmentation keys, applied left to right.
            Whitespace around each key and empty entries are ignored. Valid
            keys: ``brightness``, ``contrast``, ``sharpness``, ``color``,
            ``blur``, ``gaussian_noise``.
        brightness, contrast, sharpness, color, blur, gaussian_noise:
            Callables. Each selected one is called with the current image, and
            its return value replaces the image.

    Raises:
        KeyError: When called, if ``names`` contains a key not listed above.
    """

    def __init__(
        self, names, brightness, contrast, sharpness, color, blur, gaussian_noise
    ):
        self.brightness = brightness
        self.contrast = contrast
        self.sharpness = sharpness
        self.color = color
        self.blur = blur
        self.gaussian_noise = gaussian_noise
        # Strip each key so configs like "brightness, blur" work. Drop empty
        # entries so "" or a trailing comma means "no augmentation" instead of
        # failing later with KeyError('').
        self.names = [name.strip() for name in names.split(",") if name.strip()]
        # Dispatch table from config key to the augmentation it selects.
        self.augmentations = {
            "brightness": self.brightness,
            "contrast": self.contrast,
            "sharpness": self.sharpness,
            "color": self.color,
            "blur": self.blur,
            "gaussian_noise": self.gaussian_noise,
        }

    def __call__(self, img):
        for name in self.names:
            img = self.augmentations[name](img)
        return img
class GeometricAugmentation:
    """Apply a configurable sequence of geometric transforms to an image.

    Args:
        names: Comma-separated transform keys, applied left to right.
            Whitespace around each key and empty entries are ignored. Valid
            keys: ``random_resized_crop``, ``random_horizontal_flip``,
            ``random_vertical_flip``, ``random_rotation``.
        random_resized_crop, random_horizontal_flip, random_vertical_flip,
        random_rotation: Callables. Each selected one is called with the current
            image, and its return value replaces the image.

    Raises:
        KeyError: When called, if ``names`` contains a key not listed above.
    """

    def __init__(
        self,
        names,
        random_resized_crop,
        random_horizontal_flip,
        random_vertical_flip,
        random_rotation,
    ):
        self.random_resized_crop = random_resized_crop
        self.random_horizontal_flip = random_horizontal_flip
        self.random_vertical_flip = random_vertical_flip
        self.random_rotation = random_rotation
        # Strip each key so configs like "a, b" work. Drop empty entries so ""
        # or a trailing comma means "no transform" instead of KeyError('').
        self.names = [name.strip() for name in names.split(",") if name.strip()]
        # Dispatch table from config key to the transform it selects.
        self.augmentations = {
            "random_resized_crop": self.random_resized_crop,
            "random_horizontal_flip": self.random_horizontal_flip,
            "random_vertical_flip": self.random_vertical_flip,
            "random_rotation": self.random_rotation,
        }

    def __call__(self, img):
        for name in self.names:
            img = self.augmentations[name](img)
        return img
class ImageAugmentation:
    """Top-level image pipeline that chains named sub-transforms.

    Args:
        names: Comma-separated stage keys, applied left to right. Whitespace
            around each key and empty entries are ignored. Valid keys:
            ``clip_transform``, ``standard_augmentation``,
            ``geometric_augmentation``.
        clip_transform, standard_augmentation, geometric_augmentation:
            Callables. Each selected one is called with the current image, and
            its return value replaces the image.

    Raises:
        KeyError: When called, if ``names`` contains a key not listed above.
    """

    def __init__(
        self, names, clip_transform, standard_augmentation, geometric_augmentation
    ):
        self.clip_transform = clip_transform
        self.standard_augmentation = standard_augmentation
        self.geometric_augmentation = geometric_augmentation
        # Strip each key so configs like "a, b" work. Drop empty entries so ""
        # or a trailing comma means "no stage" instead of KeyError('').
        self.names = [name.strip() for name in names.split(",") if name.strip()]
        # Dispatch table from config key to the stage it selects.
        self.transforms = {
            "clip_transform": self.clip_transform,
            "standard_augmentation": self.standard_augmentation,
            "geometric_augmentation": self.geometric_augmentation,
        }
        print(f"Image augmentation: {self.names}")

    def __call__(self, img):
        for name in self.names:
            img = self.transforms[name](img)
        return img
if __name__ == "__main__":
    # Sanity check: for a few dataset images, save image grids in which each
    # row is the resized original followed by several independently augmented
    # copies. One grid file is written per try.
    import glob

    import torch
    from hydra.utils import instantiate
    from omegaconf import OmegaConf
    from torchvision.utils import save_image

    augmentation_config = OmegaConf.load(
        "./configs/dataset/train_transform/augmentation.yaml"
    )
    augmentation_config.names = "standard_augmentation,geometric_augmentation"
    augmentation_transform = instantiate(augmentation_config)

    image_dir = "./datasets/osv5m/test/images"
    # Sort so runs pick the same images regardless of filesystem order.
    img_paths = sorted(glob.glob(f"{image_dir}/*.jpg"))
    if not img_paths:
        raise SystemExit(f"No .jpg images found in {image_dir}")

    num_try = 20
    num_try_per_image = 8
    # Use at most 8 images, and fewer if the directory has fewer (this
    # previously raised IndexError).
    num_imgs = min(8, len(img_paths))
    to_tensor = ToTensor()
    for idx in range(num_try):
        imgs = []
        for img_path in img_paths[:num_imgs]:
            # Close the file handle once the pixels are decoded. Force RGB so
            # grayscale/CMYK JPEGs yield 3-channel tensors that can be stacked.
            with Image.open(img_path) as opened:
                img = opened.convert("RGB")
            imgs.append(to_tensor(img.resize((224, 224))))
            for _ in range(num_try_per_image):
                # NOTE(review): torch.stack assumes the augmented output is
                # also 224x224; confirm the geometric config resizes to that.
                imgs.append(to_tensor(augmentation_transform(img.copy())))
        save_image(
            torch.stack(imgs),
            f"augmentation_{idx:03d}.png",
            nrow=num_try_per_image + 1,
        )