yuxi-liu-wired's picture
init
0decf42
# --------------------------------------------------------
# RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
# Github source: https://github.com/DingXiaoH/RepVGG
# Licensed under The MIT License [see LICENSE for details]
# The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer)
# --------------------------------------------------------
import torch
import numpy as np
import torch.distributed as dist
from torchvision import datasets, transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import Mixup
from timm.data import create_transform
try:
from timm.data.transforms import str_to_pil_interp as _pil_interp
except:
from timm.data.transforms import _pil_interp
from .cached_image_folder import CachedImageFolder
from .samplers import SubsetRandomSampler
def build_loader(config):
config.defrost()
dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
config.freeze()
print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
dataset_val, _ = build_dataset(is_train=False, config=config)
print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")
num_tasks = dist.get_world_size()
global_rank = dist.get_rank()
if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
sampler_train = SubsetRandomSampler(indices)
else:
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
if dataset_val is None:
sampler_val = None
else:
indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size()) #TODO
sampler_val = SubsetRandomSampler(indices)
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=config.DATA.BATCH_SIZE,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=True,
)
if dataset_val is None:
data_loader_val = None
else:
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=config.DATA.TEST_BATCH_SIZE,
shuffle=False,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=False
)
# setup mixup / cutmix
mixup_fn = None
mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
if mixup_active:
mixup_fn = Mixup(
mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
def build_dataset(is_train, config):
if config.DATA.DATASET == 'imagenet':
transform = build_transform(is_train, config)
prefix = 'train' if is_train else 'val'
if config.DATA.ZIP_MODE:
ann_file = prefix + "_map.txt"
prefix = prefix + ".zip@/"
dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
else:
import torchvision
print('use raw ImageNet data')
dataset = torchvision.datasets.ImageNet(root=config.DATA.DATA_PATH, split='train' if is_train else 'val', transform=transform)
nb_classes = 1000
elif config.DATA.DATASET == 'cf100':
mean = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
std = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]
if is_train:
transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
dataset = datasets.CIFAR100(root=config.DATA.DATA_PATH, train=True, download=True, transform=transform)
else:
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize(mean, std)])
dataset = datasets.CIFAR100(root=config.DATA.DATA_PATH, train=False, download=True, transform=transform)
nb_classes = 100
else:
raise NotImplementedError("We only support ImageNet and CIFAR-100 now.")
return dataset, nb_classes
def build_transform(is_train, config):
resize_im = config.DATA.IMG_SIZE > 32
if is_train:
# this should always dispatch to transforms_imagenet_train
if config.AUG.PRESET is None:
transform = create_transform(
input_size=config.DATA.IMG_SIZE,
is_training=True,
color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
re_prob=config.AUG.REPROB,
re_mode=config.AUG.REMODE,
re_count=config.AUG.RECOUNT,
interpolation=config.DATA.INTERPOLATION,
)
print('=============================== original AUG! ', config.AUG.AUTO_AUGMENT)
if not resize_im:
# replace RandomResizedCropAndInterpolation with
# RandomCrop
transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
elif config.AUG.PRESET.strip() == 'raug15':
from train.randaug import RandAugPolicy
transform = transforms.Compose([
transforms.RandomResizedCrop(config.DATA.IMG_SIZE),
transforms.RandomHorizontalFlip(),
RandAugPolicy(magnitude=15),
transforms.ToTensor(),
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
print('---------------------- RAND AUG 15 distortion!')
elif config.AUG.PRESET.strip() == 'weak':
transform = transforms.Compose([
transforms.RandomResizedCrop(config.DATA.IMG_SIZE),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
elif config.AUG.PRESET.strip() == 'none':
transform = transforms.Compose([
transforms.Resize(config.DATA.IMG_SIZE, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
transforms.CenterCrop(config.DATA.IMG_SIZE),
transforms.ToTensor(),
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
else:
raise ValueError('???' + config.AUG.PRESET)
print(transform)
return transform
t = []
if resize_im:
if config.TEST.CROP:
size = int((256 / 224) * config.DATA.TEST_SIZE)
t.append(transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
# to maintain same ratio w.r.t. 224 images
)
t.append(transforms.CenterCrop(config.DATA.TEST_SIZE))
else:
# default for testing
t.append(transforms.Resize(config.DATA.TEST_SIZE, interpolation=_pil_interp(config.DATA.INTERPOLATION)))
t.append(transforms.CenterCrop(config.DATA.TEST_SIZE))
t.append(transforms.ToTensor())
t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
trans = transforms.Compose(t)
return trans