# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
"""Dataset and transform builders plus a few PIL-based augmentations."""

import os
import random

import numpy as np
import PIL
import PIL.Image  # load the submodule explicitly; build_transform uses PIL.Image.BICUBIC
import torch
import torchvision.transforms.functional as TF
from PIL import ImageFilter, ImageOps
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import datasets, transforms


def build_dataset(is_train, args):
    """Build an ``ImageFolder`` dataset for one split.

    Args:
        is_train: Selects the ``train`` sub-directory and the training
            transform when true; otherwise ``val`` and the eval transform.
        args: Object providing ``data_path`` plus every attribute read by
            :func:`build_transform`.

    Returns:
        A ``torchvision.datasets.ImageFolder`` rooted at
        ``<args.data_path>/train`` or ``<args.data_path>/val``.
    """
    transform = build_transform(is_train, args)
    root = os.path.join(args.data_path, 'train' if is_train else 'val')
    return datasets.ImageFolder(root, transform=transform)


def build_dataset_full(is_train, args):
    """Build one dataset that concatenates the ``train`` and ``val`` folders.

    The transform is built once from ``is_train`` and applied to both
    splits. The resulting dataset object is printed before being returned.

    Args:
        is_train: Forwarded to :func:`build_transform` to choose the
            transform used for both splits.
        args: Object providing ``data_path`` plus every attribute read by
            :func:`build_transform`.

    Returns:
        A ``torch.utils.data.ConcatDataset`` of the train folder followed
        by the val folder.
    """
    transform = build_transform(is_train, args)
    train_set = datasets.ImageFolder(
        root=os.path.join(args.data_path, 'train'), transform=transform)
    valid_set = datasets.ImageFolder(
        root=os.path.join(args.data_path, 'val'), transform=transform)
    full_set = torch.utils.data.ConcatDataset([train_set, valid_set])
    print(full_set)
    return full_set


def build_transform(is_train, args):
    """Build the image transform for training or evaluation.

    Args:
        is_train: When true, return the timm training transform; otherwise
            return the resize / center-crop evaluation pipeline.
        args: Object providing ``input_size`` and, for training,
            ``color_jitter``, ``aa``, ``reprob``, ``remode`` and ``recount``.

    Returns:
        The transform produced by ``timm.data.create_transform`` for
        training, or a ``transforms.Compose`` for evaluation.
    """
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD

    # train transform
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        # if args.three_aug:
        #     secondary_tfl = transforms.RandomChoice([gray_scale(p=1.0),
        #                                              Solarization(p=1.0),
        #                                              GaussianBlur(p=1.0)])
        #     transform = transforms.Compose([transform, secondary_tfl])
        return transform

    # eval transform
    if args.input_size <= 224:
        crop_pct = 224 / 232
    else:
        crop_pct = 1.0
    # Resize before cropping, to maintain same ratio w.r.t. 224 images.
    size = int(args.input_size / crop_pct)

    return transforms.Compose([
        transforms.Resize(size, interpolation=PIL.Image.BICUBIC),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])


class GaussianBlur(object):
    """Apply Gaussian blur to a PIL image with probability ``p``.

    When applied, the blur radius is drawn uniformly between
    ``radius_min`` and ``radius_max``.
    """

    def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        """Return ``img`` blurred, or unchanged when the draw exceeds ``prob``."""
        do_it = random.random() <= self.prob
        if not do_it:
            return img

        radius = random.uniform(self.radius_min, self.radius_max)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))


class Solarization(object):
    """Apply ``ImageOps.solarize`` to a PIL image with probability ``p``."""

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, img):
        """Return ``img`` solarized, or unchanged when the draw is not below ``p``."""
        if random.random() < self.p:
            return ImageOps.solarize(img)
        return img


class gray_scale(object):
    """Apply ``transforms.Grayscale(3)`` to an image with probability ``p``."""

    def __init__(self, p=0.2):
        self.p = p
        self.transf = transforms.Grayscale(3)

    def __call__(self, img):
        """Return ``img`` converted by ``self.transf``, or unchanged otherwise."""
        if random.random() < self.p:
            return self.transf(img)
        return img