File size: 2,335 Bytes
3de0e37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/python
# -*- encoding: utf-8 -*-

import sys
sys.path.append('.')

import tqdm

import os
import os.path as osp
import numpy as np
import cv2
from global_value_utils import GLOBAL_DATA_ROOT, PARSING_COLOR_LIST, DATASET_NAME
from util.imutil import read_rgb, write_rgb
from external_code.face_parsing.my_parsing_util import FaceParsing


# Names of the datasets handled by this script: every entry of DATASET_NAME
# except 'CelebaMask_HQ'.
# NOTE(review): presumably CelebaMask_HQ is skipped because it already has
# parsing labels -- confirm.
data_name = [name for name in DATASET_NAME if name != 'CelebaMask_HQ']

def makedir(pat):
    """Create the directory ``pat``, including missing parents, if needed.

    Calling this on a directory that already exists is a no-op.

    Args:
        pat: Path of the directory to create.

    Raises:
        FileExistsError: If ``pat`` exists but is not a directory.
        OSError: If the directory cannot be created.
    """
    # exist_ok=True removes the window between a separate existence check and
    # the creation call, during which another process could create the path.
    os.makedirs(pat, exist_ok=True)


def vis_parsing_maps(im, parsing_anno, stride, save_im, save_path, img_path):
    """Save a parsing label map and a colour overlay of it for one image.

    Two PNG files are written, both named after ``img_path`` with its
    extension replaced by ``.png``:

    * ``<save_path>/label/<name>.png`` -- the label map (cast to uint8 and
      rescaled by ``stride`` with nearest-neighbour interpolation), written
      with ``cv2.imwrite``.
    * ``<save_path>/vis/<name>.png`` -- the input image blended with the
      colour-coded label map (weights 0.4 / 0.6), written with ``write_rgb``.

    Args:
        im: Input image; converted with ``np.array`` and cast to uint8.
            It must have the same size and channel count as the rescaled,
            3-channel colour map, because ``cv2.addWeighted`` requires
            matching inputs.
        parsing_anno: Array of per-pixel class labels (cast to uint8).
        stride: Scale factor applied to both axes of ``parsing_anno``.
        save_im: Unused; kept so existing callers keep working.
        save_path: Root directory that receives the ``label`` and ``vis``
            sub-directories (created if missing).
        img_path: File name of the source image; only its stem is used to
            name the outputs.
    """
    label_path = os.path.join(save_path, 'label')
    vis_path = os.path.join(save_path, 'vis')
    makedir(pat=label_path)
    makedir(pat=vis_path)

    vis_im = np.array(im).astype(np.uint8)
    vis_parsing_anno = cv2.resize(
        parsing_anno.copy().astype(np.uint8),
        None,
        fx=stride,
        fy=stride,
        interpolation=cv2.INTER_NEAREST,
    )

    # Start from a white canvas; every label present is painted below.
    height, width = vis_parsing_anno.shape[:2]
    vis_parsing_anno_color = np.full((height, width, 3), 255.0)

    # Only visit labels that actually occur. Indexing with a plain int also
    # avoids doing arithmetic on a uint8 scalar, which can wrap at 255.
    for label in np.unique(vis_parsing_anno):
        mask = vis_parsing_anno == label
        vis_parsing_anno_color[mask] = PARSING_COLOR_LIST[int(label)]

    vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
    vis_im = cv2.addWeighted(vis_im, 0.4, vis_parsing_anno_color, 0.6, 0)

    # splitext handles extensions of any length ('.jpeg', '.bmp', none),
    # unlike slicing off a fixed four characters.
    out_name = os.path.splitext(img_path)[0] + '.png'

    cv2.imwrite(os.path.join(label_path, out_name), vis_parsing_anno)
    write_rgb(os.path.join(vis_path, out_name), vis_im)


def evaluate(respth, dspth):
    """Run face parsing on every file in ``dspth`` and save results to ``respth``.

    Files are processed in sorted name order with a tqdm progress bar. For
    each one, the image is loaded with ``read_rgb``, parsed with
    ``FaceParsing.parsing_img``, the labels are passed through
    ``FaceParsing.swap_parsing_label_to_celeba_mask``, and the result is
    written by ``vis_parsing_maps`` (``label/`` and ``vis/`` under ``respth``).

    Args:
        respth: Output root directory; created if it does not exist.
        dspth: Directory containing the input images. Entries that are not
            regular files (e.g. sub-directories) are skipped.
    """
    # exist_ok avoids the race of a separate exists()-then-makedirs() check.
    os.makedirs(respth, exist_ok=True)

    # Filter before handing the list to tqdm so the progress total only
    # counts entries that will actually be processed.
    files = sorted(
        name for name in os.listdir(dspth)
        if osp.isfile(osp.join(dspth, name))
    )
    for image_path in tqdm.tqdm(files):
        parsing, image = FaceParsing.parsing_img(read_rgb(osp.join(dspth, image_path)))
        # NOTE(review): presumably remaps labels to the CelebAMask-HQ
        # convention -- confirm against FaceParsing.
        parsing = FaceParsing.swap_parsing_label_to_celeba_mask(parsing)
        vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=respth, img_path=image_path)


if __name__ == "__main__":
    # For each selected dataset, parse the images in <root>/<dataset>/images_256
    # and write the outputs next to them under <root>/<dataset>.
    for dataset in data_name:
        dataset_root = os.path.join(GLOBAL_DATA_ROOT, dataset)
        images_dir = os.path.join(dataset_root, 'images_256')
        evaluate(respth=dataset_root, dspth=images_dir)