| import glob | |
| from pathlib import Path | |
| from typing import Any, Callable, Optional | |
| from torchvision.datasets import VisionDataset | |
class BaseDataset(VisionDataset):
    """Dataset of image/mask pairs stored under ``<root>/<split>/``.

    ``<split>`` is ``'train'`` or ``'test'``. Images are collected from
    ``<split>/images/*.jpg`` and masks from ``<split>/masks/*.png``. Both
    file lists are sorted independently and paired by position, so the
    i-th image is matched with the i-th mask.
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        train: bool = True
    ) -> None:
        """Index the image and mask files of the selected split.

        Args:
            root: Dataset root directory containing the split folders.
            loader: Callable that turns a file path into a sample. NOTE: it
                is invoked with a ``pathlib.Path`` (``root`` joined with the
                relative file name), not a plain ``str``.
            transforms: Joint transform called as ``transforms(img, mask)``.
            transform: Image-only transform, forwarded to ``VisionDataset``.
            target_transform: Mask-only transform, forwarded to
                ``VisionDataset``.
            train: Use the ``train`` split when true, otherwise ``test``.

        Raises:
            ValueError: If the number of images and masks found differs.
        """
        super().__init__(root, transforms, transform, target_transform)
        self.root_path = Path(root)
        self.loader = loader
        mode = 'train' if train else 'test'
        # Paths are stored relative to ``root``; they are joined with
        # ``root_path`` again when a sample is loaded.
        self.data = sorted(glob.glob(f'{mode}/images/*.jpg', root_dir=root))
        self.masks = sorted(glob.glob(f'{mode}/masks/*.png', root_dir=root))
        # Pairing is purely positional, so a count mismatch means at least
        # one sample would be mispaired or unreachable. Fail at construction
        # instead of with an IndexError in the middle of iteration.
        if len(self.data) != len(self.masks):
            raise ValueError(
                f'Found {len(self.data)} images but {len(self.masks)} masks '
                f'for split {mode!r} under {root!r}'
            )

    def __getitem__(self, index: int) -> Any:
        """Load and return the ``(image, mask)`` pair at ``index``.

        The mask is post-processed with ``squeeze(dim=0).bool().float()``,
        so the loaded (and, if configured, transformed) mask must support
        those methods -- presumably a torch tensor whose leading dimension
        is a singleton channel; confirm against the loader/transforms used.
        """
        img_path = self.root_path / self.data[index]
        mask_path = self.root_path / self.masks[index]
        img, mask = self.loader(img_path), self.loader(mask_path)
        # ``self.transforms`` is None when no transform argument was given
        # (per torchvision's VisionDataset), so only call it when present.
        if self.transforms is not None:
            img, mask = self.transforms(img, mask)
        return img, mask.squeeze(dim=0).bool().float()

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.data)