| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import os |
| import random |
| import numpy as np |
| import torch |
| import torchvision.transforms as T |
| from torchvision import datasets |
| from megatron import get_args |
| from megatron.data.image_folder import ImageFolder |
| from megatron.data.autoaugment import ImageNetPolicy |
| from megatron.data.data_samplers import RandomSeedDataset |
| from PIL import Image, ImageFilter, ImageOps |
|
|
|
|
class GaussianBlur(object):
    """
    Randomly apply a Gaussian blur to a PIL image.

    Args:
        p: threshold compared against a fresh ``random.random()`` draw on
            every call; the blur is applied when the draw is ``<= p``.
        radius_min: lower bound of the uniformly sampled blur radius.
        radius_max: upper bound of the uniformly sampled blur radius.
    """

    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
        self.prob, self.radius_min, self.radius_max = p, radius_min, radius_max

    def __call__(self, img):
        """Return ``img`` blurred with a random radius, or ``img`` itself."""
        should_blur = random.random() <= self.prob
        if should_blur:
            # The radius is only sampled when the blur is actually applied.
            blur_radius = random.uniform(self.radius_min, self.radius_max)
            return img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
        return img
|
|
|
|
class Solarization(object):
    """
    Randomly solarize a PIL image.

    Args:
        p: threshold compared against a fresh ``random.random()`` draw on
            every call; the image is solarized when the draw is ``< p``.
    """

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        """Return the solarized image, or ``img`` itself when not selected."""
        return ImageOps.solarize(img) if random.random() < self.p else img
|
|
|
|
class ClassificationTransform():
    """
    Image preprocessing pipeline for classification.

    With ``train=True`` the pipeline is random-resized-crop, horizontal
    flip, color jitter and ``ImageNetPolicy``; otherwise it is a resize
    followed by a center crop. Both variants end with tensor conversion,
    normalization with fixed per-channel mean/std constants, and a cast to
    the half-precision dtype selected by the global args.

    Raises:
        AssertionError: if neither ``args.fp16`` nor ``args.bf16`` is set.
    """

    def __init__(self, image_size, train=True):
        args = get_args()
        assert args.fp16 or args.bf16
        self.data_type = torch.half if args.fp16 else torch.bfloat16

        # Stage 1: PIL-level geometry / augmentation, depends on the mode.
        if train:
            pil_stage = [
                T.RandomResizedCrop(image_size),
                T.RandomHorizontalFlip(),
                T.ColorJitter(0.4, 0.4, 0.4, 0.1),
                ImageNetPolicy(),
            ]
        else:
            pil_stage = [
                T.Resize(image_size),
                T.CenterCrop(image_size),
            ]

        # Stage 2: tensor conversion, shared by both modes.
        tensor_stage = [
            T.ToTensor(),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            T.ConvertImageDtype(self.data_type),
        ]

        self.transform = T.Compose(pil_stage + tensor_stage)

    def __call__(self, input):
        """Run the composed pipeline on ``input`` and return its result."""
        return self.transform(input)
|
|
|
|
class InpaintingTransform():
    """
    Image preprocessing for inpainting: returns the transformed image
    together with a freshly generated patch mask.

    The mask settings (``mask_factor``, ``mask_type``, ``patch_dim``) are
    read from the global args. ``image_size`` is indexed as
    ``image_size[0]`` / ``image_size[1]``, so it must be a two-element
    sequence.

    Raises:
        AssertionError: if neither ``args.fp16`` nor ``args.bf16`` is set.
    """

    def __init__(self, image_size, train=True):
        args = get_args()
        self.mask_factor = args.mask_factor
        self.mask_type = args.mask_type
        self.image_size = image_size
        self.patch_size = args.patch_dim

        # Number of patches to mask: a fraction of the patch-grid area.
        patches_dim0 = image_size[0] / self.patch_size
        patches_dim1 = image_size[1] / self.patch_size
        self.mask_size = int(self.mask_factor * patches_dim0 * patches_dim1)

        self.train = train
        assert args.fp16 or args.bf16
        self.data_type = torch.half if args.fp16 else torch.bfloat16

        if self.train:
            pil_stage = [
                T.RandomResizedCrop(self.image_size),
                T.RandomHorizontalFlip(),
                T.ColorJitter(0.4, 0.4, 0.4, 0.1),
                ImageNetPolicy(),
            ]
        else:
            pil_stage = [
                T.Resize(self.image_size, interpolation=2),
                T.CenterCrop(self.image_size),
            ]
        # No normalization here, only tensor conversion and dtype cast.
        tensor_stage = [T.ToTensor(), T.ConvertImageDtype(self.data_type)]
        self.transform = T.Compose(pil_stage + tensor_stage)

    def gen_mask(self, image_size, mask_size, mask_type, patch_size):
        """
        Build a float mask of shape ``(image_size[0], image_size[1])`` with
        ones on the masked patches and zeros elsewhere.

        ``'random'``: start at a random patch and take ``mask_size`` random
        unit steps (clamped to the patch grid), masking the patch reached
        after each step. Revisited patches are not counted twice, so fewer
        than ``mask_size`` distinct patches may end up masked.

        ``'row'``: mask the last ``mask_size`` patches when walking the
        patch grid backwards from the highest (dim-0, dim-1) index.

        Raises:
            AssertionError: if the image is not square, or ``mask_type`` is
                neither ``'random'`` nor ``'row'``.
        """
        assert image_size[0] == image_size[1]
        grid = image_size[0] // patch_size

        mask = torch.zeros((image_size[0], image_size[1]), dtype=torch.float)

        if mask_type == 'random':
            moves = ((0, 1), (0, -1), (1, 0), (-1, 0))
            p0 = torch.randint(0, grid, ())
            p1 = torch.randint(0, grid, ())
            for _ in range(mask_size):
                step0, step1 = moves[torch.randint(0, len(moves), ())]
                p0 = torch.clamp(p0 + step0, min=0, max=grid - 1)
                p1 = torch.clamp(p1 + step1, min=0, max=grid - 1)
                lo0 = p0 * patch_size
                lo1 = p1 * patch_size
                mask[lo0:lo0 + patch_size, lo1:lo1 + patch_size] = 1
            return mask

        assert mask_type == 'row'
        # Walk the flattened patch grid from the last index downwards.
        flat = grid * grid - 1
        painted = 0
        while flat >= 0 and painted < mask_size:
            p0, p1 = divmod(flat, grid)
            lo0 = p0 * patch_size
            lo1 = p1 * patch_size
            mask[lo0:lo0 + patch_size, lo1:lo1 + patch_size] = 1
            painted += 1
            flat -= 1
        return mask

    def __call__(self, input):
        """Return ``(transformed_input, mask)``; the mask gets a leading
        dimension of size 1."""
        image = self.transform(input)
        mask = self.gen_mask(
            self.image_size, self.mask_size, self.mask_type, self.patch_size
        ).unsqueeze(dim=0)
        return image, mask
|
|
|
|
class DinoTransform(object):
    """
    Multi-crop augmentation: each call returns a list holding two "global"
    crops followed by ``args.dino_local_crops_number`` "local" crops of the
    same input image.

    Global crops are taken at ``image_size`` with scale range
    ``(0.4, 1)``; local crops are taken at ``args.dino_local_img_size``
    with scale range ``(0.05, 0.4)``. The ``train`` argument is accepted
    but not used.
    """

    def __init__(self, image_size, train=True):
        args = get_args()
        self.data_type = torch.half if args.fp16 else torch.bfloat16

        flip_and_color_jitter = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply(
                [T.ColorJitter(brightness=0.4, contrast=0.4,
                               saturation=0.2, hue=0.1)],
                p=0.8
            ),
            T.RandomGrayscale(p=0.2),
        ])

        # Tensor conversion + normalization; the dtype cast is only added
        # when a half-precision mode is enabled.
        normalize_steps = [
            T.ToTensor(),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
        if args.fp16 or args.bf16:
            normalize_steps.append(T.ConvertImageDtype(self.data_type))
        normalize = T.Compose(normalize_steps)

        # Boundary between the local and global crop scale ranges.
        scale_const = 0.4

        def bicubic_crop(size, scale):
            return T.RandomResizedCrop(size,
                                       scale=scale,
                                       interpolation=Image.BICUBIC)

        # First global crop: always blurred.
        self.global_transform1 = T.Compose([
            bicubic_crop(image_size, (scale_const, 1)),
            flip_and_color_jitter,
            GaussianBlur(1.0),
            normalize
        ])

        # Second global crop: rarely blurred, sometimes solarized.
        self.global_transform2 = T.Compose([
            bicubic_crop(image_size, (scale_const, 1)),
            flip_and_color_jitter,
            GaussianBlur(0.1),
            Solarization(0.2),
            normalize
        ])

        # Local (small) crops.
        self.local_crops_number = args.dino_local_crops_number
        self.local_transform = T.Compose([
            bicubic_crop(args.dino_local_img_size, (0.05, scale_const)),
            flip_and_color_jitter,
            GaussianBlur(p=0.5),
            normalize
        ])

    def __call__(self, image):
        """Return ``[global1, global2, local_1, ..., local_k]`` for
        ``image``, where ``k`` is ``self.local_crops_number``."""
        crops = [self.global_transform1(image), self.global_transform2(image)]
        crops.extend(
            self.local_transform(image)
            for _ in range(self.local_crops_number)
        )
        return crops
|
|
|
|
def build_train_valid_datasets(data_path, image_size=224):
    """
    Build the training and validation image datasets.

    The transforms are chosen from ``args.vision_pretraining_type``
    (``'classify'``, ``'inpaint'`` or ``'dino'``).

    Args:
        data_path: sequence of dataset roots. Validation data is read from
            ``data_path[1]``. Training data is read from ``data_path[0]``
            when the sequence has at most two entries, otherwise from
            ``data_path[2]``.
        image_size: size forwarded unchanged to the transform classes.

    Returns:
        ``(train_data, val_data)``, each an ``ImageFolder`` wrapped in
        ``RandomSeedDataset``.

    Raises:
        ValueError: if ``args.vision_pretraining_type`` is not supported.
    """
    args = get_args()

    pretraining_type = args.vision_pretraining_type
    if pretraining_type == 'classify':
        train_transform = ClassificationTransform(image_size)
        val_transform = ClassificationTransform(image_size, train=False)
    elif pretraining_type == 'inpaint':
        # NOTE(review): training also uses the train=False (resize +
        # center-crop) pipeline here -- confirm this is intentional.
        train_transform = InpaintingTransform(image_size, train=False)
        val_transform = InpaintingTransform(image_size, train=False)
    elif pretraining_type == 'dino':
        train_transform = DinoTransform(image_size, train=True)
        val_transform = ClassificationTransform(image_size, train=False)
    else:
        # Report the attribute that was actually tested above; the previous
        # code formatted ``args.vit_pretraining_type``, a different name.
        # ValueError is a subclass of Exception, so existing handlers that
        # catch Exception keep working.
        raise ValueError('{} vit pretraining type is not supported.'.format(
            pretraining_type))

    # Training dataset; only this one is sub-sampled by the fraction args.
    train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2]
    train_data = ImageFolder(
        root=train_data_path,
        transform=train_transform,
        classes_fraction=args.classes_fraction,
        data_per_class_fraction=args.data_per_class_fraction
    )
    train_data = RandomSeedDataset(train_data)

    # Validation dataset.
    val_data_path = data_path[1]
    val_data = ImageFolder(
        root=val_data_path,
        transform=val_transform
    )
    val_data = RandomSeedDataset(val_data)

    return train_data, val_data
|
|