File size: 2,678 Bytes
bc90483 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 | # Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
from typing import Any, Optional, Tuple
from torch.utils.data import Dataset
def extend_samples_with_index(dataset_class):
    """Create a subclass of ``dataset_class`` whose samples also report their index.

    The generated class takes only keyword arguments. It obtains ``root`` from
    ``dataset_class.get_root()`` and forwards the remaining keyword arguments
    to ``dataset_class.__init__``. Indexing an instance returns
    ``(image, target, index)`` instead of the parent's ``(image, target)``.
    """

    class DatasetWithIndex(dataset_class):
        def __init__(self, **kwargs) -> None:
            # The root directory always comes from the wrapped class itself.
            super().__init__(root=dataset_class.get_root(), **kwargs)

        def __getitem__(self, index: int):
            sample = super().__getitem__(index)
            # Unpacking enforces that the parent yields exactly two items.
            image, target = sample
            return image, target, index

    return DatasetWithIndex
class DatasetWithEnumeratedTargets(Dataset):
    """
    Wrap a dataset so that each sample's target is paired with the sample index.

    ``dataset[index]`` returns ``(image, (index, target))``. When the wrapped
    dataset's target is ``None``, the index is used as the target. ``get_target``
    returns ``(index, target)`` without that substitution.

    If pad_dataset is set, the length is rounded up to a multiple of num_replicas.
    This follows torch's DistributedSampler, which with drop_last=False pads the
    total number of samples to a multiple of the world size, so that every
    replica receives the same number of samples:
    https://github.com/pytorch/pytorch/blob/main/torch/utils/data/distributed.py#L91
    Padding entries wrap around to the start of the wrapped dataset. They are
    reported with index -1, presumably so callers can filter them out.
    """

    def __init__(self, dataset: Dataset, pad_dataset: bool = False, num_replicas: Optional[int] = None):
        self._dataset = dataset
        self._size = len(self._dataset)
        self._padded_size = self._size
        self._pad_dataset = pad_dataset
        if self._pad_dataset:
            # Explicit exceptions rather than asserts: asserts are stripped under `python -O`.
            if num_replicas is None:
                raise ValueError("num_replicas should be set if pad_dataset is True")
            if num_replicas < 1:
                raise ValueError(f"num_replicas should be a positive integer, got {num_replicas}")
            # Round the size up to the next multiple of num_replicas (ceil division).
            self._padded_size = num_replicas * ((self._size + num_replicas - 1) // num_replicas)

    def _wrapped_index(self, index: int) -> int:
        """Map an index of this (possibly padded) view onto the wrapped dataset.

        Indices in ``[len(wrapped), len(self))`` exist only when padding is
        enabled; they wrap around to the start of the wrapped dataset. Negative
        indices are reduced modulo the wrapped dataset's size.

        Raises:
            IndexError: if ``index >= len(self)``. Raising IndexError, rather than
                asserting or silently wrapping, also lets Python's
                sequence-iteration protocol stop at the end of the dataset.
        """
        if index >= self._padded_size:
            raise IndexError(f"index {index} is out of range for dataset of length {self._padded_size}")
        return index % self._size

    def get_image_relpath(self, index: int) -> str:
        """Return the wrapped dataset's relative image path for ``index``."""
        return self._dataset.get_image_relpath(self._wrapped_index(index))

    def get_image_data(self, index: int) -> bytes:
        """Return the wrapped dataset's raw image data for ``index``."""
        return self._dataset.get_image_data(self._wrapped_index(index))

    def get_target(self, index: int) -> Tuple[int, Any]:
        """Return ``(index, target)``, or ``(-1, target)`` for a padding entry."""
        target = self._dataset.get_target(self._wrapped_index(index))
        if index >= self._size:
            # Padding entry; only reachable when pad_dataset is set.
            return (-1, target)
        return (index, target)

    def get_sample_decoder(self, index: int) -> Any:
        """Return the wrapped dataset's sample decoder for ``index``."""
        return self._dataset.get_sample_decoder(self._wrapped_index(index))

    def __getitem__(self, index: int) -> Tuple[Any, Tuple[int, Any]]:
        """Return ``(image, (index, target))``; padding entries use index -1.

        For real (non-padding) entries, a ``None`` target is replaced by the index.
        """
        image, target = self._dataset[self._wrapped_index(index)]
        if index >= self._size:
            # Padding entry: keep the wrapped target as-is and mark it with -1.
            return image, (-1, target)
        target = index if target is None else target
        return image, (index, target)

    def __len__(self) -> int:
        """Return the (possibly padded) number of samples."""
        return self._padded_size
|