import glob
import logging
import os
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import PIL
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision.datasets import VisionDataset
| logger = logging.getLogger(__name__) | |
class PathDataset(VisionDataset):
    """Dataset of image files discovered recursively under one or more folders.

    Each folder is searched recursively once per extension.  Matches are
    collected in ``self.files`` folder by folder and, within a folder,
    extension by extension; every individual group is sorted by path.

    Items are returned as ``(sample, None)``: this dataset never produces a
    target.

    Args:
        root: Folder, or list of folders, to scan.  A single path is treated
            as a one-element list.
        loader: Optional callable ``path -> sample``.  When given, it replaces
            the whole built-in pipeline (open, ``transform``, tensor
            conversion, normalisation).
        transform: Optional callable applied to the RGB PIL image before it is
            passed to ``torchvision.transforms.functional.to_tensor``; its
            result must therefore be something ``to_tensor`` accepts.
        extra_transform: Stored on the instance as ``self.extra_transform``.
            This class itself never applies it.
        mean: Per-channel means passed to ``TF.normalize``.  Must be given
            together with ``std``.
        std: Per-channel standard deviations passed to ``TF.normalize``.  Must
            be given together with ``mean``.
        extensions: File extensions to collect, with or without a leading
            dot.  Matching is done by ``glob`` on the literal extension.
            Defaults to ``("jpg", "png")``.

    Raises:
        ValueError: If exactly one of ``mean`` and ``std`` is provided.

    Note:
        Constructing an instance sets ``PIL.Image.MAX_IMAGE_PIXELS`` to
        256000001, which is a process-wide Pillow setting.
    """

    def __init__(
        self,
        root: Union[str, List[str]],
        loader: Optional[Callable[[str], Any]] = None,
        transform: Optional[Callable[[Any], Any]] = None,
        extra_transform: Optional[Callable[[Any], Any]] = None,
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
        extensions: Sequence[str] = ("jpg", "png"),
    ):
        # Validate up front with a real exception: an ``assert`` disappears
        # under ``python -O`` and there is no point scanning the filesystem
        # for a dataset that is misconfigured.
        if (mean is None) != (std is None):
            raise ValueError(
                "mean and std must be provided together or not at all"
            )

        super().__init__(root=root)

        # Pillow's limit on the pixel count of images it will open; this is a
        # module-level setting and so affects the whole process.
        PIL.Image.MAX_IMAGE_PIXELS = 256000001

        # A bare path would otherwise be iterated character by character.
        if isinstance(self.root, (str, os.PathLike)):
            folders = [self.root]
        else:
            folders = list(self.root)

        # Same protection for a single extension passed as a plain string.
        if isinstance(extensions, str):
            extensions = (extensions,)

        self.files = []
        for folder in folders:
            for ext in extensions:
                pattern = os.path.join(folder, "**", "*." + ext.lstrip("."))
                self.files.extend(sorted(glob.glob(pattern, recursive=True)))

        self.transform = transform
        self.extra_transform = extra_transform
        self.mean = mean
        self.std = std
        self.loader = loader

        logger.info("loaded %d samples from %s", len(self.files), root)

    def __len__(self) -> int:
        """Return the number of image files that were discovered."""
        return len(self.files)

    def __getitem__(self, idx) -> Tuple[Any, None]:
        """Return ``(sample, None)`` for the file at position ``idx``.

        With a custom ``loader`` the sample is whatever the loader returns for
        the file path.  Otherwise the file is opened, converted to RGB, passed
        through ``transform`` (if any), converted with ``TF.to_tensor`` and,
        when ``mean`` and ``std`` are set, normalised with ``TF.normalize``.
        """
        path = self.files[idx]

        if self.loader is not None:
            return self.loader(path), None

        # ``convert`` returns an independent image, so the source file can be
        # closed as soon as the conversion is done.
        with Image.open(path) as source:
            img = source.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)
        img = TF.to_tensor(img)

        if self.mean is not None and self.std is not None:
            img = TF.normalize(img, self.mean, self.std)
        return img, None