import logging
from enum import Enum
from typing import Any, Callable, List, Optional, TypeVar

import torch
from torch.utils.data import Sampler

from .datasets import ImageNet, ImageNet22k
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler


logger = logging.getLogger("dinov2")


class SamplerType(Enum):
    DISTRIBUTED = 0
    EPOCH = 1
    INFINITE = 2
    SHARDED_INFINITE = 3
    SHARDED_INFINITE_NEW = 4
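
# How these map to samplers (see _make_sampler below): DISTRIBUTED wraps
# torch.utils.data.DistributedSampler, EPOCH draws a fixed number of samples
# per epoch, and the INFINITE variants yield indices forever; the two SHARDED
# variants differ only in which shuffle implementation the
# ShardedInfiniteSampler uses (its use_new_shuffle_tensor_slice flag).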


def _make_bool_str(b: bool) -> str:
    return "yes" if b else "no"


def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
    def transform(sample):
        image, target = sample
        if image_transform is not None:
            image = image_transform(image)
        if target_transform is not None:
            target = target_transform(target)
        return image, target

    return transform
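
# A minimal usage sketch (the lambdas and pil_image below are stand-ins, not
# part of this module): the returned callable applies each transform to its
# half of an (image, target) sample pair.
#
#     sample_transform = _make_sample_transform(
#         image_transform=lambda image: image.convert("RGB"),
#         target_transform=lambda target: int(target),
#     )
#     image, target = sample_transform((pil_image, "3"))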


def _parse_dataset_str(dataset_str: str):
    tokens = dataset_str.split(":")

    name = tokens[0]
    kwargs = {}

    for token in tokens[1:]:
        key, value = token.split("=")
        assert key in ("root", "extra", "split")
        kwargs[key] = value

    if name == "ImageNet":
        class_ = ImageNet
        if "split" in kwargs:
            kwargs["split"] = ImageNet.Split[kwargs["split"]]
    elif name == "ImageNet22k":
        class_ = ImageNet22k
    else:
        raise ValueError(f'Unsupported dataset "{name}"')

    return class_, kwargs
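
# The accepted grammar is "<name>[:<key>=<value>]*" with keys restricted to
# root, extra and split. A hypothetical example (placeholder path; assumes
# ImageNet.Split defines a TRAIN member):
#
#     _parse_dataset_str("ImageNet:split=TRAIN:root=/datasets/imagenet")
#     # -> (ImageNet, {"split": ImageNet.Split.TRAIN, "root": "/datasets/imagenet"})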


def make_dataset(
    *,
    dataset_str: str,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
):
    """
    Creates a dataset with the specified parameters.

    Args:
        dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
        transform: A transform to apply to images.
        target_transform: A transform to apply to targets.

    Returns:
        The created dataset.
    """
    logger.info(f'using dataset: "{dataset_str}"')

    class_, kwargs = _parse_dataset_str(dataset_str)
    dataset = class_(transform=transform, target_transform=target_transform, **kwargs)

    logger.info(f"# of dataset samples: {len(dataset):,d}")

    # Some dataset classes do not expose these attributes, so add them here.
    if not hasattr(dataset, "transform"):
        setattr(dataset, "transform", transform)
    if not hasattr(dataset, "target_transform"):
        setattr(dataset, "target_transform", target_transform)

    return dataset
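
# Usage sketch (hypothetical path; the torchvision transform is only an
# example of what can be passed in):
#
#     from torchvision import transforms
#
#     dataset = make_dataset(
#         dataset_str="ImageNet:split=TRAIN:root=/datasets/imagenet",
#         transform=transforms.Compose([transforms.RandomResizedCrop(224), transforms.ToTensor()]),
#     )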


def _make_sampler(
    *,
    dataset,
    type: Optional[SamplerType] = None,
    shuffle: bool = False,
    seed: int = 0,
    size: int = -1,
    advance: int = 0,
) -> Optional[Sampler]:
    sample_count = len(dataset)

    if type == SamplerType.INFINITE:
        logger.info("sampler: infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        return InfiniteSampler(
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
        )
    elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
        logger.info("sampler: sharded infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        # The _NEW variant only selects a different shuffle implementation.
        use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW
        return ShardedInfiniteSampler(
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
            use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice,
        )
    elif type == SamplerType.EPOCH:
        logger.info("sampler: epoch")
        if advance > 0:
            raise NotImplementedError("sampler advance > 0 is not supported")
        size = size if size > 0 else sample_count
        logger.info(f"# of samples / epoch: {size:,d}")
        return EpochSampler(
            size=size,
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
        )
    elif type == SamplerType.DISTRIBUTED:
        logger.info("sampler: distributed")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        if advance > 0:
            raise ValueError("sampler advance > 0 is invalid")
        return torch.utils.data.DistributedSampler(
            dataset=dataset,
            shuffle=shuffle,
            seed=seed,
            drop_last=False,
        )

    logger.info("sampler: none")
    return None
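
# Semantics enforced above: size only applies to EPOCH (the other types reject
# size > 0), advance is rejected by EPOCH and DISTRIBUTED, and any other type
# (including None) falls through to no sampler, leaving the DataLoader to
# iterate the dataset sequentially.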


T = TypeVar("T")


def make_data_loader(
    *,
    dataset,
    batch_size: int,
    num_workers: int,
    shuffle: bool = True,
    seed: int = 0,
    sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
    sampler_size: int = -1,
    sampler_advance: int = 0,
    drop_last: bool = True,
    persistent_workers: bool = False,
    collate_fn: Optional[Callable[[List[T]], Any]] = None,
):
    """
    Creates a data loader with the specified parameters.

    Args:
        dataset: A dataset (third party, LaViDa or WebDataset).
        batch_size: The size of batches to generate.
        num_workers: The number of workers to use.
        shuffle: Whether to shuffle samples.
        seed: The random seed to use.
        sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
        sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
        sampler_advance: How many samples to skip (when applicable).
        drop_last: Whether the last non-full batch of data should be dropped.
        persistent_workers: Whether to keep worker processes (and their Dataset instances) alive after the dataset has been consumed once.
        collate_fn: The function that performs batch collation.

    Returns:
        The created data loader.
    """
    sampler = _make_sampler(
        dataset=dataset,
        type=sampler_type,
        shuffle=shuffle,
        seed=seed,
        size=sampler_size,
        advance=sampler_advance,
    )

    logger.info("using PyTorch data loader")
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=drop_last,
        persistent_workers=persistent_workers,
        collate_fn=collate_fn,
    )

    # Infinite samplers have no length, so len(data_loader) raises TypeError.
    try:
        logger.info(f"# of batches: {len(data_loader):,d}")
    except TypeError:
        logger.info("infinite data loader")
    return data_loader
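
# End-to-end usage sketch (assumed values; with the default INFINITE sampler
# the loader never stops, so take a fixed number of batches with islice):
#
#     import itertools
#
#     data_loader = make_data_loader(
#         dataset=dataset,
#         batch_size=64,
#         num_workers=8,
#         sampler_type=SamplerType.SHARDED_INFINITE,
#     )
#     for batch in itertools.islice(data_loader, 1000):
#         ...  # training step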