Spaces:
Runtime error
Runtime error
| # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import math | |
| import time | |
| import cv2 | |
| import numpy as np | |
| import paddle | |
| import paddle.nn.functional as F | |
| from paddleseg import utils | |
| from paddleseg.core import infer | |
| from paddleseg.utils import logger, progbar, TimeAverager | |
| from matting.utils import mkdir | |
def partition_list(arr, m):
    """Split the sequence ``arr`` into exactly ``m`` contiguous pieces.

    Each piece holds ``ceil(len(arr) / m)`` items, and the last non-empty piece
    may be shorter. This chunk size can produce fewer than ``m`` pieces, for
    example 5 items over 4 ranks gives only 3. It also produces none when
    ``arr`` is empty. In both cases the result is padded with empty slices, so
    callers can always index it with any rank in ``0 .. m-1``.

    Args:
        arr (Sequence): Items to split; must support ``len`` and slicing.
        m (int): Number of pieces to produce; must be >= 1.

    Returns:
        list: ``m`` slices of ``arr``, in their original order.

    Raises:
        ValueError: If ``m`` is less than 1.
    """
    if m < 1:
        raise ValueError("m must be >= 1, got {}".format(m))
    n = int(math.ceil(len(arr) / float(m)))
    # n == 0 only for empty input; range() rejects a zero step.
    pieces = [arr[i:i + n] for i in range(0, len(arr), n)] if n > 0 else []
    # Pad with empty slices (same type as arr) so every rank gets a piece.
    while len(pieces) < m:
        pieces.append(arr[0:0])
    return pieces
def save_alpha_pred(alpha, path, trimap=None):
    """Write an alpha matte to ``path`` as an 8-bit single-channel image.

    Args:
        alpha (np.ndarray): Alpha matte of shape [h, w]. Values must already be
            on the 0-255 scale, because the array is cast to uint8 without
            rescaling. The caller's array is not modified.
        path (str): Output image path. Missing parent directories are created.
        trimap (str | os.PathLike | np.ndarray, optional): A trimap path (read
            as grayscale) or an array of the same shape as ``alpha``. Pixels
            where the trimap is 0 are forced to 0, and pixels where it is 255
            are forced to 255. If None, alpha is written without masking.
            Default: None.

    Raises:
        IOError: If ``trimap`` is a path that OpenCV cannot read.
    """
    dirname = os.path.dirname(path)
    # A bare filename has dirname '' and os.makedirs('') would raise.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    # Copy so the trimap masking below does not mutate the caller's array.
    alpha = np.array(alpha, copy=True)
    if trimap is not None:
        if isinstance(trimap, (str, os.PathLike)):
            trimap_path = os.fspath(trimap)
            trimap = cv2.imread(trimap_path, 0)
            # cv2.imread returns None instead of raising on failure. Indexing
            # with (None == 0) would then silently skip the masking.
            if trimap is None:
                raise IOError("Cannot read trimap: {}".format(trimap_path))
        alpha[trimap == 0] = 0
        alpha[trimap == 255] = 255
    cv2.imwrite(path, alpha.astype('uint8'))
def reverse_transform(alpha, trans_info):
    """Map a prediction back onto the original image geometry.

    The records in ``trans_info`` are replayed newest-first, so each
    preprocessing step is undone in the reverse of the order it was applied.

    Args:
        alpha: Prediction indexed as [N, C, H, W]. The 'padding' branch slices
            the last two axes, and the 'resize' branch uses bilinear
            ``F.interpolate``.
        trans_info (list): Records such as ``('resize', (h, w))`` or
            ``('padding', (h, w))``, where ``(h, w)`` is the size to restore.

    Returns:
        The prediction after every recorded step has been undone.

    Raises:
        Exception: If a record names an operation other than 'resize' or
            'padding'.
    """
    for step in reversed(trans_info):
        kind = step[0]
        if kind not in ('resize', 'padding'):
            raise Exception("Unexpected info '{}' in im_info".format(kind))
        target_h, target_w = step[1][0], step[1][1]
        if kind == 'resize':
            alpha = F.interpolate(alpha, [target_h, target_w], mode='bilinear')
        else:
            # Padding was added at the bottom/right; crop it away.
            alpha = alpha[:, :, 0:target_h, 0:target_w]
    return alpha
def preprocess(img, transforms, trimap=None):
    """Run ``transforms`` on one sample and turn the result into batched tensors.

    Args:
        img: Image input passed to ``transforms`` under the 'img' key.
        transforms (callable): Takes the sample dict and returns the
            transformed dict. It is expected to append to 'trans_info'.
        trimap (optional): Trimap input. When given, it is stored under
            'trimap' and listed in 'gt_fields' so the transforms also apply
            to it. Default: None.

    Returns:
        dict: The transformed sample. 'img' gets a leading batch axis, and
        'trimap' (if present) gets two leading axes.
    """
    has_trimap = trimap is not None
    sample = {'img': img}
    if has_trimap:
        sample['trimap'] = trimap
        sample['gt_fields'] = ['trimap']
    sample['trans_info'] = []
    sample = transforms(sample)

    sample['img'] = paddle.to_tensor(sample['img']).unsqueeze(0)
    if has_trimap:
        sample['trimap'] = paddle.to_tensor(sample['trimap']).unsqueeze((0, 1))
    return sample
def predict(model,
            model_path,
            transforms,
            image_list,
            image_dir=None,
            trimap_list=None,
            save_dir='output'):
    """
    Predict alpha mattes for image_list and return the last one.

    Args:
        model (nn.Layer): Used to predict for input image.
        model_path (str): The path of pretrained model.
        transforms (transforms.Compose): Preprocess for input image.
        image_list (list): A list of image path to be predicted.
        image_dir (str, optional): The root directory of the images predicted.
            Currently unused because saving results to disk is disabled.
            Default: None.
        trimap_list (list, optional): A list of trimap of image_list. Default: None.
        save_dir (str, optional): The directory to save the visualized results.
            Currently unused because saving results to disk is disabled.
            Default: 'output'.

    Returns:
        np.ndarray | None: The uint8 alpha matte (0-255) of the last image
        processed by this rank, or None if this rank received no images.
    """
    utils.utils.load_entire_model(model, model_path)
    model.eval()
    nranks = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()
    if nranks > 1:
        img_lists = partition_list(image_list, nranks)
        trimap_lists = partition_list(
            trimap_list, nranks) if trimap_list is not None else None
    else:
        img_lists = [image_list]
        trimap_lists = [trimap_list] if trimap_list is not None else None

    # Chunked partitioning can yield fewer pieces than ranks for short lists.
    # A rank without a piece simply has no work, rather than raising IndexError.
    rank_images = img_lists[local_rank] if local_rank < len(img_lists) else []
    rank_trimaps = None
    if trimap_lists is not None and local_rank < len(trimap_lists):
        rank_trimaps = trimap_lists[local_rank]

    logger.info("Start to predict...")
    progbar_pred = progbar.Progbar(target=len(img_lists[0]), verbose=1)
    preprocess_cost_averager = TimeAverager()
    infer_cost_averager = TimeAverager()
    postprocess_cost_averager = TimeAverager()

    # Initialised so an empty work list returns None instead of raising
    # UnboundLocalError at the final return.
    alpha_pred = None
    with paddle.no_grad():
        for i, im_path in enumerate(rank_images):
            preprocess_start = time.time()
            trimap = rank_trimaps[i] if rank_trimaps is not None else None
            data = preprocess(img=im_path, transforms=transforms, trimap=trimap)
            preprocess_cost_averager.record(time.time() - preprocess_start)

            infer_start = time.time()
            alpha_pred = model(data)
            infer_cost_averager.record(time.time() - infer_start)

            postprocess_start = time.time()
            # Undo the recorded resize/padding so the matte matches the input size.
            alpha_pred = reverse_transform(alpha_pred, data['trans_info'])
            alpha_pred = (alpha_pred.numpy()).squeeze()
            alpha_pred = (alpha_pred * 255).astype('uint8')
            # NOTE: writing each matte under save_dir (via save_alpha_pred) is
            # disabled in this version; only the last matte is returned.
            postprocess_cost_averager.record(time.time() - postprocess_start)

            preprocess_cost = preprocess_cost_averager.get_average()
            infer_cost = infer_cost_averager.get_average()
            postprocess_cost = postprocess_cost_averager.get_average()
            if local_rank == 0:
                progbar_pred.update(i + 1,
                                    [('preprocess_cost', preprocess_cost),
                                     ('infer_cost', infer_cost),
                                     ('postprocess_cost', postprocess_cost)])
            preprocess_cost_averager.reset()
            infer_cost_averager.reset()
            postprocess_cost_averager.reset()
    return alpha_pred