| | import os.path |
| | import torch |
| | import torch.utils.data as data |
| | from PIL import Image |
| | import random |
| | import utils |
| | import numpy as np |
| | import torchvision.transforms as transforms |
| | from utils_core import flow_viz |
| | import cv2 |
| |
|
class DDDataset(data.Dataset):
    """Paired (source, target) color/mask dataset read from a text list file.

    Each non-blank line of ``opt.datapath`` holds whitespace-separated paths;
    the first four are used, in this order:
    source color, source mask, target color, target mask.
    Lines with fewer than four paths, or referencing a missing file, are
    reported on stdout and skipped.

    Options read from ``opt``: ``datapath``, ``width``, ``height``,
    ``crop_width``, ``crop_height``.
    """

    def __init__(self):
        super().__init__()

    def initialize(self, opt):
        """Read the list file at ``opt.datapath`` and collect valid path quadruples."""
        self.opt = opt
        self.dir_txt = opt.datapath
        self.paths = []
        with open(self.dir_txt, "r") as in_file:
            for raw_line in in_file:
                fields = raw_line.split()
                if not fields:
                    # Blank lines (e.g. a trailing empty line) used to raise IndexError.
                    continue
                if len(fields) < 4:
                    print(raw_line.strip() + " has fewer than 4 paths")
                    continue
                # Report only the first missing path, as before, then skip the line.
                missing = next((p for p in fields[:4] if not os.path.exists(p)), None)
                if missing is not None:
                    print(missing + " not exists")
                    continue
                self.paths.append(fields[:4])
        self.data_size = len(self.paths)
        print("num data: ", len(self.paths))

    def process_data(self, color, mask, bound=10):
        """Zero out unmasked pixels, crop to the padded mask bounding box and resize.

        Args:
            color: image array indexed as ``color[y, x, channel]``.
            mask: 2-D array; non-zero entries mark the foreground.
            bound: padding in pixels added around the mask bounding box
                (defaults to the previously hard-coded 10).

        Returns:
            ``(crop_color, crop_params)`` where ``crop_color`` is the crop resized
            to ``(opt.crop_width, opt.crop_height)`` with bilinear interpolation,
            and ``crop_params`` is ``[[min_x], [max_x], [min_y], [max_y]]``.

        Raises:
            ValueError: if ``mask`` has no non-zero pixels.
        """
        rows, cols = mask.nonzero()
        if rows.size == 0:
            raise ValueError("mask has no non-zero pixels; cannot compute crop box")
        # Box is padded by `bound` and clamped to the configured frame size.
        min_x = max(0, cols.min() - bound)
        max_x = min(self.opt.width - 1, cols.max() + bound)
        min_y = max(0, rows.min() - bound)
        max_y = min(self.opt.height - 1, rows.max() + bound)
        color = color * (mask != 0).astype(float)[:, :, None]
        # NOTE(review): slice ends are exclusive, so column max_x / row max_y are not
        # part of the crop even though they are reported in crop_params. Left as-is
        # because downstream users of Crop_param may rely on it -- confirm.
        crop_color = color[min_y:max_y, min_x:max_x, :]
        crop_color = cv2.resize(
            np.ascontiguousarray(crop_color),
            (self.opt.crop_width, self.opt.crop_height),
            interpolation=cv2.INTER_LINEAR,
        )
        crop_params = [[min_x], [max_x], [min_y], [max_y]]
        return crop_color, crop_params

    def _load_view(self, color_path, mask_path):
        """Load one color/mask pair.

        Returns ``(color, crop_color, mask, crop_params)``: color tensors are
        channel-first floats scaled by 1/255, ``mask`` is a boolean tensor of
        shape (1, H, W).
        """
        with Image.open(color_path) as img:
            color = np.array(img).astype(np.uint8)
        with Image.open(mask_path) as img:
            mask = np.array(img)
        # Multi-channel masks use their first channel; 2-D masks are used directly.
        if mask.ndim == 3:
            mask = mask[:, :, 0]
        crop_color, crop_params = self.process_data(color, mask)
        color_t = torch.from_numpy(color).permute(2, 0, 1).float() / 255.0
        crop_t = torch.from_numpy(crop_color).permute(2, 0, 1).float() / 255.0
        mask_t = torch.tensor((mask != 0)[np.newaxis, :, :])
        return color_t, crop_t, mask_t, crop_params

    @staticmethod
    def _flow_name(src_path, tar_path):
        """Build the output name ``<src stem>_<tar stem>.oflow`` from two image paths."""
        src_stem = os.path.splitext(os.path.basename(src_path))[0]
        tar_stem = os.path.splitext(os.path.basename(tar_path))[0]
        return src_stem + "_" + tar_stem + ".oflow"

    def __getitem__(self, index):
        """Return the source/target sample dict for ``index`` (wrapped modulo dataset size)."""
        paths = self.paths[index % self.data_size]
        raw_src_color, src_crop_color, src_mask, src_crop_params = self._load_view(paths[0], paths[1])
        raw_tar_color, tar_crop_color, tar_mask, tar_crop_params = self._load_view(paths[2], paths[3])
        # 8 x 1 tensor: source box followed by target box (min_x, max_x, min_y, max_y each).
        crop_param = torch.tensor(src_crop_params + tar_crop_params)
        path1 = self._flow_name(paths[0], paths[2])
        return {
            "path_flow": path1,
            "src_crop_color": src_crop_color,
            "tar_crop_color": tar_crop_color,
            "src_color": raw_src_color,
            "tar_color": raw_tar_color,
            "src_mask": src_mask,
            "tar_mask": tar_mask,
            "Crop_param": crop_param,
        }

    def __len__(self):
        """Number of valid samples found by ``initialize``."""
        return self.data_size

    def name(self):
        """Identifier for this dataset type."""
        return 'DDDataset'
| |
|