| import bisect |
| import numpy as np |
| import albumentations |
| from PIL import Image |
| from torch.utils.data import Dataset, ConcatDataset |
|
|
|
|
class ConcatDatasetWithIndex(ConcatDataset):
    """ConcatDataset whose items also report which sub-dataset they came from.

    Adapted from the PyTorch ``ConcatDataset.__getitem__`` implementation.
    """

    def __getitem__(self, idx):
        """Return ``(sample, dataset_idx)`` for the global position ``idx``.

        Args:
            idx: Position in the concatenated dataset. Negative values count
                from the end.

        Returns:
            A 2-tuple of the sample fetched from the owning sub-dataset and
            the position of that sub-dataset in ``self.datasets``.

        Raises:
            ValueError: If ``idx`` is negative and its magnitude exceeds the
                total length.
        """
        if idx < 0:
            total = len(self)
            if -idx > total:
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = total + idx

        # cumulative_sizes holds the running end offsets of the sub-datasets,
        # so bisect_right yields the first sub-dataset whose end lies past idx.
        source = bisect.bisect_right(self.cumulative_sizes, idx)
        local = idx if source == 0 else idx - self.cumulative_sizes[source - 1]
        return self.datasets[source][local], source
|
|
|
|
class ImagePaths(Dataset):
    """Dataset that loads images from file paths and pairs them with labels.

    Each item is a dict holding the preprocessed image under ``"image"`` plus
    the ``i``-th entry of every sequence in ``labels``; the path itself is
    stored under ``"file_path_"``.
    """

    def __init__(self, paths, size=None, random_crop=False, labels=None):
        """Store the paths and build the preprocessing pipeline.

        Args:
            paths: Sequence of image file paths; its length is the dataset
                length.
            size: Target side length. When not ``None`` and positive, images
                go through ``SmallestMaxSize(max_size=size)`` followed by a
                ``size`` x ``size`` crop. Otherwise they pass through
                unchanged.
            random_crop: Use ``RandomCrop`` instead of ``CenterCrop``.
            labels: Optional mapping from key to a per-sample indexable. It
                is shallow-copied, so the caller's mapping is left untouched.
        """
        self.size = size
        self.random_crop = random_crop

        # Copy first: adding "file_path_" must not leak into the caller's dict.
        self.labels = {} if labels is None else dict(labels)
        self.labels["file_path_"] = paths
        self._length = len(paths)

        if self.size is not None and self.size > 0:
            self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            crop_cls = (
                albumentations.RandomCrop
                if self.random_crop
                else albumentations.CenterCrop
            )
            self.cropper = crop_cls(height=self.size, width=self.size)
            self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
        else:
            # A named function instead of a lambda keeps the instance
            # picklable, which multi-process data loading may require.
            self.preprocessor = self._passthrough

    @staticmethod
    def _passthrough(**kwargs):
        """Identity preprocessor: return the keyword arguments as a dict."""
        return kwargs

    def __len__(self):
        """Return the number of paths given at construction."""
        return self._length

    def preprocess_image(self, image_path):
        """Load ``image_path`` as RGB and return it as a normalised array.

        Args:
            image_path: Path of the image file to open.

        Returns:
            ``float32`` array; the uint8 range 0..255 is mapped to [-1, 1].
        """
        # The context manager releases the file handle even if decoding or
        # conversion raises.
        with Image.open(image_path) as handle:
            rgb = handle if handle.mode == "RGB" else handle.convert("RGB")
            pixels = np.array(rgb).astype(np.uint8)
        pixels = self.preprocessor(image=pixels)["image"]
        return (pixels / 127.5 - 1.0).astype(np.float32)

    def __getitem__(self, i):
        """Return the example dict for sample ``i``.

        The image is inserted first and the label entries afterwards, so a
        label keyed ``"image"`` would replace the loaded image.
        """
        example = {"image": self.preprocess_image(self.labels["file_path_"][i])}
        for key, values in self.labels.items():
            example[key] = values[i]
        return example
|
|
|
|
class NumpyPaths(ImagePaths):
    """ImagePaths variant that reads each image from a file via ``np.load``."""

    def preprocess_image(self, image_path):
        """Load a stored array and return it as a normalised image array.

        The stored array must have a leading axis of length 1 followed by
        three more axes; presumably it is laid out as (1, C, H, W) with
        C == 3 -- confirm against whatever writes these files.

        Args:
            image_path: Path handed to ``np.load``.

        Returns:
            ``float32`` array computed as ``pixels / 127.5 - 1.0``.
        """
        stored = np.load(image_path)
        stored = stored.squeeze(0)
        # Move the first remaining axis to the end before handing to PIL.
        channels_last = np.transpose(stored, (1, 2, 0))
        pil_image = Image.fromarray(channels_last, mode="RGB")
        pixels = np.array(pil_image).astype(np.uint8)
        processed = self.preprocessor(image=pixels)
        normalised = processed["image"] / 127.5 - 1.0
        return normalised.astype(np.float32)
|
|