File size: 6,726 Bytes
5604fdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import cv2
import argparse
import glob

import numpy as np
from utils.general import imwrite
from utils.restoration_helper import RestoreHelper

# Lower-case image extensions accepted for a single-file --input_path.
IMAGE_EXTENSIONS = ('jpg', 'jpeg', 'png')


def parse_args(argv=None):
    """Parse the command-line options of the face restoration script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('-i', '--input_path', type=str, default='./pic',
            help='Input image or folder. Default: ./pic')
    parser.add_argument('-o', '--output_path', type=str, default=None,
            help='Output folder. Default: results/test_img_<upscale> for a single image, '
                 'results for a folder')
    parser.add_argument('-s', '--upscale', type=int, default=1,
            help='The final upsampling scale of the image. Default: 1')
    parser.add_argument('--detect_model', type=str, default='yolov5l-face.axmodel', help='face detection model path')
    parser.add_argument('--restore_model', type=str, default='codeformer.axmodel', help='face restore model path')
    parser.add_argument('--bg_model', type=str, default='realesrgan-x2.axmodel', help='background upsampler model path')
    parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces. Default: False')
    parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face. Default: False')
    parser.add_argument('--draw_box', action='store_true', help='Draw the bounding box for the detected faces. Default: False')
    parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces. Default: None')

    return parser.parse_args(argv)


def collect_inputs(input_path, upscale):
    """Resolve ``input_path`` into a list of image paths and a default output root.

    Args:
        input_path: Path to a single jpg/jpeg/png image (extension matched
            case-insensitively) or to a folder that is scanned for images.
        upscale: Upsampling scale; only used to name the default output
            folder in the single-image case.

    Returns:
        A ``(paths, result_root)`` tuple. ``paths`` is sorted and may be empty
        when a folder contains no matching image.
    """
    if input_path.lower().endswith(IMAGE_EXTENSIONS):  # single image
        return [input_path], f'results/test_img_{upscale}'

    # Folder input: drop one trailing slash, then scan for jpg/jpeg/png files.
    if input_path.endswith('/'):
        input_path = input_path[:-1]
    pattern = os.path.join(input_path, '*.[jpJP][pnPN]*[gG]')
    return sorted(glob.glob(pattern)), 'results'


def preprocess_face(cropped_face):
    """Convert an HxWx3 face crop into the 1x3xHxW float32 network input.

    Values are mapped from [0, 255] to [-1, 1] and the channel axis is
    reversed (BGR -> RGB, assuming the crop comes from an OpenCV image).
    """
    face = (cropped_face.astype(np.float32) / 255.0) * 2.0 - 1.0
    face = np.ascontiguousarray(face[..., ::-1])
    return np.transpose(np.expand_dims(face, axis=0), (0, 3, 1, 2))


def postprocess_face(face_t):
    """Invert ``preprocess_face``: 1x3xHxW in [-1, 1] -> HxWx3 uint8.

    The channel axis is reversed back so the result has the same channel
    order as the crop that was fed to ``preprocess_face``.
    """
    face = (face_t.squeeze().transpose(1, 2, 0) * 0.5 + 0.5) * 255
    return np.clip(face[..., ::-1], 0, 255).astype(np.uint8)


def main():
    """Restore the faces of every input image and save the pasted-back result."""
    args = parse_args()

    if args.has_aligned:
        # Paste-back is skipped for aligned inputs and this script has no
        # per-face saving, so make the "no output" outcome explicit.
        print('Warning: with --has_aligned no paste-back is done and per-face saving '
              'is not implemented in this script, so no output file will be written.')

    # ------------------------ input & output ------------------------
    input_img_list, result_root = collect_inputs(args.input_path, args.upscale)
    if args.output_path is not None:  # user-provided output folder wins
        result_root = args.output_path

    test_img_num = len(input_img_list)
    if test_img_num == 0:
        raise FileNotFoundError(
            f'No input image is found in {args.input_path!r}...\n'
            '\tNote that --input_path should be a .jpg|.jpeg|.png file or a folder containing them')

    # ------------------ set up FaceRestoreHelper -------------------
    restore_helper = RestoreHelper(
        args.upscale,
        face_size=512,
        crop_ratio=(1, 1),
        det_model=args.detect_model,
        res_model=args.restore_model,
        bg_model=args.bg_model,
        save_ext='png',
        use_parse=True
        )

    # -------------------- start to processing ---------------------
    for i, img_path in enumerate(input_img_list):
        # clean all the intermediate results to process the next image
        restore_helper.clean_all()

        img_name = os.path.basename(img_path)
        basename, _ = os.path.splitext(img_name)
        print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals an unreadable/corrupt file by returning None.
            print(f'\tFailed to read image, skipping: {img_path}')
            continue

        restore_helper.read_image(img)
        # get face landmarks for each face
        num_det_faces = restore_helper.get_face_landmarks_5(
            only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
        print(f'\tdetect {num_det_faces} faces')
        # align and warp each face
        restore_helper.align_warp_face()

        # face restoration for each cropped face
        for cropped_face in restore_helper.cropped_faces:
            cropped_face_t = preprocess_face(cropped_face)
            try:
                ort_outs = restore_helper.rs_sessison.run(
                    restore_helper.rs_output,
                    {restore_helper.rs_input: cropped_face_t}
                    )
                restored_face = postprocess_face(ort_outs[0])
            except Exception as error:
                # Best effort: fall back to the unrestored crop. Going through
                # postprocess_face restores the original channel order.
                print(f'\tFailed inference for CodeFormer: {error}')
                restored_face = postprocess_face(cropped_face_t)

            restore_helper.add_restored_face(restored_face, cropped_face)

        if args.has_aligned:
            # Aligned inputs are not pasted back; nothing to save for them.
            continue

        # paste_back: upsample the background, then paste each restored face
        # onto it. Now only support RealESRGAN for upsampling background.
        bg_img = restore_helper.background_upsampling(img)
        restore_helper.get_inverse_affine(None)
        restored_img = restore_helper.paste_faces_to_input_image(
            upsample_img=bg_img, draw_box=args.draw_box)
        if restored_img is None:
            continue

        # save restored img
        if args.suffix is not None:
            basename = f'{basename}_{args.suffix}'
        save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
        imwrite(restored_img, save_restore_path)


if __name__ == '__main__':
    main()