|
|
|
|
|
|
|
|
|
|
| from typing import Any, Optional, Tuple
|
|
|
| from torch.utils.data import Dataset
|
|
|
|
|
def extend_samples_with_index(dataset_class):
    """Build a subclass of ``dataset_class`` whose samples also carry their index.

    The generated class obtains its ``root`` argument from
    ``dataset_class.get_root()`` and forwards every other keyword argument to
    the parent constructor. Indexing it yields ``(image, target, index)``.
    """

    class DatasetWithIndex(dataset_class):
        """``dataset_class`` variant that appends the sample index to each item."""

        def __init__(self, **kwargs) -> None:
            # The root is always taken from the wrapped class, never from the caller.
            super().__init__(root=dataset_class.get_root(), **kwargs)

        def __getitem__(self, index: int):
            # The parent is expected to produce exactly an (image, target) pair.
            sample, label = super().__getitem__(index)
            return sample, label, index

    return DatasetWithIndex
|
|
|
|
|
class DatasetWithEnumeratedTargets(Dataset):
    """Wrap a dataset so that every target is paired with its sample index.

    Items are returned as ``(image, (index, target))``.

    If pad_dataset is set, the reported length is rounded up to the next
    multiple of ``num_replicas``, following torch's DistributedSampler
    implementation, which with drop_last=False pads so that the sample count
    is a multiple of the world size.
    https://github.com/pytorch/pytorch/blob/main/torch/utils/data/distributed.py#L91

    Padded positions (``size <= index < len(self)``) wrap around to the start
    of the wrapped dataset and are marked with an index of ``-1`` so they can
    be told apart from real samples.
    """

    def __init__(self, dataset: Dataset, pad_dataset: bool = False, num_replicas: Optional[int] = None):
        """Store the wrapped dataset and compute the (optionally padded) length.

        Args:
            dataset: The dataset to wrap; must support ``len()`` and indexing.
            pad_dataset: Whether to pad the length to a multiple of ``num_replicas``.
            num_replicas: Number of replicas to pad for; required (and must be
                positive) when ``pad_dataset`` is true, ignored otherwise.

        Raises:
            ValueError: If ``pad_dataset`` is true and ``num_replicas`` is
                missing or not positive.
        """
        self._dataset = dataset
        self._size = len(self._dataset)
        self._padded_size = self._size
        self._pad_dataset = pad_dataset
        if self._pad_dataset:
            # Explicit raises rather than ``assert`` so the checks survive ``python -O``.
            if num_replicas is None:
                raise ValueError("num_replicas should be set if pad_dataset is True")
            if num_replicas <= 0:
                raise ValueError(f"num_replicas should be positive, got {num_replicas}")
            # Ceiling of size / num_replicas, multiplied back by num_replicas.
            self._padded_size = num_replicas * ((self._size + num_replicas - 1) // num_replicas)

    def _source_index(self, index: int) -> int:
        """Validate ``index`` and map it to an index of the wrapped dataset.

        Raises:
            IndexError: If ``index`` is at or beyond ``len(self)``, or the
                wrapped dataset is empty.
        """
        # Raising IndexError (instead of asserting) also lets plain iteration
        # over this object terminate, since it falls back to __getitem__.
        # The emptiness check avoids a modulo by zero below.
        if index >= self._padded_size or self._size == 0:
            raise IndexError(f"index {index} is out of range for dataset of length {self._padded_size}")
        # Negative indices are not rejected; they wrap by modulo.
        return index % self._size

    def get_image_relpath(self, index: int) -> str:
        """Return the wrapped dataset's relative image path for ``index``."""
        return self._dataset.get_image_relpath(self._source_index(index))

    def get_image_data(self, index: int) -> bytes:
        """Return the wrapped dataset's image data for ``index``."""
        return self._dataset.get_image_data(self._source_index(index))

    def get_target(self, index: int) -> Tuple[int, Any]:
        """Return ``(index, target)``, or ``(-1, target)`` for a padded position."""
        target = self._dataset.get_target(self._source_index(index))
        if index >= self._size:
            # Only reachable when padding is enabled (otherwise IndexError above).
            return (-1, target)
        return (index, target)

    def get_sample_decoder(self, index: int) -> Any:
        """Return the wrapped dataset's sample decoder for ``index``."""
        return self._dataset.get_sample_decoder(self._source_index(index))

    def __getitem__(self, index: int) -> Tuple[Any, Tuple[int, Any]]:
        """Return ``(image, (index, target))`` for the sample at ``index``.

        Padded positions yield ``(image, (-1, target))``. For real samples a
        ``None`` target is replaced by the index itself.

        Raises:
            IndexError: If ``index`` is at or beyond ``len(self)``.
        """
        image, target = self._dataset[self._source_index(index)]
        if index >= self._size:
            # Only reachable when padding is enabled (otherwise IndexError above).
            return image, (-1, target)
        target = index if target is None else target
        return image, (index, target)

    def __len__(self) -> int:
        """Return the dataset length, including any padding."""
        return self._padded_size
|
|
|