| | import os |
| | import numpy as np |
| | from glob import glob |
| | import torchvision.transforms.functional as TF |
| | from PIL import Image |
| | from torch.utils import data |
| |
|
| | from . import transforms as my_tf |
| | from myutils import load_image_in_PIL as load_img |
| |
|
| |
|
| | def load_image_in_PIL(path, mode='RGB'): |
| | img = Image.open(path) |
| | img.load() |
| | return img.convert(mode) |
| |
|
| |
|
| | class WaterDataset(data.Dataset): |
| |
|
| | def __init__(self, mode, dataset_path, input_size=None, test_case=None, eval_size=None): |
| |
|
| | super(WaterDataset, self).__init__() |
| |
|
| | self.mode = mode |
| | self.input_size = input_size |
| | self.test_case = test_case |
| | self.img_list = [] |
| | self.label_list = [] |
| | self.verbose_flag = False |
| | self.online_augmentation_per_epoch = 640 |
| | self.eval_size = eval_size |
| |
|
| | if mode == 'train_offline': |
| | with open(os.path.join(dataset_path, 'train_imgs.txt')) as f: |
| | water_subdirs = f.readlines() |
| | water_subdirs = [x.strip() for x in water_subdirs] |
| |
|
| | print('Initialize offline training dataset:') |
| |
|
| | for sub_folder in water_subdirs: |
| | label_list = glob(os.path.join(dataset_path, 'Annotations/', sub_folder, '*.png')) |
| | label_list.sort(key=lambda x: (len(x), x)) |
| | self.label_list += label_list |
| |
|
| | name_list = [os.path.basename(x)[:-4] for x in label_list] |
| |
|
| | img_list = glob(os.path.join(dataset_path, 'JPEGImages/', sub_folder, '*.jpg')) |
| | img_list.sort(key=lambda x: (len(x), x)) |
| | img_list_valid = [] |
| | for img_path in img_list: |
| | if os.path.basename(img_path)[:-4] in name_list: |
| | img_list_valid.append(img_path) |
| |
|
| | self.img_list += img_list_valid |
| |
|
| | print('Add', sub_folder, len(img_list_valid), 'files.') |
| |
|
| |
|
| |
|
| | elif mode == 'eval': |
| | if test_case is None: |
| | raise ('test_case can not be None.') |
| |
|
| | img_path = os.path.join(dataset_path, 'JPEGImages/', test_case) |
| | img_list = os.listdir(img_path) |
| | img_list.sort(key=lambda x: (len(x), x)) |
| | self.img_list = [os.path.join(img_path, name) for name in img_list] |
| |
|
| | first_frame_label_path = os.path.join(dataset_path, 'Annotations/', test_case, img_list[0]) |
| |
|
| | |
| | first_frame_label_path = first_frame_label_path[:-3] |
| | if os.path.exists(first_frame_label_path + 'png'): |
| | first_frame_label_path += 'png' |
| | else: |
| | first_frame_label_path += 'jpg' |
| |
|
| | if not os.path.exists(first_frame_label_path): |
| | label_list = glob(os.path.join(dataset_path, 'Annotations/', test_case, '*.png')) |
| | label_list.sort(key=lambda x: (x, len(x))) |
| | first_frame_label_path = label_list[0] |
| |
|
| | self.first_frame = load_image_in_PIL(self.img_list[0], 'RGB') |
| | self.img_list.pop(0) |
| |
|
| | self.first_frame_label = load_image_in_PIL(first_frame_label_path, 'P') |
| |
|
| | if self.eval_size: |
| | self.origin_size = self.first_frame.size |
| | self.first_frame = self.first_frame.resize(self.eval_size, Image.ANTIALIAS) |
| | self.first_frame_label = self.first_frame_label.resize(self.eval_size, Image.ANTIALIAS) |
| |
|
| | else: |
| | raise ('Mode %s does not support in [train_offline, train_online, eval].' % mode) |
| |
|
| | def __len__(self): |
| | if self.mode == 'train_online': |
| | return self.online_augmentation_per_epoch |
| | else: |
| | return len(self.img_list) |
| |
|
| | def get_first_frame(self): |
| | img_tf = TF.to_tensor(self.first_frame) |
| | img_tf = my_tf.imagenet_normalization(img_tf) |
| | return img_tf |
| |
|
| | def get_first_frame_label(self): |
| | return TF.to_tensor(self.first_frame_label) |
| |
|
| | def __getitem__(self, index): |
| | raise NotImplementedError |
| |
|
| |
|
| | class WaterDataset_RGB(WaterDataset): |
| | def __init__(self, mode, dataset_path, input_size=None, test_case=None, eval_size=None): |
| | super(WaterDataset_RGB, self).__init__(mode, dataset_path, input_size, test_case, eval_size) |
| |
|
| | def __getitem__(self, index): |
| | if self.mode == 'train_offline' or self.mode == 'val_offline' or self.mode == 'test_offline': |
| | img = load_img(self.img_list[index], 'RGB') |
| | label = load_img(self.label_list[index], 'P') |
| | return self.apply_transforms(img, label) |
| | elif self.mode == 'train_online': |
| | return self.apply_transforms(self.first_frame, self.first_frame_label) |
| | elif self.mode == 'eval': |
| | img = load_img(self.img_list[index], 'RGB') |
| | if self.eval_size: |
| | img = img.resize(self.eval_size, Image.ANTIALIAS) |
| | return self.apply_transforms(img) |
| | else: |
| | raise Exception("Error: Invalid dataset mode!") |
| |
|
| | def resize_to_origin(self, img): |
| | return img.resize(self.origin_size) |
| |
|
| | def apply_transforms(self, img, label=None): |
| | if self.mode == 'train_offline' or self.mode == 'train_online': |
| | img = my_tf.random_adjust_color(img, self.verbose_flag) |
| | img, label = my_tf.random_affine_transformation(img, None, label, self.verbose_flag) |
| | img, label = my_tf.random_resized_crop(img, None, label, self.input_size, self.verbose_flag) |
| | elif self.mode == 'test_offline' or self.mode == 'val_offline': |
| | img = TF.resize(img, self.input_size) |
| | label = TF.resize(label, self.input_size) |
| | elif self.mode == 'eval': |
| | pass |
| |
|
| | img_orig = TF.to_tensor(img) |
| | img_norm = my_tf.imagenet_normalization(img_orig) |
| |
|
| | if self.mode == 'train_offline' or self.mode == 'train_online': |
| | |
| | label = np.expand_dims(np.array(label, np.float32), axis=0) |
| | return img_norm, label |
| | elif self.mode == 'val_offline': |
| | label = np.expand_dims(np.array(label, np.float32), axis=0) |
| | return img_norm, label |
| | elif self.mode == 'test_offline': |
| | label = np.expand_dims(np.array(label, np.float32), axis=0) |
| | return img_norm, label, img_orig |
| | else: |
| | return None |