| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import PIL |
|
|
| from torchvision import datasets, transforms |
| import torch |
| from timm.data import create_transform |
| from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
|
def build_dataset(is_train, args):
    """Create an ``ImageFolder`` dataset for a single split.

    The split directory is ``<args.data_path>/train`` when ``is_train`` is
    true and ``<args.data_path>/val`` otherwise. Images are preprocessed with
    the pipeline returned by ``build_transform(is_train, args)``.

    Args:
        is_train: Selects the split directory and the transform flavour.
        args: Namespace-like object providing ``data_path`` plus whatever
            ``build_transform`` reads.

    Returns:
        The constructed ``datasets.ImageFolder``.
    """
    preprocess = build_transform(is_train, args)
    split_name = "train" if is_train else "val"
    split_root = os.path.join(args.data_path, split_name)
    return datasets.ImageFolder(split_root, transform=preprocess)
|
|
def build_dataset_full(is_train, args):
    """Build a single dataset that concatenates the ``train`` and ``val`` splits.

    Both splits are loaded with ``datasets.ImageFolder`` from
    ``<args.data_path>/train`` and ``<args.data_path>/val`` and share one
    transform, chosen by ``is_train`` via ``build_transform``.

    Args:
        is_train: Passed to ``build_transform``; selects the training
            augmentation (True) or evaluation preprocessing (False) for BOTH
            splits.
        args: Namespace-like object providing ``data_path`` plus whatever
            ``build_transform`` reads.

    Returns:
        A ``torch.utils.data.ConcatDataset`` yielding every ``train`` sample
        followed by every ``val`` sample.

    Raises:
        ValueError: If the two splits do not have identical class-to-index
            mappings, which would make the concatenated labels inconsistent.
    """
    transform = build_transform(is_train, args)

    train_set = datasets.ImageFolder(root=os.path.join(args.data_path, 'train'),
                                     transform=transform)
    valid_set = datasets.ImageFolder(root=os.path.join(args.data_path, 'val'),
                                     transform=transform)

    # Each ImageFolder assigns label indices from its own sub-folders; if the
    # class folders differ between splits, the same index would mean different
    # classes in the concatenated dataset.
    if train_set.class_to_idx != valid_set.class_to_idx:
        raise ValueError(
            "train and val splits have different class folders; "
            "cannot concatenate them with consistent labels")

    full_set = torch.utils.data.ConcatDataset([train_set, valid_set])
    # ConcatDataset has no informative repr, so report the split sizes instead.
    print(f"Full dataset: {len(train_set)} train + {len(valid_set)} val "
          f"= {len(full_set)} samples")
    return full_set
|
|
def build_transform(is_train, args):
    """Build the image preprocessing pipeline for training or evaluation.

    Args:
        is_train: When true, build timm's training augmentation pipeline via
            ``create_transform``; otherwise build a deterministic
            resize -> center-crop -> to-tensor -> normalize pipeline.
        args: Namespace-like object. Always reads ``input_size``; the training
            branch additionally reads ``color_jitter``, ``aa``, ``reprob``,
            ``remode`` and ``recount``.

    Returns:
        A callable image transform. Both branches are configured with the
        ImageNet default mean/std (``IMAGENET_DEFAULT_MEAN`` /
        ``IMAGENET_DEFAULT_STD``).
    """
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD

    if is_train:
        return create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )

    # Evaluation: resize first, then center-crop ``input_size``. For inputs up
    # to 224 the resize target is input_size / (224 / 232) (224 -> 232), so the
    # crop keeps a fixed fraction of the resized image. For larger inputs
    # crop_pct is 1.0, so the resize target equals ``input_size``.
    if args.input_size <= 224:
        crop_pct = 224 / 232
    else:
        crop_pct = 1.0
    size = int(args.input_size / crop_pct)

    return transforms.Compose([
        # Use torchvision's InterpolationMode instead of ``PIL.Image.BICUBIC``.
        # Passing PIL integer constants to Resize is deprecated in torchvision,
        # and ``import PIL`` alone does not import the ``PIL.Image`` submodule.
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
| import torch |
| from torchvision import transforms |
|
|
| import numpy as np |
| from torchvision import datasets, transforms |
| import random |
|
|
|
|
|
|
| from PIL import ImageFilter, ImageOps |
| import torchvision.transforms.functional as TF |
|
|
|
|
class GaussianBlur(object):
    """Randomly apply a Gaussian blur to a PIL image.

    With probability ``p`` the image is blurred via
    ``img.filter(ImageFilter.GaussianBlur(radius=r))``, where ``r`` is drawn
    with ``random.uniform(radius_min, radius_max)``. Otherwise the image is
    returned unchanged.

    Args:
        p: Probability of applying the blur; ``p=0`` never blurs and ``p=1``
            always blurs.
        radius_min: Lower bound passed to ``random.uniform`` for the radius.
        radius_max: Upper bound passed to ``random.uniform`` for the radius.
    """

    def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        # Strict ``<`` (as in Solarization and gray_scale) makes the blur
        # probability exactly ``p``. random.random() is in [0, 1), so ``<=``
        # could still blur when p == 0.
        if not random.random() < self.prob:
            return img

        radius = random.uniform(self.radius_min, self.radius_max)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))
|
|
class Solarization:
    """Solarize a PIL image at random.

    Each call draws one ``random.random()`` value; if it is below ``p`` the
    image is passed through ``ImageOps.solarize``, otherwise it is returned
    untouched.

    Args:
        p: Threshold compared against the random draw (the chance of
            solarizing).
    """

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, img):
        should_solarize = random.random() < self.p
        if not should_solarize:
            return img
        return ImageOps.solarize(img)
|
|
class gray_scale(object):
    """Randomly convert an image to grayscale.

    With probability ``p`` the image is passed through
    ``transforms.Grayscale(3)``. Otherwise it is returned unchanged.

    Args:
        p: Probability of converting the image.
    """

    def __init__(self, p=0.2):
        self.p = p
        # Grayscale(3) keeps three (identical) output channels, presumably so
        # the result stays compatible with 3-channel (RGB) pipelines.
        self.transf = transforms.Grayscale(3)

    def __call__(self, img):
        if random.random() < self.p:
            return self.transf(img)
        return img
|
|