| | |
| | |
| | |
| | |
| |
|
| | import logging |
| | from enum import Enum |
| | from typing import Any, Callable, List, Optional, TypeVar |
| |
|
| | import torch |
| | from torch.utils.data import Sampler |
| |
|
| | from .datasets import ImageNet, ImageNet22k |
| | from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler |
| |
|
| |
|
# Shared module-level logger for the dinov2 data pipeline.
logger = logging.getLogger("dinov2")
| |
|
| |
|
# Sampler flavors understood by `_make_sampler` / `make_data_loader`.
# Built with the functional Enum API; member names and values match the
# original class-based definition exactly.
SamplerType = Enum(
    "SamplerType",
    [
        ("DISTRIBUTED", 0),
        ("EPOCH", 1),
        ("INFINITE", 2),
        ("SHARDED_INFINITE", 3),
        ("SHARDED_INFINITE_NEW", 4),
    ],
)
| |
|
| |
|
| | def _make_bool_str(b: bool) -> str: |
| | return "yes" if b else "no" |
| |
|
| |
|
| | def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None): |
| | def transform(sample): |
| | image, target = sample |
| | if image_transform is not None: |
| | image = image_transform(image) |
| | if target_transform is not None: |
| | target = target_transform(target) |
| | return image, target |
| |
|
| | return transform |
| |
|
| |
|
| | def _parse_dataset_str(dataset_str: str): |
| | tokens = dataset_str.split(":") |
| |
|
| | name = tokens[0] |
| | kwargs = {} |
| |
|
| | for token in tokens[1:]: |
| | key, value = token.split("=") |
| | assert key in ("root", "extra", "split") |
| | kwargs[key] = value |
| |
|
| | if name == "ImageNet": |
| | class_ = ImageNet |
| | if "split" in kwargs: |
| | kwargs["split"] = ImageNet.Split[kwargs["split"]] |
| | elif name == "ImageNet22k": |
| | class_ = ImageNet22k |
| | else: |
| | raise ValueError(f'Unsupported dataset "{name}"') |
| |
|
| | return class_, kwargs |
| |
|
| |
|
def make_dataset(
    *,
    dataset_str: str,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
):
    """
    Build the dataset described by a dataset string.

    Args:
        dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
        transform: A transform to apply to images.
        target_transform: A transform to apply to targets.

    Returns:
        The instantiated dataset, guaranteed to carry `transform` and
        `target_transform` attributes.
    """
    logger.info(f'using dataset: "{dataset_str}"')

    dataset_class, dataset_kwargs = _parse_dataset_str(dataset_str)
    dataset = dataset_class(transform=transform, target_transform=target_transform, **dataset_kwargs)

    logger.info(f"# of dataset samples: {len(dataset):,d}")

    # Mirror the transforms onto the dataset object when absent, for code
    # elsewhere that expects these attributes to exist on any dataset.
    for attribute, value in (("transform", transform), ("target_transform", target_transform)):
        if not hasattr(dataset, attribute):
            setattr(dataset, attribute, value)

    return dataset
| |
|
| |
|
def _make_sampler(
    *,
    dataset,
    type: Optional[SamplerType] = None,
    shuffle: bool = False,
    seed: int = 0,
    size: int = -1,
    advance: int = 0,
) -> Optional[Sampler]:
    """Create a sampler of the requested type for the given dataset.

    Args:
        dataset: The dataset to sample from. Only its length is used, except
            for the DISTRIBUTED type, which receives the dataset itself.
        type: Which sampler to create, or None for no sampler.
        shuffle: Whether the sampler should shuffle samples.
        seed: Random seed forwarded to the sampler.
        size: Number of samples per epoch (EPOCH type only); -1 means use the
            full dataset. A value > 0 is rejected for the other types.
        advance: Number of samples to skip (infinite types only).

    Returns:
        The created sampler, or None when `type` is None.

    Raises:
        ValueError: If `size` or `advance` is set for a type that does not
            support it.
        NotImplementedError: If `advance` > 0 is requested with EPOCH.
    """
    sample_count = len(dataset)

    if type == SamplerType.INFINITE:
        logger.info("sampler: infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        return InfiniteSampler(
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
        )
    elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
        logger.info("sampler: sharded infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        # The _NEW variant only toggles the shuffled-tensor slicing implementation.
        use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW
        return ShardedInfiniteSampler(
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
            use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice,
        )
    elif type == SamplerType.EPOCH:
        logger.info("sampler: epoch")
        if advance > 0:
            raise NotImplementedError("sampler advance > 0 is not supported")
        # -1 (or any non-positive size) means "one full pass over the dataset".
        size = size if size > 0 else sample_count
        logger.info(f"# of samples / epoch: {size:,d}")
        return EpochSampler(
            size=size,
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
        )
    elif type == SamplerType.DISTRIBUTED:
        logger.info("sampler: distributed")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        if advance > 0:
            raise ValueError("sampler advance > 0 is invalid")
        return torch.utils.data.DistributedSampler(
            dataset=dataset,
            shuffle=shuffle,
            seed=seed,
            drop_last=False,
        )

    # type is None (or unrecognized): let the data loader use default sampling.
    logger.info("sampler: none")
    return None
| |
|
| |
|
# Generic sample element type, used in the collate-function signature below.
T = TypeVar("T")
| |
|
| |
|
def make_data_loader(
    *,
    dataset,
    batch_size: int,
    num_workers: int,
    shuffle: bool = True,
    seed: int = 0,
    sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
    sampler_size: int = -1,
    sampler_advance: int = 0,
    drop_last: bool = True,
    persistent_workers: bool = False,
    collate_fn: Optional[Callable[[List[T]], Any]] = None,
):
    """
    Wrap a dataset in a PyTorch data loader with the requested sampler.

    Args:
        dataset: A dataset (third party, LaViDa or WebDataset).
        batch_size: The size of batches to generate.
        num_workers: The number of workers to use.
        shuffle: Whether to shuffle samples.
        seed: The random seed to use.
        sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
        sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
        sampler_advance: How many samples to skip (when applicable).
        drop_last: Whether the last non-full batch of data should be dropped.
        persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once.
        collate_fn: Function that performs batch collation

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    sampler = _make_sampler(
        dataset=dataset,
        type=sampler_type,
        shuffle=shuffle,
        seed=seed,
        size=sampler_size,
        advance=sampler_advance,
    )

    logger.info("using PyTorch data loader")
    loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=drop_last,
        persistent_workers=persistent_workers,
        collate_fn=collate_fn,
    )

    # Infinite samplers yield a loader with no length; len() raises TypeError.
    try:
        batch_count = len(loader)
    except TypeError:
        logger.info("infinite data loader")
    else:
        logger.info(f"# of batches: {batch_count:,d}")
    return loader
| |
|