| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.distributed as dist |
| | from timm.data import Mixup, create_transform |
| | from torchvision import transforms |
| | from torchvision.datasets import ImageFolder |
| |
|
| | from .cached_image_folder import ImageCephDataset |
| | from .samplers import NodeDistributedSampler, SubsetRandomSampler |
| |
|
# Resolve `_pil_interp` in a torchvision-version-agnostic way: recent
# torchvision exposes the `InterpolationMode` enum, while older installs
# rely on timm's helper that returns PIL integer constants.
try:
    from torchvision.transforms import InterpolationMode

    def _pil_interp(method):
        """Map an interpolation-method name to a torchvision InterpolationMode.

        Args:
            method: one of 'bicubic', 'lanczos', 'hamming'; anything else
                falls back to bilinear.
        """
        if method == 'bicubic':
            return InterpolationMode.BICUBIC
        elif method == 'lanczos':
            return InterpolationMode.LANCZOS
        elif method == 'hamming':
            return InterpolationMode.HAMMING
        else:
            return InterpolationMode.BILINEAR
# Catch only the import failure: a bare `except:` would also swallow
# SystemExit/KeyboardInterrupt and mask unrelated bugs.
except ImportError:
    from timm.data.transforms import _pil_interp
| |
|
| |
|
class TTA(torch.nn.Module):
    """Test-time-augmentation transform.

    Resizes the input image at several scales relative to `size`, center-crops
    each resized image back to `size`, and returns the list of crops.
    """

    # NOTE: the default was previously a mutable list literal, which is shared
    # across all instances that rely on the default (classic mutable-default
    # pitfall); a tuple default avoids that while keeping the same values.
    def __init__(self, size, scales=(1.0, 1.05, 1.1)):
        super().__init__()
        self.size = size
        # Stored as a list so `__repr__` output matches the original behavior.
        self.scales = list(scales)

    def forward(self, img):
        """Return a list of center crops of `img`, one per scale."""
        out = []
        cc = transforms.CenterCrop(self.size)
        for scale in self.scales:
            size_ = int(scale * self.size)
            rs = transforms.Resize(size_, interpolation=_pil_interp('bicubic'))
            img_ = rs(img)
            img_ = cc(img_)
            out.append(img_)

        return out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(size={self.size}, scale={self.scales})'
| |
|
| |
|
def build_loader(config):
    """Build datasets, distributed samplers, dataloaders and the optional
    mixup function for distributed training.

    Returns:
        (dataset_train, dataset_val, dataset_test,
         data_loader_train, data_loader_val, data_loader_test, mixup_fn)
        Each loader is None when the corresponding dataset is unavailable;
        mixup_fn is None when mixup/cutmix are disabled.
    """
    # NUM_CLASSES is discovered while building the train split, so the frozen
    # config must be temporarily unfrozen to record it.
    config.defrost()
    dataset_train, config.MODEL.NUM_CLASSES = build_dataset('train', config=config)
    config.freeze()
    # Fixed: the concatenated literal was missing a leading space, producing
    # messages like "global rank 0successfully build train dataset".
    print(f'local rank {config.LOCAL_RANK} / global rank {dist.get_rank()}'
          ' successfully build train dataset')

    dataset_val, _ = build_dataset('val', config=config)
    print(f'local rank {config.LOCAL_RANK} / global rank {dist.get_rank()}'
          ' successfully build val dataset')

    dataset_test, _ = build_dataset('test', config=config)
    print(f'local rank {config.LOCAL_RANK} / global rank {dist.get_rank()}'
          ' successfully build test dataset')

    num_tasks = dist.get_world_size()
    global_rank = dist.get_rank()

    if dataset_train is not None:
        if config.DATA.IMG_ON_MEMORY:
            # Images are cached in RAM per node; use the node-aware sampler.
            sampler_train = NodeDistributedSampler(dataset_train)
        else:
            if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
                # Each rank caches only a strided shard of the zip archive,
                # so it must sample exactly from its own stride.
                indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
                sampler_train = SubsetRandomSampler(indices)
            else:
                sampler_train = torch.utils.data.DistributedSampler(
                    dataset_train,
                    num_replicas=num_tasks,
                    rank=global_rank,
                    shuffle=True)

    if dataset_val is not None:
        if config.TEST.SEQUENTIAL:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
        else:
            sampler_val = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)

    if dataset_test is not None:
        if config.TEST.SEQUENTIAL:
            sampler_test = torch.utils.data.SequentialSampler(dataset_test)
        else:
            sampler_test = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=True,
        persistent_workers=True) if dataset_train is not None else None

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=config.DATA.BATCH_SIZE,
        shuffle=False,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=False,
        persistent_workers=True) if dataset_val is not None else None

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        sampler=sampler_test,
        batch_size=config.DATA.BATCH_SIZE,
        shuffle=False,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=False,
        persistent_workers=True) if dataset_test is not None else None

    # Mixup/cutmix are active when either alpha is positive or an explicit
    # cutmix min/max ratio is configured.
    mixup_fn = None
    mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=config.AUG.MIXUP,
                         cutmix_alpha=config.AUG.CUTMIX,
                         cutmix_minmax=config.AUG.CUTMIX_MINMAX,
                         prob=config.AUG.MIXUP_PROB,
                         switch_prob=config.AUG.MIXUP_SWITCH_PROB,
                         mode=config.AUG.MIXUP_MODE,
                         label_smoothing=config.MODEL.LABEL_SMOOTHING,
                         num_classes=config.MODEL.NUM_CLASSES)

    return dataset_train, dataset_val, dataset_test, data_loader_train, \
        data_loader_val, data_loader_test, mixup_fn
| |
|
| |
|
def build_loader2(config):
    """Single-process variant of `build_loader`: plain (non-distributed)
    dataloaders, no samplers.

    Returns:
        (dataset_train, dataset_val, dataset_test,
         data_loader_train, data_loader_val, data_loader_test, mixup_fn)
        Loaders are None for missing datasets; mixup_fn is None when
        mixup/cutmix are disabled.
    """
    # Building the train split also discovers NUM_CLASSES, so briefly unfreeze.
    config.defrost()
    dataset_train, config.MODEL.NUM_CLASSES = build_dataset('train', config=config)
    config.freeze()
    dataset_val, _ = build_dataset('val', config=config)
    dataset_test, _ = build_dataset('test', config=config)

    def _make_loader(dataset, training):
        # Shared loader factory: training loaders shuffle and drop the last
        # partial batch; evaluation loaders do neither.
        if dataset is None:
            return None
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=config.DATA.BATCH_SIZE,
            shuffle=training,
            num_workers=config.DATA.NUM_WORKERS,
            pin_memory=config.DATA.PIN_MEMORY,
            drop_last=training,
            persistent_workers=True)

    data_loader_train = _make_loader(dataset_train, True)
    data_loader_val = _make_loader(dataset_val, False)
    data_loader_test = _make_loader(dataset_test, False)

    # Mixup/cutmix are enabled by a positive alpha or an explicit min/max.
    mixup_fn = None
    if (config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0.
            or config.AUG.CUTMIX_MINMAX is not None):
        mixup_fn = Mixup(mixup_alpha=config.AUG.MIXUP,
                         cutmix_alpha=config.AUG.CUTMIX,
                         cutmix_minmax=config.AUG.CUTMIX_MINMAX,
                         prob=config.AUG.MIXUP_PROB,
                         switch_prob=config.AUG.MIXUP_SWITCH_PROB,
                         mode=config.AUG.MIXUP_MODE,
                         label_smoothing=config.MODEL.LABEL_SMOOTHING,
                         num_classes=config.MODEL.NUM_CLASSES)

    return dataset_train, dataset_val, dataset_test, data_loader_train, \
        data_loader_val, data_loader_test, mixup_fn
| |
|
| |
|
def build_dataset(split, config):
    """Create the dataset for one split.

    Args:
        split: 'train', 'val' or 'test'.
        config: experiment config (DATA.DATASET selects the dataset family).

    Returns:
        (dataset, nb_classes). Either may be None when the combination of
        dataset / split / EVAL_MODE does not define one (e.g. the 'train'
        split of an evaluation-only benchmark).

    Raises:
        NotImplementedError: unknown transform name or dataset name.
    """
    if config.DATA.TRANSFORM == 'build_transform':
        transform = build_transform(split == 'train', config)
    elif config.DATA.TRANSFORM == 'build_transform_for_linear_probe':
        transform = build_transform_for_linear_probe(split == 'train', config)
    else:
        raise NotImplementedError
    print(split, transform)
    dataset = None
    nb_classes = None
    prefix = split
    if config.DATA.DATASET == 'imagenet' or config.DATA.DATASET == 'imagenet-real':
        if prefix == 'train' and not config.EVAL_MODE:
            root = os.path.join(config.DATA.DATA_PATH, 'train')
            dataset = ImageCephDataset(root, 'train',
                                       transform=transform,
                                       on_memory=config.DATA.IMG_ON_MEMORY)
        elif prefix == 'val':
            root = os.path.join(config.DATA.DATA_PATH, 'val')
            dataset = ImageCephDataset(root, 'val', transform=transform)
        nb_classes = 1000
    elif config.DATA.DATASET == 'imagenet22K':
        if prefix == 'train':
            if not config.EVAL_MODE:
                root = config.DATA.DATA_PATH
                dataset = ImageCephDataset(root, 'train',
                                           transform=transform,
                                           on_memory=config.DATA.IMG_ON_MEMORY)
            # 22K pretraining uses the full 21841-class label space.
            nb_classes = 21841
        elif prefix == 'val':
            # Validation is done on the standard 1000-class ImageNet split.
            root = os.path.join(config.DATA.DATA_PATH, 'val')
            dataset = ImageCephDataset(root, 'val', transform=transform)
            nb_classes = 1000
    elif config.DATA.DATASET == 'imagenetv2':
        from .imagenetv2 import ImageNetV2Dataset
        if prefix == 'train' and not config.EVAL_MODE:
            print(f'Only test split available for {config.DATA.DATASET}')
        else:
            dataset = ImageNetV2Dataset(variant='matched-frequency',
                                        transform=transform,
                                        location=config.DATA.DATA_PATH)
        nb_classes = 1000
    # The three robustness benchmarks share identical handling: evaluation-only
    # ImageFolder trees with the 1000-class ImageNet label space.
    elif config.DATA.DATASET in ('imagenet_sketch', 'imagenet_a', 'imagenet_r'):
        if prefix == 'train' and not config.EVAL_MODE:
            print(f'Only test split available for {config.DATA.DATASET}')
        else:
            dataset = ImageFolder(root=config.DATA.DATA_PATH, transform=transform)
        nb_classes = 1000
    else:
        # Fixed: the message previously read "does support", inverting its meaning.
        raise NotImplementedError(
            f'build_dataset does not support {config.DATA.DATASET}')

    return dataset, nb_classes
| |
|
| |
|
def build_transform_for_linear_probe(is_train, config):
    """Return the simple transform pipeline used for linear probing.

    Training: random resized crop + horizontal flip; evaluation: resize and
    center crop. Both end with tensor conversion and normalization.
    """
    normalize = transforms.Normalize(mean=config.AUG.MEAN, std=config.AUG.STD)
    if is_train:
        steps = [
            transforms.RandomResizedCrop(
                config.DATA.IMG_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        steps = [
            transforms.Resize(
                config.DATA.IMG_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(config.DATA.IMG_SIZE),
        ]
    steps.append(transforms.ToTensor())
    steps.append(normalize)
    return transforms.Compose(steps)
| |
|
| |
|
def build_transform(is_train, config):
    """Return the transform pipeline for training or evaluation.

    Training uses timm's `create_transform` (augmentation, random erasing);
    evaluation composes resize/crop, tensor conversion and normalization.
    """
    resize_im = config.DATA.IMG_SIZE > 32

    if is_train:
        transform = create_transform(
            input_size=config.DATA.IMG_SIZE,
            is_training=True,
            color_jitter=config.AUG.COLOR_JITTER
            if config.AUG.COLOR_JITTER > 0 else None,
            auto_augment=config.AUG.AUTO_AUGMENT
            if config.AUG.AUTO_AUGMENT != 'none' else None,
            re_prob=config.AUG.REPROB,
            re_mode=config.AUG.REMODE,
            re_count=config.AUG.RECOUNT,
            interpolation=config.DATA.INTERPOLATION,
        )
        if not resize_im:
            # Small images (e.g. CIFAR): swap the leading RandomResizedCrop
            # for a padded RandomCrop.
            transform.transforms[0] = transforms.RandomCrop(
                config.DATA.IMG_SIZE, padding=4)
        return transform

    # Evaluation pipeline.
    interpolation = _pil_interp(config.DATA.INTERPOLATION)
    pipeline = []
    if resize_im:
        if config.TEST.CROP:
            crop_pct_size = int(1.0 * config.DATA.IMG_SIZE)
            pipeline.append(
                transforms.Resize(crop_pct_size, interpolation=interpolation))
            pipeline.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
        elif config.AUG.RANDOM_RESIZED_CROP:
            pipeline.append(
                transforms.RandomResizedCrop(
                    (config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
                    interpolation=interpolation))
        else:
            pipeline.append(
                transforms.Resize(
                    (config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
                    interpolation=interpolation))
    pipeline.append(transforms.ToTensor())
    pipeline.append(transforms.Normalize(config.AUG.MEAN, config.AUG.STD))
    return transforms.Compose(pipeline)
| |
|