| | |
| | |
| | import numpy as np |
| | from model import BiSeNet |
| |
|
| | import torch |
| |
|
| | import os |
| | import os.path as osp |
| |
|
| | from PIL import Image |
| | import torchvision.transforms as transforms |
| | import cv2 |
| | from pathlib import Path |
| | import configargparse |
| | import tqdm |
| |
|
| | |
| |
|
# Label ids painted red in the "_face" mask written by vis_parsing_maps.
# NOTE(review): these ids refer to the 19-class BiSeNet output produced in
# evaluate(); their anatomical meaning is not visible in this file.
_FACE_LABELS = (1, 2, 3, 4, 5, 6, 10, 11, 12, 13)


def _colorize_parsing(anno):
    """Return an RGB uint8 image grouping parsing labels into three colours.

    Background (label 0) stays white; labels 1-13 and >= 17 are red,
    labels 14-15 green and label 16 blue.
    """
    color = np.full(anno.shape + (3,), 255, dtype=np.uint8)
    color[((anno >= 1) & (anno <= 13)) | (anno >= 17)] = (255, 0, 0)
    color[(anno == 14) | (anno == 15)] = (0, 255, 0)
    color[anno == 16] = (0, 0, 255)
    return color


def _face_mask(anno, pad=5):
    """Return an RGB uint8 image with the ``_FACE_LABELS`` region red on white.

    The red region is also extended downwards. This is done for every column
    in (roughly) the middle half of the face's horizontal span. For each such
    column, rows ``[anchor - pad, anchor + pad)`` are painted red, where
    ``anchor = min(lowest_face_row + pad, last_row)``.
    If no face label is present, the plain white image is returned.
    """
    rows = anno.shape[0]
    face = np.full(anno.shape + (3,), 255, dtype=np.uint8)
    is_face = np.isin(anno, _FACE_LABELS)
    face[is_face] = (255, 0, 0)

    # Columns (ascending) that contain at least one face pixel.
    cols = np.flatnonzero(is_face.any(axis=0))
    if cols.size == 0:
        # No face span to extend (min/max of an empty column set is undefined).
        return face

    # Lowest face row per column: argmax on the row-flipped mask finds the
    # first face pixel counted from the bottom edge.
    bottom_rows = rows - 1 - np.argmax(is_face[::-1, cols], axis=0)
    anchors = np.minimum(bottom_rows + pad, rows - 1)

    # Only columns within [min + q, min + 3q], q = (max - min) // 4, are extended.
    quarter = (cols[-1] - cols[0]) // 4
    lo = cols[0] + quarter
    hi = lo + 2 * quarter
    for anchor, col in zip(anchors, cols):
        if lo <= col <= hi:
            face[max(anchor - pad, 0):min(anchor + pad, rows), col] = (255, 0, 0)
    return face


def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg',
                     img_size=(512, 512)):
    """Write colour visualisations of a face-parsing label map to disk.

    Two images may be written:

    * only when ``save_im`` is true: a 3-colour rendering of all labels
      (``_colorize_parsing``), written to ``save_path``;
    * always: a Gaussian-blurred face-region mask (``_face_mask``), written to
      ``save_path`` with ``_face`` inserted before the extension
      (``out/7.png`` -> ``out/7_face.png``).

    Args:
        im: Input image. Not used for drawing; kept for API compatibility.
        parsing_anno: Label map of integer class ids, presumably 2-D (H, W).
        stride: Nearest-neighbour scale factor applied to ``parsing_anno``.
        save_im: Whether to write the full-label visualisation.
        save_path: Output path of the full-label visualisation.
        img_size: ``(width, height)`` given to ``cv2.resize`` for both outputs.

    NOTE(review): the images are built in RGB order, but ``cv2.imwrite``
    interprets channels as BGR, so red and blue are swapped on disk. This is
    kept as-is in case downstream readers rely on it.
    """
    anno = cv2.resize(np.asarray(parsing_anno).astype(np.uint8), None, fx=stride, fy=stride,
                      interpolation=cv2.INTER_NEAREST)

    if save_im:
        full_vis = cv2.resize(_colorize_parsing(anno), img_size, interpolation=cv2.INTER_NEAREST)
        cv2.imwrite(save_path, full_vis)

    # GaussianBlur(src, ksize, sigmaX, ...): the third positional argument is
    # sigmaX. sigmaX=4 keeps the blur strength of the previous call, which
    # passed cv2.BORDER_DEFAULT (== 4) into that slot.
    face_vis = cv2.GaussianBlur(_face_mask(anno), (9, 9), sigmaX=4)
    face_vis = cv2.resize(face_vis, img_size, interpolation=cv2.INTER_NEAREST)

    # Insert "_face" before the extension whatever it is. A plain
    # str.replace('.png', ...) would leave e.g. ".jpg" paths unchanged and
    # overwrite the full visualisation.
    root, ext = osp.splitext(save_path)
    cv2.imwrite(root + '_face' + ext, face_vis)
| |
|
| |
|
def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
    """Run BiSeNet face parsing on every ``.jpg``/``.png`` file in ``dspth``.

    Each image is resized to 512x512 and parsed into ``n_classes`` labels.
    The label map is passed to ``vis_parsing_maps``, which writes
    ``<name>.png`` (plus a ``_face`` variant) into ``respth`` at the image's
    original size. ``<name>`` is the file name without its 4-character
    extension. Numeric names are normalised through ``int``
    (``0007.jpg`` -> ``7.png``); other names are kept verbatim.

    Args:
        respth: Output directory; created (with parents) if missing.
        dspth: Directory of input images (not searched recursively).
        cp: Path of the BiSeNet state dict loaded with ``torch.load``.

    Requires a CUDA device: the model and inputs are moved with ``.cuda()``.
    """
    Path(respth).mkdir(parents=True, exist_ok=True)

    print('[INFO] loading model...')
    n_classes = 19
    net = BiSeNet(n_classes=n_classes)
    net.cuda()
    net.load_state_dict(torch.load(cp))
    net.eval()

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        # Standard ImageNet channel mean / std.
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # Filter up front so the progress bar only counts processed files.
    image_names = [name for name in sorted(os.listdir(dspth))
                   if name.endswith(('.jpg', '.png'))]

    with torch.no_grad():
        for name in tqdm.tqdm(image_names):
            # Image.open keeps the file open lazily; the context manager
            # releases the handle once the resized copy has been made.
            with Image.open(osp.join(dspth, name)) as src:
                ori_size = src.size  # PIL size is (width, height)
                image = src.resize((512, 512), Image.BILINEAR).convert("RGB")

            inputs = to_tensor(image).unsqueeze(0).cuda()
            # NOTE(review): assumes net(...) returns a tensor whose leading axis
            # averages away to (n_classes, H, W); depends on model.BiSeNet.
            parsing = net(inputs).mean(0).cpu().numpy().argmax(0)

            # Both accepted extensions are 4 characters long.
            stem = name[:-4]
            try:
                stem = str(int(stem))  # e.g. "0007" -> "7"
            except ValueError:
                pass  # non-numeric names are kept as-is instead of aborting the run

            vis_parsing_maps(image, parsing, stride=1, save_im=True,
                             save_path=osp.join(respth, stem + '.png'), img_size=ori_size)
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: parse every image in --imgpath and write the
    # label visualisations to --respath.
    parser = configargparse.ArgumentParser()
    parser.add_argument('--respath', type=str, default='./result/', help='result path for label')
    parser.add_argument('--imgpath', type=str, default='./imgs/', help='path for input images')
    parser.add_argument('--modelpath', type=str, default='data_utils/face_parsing/79999_iter.pth',
                        help='path to the BiSeNet checkpoint (state dict)')
    args = parser.parse_args()
    evaluate(respth=args.respath, dspth=args.imgpath, cp=args.modelpath)
| |
|