| import logging |
| from os import listdir |
| from os.path import splitext |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from torch.utils.data import Dataset |
|
|
|
|
| class BasicDataset(Dataset): |
| def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = ''): |
| self.images_dir = Path(images_dir) |
| self.masks_dir = Path(masks_dir) |
| assert 0 < scale <= 1, 'Scale must be between 0 and 1' |
| self.scale = scale |
| self.mask_suffix = mask_suffix |
|
|
| self.ids = [splitext(file)[0] for file in listdir(images_dir) if not file.startswith('.')] |
| if not self.ids: |
| raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there') |
| logging.info(f'Creating dataset with {len(self.ids)} examples') |
|
|
| def __len__(self): |
| return len(self.ids) |
|
|
| @classmethod |
| def preprocess(cls, pil_img, scale, is_mask): |
| w= h = pil_img.size |
| newW, newH = int(scale * w), int(scale * h) |
| assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel' |
| |
| |
| img_ndarray = np.asarray(pil_img) |
|
|
| if img_ndarray.ndim == 2 and not is_mask: |
| img_ndarray = img_ndarray[np.newaxis, ...] |
| elif not is_mask: |
| img_ndarray = img_ndarray.transpose((2, 0, 1)) |
|
|
| if not is_mask: |
| img_ndarray = img_ndarray / 255 |
|
|
| return img_ndarray |
|
|
| @classmethod |
| def load(cls, filename): |
| ext = splitext(filename)[1] |
| if ext in ['.npz', '.npy']: |
| return Image.fromarray(np.load(filename)) |
| elif ext in ['.pt', '.pth']: |
| return Image.fromarray(torch.load(filename).numpy()) |
| else: |
| return Image.open(filename) |
|
|
| def __getitem__(self, idx): |
| name = self.ids[idx] |
| mask_file = list(self.masks_dir.glob(name + self.mask_suffix + '.*')) |
| img_file = list(self.images_dir.glob(name + '.*')) |
|
|
| assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}' |
| assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}' |
| mask = self.load(mask_file[0]) |
| img = self.load(img_file[0]) |
|
|
| assert img.size == mask.size, \ |
| 'Image and mask {name} should be the same size, but are {img.size} and {mask.size}' |
|
|
| img = self.preprocess(img, self.scale, is_mask=False) |
| mask = self.preprocess(mask, self.scale, is_mask=True) |
|
|
| return { |
| 'image': torch.as_tensor(img.copy()).float().contiguous(), |
| 'mask': torch.as_tensor(mask.copy()).long().contiguous() |
| } |
|
|
|
|
| class GODataset(BasicDataset): |
| def __init__(self, images_dir, masks_dir, scale=1): |
| super().__init__(images_dir, masks_dir, scale, mask_suffix='_gt') |
|
|