Spaces:
Paused
Paused
| #!/usr/bin/python | |
| # -*- encoding: utf-8 -*- | |
| import sys | |
| sys.path.append('.') | |
| import tqdm | |
| import os | |
| import os.path as osp | |
| import numpy as np | |
| import cv2 | |
| from global_value_utils import GLOBAL_DATA_ROOT, PARSING_COLOR_LIST, DATASET_NAME | |
| from util.imutil import read_rgb, write_rgb | |
| from external_code.face_parsing.my_parsing_util import FaceParsing | |
# Names of all datasets from DATASET_NAME that this script processes;
# CelebaMask_HQ is deliberately excluded.
data_name = list(filter(lambda name: name != 'CelebaMask_HQ', DATASET_NAME))
def makedir(pat):
    """Create directory ``pat`` (including missing parents) if it is absent.

    ``exist_ok=True`` makes the operation safe when the directory already
    exists or is created concurrently by another process, avoiding the race
    between a separate existence check and ``os.makedirs``.

    :param pat: path of the directory to create.
    :raises FileExistsError: if ``pat`` exists but is not a directory.
    """
    os.makedirs(pat, exist_ok=True)
def vis_parsing_maps(im, parsing_anno, stride, save_im, save_path, img_path):
    """Build (and optionally save) a face-parsing label map and colour overlay.

    The label map is ``parsing_anno`` cast to uint8 and upscaled by ``stride``
    with nearest-neighbour interpolation. The overlay blends ``im`` (weight
    0.4) with a colour rendering of the labels (weight 0.6), where label ``k``
    is drawn with colour ``PARSING_COLOR_LIST[k]``.

    When ``save_im`` is true, two files are written, both named after
    ``img_path`` with its extension replaced by ``.png``:

    * ``<save_path>/label/<stem>.png`` -- the raw label map (via cv2.imwrite)
    * ``<save_path>/vis/<stem>.png``   -- the overlay (via write_rgb)

    :param im: image array or array-like; presumably RGB since it is saved
        with ``write_rgb`` -- TODO confirm.
    :param parsing_anno: 2-D integer label map; values must fit in uint8 and
        index into PARSING_COLOR_LIST.
    :param stride: upscale factor for ``parsing_anno``; the upscaled map must
        match ``im``'s height and width for the blend to succeed.
    :param save_im: if false, nothing is written to disk.
    :param save_path: root directory under which ``label/`` and ``vis/`` are
        created.
    :param img_path: source image file name; only its stem is reused.
    :return: ``(label_map, vis_im)`` as uint8 arrays.
    :raises OSError: if OpenCV fails to write the label map.
    """
    vis_im = np.asarray(im).astype(np.uint8)  # astype returns a new array

    label_map = np.asarray(parsing_anno).astype(np.uint8)
    label_map = cv2.resize(label_map, None, fx=stride, fy=stride,
                           interpolation=cv2.INTER_NEAREST)

    # One vectorised palette lookup replaces a per-class np.where loop; every
    # pixel gets the colour of its label.
    palette = np.asarray(PARSING_COLOR_LIST)
    color_map = palette[label_map].astype(np.uint8)
    vis_im = cv2.addWeighted(vis_im, 0.4, color_map, 0.6, 0)

    if save_im:
        label_path = os.path.join(save_path, 'label')
        vis_path = os.path.join(save_path, 'vis')
        makedir(pat=label_path)
        makedir(pat=vis_path)
        # splitext handles extensions of any length (".jpeg", none at all),
        # unlike slicing off a fixed 4 characters.
        out_name = os.path.splitext(img_path)[0] + '.png'
        label_file = os.path.join(label_path, out_name)
        # cv2.imwrite signals failure through its return value, not an error.
        if not cv2.imwrite(label_file, label_map):
            raise OSError('cv2.imwrite failed for %s' % label_file)
        write_rgb(os.path.join(vis_path, out_name), vis_im)

    return label_map, vis_im
def evaluate(respth, dspth):
    """Face-parse every image file in ``dspth`` and save the results.

    Regular files in ``dspth`` are processed in sorted name order. Each one
    is read with ``read_rgb`` and parsed with ``FaceParsing.parsing_img``.
    Its labels are remapped with
    ``FaceParsing.swap_parsing_label_to_celeba_mask``, and the label map and
    overlay are written under ``respth`` by ``vis_parsing_maps``.

    :param respth: output root directory; created if missing.
    :param dspth: directory containing the input images.
    """
    os.makedirs(respth, exist_ok=True)

    # Only regular files are parsed; sub-directories returned by listdir
    # would otherwise be handed to read_rgb.
    image_names = sorted(
        name for name in os.listdir(dspth)
        if os.path.isfile(os.path.join(dspth, name))
    )
    for image_name in tqdm.tqdm(image_names):
        parsing, image = FaceParsing.parsing_img(
            read_rgb(os.path.join(dspth, image_name)))
        parsing = FaceParsing.swap_parsing_label_to_celeba_mask(parsing)
        vis_parsing_maps(image, parsing, stride=1, save_im=True,
                         save_path=respth, img_path=image_name)
if __name__ == "__main__":
    # For each selected dataset, read <GLOBAL_DATA_ROOT>/<name>/images_256
    # and write the parsing outputs into <GLOBAL_DATA_ROOT>/<name>.
    for dataset in data_name:
        dataset_root = os.path.join(GLOBAL_DATA_ROOT, dataset)
        evaluate(respth=dataset_root,
                 dspth=os.path.join(dataset_root, 'images_256'))