from typing import Any, Tuple

from torch.utils.data import Dataset


class DatasetWithEnumeratedTargets(Dataset):
    """Wrap a dataset so that every target is paired with its sample index.

    The wrapped object must support ``len()``, integer indexing that yields
    an ``(image, target)`` pair, and the ``get_image_data(index)`` /
    ``get_target(index)`` accessors; all calls are delegated to it.
    """

    def __init__(self, dataset):
        """Store the dataset to wrap; nothing is copied or validated."""
        super().__init__()
        self._dataset = dataset

    def get_image_data(self, index: int) -> bytes:
        """Return whatever the wrapped dataset's ``get_image_data`` returns."""
        return self._dataset.get_image_data(index)

    def get_target(self, index: int) -> Tuple[int, Any]:
        """Return ``(index, target)`` for the sample at ``index``.

        The target is passed through unchanged: unlike ``__getitem__``, a
        ``None`` target is NOT replaced by the index here.
        """
        target = self._dataset.get_target(index)
        return (index, target)

    def __getitem__(self, index: int) -> Tuple[Any, Tuple[int, Any]]:
        """Return ``(image, (index, target))`` for the sample at ``index``.

        When the wrapped dataset yields ``None`` as the target, the index
        itself is used in its place, so the result is ``(index, index)``.
        """
        image, target = self._dataset[index]
        if target is None:
            # No target available: fall back to the sample index.
            target = index
        return image, (index, target)

    def __len__(self) -> int:
        """Return the number of samples in the wrapped dataset."""
        return len(self._dataset)