| import os.path |
| import torch |
| import torch.utils.data as data |
| from PIL import Image |
| import random |
| import utils |
| import numpy as np |
| import torchvision.transforms as transforms |
| from utils_core import flow_viz |
| import cv2 |
|
|
class DDDataset(data.Dataset):
    """Dataset of paired source/target color images with foreground masks.

    ``initialize(opt)`` must be called before use. It reads the list file at
    ``opt.datapath``. Each non-blank line must hold at least four
    whitespace-separated paths, in this order: source color, source mask,
    target color, target mask. Extra fields are ignored. A line whose four
    files do not all exist is reported with ``print`` and skipped.

    Sample construction additionally reads ``opt.width``, ``opt.height``,
    ``opt.crop_width`` and ``opt.crop_height``.
    """

    # Number of path fields consumed from each list-file line.
    _NUM_FIELDS = 4

    def __init__(self):
        super(DDDataset, self).__init__()

    def initialize(self, opt):
        """Read the list file ``opt.datapath`` and collect valid path groups.

        Args:
            opt: options object. ``opt.datapath`` is the list-file path. The
                object is kept as ``self.opt`` for later sample construction.

        Side effects:
            Sets ``self.opt``, ``self.dir_txt``, ``self.paths`` (a list of
            4-element path lists) and ``self.data_size``. Prints skipped
            entries and the final count.
        """
        self.opt = opt
        self.dir_txt = opt.datapath
        self.paths = []
        # `with` guarantees the list file is closed even if parsing fails.
        with open(self.dir_txt, "r") as in_file:
            for raw_line in in_file:
                fields = raw_line.split()
                if not fields:
                    # Blank lines (e.g. trailing newlines) carry no sample;
                    # previously they crashed with IndexError on fields[0].
                    continue
                if len(fields) < self._NUM_FIELDS:
                    print(raw_line.strip() + " has fewer than 4 paths, skipped")
                    continue
                path_list = fields[:self._NUM_FIELDS]
                # Report only the first missing file, as the original checks did.
                missing = next(
                    (p for p in path_list if not os.path.exists(p)), None
                )
                if missing is not None:
                    print(missing + " not exists")
                    continue
                self.paths.append(path_list)
        self.data_size = len(self.paths)
        print("num data: ", len(self.paths))

    def process_data(self, color, mask):
        """Mask ``color`` and resize the padded mask bounding box to crop size.

        Args:
            color: image array indexed as ``color[y, x, c]``.
            mask: array whose non-zero entries mark foreground. It is
                presumably 2-D ``(H, W)``, since it is broadcast against
                ``color`` via ``[:, :, None]`` (TODO confirm).

        Returns:
            ``(crop_color, crop_params)``. ``crop_color`` is the masked crop
            resized with ``cv2.resize`` to width ``opt.crop_width`` and height
            ``opt.crop_height``. ``crop_params`` is
            ``[[min_x], [max_x], [min_y], [max_y]]``.

        Raises:
            ValueError: if ``mask`` has no non-zero pixels.
        """
        non_zero = mask.nonzero()
        if len(non_zero[0]) == 0:
            raise ValueError("mask has no non-zero pixels; cannot compute crop")
        bound = 10  # padding (pixels) added around the mask's bounding box
        min_x = max(0, non_zero[1].min() - bound)
        max_x = min(self.opt.width - 1, non_zero[1].max() + bound)
        min_y = max(0, non_zero[0].min() - bound)
        max_y = min(self.opt.height - 1, non_zero[0].max() + bound)
        # Zero out background pixels before cropping.
        color = color * (mask != 0).astype(float)[:, :, None]
        # NOTE(review): slice ends are exclusive, so row max_y / column max_x
        # are not included. This is kept as-is because crop_params is
        # presumably consumed downstream with this same convention.
        crop_color = color[min_y:max_y, min_x:max_x, :]
        crop_color = cv2.resize(
            np.ascontiguousarray(crop_color),
            (self.opt.crop_width, self.opt.crop_height),
            interpolation=cv2.INTER_LINEAR,
        )
        crop_params = [[min_x], [max_x], [min_y], [max_y]]
        return crop_color, crop_params

    def _load_view(self, color_path, mask_path):
        """Load one color/mask pair and build its tensors.

        Returns:
            ``(raw_color, crop_color, mask, crop_params)``:

            - ``raw_color`` is the full color image, permuted with
              ``permute(2, 0, 1)`` and divided by 255 as float.
            - ``crop_color`` is the cropped image from ``process_data``,
              treated the same way.
            - ``mask`` is a bool tensor of ``mask != 0`` with a new leading
              axis.
            - ``crop_params`` is the list from ``process_data``.
        """
        with Image.open(color_path) as img:
            color = np.array(img).astype(np.uint8)
        with Image.open(mask_path) as img:
            mask = np.array(img)
        crop_color, crop_params = self.process_data(color, mask)
        raw_color = torch.from_numpy(color).permute(2, 0, 1).float() / 255.0
        crop_color = torch.from_numpy(crop_color).permute(2, 0, 1).float() / 255.0
        mask_tensor = torch.tensor((mask != 0)[np.newaxis, :, :])
        return raw_color, crop_color, mask_tensor, crop_params

    def __getitem__(self, index):
        """Return the sample dict for ``index``, wrapped modulo the dataset size."""
        paths = self.paths[index % self.data_size]
        raw_src_color, src_crop_color, src_mask, src_crop_params = self._load_view(
            paths[0], paths[1]
        )
        raw_tar_color, tar_crop_color, tar_mask, tar_crop_params = self._load_view(
            paths[2], paths[3]
        )
        # Concatenating the two lists gives 8 single-element rows: src params, then tar.
        Crop_param = torch.tensor(src_crop_params + tar_crop_params)
        # NOTE(review): [:-4] assumes a 3-character extension such as ".png".
        # It is kept as-is so generated flow names stay stable for existing consumers.
        src_stem = paths[0].split("/")[-1][:-4]
        tar_stem = paths[2].split("/")[-1][:-4]
        path1 = src_stem + "_" + tar_stem + ".oflow"
        return {
            "path_flow": path1,
            "src_crop_color": src_crop_color,
            "tar_crop_color": tar_crop_color,
            "src_color": raw_src_color,
            "tar_color": raw_tar_color,
            "src_mask": src_mask,
            "tar_mask": tar_mask,
            "Crop_param": Crop_param,
        }

    def __len__(self):
        """Return the number of valid entries found by ``initialize``."""
        return self.data_size

    def name(self):
        """Return the dataset's identifying name."""
        return 'DDDataset'