# Hugging Face upload metadata (kept as comments so the module stays importable):
# tqv06's picture
# Upload folder using huggingface_hub
# 866ee56 verified
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import os
import numpy as np
import torch
import torch.distributed as dist
from timm.data import Mixup, create_transform
from torchvision import transforms
from torchvision.datasets import ImageFolder
from .cached_image_folder import ImageCephDataset
from .samplers import NodeDistributedSampler, SubsetRandomSampler
try:
    from torchvision.transforms import InterpolationMode

    def _pil_interp(method):
        """Map an interpolation-method name to a torchvision InterpolationMode.

        Unknown names fall back to bilinear, mirroring timm's helper.
        """
        if method == 'bicubic':
            return InterpolationMode.BICUBIC
        elif method == 'lanczos':
            return InterpolationMode.LANCZOS
        elif method == 'hamming':
            return InterpolationMode.HAMMING
        else:
            return InterpolationMode.BILINEAR
except ImportError:
    # Older torchvision without InterpolationMode: use timm's private helper.
    # Catch only ImportError — the previous bare `except:` would also have
    # hidden unrelated errors raised while defining the function above.
    from timm.data.transforms import _pil_interp
class TTA(torch.nn.Module):
    """Test-time augmentation: resize the input at several scales, then
    center-crop each resized image back to ``size``.

    Args:
        size: target crop size, also the base for the scaled resizes.
        scales: multipliers applied to ``size`` before the center crop.

    ``forward`` returns a list with one crop per scale.
    """

    def __init__(self, size, scales=(1.0, 1.05, 1.1)):
        super().__init__()
        self.size = size
        # Copy into a fresh list: the original used a mutable list default,
        # which is shared across all instances constructed without `scales`.
        self.scales = list(scales)

    def forward(self, img):
        out = []
        cc = transforms.CenterCrop(self.size)
        for scale in self.scales:
            size_ = int(scale * self.size)
            rs = transforms.Resize(size_, interpolation=_pil_interp('bicubic'))
            img_ = rs(img)
            img_ = cc(img_)
            out.append(img_)
        return out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(size={self.size}, scale={self.scales})'
def build_loader(config):
    """Build datasets, distributed samplers, data loaders and mixup for all splits.

    Returns:
        (dataset_train, dataset_val, dataset_test,
         data_loader_train, data_loader_val, data_loader_test, mixup_fn)
        where a loader is None when its dataset could not be built, and
        mixup_fn is None when mixup/cutmix are disabled.
    """
    config.defrost()
    # build_dataset also reports the class count; persist it on the config.
    dataset_train, config.MODEL.NUM_CLASSES = build_dataset('train', config=config)
    config.freeze()
    # NOTE: the original f-strings were missing the space before
    # 'successfully', gluing the rank number to the message in the logs.
    print(f'local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} '
          'successfully build train dataset')
    dataset_val, _ = build_dataset('val', config=config)
    print(f'local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} '
          'successfully build val dataset')
    dataset_test, _ = build_dataset('test', config=config)
    print(f'local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} '
          'successfully build test dataset')

    num_tasks = dist.get_world_size()
    global_rank = dist.get_rank()
    if dataset_train is not None:
        if config.DATA.IMG_ON_MEMORY:
            # Samples are cached per node; shard by node to avoid duplicate caching.
            sampler_train = NodeDistributedSampler(dataset_train)
        else:
            if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
                # Each rank samples a strided subset of the zipped dataset.
                indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
                sampler_train = SubsetRandomSampler(indices)
            else:
                sampler_train = torch.utils.data.DistributedSampler(
                    dataset_train,
                    num_replicas=num_tasks,
                    rank=global_rank,
                    shuffle=True)
    if dataset_val is not None:
        if config.TEST.SEQUENTIAL:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
        else:
            sampler_val = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
    if dataset_test is not None:
        if config.TEST.SEQUENTIAL:
            sampler_test = torch.utils.data.SequentialSampler(dataset_test)
        else:
            sampler_test = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=True,  # keep batch size constant for training
        persistent_workers=True) if dataset_train is not None else None
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=config.DATA.BATCH_SIZE,
        shuffle=False,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=False,
        persistent_workers=True) if dataset_val is not None else None
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        sampler=sampler_test,
        batch_size=config.DATA.BATCH_SIZE,
        shuffle=False,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=False,
        persistent_workers=True) if dataset_test is not None else None

    # setup mixup / cutmix
    mixup_fn = None
    mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=config.AUG.MIXUP,
                         cutmix_alpha=config.AUG.CUTMIX,
                         cutmix_minmax=config.AUG.CUTMIX_MINMAX,
                         prob=config.AUG.MIXUP_PROB,
                         switch_prob=config.AUG.MIXUP_SWITCH_PROB,
                         mode=config.AUG.MIXUP_MODE,
                         label_smoothing=config.MODEL.LABEL_SMOOTHING,
                         num_classes=config.MODEL.NUM_CLASSES)

    return dataset_train, dataset_val, dataset_test, data_loader_train, \
        data_loader_val, data_loader_test, mixup_fn
def build_loader2(config):
    """Build datasets, plain (non-distributed) data loaders and the mixup fn.

    Single-process variant of build_loader: no samplers, the train loader
    simply shuffles. Returns the same 7-tuple; loaders are None for splits
    whose dataset is unavailable.
    """
    config.defrost()
    dataset_train, config.MODEL.NUM_CLASSES = build_dataset('train', config=config)
    config.freeze()
    dataset_val, _ = build_dataset('val', config=config)
    dataset_test, _ = build_dataset('test', config=config)

    def _make_loader(dataset, shuffle, drop_last):
        # Shared DataLoader construction; a missing dataset yields no loader.
        if dataset is None:
            return None
        return torch.utils.data.DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=config.DATA.BATCH_SIZE,
            num_workers=config.DATA.NUM_WORKERS,
            pin_memory=config.DATA.PIN_MEMORY,
            drop_last=drop_last,
            persistent_workers=True)

    data_loader_train = _make_loader(dataset_train, shuffle=True, drop_last=True)
    data_loader_val = _make_loader(dataset_val, shuffle=False, drop_last=False)
    data_loader_test = _make_loader(dataset_test, shuffle=False, drop_last=False)

    # setup mixup / cutmix
    mixup_fn = None
    mixup_active = (config.AUG.MIXUP > 0
                    or config.AUG.CUTMIX > 0.
                    or config.AUG.CUTMIX_MINMAX is not None)
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=config.AUG.MIXUP,
                         cutmix_alpha=config.AUG.CUTMIX,
                         cutmix_minmax=config.AUG.CUTMIX_MINMAX,
                         prob=config.AUG.MIXUP_PROB,
                         switch_prob=config.AUG.MIXUP_SWITCH_PROB,
                         mode=config.AUG.MIXUP_MODE,
                         label_smoothing=config.MODEL.LABEL_SMOOTHING,
                         num_classes=config.MODEL.NUM_CLASSES)

    return dataset_train, dataset_val, dataset_test, data_loader_train, \
        data_loader_val, data_loader_test, mixup_fn
def build_dataset(split, config):
    """Build the dataset for one split ('train', 'val' or 'test').

    Returns:
        (dataset, nb_classes). Either element may be None when the requested
        split is unavailable for the configured dataset (e.g. a train split
        of a test-only dataset, or the test split of ImageNet-1k).

    Raises:
        NotImplementedError: for an unknown config.DATA.TRANSFORM or
            config.DATA.DATASET.
    """
    if config.DATA.TRANSFORM == 'build_transform':
        transform = build_transform(split == 'train', config)
    elif config.DATA.TRANSFORM == 'build_transform_for_linear_probe':
        transform = build_transform_for_linear_probe(split == 'train', config)
    else:
        raise NotImplementedError
    print(split, transform)

    dataset = None
    nb_classes = None
    prefix = split
    if config.DATA.DATASET == 'imagenet' or config.DATA.DATASET == 'imagenet-real':
        if prefix == 'train' and not config.EVAL_MODE:
            root = os.path.join(config.DATA.DATA_PATH, 'train')
            dataset = ImageCephDataset(root, 'train',
                                       transform=transform,
                                       on_memory=config.DATA.IMG_ON_MEMORY)
        elif prefix == 'val':
            root = os.path.join(config.DATA.DATA_PATH, 'val')
            dataset = ImageCephDataset(root, 'val', transform=transform)
        nb_classes = 1000
    elif config.DATA.DATASET == 'imagenet22K':
        if prefix == 'train':
            if not config.EVAL_MODE:
                root = config.DATA.DATA_PATH
                dataset = ImageCephDataset(root, 'train',
                                           transform=transform,
                                           on_memory=config.DATA.IMG_ON_MEMORY)
            nb_classes = 21841
        elif prefix == 'val':
            # 22K training is validated on the 1k-class ImageNet val set.
            root = os.path.join(config.DATA.DATA_PATH, 'val')
            dataset = ImageCephDataset(root, 'val', transform=transform)
            nb_classes = 1000
    elif config.DATA.DATASET == 'imagenetv2':
        from .imagenetv2 import ImageNetV2Dataset
        if prefix == 'train' and not config.EVAL_MODE:
            print(f'Only test split available for {config.DATA.DATASET}')
        else:
            dataset = ImageNetV2Dataset(variant='matched-frequency',
                                        transform=transform,
                                        location=config.DATA.DATA_PATH)
        nb_classes = 1000
    elif config.DATA.DATASET == 'imagenet_sketch':
        if prefix == 'train' and not config.EVAL_MODE:
            print(f'Only test split available for {config.DATA.DATASET}')
        else:
            dataset = ImageFolder(root=config.DATA.DATA_PATH, transform=transform)
        nb_classes = 1000
    elif config.DATA.DATASET == 'imagenet_a':
        if prefix == 'train' and not config.EVAL_MODE:
            print(f'Only test split available for {config.DATA.DATASET}')
        else:
            dataset = ImageFolder(root=config.DATA.DATA_PATH, transform=transform)
        nb_classes = 1000  # actual number of classes is 200
    elif config.DATA.DATASET == 'imagenet_r':
        if prefix == 'train' and not config.EVAL_MODE:
            print(f'Only test split available for {config.DATA.DATASET}')
        else:
            dataset = ImageFolder(root=config.DATA.DATA_PATH, transform=transform)
        nb_classes = 1000  # actual number of classes is 200
    else:
        # Fixed message: the original read 'does support', dropping the 'not'.
        raise NotImplementedError(
            f'build_dataset does not support {config.DATA.DATASET}')
    return dataset, nb_classes
def build_transform_for_linear_probe(is_train, config):
    """Weak-augmentation transform pipeline for linear-probe runs.

    Train: random resized crop + horizontal flip. Eval: resize + center crop.
    Both end with ToTensor and normalization from config.AUG.
    """
    bicubic = transforms.InterpolationMode.BICUBIC
    if is_train:
        # linear probe: weak augmentation
        ops = [
            transforms.RandomResizedCrop(config.DATA.IMG_SIZE, interpolation=bicubic),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        ops = [
            transforms.Resize(config.DATA.IMG_SIZE, interpolation=bicubic),
            transforms.CenterCrop(config.DATA.IMG_SIZE),
        ]
    ops.append(transforms.ToTensor())
    ops.append(transforms.Normalize(mean=config.AUG.MEAN, std=config.AUG.STD))
    return transforms.Compose(ops)
def build_transform(is_train, config):
    """Build the training or evaluation transform pipeline.

    Training uses timm's create_transform (color jitter, auto-augment,
    random erasing); evaluation uses resize/crop + normalization according
    to config.TEST.CROP and config.AUG.RANDOM_RESIZED_CROP.
    """
    resize_im = config.DATA.IMG_SIZE > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=config.DATA.IMG_SIZE,
            is_training=True,
            color_jitter=config.AUG.COLOR_JITTER
            if config.AUG.COLOR_JITTER > 0 else None,
            auto_augment=config.AUG.AUTO_AUGMENT
            if config.AUG.AUTO_AUGMENT != 'none' else None,
            re_prob=config.AUG.REPROB,
            re_mode=config.AUG.REMODE,
            re_count=config.AUG.RECOUNT,
            interpolation=config.DATA.INTERPOLATION,
        )
        if not resize_im:
            # Small images: swap the leading RandomResizedCropAndInterpolation
            # for a padded RandomCrop.
            transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
        return transform

    # Evaluation pipeline.
    interp = _pil_interp(config.DATA.INTERPOLATION)
    pipeline = []
    if resize_im:
        if config.TEST.CROP:
            # Resize with crop ratio 1.0 to maintain the same ratio w.r.t.
            # 224 images, then center-crop to the target size.
            pipeline.append(transforms.Resize(int(1.0 * config.DATA.IMG_SIZE),
                                              interpolation=interp))
            pipeline.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
        elif config.AUG.RANDOM_RESIZED_CROP:
            pipeline.append(transforms.RandomResizedCrop(
                (config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
                interpolation=interp))
        else:
            pipeline.append(transforms.Resize(
                (config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
                interpolation=interp))
    pipeline.append(transforms.ToTensor())
    pipeline.append(transforms.Normalize(config.AUG.MEAN, config.AUG.STD))
    return transforms.Compose(pipeline)