# (Provenance note from file hosting page — not Python code; kept as comments.)
# Munazz's picture
# Move files to Clothes-Category-Classifier
# e94e577
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
import os
import PIL
from torchvision import datasets, transforms
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def build_dataset(is_train, args) :
    """Build an ImageFolder dataset for the requested benchmark.

    Args:
        is_train: True for the training split, False for the eval split.
        args: namespace providing `data_set` (dataset name), `data_path`
            (root directory containing the split sub-folders), and the
            fields consumed by `build_transform`.

    Returns:
        (dataset, nb_cls): the `torchvision.datasets.ImageFolder` instance
        and the number of classes for that benchmark.

    Raises:
        ValueError: if `args.data_set` is not one of the supported names.
            (The original code fell through to an `UnboundLocalError` here.)
    """
    transform = build_transform(is_train, args)
    # (train-split dir, eval-split dir, number of classes) per benchmark.
    # Clothing1M uses a randomly selected balanced training subset.
    splits = {
        'CIFAR10'   : ('train', 'val', 10),
        'CIFAR100'  : ('train', 'val', 100),
        'Animal10N' : ('train', 'test', 10),
        'Clothing1M': ('noisy_rand_subtrain', 'clean_val', 14),
        'Food101N'  : ('train', 'test', 101),
    }
    if args.data_set not in splits :
        raise ValueError(f'Unknown dataset: {args.data_set!r}; '
                         f'expected one of {sorted(splits)}')
    train_dir, eval_dir, nb_cls = splits[args.data_set]
    root = os.path.join(args.data_path, train_dir if is_train else eval_dir)
    dataset = datasets.ImageFolder(root, transform=transform)
    print(dataset)
    return dataset, nb_cls
def build_transform(is_train, args) :
    """Return the train- or eval-time image transform pipeline.

    Training uses timm's `create_transform` (augmentation, random erasing);
    for small images (input_size <= 32, e.g. CIFAR) the initial resized crop
    is swapped for a padded RandomCrop. Evaluation resizes with bicubic
    interpolation, center-crops, converts to tensor, and normalizes.

    Args:
        is_train: True for the training pipeline, False for evaluation.
        args: namespace providing `data_set`, `input_size`, `color_jitter`,
            `aa`, `reprob`, `remode`, `recount`.
    """
    # CIFAR uses its own channel statistics; everything else uses ImageNet's.
    if args.data_set in ('CIFAR10', 'CIFAR100') :
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    else :
        mean = IMAGENET_DEFAULT_MEAN
        std = IMAGENET_DEFAULT_STD

    resize_im = args.input_size > 32

    if is_train :
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        if not resize_im :
            # replace RandomResizedCropAndInterpolation with RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    # eval transform: resize so the center crop keeps the same ratio
    # w.r.t. 224-pixel images, then crop / tensorize / normalize.
    crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
    resize_target = int(args.input_size / crop_pct)
    eval_steps = [
        transforms.Resize(resize_target, interpolation=PIL.Image.BICUBIC),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    return transforms.Compose(eval_steps)