|
|
|
|
|
|
|
|
|
|
from typing import Any, Tuple, Type

from torchvision.datasets import VisionDataset

from .decoders import Decoder, ImageDataDecoder, TargetDecoder
|
|
|
|
|
class ExtendedVisionDataset(VisionDataset):
    """Base dataset that assembles samples from raw data via pluggable decoders.

    A sample is produced by fetching the raw image data and the raw target
    through the ``get_image_data`` / ``get_target`` hooks, wrapping each in a
    decoder and calling ``.decode()`` on it, and finally applying
    ``self.transforms`` when it is not ``None``.

    Subclasses must override ``get_image_data``, ``get_target`` and
    ``__len__``; the versions defined here raise ``NotImplementedError``.
    """

    def __init__(
        self,
        image_decoder: Type[Decoder] = ImageDataDecoder,
        target_decoder: Type[Decoder] = TargetDecoder,
        *args,
        **kwargs,
    ) -> None:
        """Store the decoder classes and initialise the base dataset.

        Args:
            image_decoder: Decoder class called with the raw image data; the
                resulting object's ``decode()`` provides the image.
            target_decoder: Decoder class called with the raw target; the
                resulting object's ``decode()`` provides the target.
            *args: Extra positional arguments forwarded to the base-class
                initialiser.
            **kwargs: Keyword arguments forwarded to the base-class
                initialiser.

        Note:
            The two decoder parameters come before ``*args``, so the first two
            positional arguments are always taken as the decoders. Pass
            base-class arguments by keyword unless both decoders are given
            positionally as well.
        """
        super().__init__(*args, **kwargs)
        # These are kept as classes (factories) and instantiated once per
        # sample in __getitem__, not as ready-made decoder instances.
        self.image_decoder = image_decoder
        self.target_decoder = target_decoder

    def get_image_data(self, index: int) -> bytes:
        """Return the raw image data for sample ``index`` (subclass hook)."""
        raise NotImplementedError

    def get_target(self, index: int) -> Any:
        """Return the raw, undecoded target for sample ``index`` (subclass hook)."""
        raise NotImplementedError

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the decoded ``(image, target)`` pair for sample ``index``.

        Raises:
            RuntimeError: If fetching or decoding the image fails; the
                original exception is attached as the cause. Failures while
                fetching or decoding the target, or inside the transforms,
                propagate unchanged.
        """
        try:
            image_data = self.get_image_data(index)
            image = self.image_decoder(image_data).decode()
        except Exception as e:
            # Re-raise with the sample index so a bad sample can be located.
            raise RuntimeError(f"can not read image for sample {index}") from e
        target = self.get_target(index)
        target = self.target_decoder(target).decode()

        # `transforms` is not assigned in this class; presumably the
        # base-class initialiser sets it.
        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of samples (subclass hook)."""
        raise NotImplementedError
|
|
|