| import numpy as np |
| import cv2 |
| from torchvision.transforms import functional as F |
| from PIL import Image |
| from torchvision import transforms |
|
|
|
|
| class BasicAugmenter(): |
|
|
| def __init__(self, crop_augmentation_prob, photometric_augmentation_prob, low_res_augmentation_prob): |
| self.crop_augmentation_prob = crop_augmentation_prob |
| self.photometric_augmentation_prob = photometric_augmentation_prob |
| self.low_res_augmentation_prob = low_res_augmentation_prob |
| self.random_resized_crop = transforms.RandomResizedCrop(size=(112, 112), |
| scale=(0.2, 1.0), |
| ratio=(0.75, 1.3333333333333333)) |
| self.photometric = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0) |
|
|
| def augment(self, sample): |
|
|
| |
| if np.random.random() < self.crop_augmentation_prob: |
| |
| sample, crop_ratio = self.crop_augment(sample) |
|
|
| |
| if np.random.random() < self.low_res_augmentation_prob: |
| |
| img_np, resize_ratio = self.low_res_augmentation(np.array(sample)) |
| sample = Image.fromarray(img_np.astype(np.uint8)) |
|
|
| |
| if np.random.random() < self.photometric_augmentation_prob: |
| sample = self.photometric_augmentation(sample) |
|
|
| |
| if np.random.random() < 0.5: |
| sample = F.hflip(sample) |
|
|
| return sample |
|
|
| def crop_augment(self, sample): |
| new = np.zeros_like(np.array(sample)) |
| if hasattr(F, '_get_image_size'): |
| orig_W, orig_H = F._get_image_size(sample) |
| else: |
| |
| orig_W, orig_H = F.get_image_size(sample) |
| i, j, h, w = self.random_resized_crop.get_params(sample, |
| self.random_resized_crop.scale, |
| self.random_resized_crop.ratio) |
| cropped = F.crop(sample, i, j, h, w) |
| new[i:i+h,j:j+w, :] = np.array(cropped) |
| sample = Image.fromarray(new.astype(np.uint8)) |
| crop_ratio = min(h, w) / max(orig_H, orig_W) |
| return sample, crop_ratio |
|
|
| def low_res_augmentation(self, img): |
| |
| img_shape = img.shape |
| side_ratio = np.random.uniform(0.2, 1.0) |
| small_side = int(side_ratio * img_shape[0]) |
| interpolation = np.random.choice( |
| [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]) |
| small_img = cv2.resize(img, (small_side, small_side), interpolation=interpolation) |
| interpolation = np.random.choice( |
| [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]) |
| aug_img = cv2.resize(small_img, (img_shape[1], img_shape[0]), interpolation=interpolation) |
|
|
| return aug_img, side_ratio |
|
|
| def photometric_augmentation(self, sample): |
| fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ |
| self.photometric.get_params(self.photometric.brightness, self.photometric.contrast, |
| self.photometric.saturation, self.photometric.hue) |
| for fn_id in fn_idx: |
| if fn_id == 0 and brightness_factor is not None: |
| sample = F.adjust_brightness(sample, brightness_factor) |
| elif fn_id == 1 and contrast_factor is not None: |
| sample = F.adjust_contrast(sample, contrast_factor) |
| elif fn_id == 2 and saturation_factor is not None: |
| sample = F.adjust_saturation(sample, saturation_factor) |
|
|
| return sample |
|
|
|
|
|
|
| def main(): |
| from PIL import Image, ImageDraw |
| import torch |
|
|
| image = Image.open('/data/data/faces/ms1mv2_subset_images/84946/5770863.jpg') |
| |
| image_draw = ImageDraw.Draw(image) |
| image_draw.rectangle((10, 10, 110, 110), outline='red') |
| image_draw.rectangle((0, 0, 120, 120), outline='blue') |
|
|
| augmenter = BasicAugmenter(0.2, 0.2, 0.2) |
| |
| grids = [] |
| for i in range(10): |
| grid = [] |
| for j in range(10): |
| align_input_sample = augmenter.augment(image) |
| grid.append(align_input_sample) |
| grids.append(grid) |
| |
| grid_image = Image.new('RGB', (1120, 1120)) |
| for i in range(10): |
| for j in range(10): |
| grid_image.paste(grids[i][j], (112 * j, 112 * i)) |
| grid_image.save(f'/mckim/temp/BasicAugmenter.jpg') |
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|