# ericzhang0328's picture
# Upload folder using huggingface_hub
# 0c1e054 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
import os
import PIL
from torchvision import datasets, transforms
import torch
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def build_dataset(is_train, args):
    """Create the ImageFolder dataset for a single split.

    Args:
        is_train: Selects the ``train`` sub-directory (and the training
            transform) when truthy, otherwise ``val`` (and the eval transform).
        args: Options object; ``args.data_path`` is the dataset root and the
            whole object is forwarded to ``build_transform``.

    Returns:
        A ``datasets.ImageFolder`` rooted at ``<data_path>/<split>``.
    """
    split = 'train' if is_train else 'val'
    pipeline = build_transform(is_train, args)
    return datasets.ImageFolder(
        os.path.join(args.data_path, split),
        transform=pipeline,
    )
def build_dataset_full(is_train, args):
    """Create one dataset that concatenates the ``train`` and ``val`` folders.

    Both splits share the single transform selected by ``is_train``; the
    combined dataset is printed before being returned.

    Args:
        is_train: Forwarded to ``build_transform`` to pick the transform.
        args: Options object; ``args.data_path`` is the dataset root.

    Returns:
        A ``torch.utils.data.ConcatDataset`` of the train folder followed by
        the val folder.
    """
    shared_transform = build_transform(is_train, args)
    # Order matters for indexing: train samples first, then val samples.
    per_split = [
        datasets.ImageFolder(
            root=os.path.join(args.data_path, split),
            transform=shared_transform,
        )
        for split in ('train', 'val')
    ]
    combined = torch.utils.data.ConcatDataset(per_split)
    print(combined)
    return combined
def build_transform(is_train, args):
    """Build the image transform for training or evaluation.

    Training: delegates to ``create_transform`` with bicubic interpolation and
    the colour-jitter / auto-augment / random-erasing options read from
    ``args`` (``color_jitter``, ``aa``, ``reprob``, ``remode``, ``recount``).

    Evaluation: ``Resize`` -> ``CenterCrop(args.input_size)`` -> ``ToTensor``
    -> ``Normalize``. For ``args.input_size <= 224`` the resize target is
    ``int(input_size / (224 / 232))``; for larger inputs it equals
    ``input_size`` (crop ratio 1.0).

    Both paths normalise with ``IMAGENET_DEFAULT_MEAN`` / ``IMAGENET_DEFAULT_STD``.

    Args:
        is_train: Truthy for the training transform, falsy for evaluation.
        args: Options object providing ``input_size`` and, for training, the
            augmentation settings listed above.

    Returns:
        The transform produced by ``create_transform`` (training) or a
        ``transforms.Compose`` (evaluation).
    """
    norm_mean, norm_std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

    if is_train:
        # Optional extra stage, currently disabled:
        # if args.three_aug:
        #     secondary_tfl = transforms.RandomChoice([gray_scale(p=1.0),
        #                                              Solarization(p=1.0),
        #                                              GaussianBlur(p=1.0)])
        #     transform = transforms.Compose([transform, secondary_tfl])

        # this should always dispatch to transforms_imagenet_train
        return create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=norm_mean,
            std=norm_std,
        )

    # Evaluation path.
    crop_pct = 224 / 232 if args.input_size <= 224 else 1.0
    resize_to = int(args.input_size / crop_pct)
    return transforms.Compose([
        # Resize first to maintain the same ratio w.r.t. 224 images.
        transforms.Resize(resize_to, interpolation=PIL.Image.BICUBIC),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])
import torch
from torchvision import transforms
import numpy as np
from torchvision import datasets, transforms
import random
from PIL import ImageFilter, ImageOps
import torchvision.transforms.functional as TF
class GaussianBlur:
    """Randomly blur a PIL image with a Gaussian filter.

    On each call one random number is drawn; if it is ``<= p`` the image is
    filtered with ``ImageFilter.GaussianBlur`` using a radius sampled
    uniformly between ``radius_min`` and ``radius_max``. Otherwise the input
    image is returned untouched.
    """

    def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        if random.random() <= self.prob:
            radius = random.uniform(self.radius_min, self.radius_max)
            return img.filter(ImageFilter.GaussianBlur(radius=radius))
        return img
class Solarization:
    """Randomly solarize a PIL image.

    On each call one random number is drawn; if it is ``< p`` the result of
    ``ImageOps.solarize(img)`` is returned, otherwise the input image is
    returned untouched.
    """

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, img):
        return ImageOps.solarize(img) if random.random() < self.p else img
class gray_scale:
    """Randomly convert an image to grayscale.

    The conversion is ``transforms.Grayscale(3)``, built once in the
    constructor. On each call one random number is drawn; if it is ``< p``
    the converted image is returned, otherwise the input is returned
    untouched.
    """

    def __init__(self, p=0.2):
        self.p = p
        self.transf = transforms.Grayscale(3)

    def __call__(self, img):
        apply_it = random.random() < self.p
        return self.transf(img) if apply_it else img