|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """
|
| Created in September 2022
|
| @author: fabrizio.guillaro
|
| """
|
|
|
| from abc import ABC, abstractmethod
|
| from PIL import Image
|
| import numpy as np
|
| import torch
|
| import random
|
| import cv2
|
|
|
|
|
class AbstractDataset(ABC):
    """Base class for datasets that turn an RGB image (plus optional mask) into tensors.

    Subclasses implement :meth:`get_img` and are expected to fill
    ``self.img_list`` (it is ``None`` after ``__init__`` and is used by
    ``__len__`` / ``get_img_name``).
    """

    # Value written into mask pixels introduced by padding (the original code
    # called it ``ignore_index``; presumably the loss's ignore index — confirm).
    _IGNORE_INDEX = -1

    def __init__(self, crop_size, grid_crop: bool, max_dim=None, aug=None):
        """
        :param crop_size: (H, W) or None. H and W must be multiples of 8 if grid_crop is True.
        :param grid_crop: True: crop offsets are aligned to the 8x8 grid (and, when
            crop_size is None, the image is padded up to a multiple of 8). False: crop anywhere.
        :param max_dim: if the image is bigger than this size, it is cropped to at most
            max_dim x max_dim around its centre (offsets rounded down to a multiple of 8).
        :param aug: optional augmentation, called as ``aug(image=..., mask=...)`` and
            expected to return a dict with ``'image'`` and ``'mask'`` entries.
        """
        self._crop_size = crop_size
        self._max_dim = max_dim
        self._grid_crop = grid_crop

        if grid_crop and crop_size is not None:
            assert crop_size[0] % 8 == 0 and crop_size[1] % 8 == 0

        self.img_list = None  # populated by subclasses
        self.aug = aug

    def _create_tensor(self, mask=None, rgb_path=None):
        """Load an image, align/augment/crop it together with its mask, and return tensors.

        :param mask: optional array whose first two dims should match the image's (H, W).
            If None, an all-zero mask is used.
        :param rgb_path: path of the image to load.
        :return: ``(t_RGB, t_mask)``. ``t_RGB`` is a float tensor of shape (3, H, W) with
            pixel values divided by 256. ``t_mask`` is a long tensor of shape (H, W); pixels
            added by padding hold -1.
        :raises ValueError: if the image cannot be read or the mask cannot be padded.
        """
        img_RGB = self._load_rgb(rgb_path)
        h, w = img_RGB.shape[0], img_RGB.shape[1]

        if mask is None:
            mask = np.zeros((h, w))
        else:
            mask = self._match_mask_shape(mask, h, w, rgb_path)

        if self.aug is not None:
            mask = np.uint8(mask)
            dat = self.aug(image=img_RGB, mask=mask)
            # The augmentation must not change dtypes; downstream code relies on them.
            assert dat['image'].dtype == img_RGB.dtype
            assert dat['mask'].dtype == mask.dtype
            img_RGB = dat['image']
            mask = dat['mask'] > 0
            del dat

        img_RGB, mask = self._pad_and_crop(img_RGB, mask, rgb_path)

        t_mask = torch.tensor(mask, dtype=torch.long)
        # NOTE: scaling by 1/256 (not 1/255) is kept as-is; downstream models
        # presumably depend on this normalisation.
        t_RGB = torch.tensor(img_RGB.transpose(2, 0, 1), dtype=torch.float) / 256.0
        return t_RGB, t_mask

    @staticmethod
    def _load_rgb(rgb_path):
        """Read ``rgb_path`` as an (H, W, 3) RGB array, wrapping any failure in ValueError."""
        try:
            # Context manager closes the underlying file handle promptly.
            with Image.open(rgb_path) as img:
                return np.array(img.convert("RGB"))
        except Exception as err:
            raise ValueError(f'error path: {rgb_path}') from err

    @staticmethod
    def _match_mask_shape(mask, h, w, rgb_path=None):
        """Return ``mask`` with first two dims equal to (h, w).

        If the shapes already match, the mask is returned unchanged. Otherwise
        a 90-degree rotation is tried. If that does not fit either, the
        ORIGINAL mask is nearest-neighbour resized and binarised.
        """
        if mask.shape[0] == h and mask.shape[1] == w:
            return mask

        print(f'MASK MISMATCH: {rgb_path} \n h:{h}, w:{w}, mask: {mask.shape}', flush=True)

        # Explicit shape check instead of `assert`, which is stripped under -O.
        if mask.ndim >= 2:
            rotated = np.ascontiguousarray(np.rot90(mask))
            if rotated.shape[0] == h and rotated.shape[1] == w:
                return rotated

        # cv2.resize takes dsize as (width, height).
        return cv2.resize(np.uint8(mask), (w, h), interpolation=cv2.INTER_NEAREST) > 0

    def _resolve_crop_size(self, h, w):
        """Return the (H, W) crop size to use for an h x w image, or None for no crop."""
        if self._crop_size is not None:
            return self._crop_size
        if self._grid_crop:
            # Round each side up to the next multiple of 8.
            return (-(-h // 8) * 8, -(-w // 8) * 8)
        return None

    def _pad_and_crop(self, img_RGB, mask, rgb_path=None):
        """Apply the random crop (with padding), then the optional max_dim centre crop.

        Shapes are re-read from the arrays at each stage, so the max_dim crop
        is computed on the already-cropped image.
        """
        crop_size = self._resolve_crop_size(img_RGB.shape[0], img_RGB.shape[1])
        if crop_size is not None:
            img_RGB, mask = self._random_crop(img_RGB, mask, crop_size, rgb_path)
        if self._max_dim is not None:
            img_RGB, mask = self._center_crop(img_RGB, mask, self._max_dim)
        return img_RGB, mask

    def _random_crop(self, img_RGB, mask, crop_size, rgb_path=None):
        """Pad (image with 127.5, mask with -1) up to ``crop_size`` if needed, then crop randomly."""
        h, w = img_RGB.shape[0], img_RGB.shape[1]
        out_h, out_w = crop_size

        if h < out_h or w < out_w:
            pad_shape = (max(h, out_h), max(w, out_w))

            padded_img = np.full(pad_shape + (3,), 127.5)
            padded_img[:img_RGB.shape[0], :img_RGB.shape[1], :] = img_RGB
            img_RGB = padded_img

            padded_mask = np.full(pad_shape, self._IGNORE_INDEX)
            try:
                padded_mask[:mask.shape[0], :mask.shape[1]] = mask
            except ValueError as err:  # broadcasting failure: mask shape incompatible
                raise ValueError(
                    f'{rgb_path}\nh:{h}, w:{w}, temp:{padded_mask.shape}, mask: {mask.shape}'
                ) from err
            mask = padded_mask

        # Row offset is drawn before column offset (same order as before, for seeded runs).
        max_r = max(h - out_h, 0)
        max_c = max(w - out_w, 0)
        if self._grid_crop:
            s_r = (random.randint(0, max_r) // 8) * 8
            s_c = (random.randint(0, max_c) // 8) * 8
        else:
            s_r = random.randint(0, max_r)
            s_c = random.randint(0, max_c)

        mask = mask[s_r:s_r + out_h, s_c:s_c + out_w]
        img_RGB = img_RGB[s_r:s_r + out_h, s_c:s_c + out_w, :]
        return img_RGB, mask

    @staticmethod
    def _center_crop(img_RGB, mask, max_dim):
        """Crop to at most max_dim x max_dim around the centre, offsets rounded down to multiples of 8."""
        h, w = img_RGB.shape[0], img_RGB.shape[1]
        s_r = (max((h - max_dim) // 2, 0) // 8) * 8
        s_c = (max((w - max_dim) // 2, 0) // 8) * 8
        mask = mask[s_r:s_r + max_dim, s_c:s_c + max_dim]
        img_RGB = img_RGB[s_r:s_r + max_dim, s_c:s_c + max_dim, :]
        return img_RGB, mask

    @abstractmethod
    def get_img(self, index):
        """Return the sample at ``index`` (implemented by subclasses)."""
        pass

    def get_img_name(self, index):
        """Return the name for ``index``: the first element if the entry is a list, else the entry itself."""
        item = self.img_list[index]
        if isinstance(item, list):
            return item[0]
        else:
            return item

    def __len__(self):
        """Number of entries in ``self.img_list``."""
        return len(self.img_list)
|
|
|
|
|