Spaces:
Paused
Paused
File size: 2,335 Bytes
3de0e37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import sys
sys.path.append('.')
import tqdm
import os
import os.path as osp
import numpy as np
import cv2
from global_value_utils import GLOBAL_DATA_ROOT, PARSING_COLOR_LIST, DATASET_NAME
from util.imutil import read_rgb, write_rgb
from external_code.face_parsing.my_parsing_util import FaceParsing
# Names of the datasets to process: every entry of DATASET_NAME except
# 'CelebaMask_HQ' (presumably because it already ships with labels -- confirm).
data_name = list(filter(lambda name: name != 'CelebaMask_HQ', DATASET_NAME))
def makedir(pat):
    """Create directory ``pat``, including missing parents, if it is absent.

    Args:
        pat: Path of the directory to create.

    Raises:
        FileExistsError: If ``pat`` exists but is not a directory.
    """
    # exist_ok=True replaces the separate exists() check, so there is no
    # window in which another process can create the directory between the
    # check and the makedirs() call.
    os.makedirs(pat, exist_ok=True)
def vis_parsing_maps(im, parsing_anno, stride, save_im, save_path, img_path):
    """Write a parsing label map and a colour overlay of it for one image.

    Two PNG files named after ``img_path`` (its extension replaced by
    ``.png``) are written:

    * ``<save_path>/label/<name>.png`` -- the uint8 label map, via
      ``cv2.imwrite``.
    * ``<save_path>/vis/<name>.png`` -- ``im`` blended (weight 0.4) with the
      per-class colours from ``PARSING_COLOR_LIST`` (weight 0.6), via
      ``write_rgb``.

    Args:
        im: Source image; converted with ``np.array`` and cast to uint8.
            Presumably an RGB array of the same height/width as the resized
            label map (required by ``cv2.addWeighted``) -- confirm with caller.
        parsing_anno: Array of per-pixel class indices; cast to uint8.
        stride: Scale factor applied to both axes of the label map
            (nearest-neighbour interpolation, so labels are not blended).
        save_im: Currently ignored -- both files are always written. Kept so
            existing callers keep working.
        save_path: Root output directory; ``label`` and ``vis``
            sub-directories are created inside it.
        img_path: File name of the source image, used to name the outputs.
    """
    label_path = os.path.join(save_path, 'label')
    vis_path = os.path.join(save_path, 'vis')
    makedir(pat=label_path)
    makedir(pat=vis_path)

    vis_im = np.array(im).astype(np.uint8)

    vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
    vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride,
                                  interpolation=cv2.INTER_NEAREST)

    # White canvas; every pixel whose label occurs below is painted over.
    vis_parsing_anno_color = np.zeros(
        (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255

    # Convert to a Python int so ``num_of_class + 1`` cannot wrap around in
    # uint8 arithmetic when the largest label is 255.
    num_of_class = int(np.max(vis_parsing_anno))

    # Swap the real extension for '.png'. Slicing off a fixed 4 characters
    # would mangle names such as 'x.jpeg' or names without an extension.
    png_name = os.path.splitext(img_path)[0] + '.png'

    for pi in range(num_of_class + 1):
        mask = vis_parsing_anno == pi
        if mask.any():
            vis_parsing_anno_color[mask] = PARSING_COLOR_LIST[pi]

    vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
    vis_im = cv2.addWeighted(vis_im, 0.4, vis_parsing_anno_color, 0.6, 0)

    cv2.imwrite(os.path.join(label_path, png_name), vis_parsing_anno)
    write_rgb(os.path.join(vis_path, png_name), vis_im)
def evaluate(respth, dspth):
    """Run face parsing on every entry of ``dspth`` and save results in ``respth``.

    Each entry is read with ``read_rgb``, parsed with
    ``FaceParsing.parsing_img``, relabelled with
    ``FaceParsing.swap_parsing_label_to_celeba_mask`` and handed to
    ``vis_parsing_maps``, which writes the label map and the overlay.

    Args:
        respth: Output root directory; created if missing.
        dspth: Input directory. Every entry returned by ``os.listdir`` is
            passed to ``read_rgb``, so it is assumed to contain only image
            files -- confirm against the dataset layout.
    """
    # exist_ok=True avoids the race between a separate exists() check and
    # the directory creation.
    os.makedirs(respth, exist_ok=True)

    # Sorted so the processing order is deterministic across runs.
    for image_path in tqdm.tqdm(sorted(os.listdir(dspth))):
        parsing, image = FaceParsing.parsing_img(read_rgb(osp.join(dspth, image_path)))
        parsing = FaceParsing.swap_parsing_label_to_celeba_mask(parsing)
        vis_parsing_maps(image, parsing, stride=1, save_im=True,
                         save_path=respth, img_path=image_path)
if __name__ == "__main__":
    # For each selected dataset, parse the images under
    # <GLOBAL_DATA_ROOT>/<dataset>/images_256 and write the results into the
    # dataset's own root directory.
    for dataset in data_name:
        dataset_root = os.path.join(GLOBAL_DATA_ROOT, dataset)
        images_dir = os.path.join(dataset_root, 'images_256')
        evaluate(respth=dataset_root, dspth=images_dir)
|