Spaces:
Sleeping
Sleeping
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| import os | |
| import math | |
| import cv2 | |
| import numpy as np | |
| import random | |
| import paddle | |
| from paddleseg.cvlibs import manager | |
| import ppmatting.transforms as T | |
class MattingDataset(paddle.io.Dataset):
    """
    Pass in a dataset that conforms to the format.

        matting_dataset/
        |--bg/
        |
        |--train/
        |  |--fg/
        |  |--alpha/
        |
        |--val/
        |  |--fg/
        |  |--alpha/
        |  |--trimap/ (if existing)
        |
        |--train.txt
        |
        |--val.txt

    See README.md for more information of dataset.

    Args:
        dataset_root (str): The root path of dataset.
        transforms (list): Transforms for image.
        mode (str, optional): Which part of dataset to use. It is one of
            ('train', 'val', 'trainval'). Default: 'train'.
        train_file (str|list, optional): File list used for training. Each line
            should be `foreground_image.png background_image.png` or
            `foreground_image.png`. It should be provided if mode is 'train' or
            'trainval'. Default: None.
        val_file (str|list, optional): File list used for evaluation. Each line
            should be `foreground_image.png background_image.png`,
            `foreground_image.png` or
            `foreground_image.png background_image.png trimap_image.png`.
            It should be provided if mode is 'val' or 'trainval'. Default: None.
        get_trimap (bool, optional): Whether to get trimap. Default: True.
        separator (str, optional): The separator of train_file or val_file. If
            file names contain ' ', '|' may be a better choice. Default: ' '.
        key_del (tuple|list, optional): Keys which are not needed and will be
            deleted to accelerate the data reader. Default: None.
        if_rssn (bool, optional): Whether to use RSSN (denoise and blur) while
            compositing the image. Default: False.

    Raises:
        ValueError: If `mode` is unknown, or the file list required by `mode`
            is not provided.
    """

    # Modes accepted by `__init__`; 'trainval' concatenates both file lists.
    _MODES = ('train', 'val', 'trainval')

    def __init__(self,
                 dataset_root,
                 transforms,
                 mode='train',
                 train_file=None,
                 val_file=None,
                 get_trimap=True,
                 separator=' ',
                 key_del=None,
                 if_rssn=False):
        super().__init__()
        self.dataset_root = dataset_root
        self.transforms = T.Compose(transforms)
        self.mode = mode
        self.get_trimap = get_trimap
        self.separator = separator
        self.key_del = key_del
        self.if_rssn = if_rssn

        # Fail fast: an unknown mode would otherwise skip every branch below
        # and surface later as an UnboundLocalError on `file_list`.
        if mode not in self._MODES:
            raise ValueError("`mode` should be one of {}, but got '{}'.".format(
                self._MODES, mode))

        # check file
        if mode in ('train', 'trainval'):
            if train_file is None:
                raise ValueError(
                    "When `mode` is 'train' or 'trainval', `train_file` must be provided!"
                )
            if isinstance(train_file, str):
                train_file = [train_file]
            file_list = train_file
        if mode in ('val', 'trainval'):
            if val_file is None:
                raise ValueError(
                    "When `mode` is 'val' or 'trainval', `val_file` must be provided!"
                )
            if isinstance(val_file, str):
                val_file = [val_file]
            file_list = val_file
        if mode == 'trainval':
            # list() so that a tuple and a list can be combined.
            file_list = list(train_file) + list(val_file)

        # read file: one sample per non-blank line
        self.fg_bg_list = []
        for list_file in file_list:
            with open(os.path.join(dataset_root, list_file), 'r') as f:
                for line in f:
                    line = line.strip()
                    # Blank lines (e.g. a trailing newline) would otherwise become
                    # samples whose fg path is `dataset_root` itself.
                    if line:
                        self.fg_bg_list.append(line)
        if mode != 'val':
            random.shuffle(self.fg_bg_list)

    @staticmethod
    def _imread(path, flags=cv2.IMREAD_COLOR):
        """Read an image with cv2, raising instead of silently returning None."""
        img = cv2.imread(path, flags)
        if img is None:
            raise FileNotFoundError('Failed to read image: {}'.format(path))
        return img

    def __getitem__(self, idx):
        """
        Load, composite and transform the sample at `idx`.

        Each list line is `fg [bg] [trimap]` split by `self.separator`. The alpha
        path is derived from the fg path by replacing '/fg' with '/alpha'.

        Returns:
            dict: Always contains 'img_name', 'img' (float32), 'alpha' (a leading
            axis added and values divided by 255), 'gt_fields' and 'trans_info'.
            Depending on mode and flags it may also contain 'fg', 'bg', 'trimap'
            and 'ori_trimap' (the trimaps get a leading axis). Keys listed in
            `key_del` are removed before the transforms. NOTE(review): the
            transforms may add or reshape entries — see ppmatting.transforms.

        Raises:
            FileNotFoundError: If an image cannot be read, or a val-mode trimap
                path does not exist.
        """
        data = {}
        fg_bg_file = self.fg_bg_list[idx].split(self.separator)
        data['img_name'] = fg_bg_file[0]  # used when saving prediction results
        fg_file = os.path.join(self.dataset_root, fg_bg_file[0])
        # NOTE(review): str.replace rewrites *every* '/fg' occurrence, including
        # any inside dataset_root or the file name, so avoid such names.
        alpha_file = fg_file.replace('/fg', '/alpha')
        fg = self._imread(fg_file)
        alpha = self._imread(alpha_file, cv2.IMREAD_GRAYSCALE)
        data['alpha'] = alpha
        data['gt_fields'] = []
        is_train = self.mode in ('train', 'trainval')

        # line is: fg [bg] [trimap]
        if len(fg_bg_file) >= 2:
            bg_file = os.path.join(self.dataset_root, fg_bg_file[1])
            bg = self._imread(bg_file)
            data['img'], data['fg'], data['bg'] = self.composite(fg, alpha, bg)
            if is_train:
                data['gt_fields'].extend(['fg', 'bg', 'alpha'])
            if len(fg_bg_file) == 3 and self.get_trimap:
                if self.mode == 'val':
                    trimap_path = os.path.join(self.dataset_root,
                                               fg_bg_file[2])
                    if os.path.exists(trimap_path):
                        # Stored as a path; presumably loaded by one of the
                        # transforms — TODO confirm against ppmatting.transforms.
                        data['trimap'] = trimap_path
                        data['gt_fields'].append('trimap')
                        data['ori_trimap'] = self._imread(
                            trimap_path, cv2.IMREAD_GRAYSCALE)
                    else:
                        raise FileNotFoundError(
                            'trimap is not Found: {}'.format(fg_bg_file[2]))
        else:
            data['img'] = fg
            if is_train:
                data['fg'] = fg.copy()
                data['bg'] = fg.copy()
                data['gt_fields'].extend(['fg', 'bg', 'alpha'])

        data['trans_info'] = []  # Record shape change information

        # Generate trimap from alpha if no trimap file provided
        if self.get_trimap and 'trimap' not in data:
            # 'trainval' samples are used for training, so they get the
            # randomized training trimap rather than the fixed evaluation one.
            trimap_mode = 'train' if is_train else self.mode
            data['trimap'] = self.gen_trimap(
                data['alpha'], mode=trimap_mode).astype('float32')
            data['gt_fields'].append('trimap')
            if self.mode == 'val':
                data['ori_trimap'] = data['trimap'].copy()

        # Delete keys which are not needed
        if self.key_del is not None:
            for key in self.key_del:
                data.pop(key, None)
                if key in data['gt_fields']:
                    data['gt_fields'].remove(key)
        data = self.transforms(data)

        # When evaluating, gt should not be transformed, so alpha only joins
        # gt_fields after the transforms have run.
        if self.mode == 'val':
            data['gt_fields'].append('alpha')

        data['img'] = data['img'].astype('float32')
        for key in data.get('gt_fields', []):
            data[key] = data[key].astype('float32')

        if 'trimap' in data:
            data['trimap'] = data['trimap'][np.newaxis, :, :]
        if 'ori_trimap' in data:
            data['ori_trimap'] = data['ori_trimap'][np.newaxis, :, :]

        data['alpha'] = data['alpha'][np.newaxis, :, :] / 255.

        return data

    def __len__(self):
        """Return the number of samples read from the file list(s)."""
        return len(self.fg_bg_list)

    def composite(self, fg, alpha, ori_bg):
        """
        Alpha-blend `fg` over a background cropped from `ori_bg`.

        If `ori_bg` is smaller than `fg` in either dimension, it is first scaled
        up uniformly (keeping its aspect ratio). Then its top-left fg-sized
        region is used.

        Args:
            fg (np.ndarray): Foreground image (H, W, C).
            alpha (np.ndarray): Alpha matte (H, W); divided by 255 before blending.
            ori_bg (np.ndarray): Background image (h, w, C).

        Returns:
            tuple: (image as uint8, fg, bg). `fg` and `bg` are the possibly
            RSSN-processed inputs, and `bg` is cropped to fg's size.
        """
        if self.if_rssn:
            if np.random.rand() < 0.5:
                fg = cv2.fastNlMeansDenoisingColored(fg, None, 3, 3, 7, 21)
                ori_bg = cv2.fastNlMeansDenoisingColored(ori_bg, None, 3, 3, 7,
                                                         21)
            if np.random.rand() < 0.5:
                # np.random.choice returns a NumPy scalar; cast to a plain int
                # for cv2's ksize tuple.
                radius = int(np.random.choice([19, 29, 39, 49, 59]))
                ori_bg = cv2.GaussianBlur(ori_bg, (radius, radius), 0, 0)
        fg_h, fg_w = fg.shape[:2]
        ori_bg_h, ori_bg_w = ori_bg.shape[:2]

        # Resize ori_bg if it is smaller than fg; ceil guarantees it then
        # covers fg in both dimensions.
        ratio = max(fg_w / ori_bg_w, fg_h / ori_bg_h)
        if ratio > 1:
            resize_h = math.ceil(ori_bg_h * ratio)
            resize_w = math.ceil(ori_bg_w * ratio)
            bg = cv2.resize(
                ori_bg, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR)
        else:
            bg = ori_bg
        bg = bg[0:fg_h, 0:fg_w, :]

        alpha = np.expand_dims(alpha / 255, axis=2)
        image = alpha * fg + (1 - alpha) * bg
        image = image.astype(np.uint8)
        return image, fg, bg

    @staticmethod
    def gen_trimap(alpha, mode='train', eval_kernel=7):
        """
        Generate a trimap (0 = background, 128 = unknown, 255 = foreground).

        Args:
            alpha (np.ndarray): Alpha matte; the thresholds below assume a 0-255
                range (as produced by cv2.imread in grayscale).
            mode (str, optional): 'train' uses a random elliptical kernel size
                in [2, 4] and a random number of dilate/erode iterations in
                [5, 14]. Any other value uses a single dilation with a fixed
                `eval_kernel` kernel. Default: 'train'.
            eval_kernel (int, optional): Kernel size used when mode is not
                'train'. Default: 7.

        Returns:
            np.ndarray: float64 array with the same shape as `alpha`.
        """
        # 128 marks the unknown region; definite fg/bg are carved out below.
        trimap = np.full(alpha.shape, 128.0)
        if mode == 'train':
            k_size = random.choice(range(2, 5))
            iterations = int(np.random.randint(5, 15))
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                               (k_size, k_size))
            dilated = cv2.dilate(alpha, kernel, iterations=iterations)
            eroded = cv2.erode(alpha, kernel, iterations=iterations)
            trimap[eroded > 254.5] = 255
            trimap[dilated < 0.5] = 0
        else:
            k_size = eval_kernel
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                               (k_size, k_size))
            dilated = cv2.dilate(alpha, kernel)
            trimap[alpha >= 250] = 255
            trimap[dilated <= 5] = 0
        return trimap