""" Loader Factory, Fast Collate, CUDA Prefetcher

Prefetcher and Fast Collate inspired by NVIDIA APEX example at
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf

Hacked together by / Copyright 2019, Ross Wightman
"""
| | import logging |
| | import random |
| | from contextlib import suppress |
| | from functools import partial |
| | from itertools import repeat |
| | from typing import Callable, Optional, Tuple, Union |
| |
|
| | import torch |
| | import torch.utils.data |
| | import numpy as np |
| |
|
| | from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
| | from .dataset import IterableImageDataset, ImageDataset |
| | from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler |
| | from .random_erasing import RandomErasing |
| | from .mixup import FastCollateMixup |
| | from .transforms_factory import create_transform |
| |
|
| | _logger = logging.getLogger(__name__) |
| |
|
| |
|
def fast_collate(batch):
    """Collate ``(image, target)`` samples into a uint8 image tensor and int64 targets.

    The layout of the first sample's image selects one of three paths:

    * ``tuple`` of ``np.ndarray``: every sample carries several images. They are
      flattened into one tensor grouped by tuple position, and the sample's target
      is repeated once per position.
    * ``np.ndarray``: images are stacked along a new leading batch dimension.
    * ``torch.Tensor``: images are copied into a new uint8 batch tensor.

    Args:
        batch: Non-empty sequence of ``(image, target)`` tuples. All images are
            expected to share the shape of the first one.

    Returns:
        Tuple ``(tensor, targets)`` where ``tensor`` is ``torch.uint8`` and
        ``targets`` is ``torch.int64``.

    Raises:
        TypeError: If the image of the first sample is none of the supported types.
    """
    assert isinstance(batch[0], tuple)
    batch_size = len(batch)
    first_image = batch[0][0]

    if isinstance(first_image, tuple):
        # Deinterleave: image j of sample i is written to row i + j * batch_size, so
        # torch.split(tensor, batch_size) yields one chunk per tuple position.
        inner_tuple_size = len(first_image)
        flattened_batch_size = batch_size * inner_tuple_size
        targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
        tensor = torch.zeros((flattened_batch_size, *first_image[0].shape), dtype=torch.uint8)
        for i, (images, target) in enumerate(batch):
            assert len(images) == inner_tuple_size
            for j, image in enumerate(images):
                row = i + j * batch_size
                targets[row] = target
                tensor[row] += torch.from_numpy(image)
        return tensor, targets

    if isinstance(first_image, np.ndarray):
        targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *first_image.shape), dtype=torch.uint8)
        for i, (image, _) in enumerate(batch):
            tensor[i] += torch.from_numpy(image)
        return tensor, targets

    if isinstance(first_image, torch.Tensor):
        targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *first_image.shape), dtype=torch.uint8)
        for i, (image, _) in enumerate(batch):
            tensor[i].copy_(image)
        return tensor, targets

    # An explicit raise (rather than `assert False`) so the failure is not stripped
    # under `python -O`, which would otherwise make this function return None.
    raise TypeError(
        f'fast_collate expects images as tuple, np.ndarray or torch.Tensor, got {type(first_image)!r}')
| |
|
| |
|
def adapt_to_chs(x, n):
    """Adapt a per-channel statistic (mean or std) to ``n`` channels.

    Args:
        x: A scalar, or a tuple/list of per-channel values.
        n: Number of channels the result must cover.

    Returns:
        ``x`` unchanged when it is a tuple/list of length ``n``; a tuple with ``x``
        repeated ``n`` times when ``x`` is not a tuple/list; otherwise a tuple holding
        the mean of ``x`` repeated ``n`` times (a warning is logged in that case).
    """
    if not isinstance(x, (tuple, list)):
        return tuple(repeat(x, n))
    if len(x) != n:
        # Channel count mismatch: fall back to the average value for every channel.
        x_mean = np.mean(x).item()
        x = (x_mean,) * n
        _logger.warning(f'Pretrained mean/std different shape than model, using avg value {x}.')
    return x
| |
|
| |
|
class PrefetchLoader:
    """Wrap a loader so every batch is moved to ``device`` and normalized before use.

    For each ``(input, target)`` batch of the wrapped loader, both tensors are
    transferred to ``device``, the input is cast to ``img_dtype`` and normalized in
    place as ``(input - mean * 255) / (std * 255)``, and ``RandomErasing`` is applied
    when ``re_prob > 0``. When CUDA is used, this work is issued on a separate CUDA
    stream and every batch is yielded one iteration late, so the next batch is already
    queued when the previous one is handed to the caller.

    Args:
        loader: Iterable yielding ``(input, target)`` tensor pairs.
        mean: Normalization mean in 0-1 range; adapted to ``channels`` via ``adapt_to_chs``.
        std: Normalization std in 0-1 range; adapted to ``channels`` via ``adapt_to_chs``.
        channels: Number of image channels.
        device: Device the batches are moved to.
        img_dtype: Dtype the input is cast to before normalization.
        fp16: If True, overrides ``img_dtype`` with ``torch.float16``.
        re_prob: Random erasing probability; erasing is disabled when not > 0.
        re_mode: Random erasing mode, forwarded to ``RandomErasing``.
        re_count: Forwarded to ``RandomErasing`` as ``max_count``.
        re_num_splits: Forwarded to ``RandomErasing`` as ``num_splits``.
    """

    def __init__(
            self,
            loader,
            mean=IMAGENET_DEFAULT_MEAN,
            std=IMAGENET_DEFAULT_STD,
            channels=3,
            device=torch.device('cuda'),
            img_dtype=torch.float32,
            fp16=False,
            re_prob=0.,
            re_mode='const',
            re_count=1,
            re_num_splits=0):

        mean = adapt_to_chs(mean, channels)
        std = adapt_to_chs(std, channels)
        normalization_shape = (1, channels, 1, 1)

        self.loader = loader
        self.device = device
        if fp16:
            # fp16 is kept for backward compatibility and takes precedence over img_dtype
            img_dtype = torch.float16
        self.img_dtype = img_dtype
        # Stats are scaled by 255 because they are applied to inputs in 0-255 range
        # (presumably the uint8 batches produced by fast_collate).
        self.mean = torch.tensor(
            [x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
        self.std = torch.tensor(
            [x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
        if re_prob > 0.:
            self.random_erasing = RandomErasing(
                probability=re_prob,
                mode=re_mode,
                max_count=re_count,
                num_splits=re_num_splits,
                device=device,
            )
        else:
            self.random_erasing = None
        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'

    def __iter__(self):
        """Yield ``(input, target)`` batches that are on ``device`` and normalized.

        Yields nothing when the wrapped loader is empty.
        """
        if self.is_cuda:
            stream = torch.cuda.Stream()
            stream_context = partial(torch.cuda.stream, stream=stream)
        else:
            stream = None
            # suppress() with no arguments acts as a no-op context manager
            stream_context = suppress

        # Batch prepared in the previous iteration and not yet handed to the caller.
        pending = None

        for next_input, next_target in self.loader:
            with stream_context():
                next_input = next_input.to(device=self.device, non_blocking=True)
                next_target = next_target.to(device=self.device, non_blocking=True)
                next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)

            if pending is not None:
                yield pending

            if stream is not None:
                # Make the current stream wait for the side-stream work queued above
                # before the freshly prepared batch can be consumed.
                torch.cuda.current_stream().wait_stream(stream)

            pending = (next_input, next_target)

        # Flush the last prepared batch; skipped when the loader produced no batches.
        if pending is not None:
            yield pending

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        """Sampler of the wrapped loader."""
        return self.loader.sampler

    @property
    def dataset(self):
        """Dataset of the wrapped loader."""
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        """Mixup state of the wrapped loader's collate_fn; False if it is not a FastCollateMixup."""
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        else:
            return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        # Silently ignored when the collate_fn is not a FastCollateMixup.
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x
| |
|
| |
|
def _worker_init(worker_id, worker_seeding='all'):
    """Seed the random number generators of a DataLoader worker process.

    Intended for use as a DataLoader ``worker_init_fn``.

    Args:
        worker_id: Id of the worker, checked against ``get_worker_info().id``.
        worker_seeding: Either a callable taking the worker info and returning an int
            seed (applied to ``random``, ``torch`` and ``numpy``), or one of:

            * ``'all'``: additionally seed numpy from ``worker_info.seed``.
            * ``'part'``: leave numpy unseeded here.

    Raises:
        ValueError: If ``worker_seeding`` is neither callable nor ``'all'`` / ``'part'``.
    """
    worker_info = torch.utils.data.get_worker_info()
    assert worker_info.id == worker_id

    if callable(worker_seeding):
        seed = worker_seeding(worker_info)
        random.seed(seed)
        torch.manual_seed(seed)
        # numpy only accepts 32-bit seeds, so fold larger values into range
        np.random.seed(seed % (2 ** 32 - 1))
        return

    # Explicit raise instead of an assert so a typo is not silently treated as
    # 'part' when running under `python -O`.
    if worker_seeding not in ('all', 'part'):
        raise ValueError(
            f"worker_seeding must be a callable, 'all' or 'part', got {worker_seeding!r}")

    # Only numpy is handled here; `random` and `torch` are left untouched in this
    # path (presumably already seeded per worker by the DataLoader - see PyTorch docs).
    if worker_seeding == 'all':
        np.random.seed(worker_info.seed % (2 ** 32 - 1))
| |
|
| |
|
def create_loader(
        dataset: Union[ImageDataset, IterableImageDataset],
        input_size: Union[int, Tuple[int, int], Tuple[int, int, int]],
        batch_size: int,
        is_training: bool = False,
        no_aug: bool = False,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_split: bool = False,
        train_crop_mode: Optional[str] = None,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: float = 0.4,
        color_jitter_prob: Optional[float] = None,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        num_aug_repeats: int = 0,
        num_aug_splits: int = 0,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        num_workers: int = 1,
        distributed: bool = False,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,
        collate_fn: Optional[Callable] = None,
        pin_memory: bool = False,
        fp16: bool = False,
        img_dtype: torch.dtype = torch.float32,
        device: torch.device = torch.device('cuda'),
        use_prefetcher: bool = True,
        use_multi_epochs_loader: bool = False,
        persistent_workers: bool = True,
        worker_seeding: str = 'all',
        tf_preprocessing: bool = False,
):
    """Build the transform for ``dataset`` and wrap the dataset in a data loader.

    Note that ``dataset.transform`` is overwritten with the transform created here.

    Args:
        dataset: The image dataset to load.
        input_size: Target input size (channels, height, width) tuple or size scalar.
        batch_size: Number of samples in a batch.
        is_training: Return training (random) transforms.
        no_aug: Disable augmentation for training (useful for debug).
        re_prob: Random erasing probability.
        re_mode: Random erasing fill mode.
        re_count: Number of random erasing regions.
        re_split: Control split of random erasing across batch size.
        train_crop_mode: Training crop mode, forwarded to ``create_transform``.
        scale: Random resize scale range (crop area, < 1.0 => zoom in).
        ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
        hflip: Horizontal flip probability.
        vflip: Vertical flip probability.
        color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
            Scalar is applied as (scalar,) * 3 (no hue).
        color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
        grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
        gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
        auto_augment: Auto augment configuration string (see auto_augment.py).
        num_aug_repeats: Enable special sampler to repeat same augmentation across distributed GPUs.
        num_aug_splits: Enable mode where augmentations can be split across the batch.
        interpolation: Image interpolation mode.
        mean: Image normalization mean.
        std: Image normalization standard deviation.
        num_workers: Num worker processes per DataLoader.
        distributed: Enable dataloading for distributed training.
        crop_pct: Inference crop percentage (output size / resize size).
        crop_mode: Inference crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
        crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
        collate_fn: Override default collate_fn.
        pin_memory: Pin memory for device transfer.
        fp16: Deprecated argument for half-precision input dtype. Use img_dtype.
        img_dtype: Data type for input image.
        device: Device to transfer inputs and targets to.
        use_prefetcher: Use efficient pre-fetcher to load samples onto device.
        use_multi_epochs_loader: Use ``MultiEpochsDataLoader`` instead of the stock ``DataLoader``.
        persistent_workers: Enable persistent worker processes (ignored when ``num_workers`` is 0).
        worker_seeding: Control worker random seeding at init.
        tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports.

    Returns:
        The DataLoader, wrapped in a ``PrefetchLoader`` when ``use_prefetcher`` is set.
    """
    re_num_splits = 0
    if re_split:
        # Line random erasing splits up with the aug splits when set, else use 2 splits
        re_num_splits = num_aug_splits or 2
    dataset.transform = create_transform(
        input_size,
        is_training=is_training,
        no_aug=no_aug,
        train_crop_mode=train_crop_mode,
        scale=scale,
        ratio=ratio,
        hflip=hflip,
        vflip=vflip,
        color_jitter=color_jitter,
        color_jitter_prob=color_jitter_prob,
        grayscale_prob=grayscale_prob,
        gaussian_blur_prob=gaussian_blur_prob,
        auto_augment=auto_augment,
        interpolation=interpolation,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        crop_mode=crop_mode,
        crop_border_pixels=crop_border_pixels,
        re_prob=re_prob,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
        tf_preprocessing=tf_preprocessing,
        use_prefetcher=use_prefetcher,
        separate=num_aug_splits > 0,
    )

    if isinstance(dataset, IterableImageDataset):
        # Tell the iterable dataset the worker count before the loader is created
        dataset.set_loader_cfg(num_workers=num_workers)

    sampler = None
    if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
        if is_training:
            if num_aug_repeats:
                sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
            else:
                sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            # Evaluation uses the ordered (non-shuffling) distributed sampler
            sampler = OrderedDistributedSampler(dataset)
    else:
        assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"

    if collate_fn is None:
        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate

    loader_class = torch.utils.data.DataLoader
    if use_multi_epochs_loader:
        loader_class = MultiEpochsDataLoader

    loader_args = dict(
        batch_size=batch_size,
        shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=is_training,
        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
        # DataLoader raises ValueError for persistent_workers=True with num_workers=0,
        # so the (default-on) flag only applies when there are worker processes.
        persistent_workers=persistent_workers and num_workers > 0,
    )
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError:
        # Retry without the argument for DataLoader versions that do not accept it
        loader_args.pop('persistent_workers')
        loader = loader_class(dataset, **loader_args)
    if use_prefetcher:
        prefetch_re_prob = re_prob if is_training and not no_aug else 0.
        # input_size may be a scalar or an (H, W) pair; only a 3-element size carries a
        # channel count, otherwise fall back to PrefetchLoader's default of 3 channels.
        if isinstance(input_size, (tuple, list)) and len(input_size) == 3:
            channels = input_size[0]
        else:
            channels = 3
        loader = PrefetchLoader(
            loader,
            mean=mean,
            std=std,
            channels=channels,
            device=device,
            fp16=fp16,
            img_dtype=img_dtype,
            re_prob=prefetch_re_prob,
            re_mode=re_mode,
            re_count=re_count,
            re_num_splits=re_num_splits,
        )

    return loader
| |
|
| |
|
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that reuses one underlying iterator for every epoch.

    The sampler (or batch sampler) is wrapped in ``_RepeatSampler`` so the iterator
    created once in ``__init__`` does not run dry; each call to ``__iter__`` then
    draws ``len(self)`` items from that shared iterator.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The name-mangled flag is cleared around the assignments below; presumably
        # DataLoader rejects rebinding sampler attributes while it is set.
        self._DataLoader__initialized = False
        if self.batch_sampler is None:
            self.sampler = _RepeatSampler(self.sampler)
        else:
            self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        """Number of items per epoch, i.e. the length of the wrapped (batch) sampler."""
        # Both attributes hold a _RepeatSampler, which has no length itself, so the
        # length must be taken from the sampler it wraps in either case.
        repeating = self.sampler if self.batch_sampler is None else self.batch_sampler
        return len(repeating.sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler:
    """ Sampler that repeats forever.

    Iteration only stops if the wrapped sampler yields nothing at all, which avoids
    spinning forever on an empty sampler.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            produced = False
            for item in self.sampler:
                produced = True
                yield item
            if not produced:
                # Empty sampler: restarting it would loop without ever yielding
                return
| |
|