| | import cv2
|
| | import numpy as np
|
| | import torch
|
| | from torch.utils.data import Dataset
|
| |
|
| | from config import *
|
| |
|
| |
|
class InpaintDataset(Dataset):
    """Dataset wrapping a single (image, mask) pair for inpainting.

    The pair is read from disk on every access, resized to a fixed size and
    returned as float32 tensors scaled from [0, 255] to [0, 1].
    """

    def __init__(self, in_image, mask_image, resize_to=None):
        """Store the image/mask paths and the target size.

        Args:
            in_image: Path to the input image, as accepted by ``cv2.imread``.
            mask_image: Path to the mask image, as accepted by ``cv2.imread``.
            resize_to: Target size passed as ``dsize`` to ``cv2.resize``,
                i.e. ``(width, height)``. When ``None``, ``RESIZE_TO`` from
                the ``config`` module is used.
        """
        if resize_to is None:
            resize_to = RESIZE_TO
        # Stored as one-element lists so the Dataset is indexable like a
        # multi-sample dataset even though it only ever holds one pair.
        self.imglist = [in_image]
        self.masklist = [mask_image]
        self.setsize = resize_to

    def __len__(self):
        """Return the number of (image, mask) pairs held (always 1)."""
        return len(self.imglist)

    @staticmethod
    def _read(path):
        """Read *path* with ``cv2.imread``; raise ``OSError`` if it fails."""
        data = cv2.imread(path)
        # cv2.imread signals a missing or undecodable file by returning None
        # rather than raising, which would otherwise surface later as an
        # unrelated TypeError / cv2.error.
        if data is None:
            raise OSError(f"could not read or decode image file: {path!r}")
        return data

    def __getitem__(self, index):
        """Load, resize and tensorize the pair at *index*.

        Returns:
            ``(img, mask)`` where ``img`` is a float32 tensor of shape
            ``(3, H, W)`` in RGB channel order and ``mask`` is a float32
            tensor of shape ``(1, H, W)``, both scaled by 1/255. ``(W, H)``
            is ``self.setsize``.

        Raises:
            OSError: if either file cannot be read or decoded by OpenCV.
        """
        img = self._read(self.imglist[index])
        # Only the first channel (blue, in OpenCV's BGR order) of the mask
        # file is used.
        mask = self._read(self.masklist[index])[:, :, 0]

        img = cv2.resize(img, self.setsize)
        # NOTE(review): default (bilinear) interpolation yields fractional
        # values along mask edges; confirm this is intended rather than
        # cv2.INTER_NEAREST.
        mask = cv2.resize(mask, self.setsize)

        # OpenCV loads BGR; the returned tensor is RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # HWC -> CHW; contiguous() materializes the permuted layout.
        img = (
            torch.from_numpy(img.astype(np.float32) / 255.0)
            .permute(2, 0, 1)
            .contiguous()
        )
        # HW -> 1HW (add a channel dimension).
        mask = (
            torch.from_numpy(mask.astype(np.float32) / 255.0)
            .unsqueeze(0)
            .contiguous()
        )
        return img, mask
|
| |
|