| """ |
| Code adapted from timm https://github.com/huggingface/pytorch-image-models |
| |
| Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich |
| """ |
|
|
| import logging |
| from contextlib import suppress |
| from functools import partial |
| from itertools import repeat |
|
|
| import numpy as np |
| import torch |
| import torch.utils.data |
| from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
| from timm.data.dataset import IterableImageDataset |
| from timm.data.loader import PrefetchLoader, _worker_init |
| from timm.data.transforms_factory import create_transform |
|
|
| _logger = logging.getLogger(__name__) |
|
|
|
|
def fast_collate(batch, target_dtype=torch.uint8):
    """Collate ``(image, target)`` samples into a uint8 image batch and a target tensor.

    Only NumPy image arrays are handled. Every image is written into one
    pre-allocated ``torch.uint8`` tensor whose per-sample shape is taken from
    the first image, and the targets are packed into a tensor of
    ``target_dtype``.

    Args:
        batch: Sequence of ``(image, target)`` tuples.
        target_dtype: Torch dtype of the returned target tensor.

    Returns:
        Tuple ``(images, targets)`` where ``images`` has shape
        ``(len(batch), *first_image.shape)`` and dtype ``torch.uint8``.

    Raises:
        ValueError: If the first image of the batch is not a ``np.ndarray``.
    """
    first_sample = batch[0]
    assert isinstance(first_sample, tuple)
    n_samples = len(batch)
    first_image = first_sample[0]

    # Guard clause: anything but a NumPy image is rejected up front.
    if not isinstance(first_image, np.ndarray):
        raise ValueError(f"Incorrect batch type: {type(first_image)}")

    labels = [sample[1] for sample in batch]
    targets = torch.tensor(labels, dtype=target_dtype)
    assert len(targets) == n_samples

    images = torch.zeros((n_samples, *first_image.shape), dtype=torch.uint8)
    for index, sample in enumerate(batch):
        # Adding in place to the zero-initialised slot copies the sample's pixels into the batch.
        images[index].add_(torch.from_numpy(sample[0]))
    return images, targets
|
|
|
|
def adapt_to_chs(x, n):
    """Adapt a normalization statistic (mean or std) to ``n`` image channels.

    Args:
        x: A single value, or a tuple/list of per-channel values.
        n: Number of channels the result must cover.

    Returns:
        ``x`` itself when it is a tuple/list that already has ``n`` entries;
        otherwise a tuple of length ``n`` built by repeating the single value,
        by concatenating ``x`` with itself (when ``n`` is exactly twice
        ``len(x)``), or by repeating the average of ``x``.
    """
    if not isinstance(x, (tuple, list)):
        # Anything that is not a tuple/list is treated as one value shared by all channels.
        return tuple(repeat(x, n))

    if len(x) == n:
        return x

    if len(x) * 2 == n:
        # Doubled channels: reuse the same per-channel stats for both halves.
        # A tuple is returned (not an ndarray) so every computed branch has the same type.
        x = tuple(x) * 2
        _logger.warning(f"Pretrained mean/std different shape than model (doubled channels), using concat: {x}.")
    else:
        # No per-channel mapping is available, so fall back to the average value for every channel.
        x_mean = np.mean(x).item()
        x = (x_mean,) * n
        _logger.warning(f"Pretrained mean/std different shape than model, using avg value {x}.")
    return x
|
|
|
|
class PrefetchLoaderForMultiInput(PrefetchLoader):
    """Loader wrapper that moves batches to ``device`` and normalizes the inputs.

    Each batch from the wrapped loader is transferred to ``device``, cast to
    ``img_dtype`` and normalized with ``mean`` / ``std`` adapted to
    ``channels``. When CUDA is available and ``device`` is a CUDA device, that
    work is issued on a separate CUDA stream while the previously prepared
    batch is handed to the consumer.
    """

    def __init__(
        self,
        loader,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        channels=3,
        device=torch.device("cpu"),
        img_dtype=torch.float32,
    ):
        """Store the wrapped loader and precompute the normalization tensors.

        Args:
            loader: Iterable yielding ``(input, target)`` tensor pairs.
            mean: Normalization mean; a single value or per-channel sequence.
            std: Normalization std; a single value or per-channel sequence.
            channels: Number of input channels the stats are adapted to.
            device: Target device (``torch.device`` or anything it accepts).
            img_dtype: Dtype the inputs are cast to before normalization.
        """
        # NOTE(review): PrefetchLoader.__init__ is deliberately left uncalled, as before;
        # confirm that the inherited members only rely on the attributes set below.
        device = torch.device(device)  # also accepts e.g. "cuda:0"

        mean = adapt_to_chs(mean, channels)
        std = adapt_to_chs(std, channels)
        normalization_shape = (1, channels, 1, 1)

        self.loader = loader
        self.device = device
        self.img_dtype = img_dtype
        # Stats are multiplied by 255 because they are applied to 0-255 pixel values
        # (batches are collated as uint8).
        self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
        self.std = torch.tensor([x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)

        # A side CUDA stream is only useful when the target device really is a CUDA device.
        self.is_cuda = torch.cuda.is_available() and device.type == "cuda"

    def __iter__(self):
        """Yield ``(input, target)`` batches, preparing the next batch before yielding the current one.

        Nothing is yielded when the wrapped loader is empty.
        """
        if self.is_cuda:
            stream = torch.cuda.Stream()
            stream_context = partial(torch.cuda.stream, stream=stream)
        else:
            stream = None
            # ``suppress()`` without exception types acts as a no-op context manager.
            stream_context = suppress

        pending = None  # batch that is ready but has not been yielded yet
        for next_input, next_target in self.loader:
            with stream_context():
                next_input = next_input.to(device=self.device, non_blocking=True)
                next_target = next_target.to(device=self.device, non_blocking=True)
                next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)

            if pending is not None:
                yield pending

            if stream is not None:
                # Make the current stream wait for the side stream before the new batch is exposed.
                torch.cuda.current_stream().wait_stream(stream)

            pending = (next_input, next_target)

        # Flush the last prepared batch; guarded so an empty loader does not fail on unbound locals.
        if pending is not None:
            yield pending
|
|
|
|
def create_loader(
    dataset,
    input_size,
    batch_size,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
    num_workers=1,
    crop_pct=None,
    crop_mode=None,
    pin_memory=False,
    img_dtype=torch.float32,
    device=torch.device("cpu"),
    persistent_workers=True,
    worker_seeding="all",
    target_type=torch.int64,
):
    """Build a non-shuffling evaluation loader wrapped in ``PrefetchLoaderForMultiInput``.

    The dataset's ``transform`` attribute is replaced with an evaluation
    transform (``is_training=False``, ``use_prefetcher=True``); normalization
    is applied later by the prefetch wrapper.

    Args:
        dataset: Map-style dataset yielding ``(image, target)`` samples.
        input_size: Input size passed to ``create_transform``; its first
            element is used as the channel count.
        batch_size: Number of samples per batch.
        mean: Normalization mean.
        std: Normalization std.
        num_workers: Number of ``DataLoader`` worker processes.
        crop_pct: Forwarded to ``create_transform``.
        crop_mode: Forwarded to ``create_transform``.
        pin_memory: Forwarded to ``DataLoader``.
        img_dtype: Dtype the inputs are cast to before normalization.
        device: Device the batches are moved to.
        persistent_workers: Keep workers alive between epochs; ignored when
            ``num_workers`` is 0 because ``DataLoader`` rejects that combination.
        worker_seeding: Forwarded to timm's ``_worker_init``.
        target_type: Torch dtype of the collated targets.

    Returns:
        A ``PrefetchLoaderForMultiInput`` wrapping the created ``DataLoader``.

    Raises:
        ValueError: If ``dataset`` is an ``IterableImageDataset``.
    """
    # Reject unsupported datasets before touching them.
    if isinstance(dataset, IterableImageDataset):
        raise ValueError("Incorrect dataset type: IterableImageDataset")

    dataset.transform = create_transform(
        input_size,
        is_training=False,
        use_prefetcher=True,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        crop_mode=crop_mode,
    )

    loader_class = torch.utils.data.DataLoader
    loader_args = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        sampler=None,
        # ``partial`` (unlike a lambda) can be pickled for spawn-based worker processes.
        collate_fn=partial(fast_collate, target_dtype=target_type),
        pin_memory=pin_memory,
        drop_last=False,
        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
        # DataLoader raises ValueError for persistent_workers=True with num_workers=0.
        persistent_workers=bool(persistent_workers) and num_workers > 0,
    )
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError:
        # Fallback for DataLoader versions that do not accept ``persistent_workers``.
        loader_args.pop("persistent_workers")
        loader = loader_class(dataset, **loader_args)

    return PrefetchLoaderForMultiInput(
        loader,
        mean=mean,
        std=std,
        channels=input_size[0],
        device=device,
        img_dtype=img_dtype,
    )
|
|