jounery-d committed on
Commit 5604fdf · verified · 1 Parent(s): 9df1736

Upload 4 files

python/run_whole_image.py ADDED
@@ -0,0 +1,136 @@
+ import os
+ import cv2
+ import argparse
+ import glob
+
+ import numpy as np
+ from utils.general import imwrite
+ from utils.restoration_helper import RestoreHelper
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument('-i', '--input_path', type=str, default='./pic',
+                         help='Input image or folder. Default: ./pic')
+     parser.add_argument('-o', '--output_path', type=str, default=None,
+                         help='Output folder. Default: results')
+     parser.add_argument('-s', '--upscale', type=int, default=1,
+                         help='The final upsampling scale of the image. Default: 1')
+     parser.add_argument('--detect_model', type=str, default='yolov5l-face.axmodel', help='Face detection model path')
+     parser.add_argument('--restore_model', type=str, default='codeformer.axmodel', help='Face restoration model path')
+     parser.add_argument('--bg_model', type=str, default='realesrgan-x2.axmodel', help='Background upsampler model path')
+     parser.add_argument('--has_aligned', action='store_true', help='Inputs are cropped and aligned faces. Default: False')
+     parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face. Default: False')
+     parser.add_argument('--draw_box', action='store_true', help='Draw the bounding box for the detected faces. Default: False')
+     parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces. Default: None')
+
+     args = parser.parse_args()
+
+     # ------------------------ input & output ------------------------
+     if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')):  # single image path
+         input_img_list = [args.input_path]
+         result_root = f'results/test_img_{args.upscale}'
+     else:  # image folder
+         if args.input_path.endswith('/'):  # strip a trailing slash
+             args.input_path = args.input_path[:-1]
+         # scan all the jpg and png images
+         input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]')))
+         result_root = 'results'
+
+     if args.output_path is not None:  # set output path
+         result_root = args.output_path
+
+     test_img_num = len(input_img_list)
+     if test_img_num == 0:
+         raise FileNotFoundError('No input image is found...\n'
+                                 '\tNote that --input_path should be an image file or a folder of images')
+
+     # ------------------ set up RestoreHelper -------------------
+     restore_helper = RestoreHelper(
+         args.upscale,
+         face_size=512,
+         crop_ratio=(1, 1),
+         det_model=args.detect_model,
+         res_model=args.restore_model,
+         bg_model=args.bg_model,
+         save_ext='png',
+         use_parse=True
+     )
+
+     # -------------------- start processing ---------------------
+     for i, img_path in enumerate(input_img_list):
+         # clean all the intermediate results before processing the next image
+         restore_helper.clean_all()
+
+         if isinstance(img_path, str):
+             img_name = os.path.basename(img_path)
+             basename, ext = os.path.splitext(img_name)
+             print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
+             img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+
+         restore_helper.read_image(img)
+         # get 5-point landmarks for each face
+         num_det_faces = restore_helper.get_face_landmarks_5(
+             only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
+         print(f'\tdetect {num_det_faces} faces')
+         # align and warp each face
+         restore_helper.align_warp_face()
+         # face restoration for each cropped face
+         for idx, cropped_face in enumerate(restore_helper.cropped_faces):
+             # prepare data: BGR uint8 [0, 255] -> RGB float32 [-1, 1], NCHW
+             cropped_face_t = (cropped_face.astype(np.float32) / 255.0) * 2.0 - 1.0
+             cropped_face_t = np.transpose(
+                 np.expand_dims(np.ascontiguousarray(cropped_face_t[..., ::-1]), axis=0),
+                 (0, 3, 1, 2)
+             )
+
+             try:
+                 ort_outs = restore_helper.rs_session.run(
+                     restore_helper.rs_output,
+                     {restore_helper.rs_input: cropped_face_t}
+                 )
+                 restored_face = ort_outs[0]
+                 # NCHW RGB [-1, 1] -> HWC BGR uint8 [0, 255]
+                 restored_face = (restored_face.squeeze().transpose(1, 2, 0) * 0.5 + 0.5) * 255
+                 restored_face = np.clip(restored_face[..., ::-1], 0, 255).astype(np.uint8)
+             except Exception as error:
+                 print(f'\tFailed inference for CodeFormer: {error}')
+                 # fall back to the unrestored face (also flipped back to BGR)
+                 restored_face = (cropped_face_t.squeeze().transpose(1, 2, 0) * 0.5 + 0.5) * 255
+                 restored_face = np.clip(restored_face[..., ::-1], 0, 255).astype(np.uint8)
+
+             restore_helper.add_restored_face(restored_face, cropped_face)
+
+         # paste back
+         if not args.has_aligned:
+             # upsample the background (only RealESRGAN is supported for now)
+             bg_img = restore_helper.background_upsampling(img)
+             restore_helper.get_inverse_affine(None)
+             # paste each restored face onto the input image
+             restored_img = restore_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box)
+
+         # save faces
+         # for idx, (cropped_face, restored_face) in enumerate(zip(restore_helper.cropped_faces, restore_helper.restored_faces)):
+         #     # save cropped face
+         #     if not args.has_aligned:
+         #         save_crop_path = os.path.join(result_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
+         #         imwrite(cropped_face, save_crop_path)
+         #     # save restored face
+         #     if args.has_aligned:
+         #         save_face_name = f'{basename}.png'
+         #     else:
+         #         save_face_name = f'{basename}_{idx:02d}.png'
+         #     if args.suffix is not None:
+         #         save_face_name = f'{save_face_name[:-4]}_{args.suffix}.png'
+         #     save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name)
+         #     imwrite(restored_face, save_restore_path)
+
+         # save restored image
+         if not args.has_aligned and restored_img is not None:
+             if args.suffix is not None:
+                 basename = f'{basename}_{args.suffix}'
+             save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
+             imwrite(restored_img, save_restore_path)
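+
+ # Example invocation (hypothetical paths; assumes the three .axmodel files sit in
+ # the working directory and ./pic holds test images):
+ #     python run_whole_image.py -i ./pic -o results -s 2 --draw_box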
python/utils/face_detector.py ADDED
@@ -0,0 +1,136 @@
+ import cv2
+ import copy
+ import torch
+ import numpy as np
+ import axengine as ort
+
+ from pathlib import Path
+ from utils.general import (
+     non_max_suppression_face,
+     scale_coords,
+     scale_coords_landmarks,
+     letterbox,
+ )
+
+
+ def isListempty(inList):
+     if isinstance(inList, list):  # is a list
+         return all(map(isListempty, inList))
+     return False  # not a list
+
+
+ class YoloDetector:
+     def __init__(
+         self,
+         model_path='yolov5l-face.onnx',
+         min_face=10,
+         target_size=None,
+     ):
+         """
+         model_path: path to the detection model file.
+         min_face: minimal face size in pixels.
+         target_size: target size of the smaller image axis (choose lower for faster inference),
+             e.g. 480, 720, 1080. None keeps the original resolution.
+         """
+         self._class_path = Path(__file__).parent.absolute()
+         self.target_size = target_size
+         self.min_face = min_face
+         self.session = ort.InferenceSession(model_path)
+         self.input_name = self.session.get_inputs()[0].name
+         self.output_names = [x.name for x in self.session.get_outputs()]
+
+     def _preprocess(self, imgs):
+         """
+         Preprocess images before passing them through the network: optional resize,
+         letterbox to 640x640 and conversion to a float32 batch.
+         """
+         pp_imgs = []
+         for img in imgs:
+             h0, w0 = img.shape[:2]  # original hw
+             if self.target_size:
+                 r = self.target_size / min(h0, w0)  # resize image to target_size
+                 if r < 1:
+                     img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR)
+
+             imgsz = (640, 640)
+             img = letterbox(img, new_shape=imgsz)[0]
+             pp_imgs.append(img)
+         pp_imgs = np.array(pp_imgs)
+         # note: no NCHW transpose here; the detector model is fed NHWC batches
+         pp_imgs = pp_imgs.astype(np.float32)  # uint8 to fp32
+         return pp_imgs
+
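+     # e.g. a list of two 1280x720 BGR frames with target_size=None comes back as
+     # pp_imgs.shape == (2, 640, 640, 3), float32 (NHWC; see the note above)
+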
+     def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres):
+         """
+         Postprocess the raw model output.
+         Returns:
+             bboxes: list of arrays with 4 bounding-box coordinates in x1,y1,x2,y2 format.
+             points: list of arrays with the coordinates of 5 facial keypoints (eyes, nose, lip corners).
+         """
+         bboxes = [[] for _ in range(len(origimgs))]
+         landmarks = [[] for _ in range(len(origimgs))]
+
+         pred = non_max_suppression_face(pred, conf_thres, iou_thres)
+
+         for image_id, origimg in enumerate(origimgs):
+             img_shape = origimg.shape
+             image_height, image_width = img_shape[:2]
+             gn = torch.tensor(img_shape)[[1, 0, 1, 0]]  # normalization gain whwh
+             gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]]  # normalization gain for landmarks
+             det = pred[image_id].cpu()
+             scale_coords(imgs[image_id].shape[0:], det[:, :4], img_shape).round()
+             scale_coords_landmarks(imgs[image_id].shape[0:], det[:, 5:15], img_shape).round()
+
+             for j in range(det.size()[0]):
+                 box = (det[j, :4].view(1, 4) / gn).view(-1).tolist()
+                 box = list(
+                     map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height])
+                 )
+                 if box[3] - box[1] < self.min_face:
+                     continue
+                 lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist()
+                 lm = list(map(int, [v * image_width if k % 2 == 0 else v * image_height for k, v in enumerate(lm)]))
+                 lm = [lm[i:i + 2] for i in range(0, len(lm), 2)]
+                 bboxes[image_id].append(box)
+                 landmarks[image_id].append(lm)
+         return bboxes, landmarks
+
+     def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5):
+         """
+         Get bbox coordinates and keypoints of faces on the original image.
+         Params:
+             imgs: image or list of images in BGR order (converted to RGB for inference)
+             conf_thres: confidence threshold for each prediction
+             iou_thres: IoU threshold for NMS (filters intersecting bboxes)
+         Returns:
+             an (n, 15) array per detection: x1,y1,x2,y2, a placeholder score column
+             and 10 landmark coordinates, or None if no face is found.
+         """
+         # pass the input images through the face detector
+         images = imgs if isinstance(imgs, list) else [imgs]
+         images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
+         origimgs = copy.deepcopy(images)
+
+         images = self._preprocess(images)
+
+         # run the detection model
+         pred = self.session.run(self.output_names, {self.input_name: images})[0]
+         pred = torch.from_numpy(pred)
+
+         # postprocess the output
+         bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres)
+
+         if not isListempty(points):
+             bboxes = np.array(bboxes).reshape(-1, 4)
+             points = np.array(points).reshape(-1, 10)
+             # insert a placeholder column so that downstream code can read each
+             # row as box(4) + score + 10 landmark coordinates
+             padding = bboxes[:, 0].reshape(-1, 1)
+             return np.concatenate((bboxes, padding, points), axis=1)
+         else:
+             return None
+
+     def __call__(self, *args):
+         # there is no separate predict() method; detection is the entry point
+         return self.detect_faces(*args)
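+
+
+ if __name__ == '__main__':
+     # Minimal smoke test (hypothetical paths: assumes the detector model and a
+     # test image exist next to this script).
+     detector = YoloDetector(model_path='yolov5l-face.axmodel')
+     frame = cv2.imread('test.jpg', cv2.IMREAD_COLOR)
+     dets = detector.detect_faces(frame)
+     # each row: x1, y1, x2, y2, placeholder score column, 10 landmark coords
+     print(None if dets is None else dets.shape)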
python/utils/general.py ADDED
@@ -0,0 +1,361 @@
+ # General utils
+
+ import glob
+ import os
+ import random
+ import time
+
+ import cv2
+ import numpy as np
+ import torch
+ import torchvision
+ from PIL import Image
+
+ # Settings
+ torch.set_printoptions(linewidth=320, precision=5, profile='long')
+ np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
+ cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
+ os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads
+
+
+ def init_seeds(seed=0):
+     # Initialize random number generator (RNG) seeds
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)  # init_torch_seeds() is not defined in this repo; seed torch directly
+
+
+ def get_latest_run(search_dir='.'):
+     # Return path to the most recent 'last.pt' in /runs (i.e. to --resume from)
+     last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
+     return max(last_list, key=os.path.getctime) if last_list else ''
+
+
+ def imwrite(img, file_path, params=None, auto_mkdir=True):
+     """Write image to file.
+
+     Args:
+         img (ndarray): Image array to be written.
+         file_path (str): Image file path.
+         params (None or list): Same as opencv's :func:`imwrite` interface.
+         auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+             whether to create it automatically.
+
+     Returns:
+         bool: Successful or not.
+     """
+     if auto_mkdir:
+         dir_name = os.path.abspath(os.path.dirname(file_path))
+         os.makedirs(dir_name, exist_ok=True)
+     return cv2.imwrite(file_path, img, params)
+
+
+ def img2tensor(imgs, bgr2rgb=True, float32=True):
+     """Numpy array to tensor.
+
+     Args:
+         imgs (list[ndarray] | ndarray): Input images.
+         bgr2rgb (bool): Whether to change bgr to rgb.
+         float32 (bool): Whether to change to float32.
+
+     Returns:
+         list[tensor] | tensor: Tensor images. If the result has only one
+         element, the tensor is returned directly.
+     """
+
+     def _totensor(img, bgr2rgb, float32):
+         if img.shape[2] == 3 and bgr2rgb:
+             if img.dtype == 'float64':
+                 img = img.astype('float32')
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+         img = torch.from_numpy(img.transpose(2, 0, 1))
+         if float32:
+             img = img.float()
+         return img
+
+     if isinstance(imgs, list):
+         return [_totensor(img, bgr2rgb, float32) for img in imgs]
+     else:
+         return _totensor(imgs, bgr2rgb, float32)
+
+
+ def is_gray(img, threshold=10):
+     # Treat an image as grayscale if its channels differ by less than `threshold` in variance
+     img = Image.fromarray(img)
+     if len(img.getbands()) == 1:
+         return True
+     img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16)
+     img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16)
+     img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16)
+     diff1 = (img1 - img2).var()
+     diff2 = (img2 - img3).var()
+     diff3 = (img3 - img1).var()
+     diff_sum = (diff1 + diff2 + diff3) / 3.0
+     return diff_sum <= threshold
+
+
+ def rgb2gray(img, out_channel=3):
+     r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
+     gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
+     if out_channel == 3:
+         gray = gray[:, :, np.newaxis].repeat(3, axis=2)
+     return gray
+
+
+ def bgr2gray(img, out_channel=3):
+     b, g, r = img[:, :, 0], img[:, :, 1], img[:, :, 2]
+     gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
+     if out_channel == 3:
+         gray = gray[:, :, np.newaxis].repeat(3, axis=2)
+     return gray
+
+
+ def calc_mean_std(feat, eps=1e-5):
+     """Compute the per-channel mean and std of a feature map.
+
+     Args:
+         feat (ndarray): 3D array of shape (h, w, c).
+     """
+     size = feat.shape
+     assert len(size) == 3, 'The input feature should be a 3D array.'
+     c = size[2]
+     feat_var = feat.reshape(-1, c).var(axis=0) + eps
+     feat_std = np.sqrt(feat_var).reshape(1, 1, c)
+     feat_mean = feat.reshape(-1, c).mean(axis=0).reshape(1, 1, c)
+     return feat_mean, feat_std
+
+
+ def adain_npy(content_feat, style_feat):
+     """Adaptive instance normalization (AdaIN) for numpy arrays.
+
+     Args:
+         content_feat (ndarray): The input feature.
+         style_feat (ndarray): The reference feature.
+     """
+     size = content_feat.shape
+     style_mean, style_std = calc_mean_std(style_feat)
+     content_mean, content_std = calc_mean_std(content_feat)
+     normalized_feat = (content_feat - np.broadcast_to(content_mean, size)) / np.broadcast_to(content_std, size)
+     return normalized_feat * np.broadcast_to(style_std, size) + np.broadcast_to(style_mean, size)
+
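+ # e.g. adain_npy(content, style) re-normalizes each channel of `content` so its
+ # mean/std match `style`; RestoreHelper.add_restored_face uses this to transfer
+ # the original face's color statistics onto the restored face
+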
+ def xyxy2xywh(x):
+     # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
+     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+     y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
+     y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
+     y[:, 2] = x[:, 2] - x[:, 0]  # width
+     y[:, 3] = x[:, 3] - x[:, 1]  # height
+     return y
+
+
+ def xywh2xyxy(x):
+     # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+     y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
+     y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
+     y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
+     y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
+     return y
+
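+ # e.g. xywh2xyxy(np.array([[50., 50., 20., 10.]])) -> [[40., 45., 60., 55.]]
+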
+ def xywhn2xyxy(x, w=640, h=640, padw=32, padh=32):
+     # Convert nx4 boxes from normalized [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+     y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
+     y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left y
+     y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right x
+     y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right y
+     return y
+
+
+ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
+     # Rescale coords (xyxy) from img1_shape to img0_shape
+     if ratio_pad is None:  # calculate from img0_shape
+         gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
+         pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
+     else:
+         gain = ratio_pad[0][0]
+         pad = ratio_pad[1]
+
+     coords[:, [0, 2]] -= pad[0]  # x padding
+     coords[:, [1, 3]] -= pad[1]  # y padding
+     coords[:, :4] /= gain
+     clip_coords(coords, img0_shape)
+     return coords
+
+
+ def clip_coords(boxes, img_shape):
+     # Clip xyxy bounding boxes to image shape (height, width)
+     boxes[:, 0].clamp_(0, img_shape[1])  # x1
+     boxes[:, 1].clamp_(0, img_shape[0])  # y1
+     boxes[:, 2].clamp_(0, img_shape[1])  # x2
+     boxes[:, 3].clamp_(0, img_shape[0])  # y2
+
+
+ def box_iou(box1, box2):
+     # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
+     """
+     Return intersection-over-union (Jaccard index) of boxes.
+     Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+     Arguments:
+         box1 (Tensor[N, 4])
+         box2 (Tensor[M, 4])
+     Returns:
+         iou (Tensor[N, M]): the NxM matrix containing the pairwise
+             IoU values for every element in box1 and box2
+     """
+
+     def box_area(box):
+         # box = 4xn
+         return (box[2] - box[0]) * (box[3] - box[1])
+
+     area1 = box_area(box1.T)
+     area2 = box_area(box2.T)
+
+     # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
+     inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) -
+              torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
+     # iou = inter / (area1 + area2 - inter)
+     return inter / (area1[:, None] + area2 - inter)
+
+
+ def wh_iou(wh1, wh2):
+     # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
+     wh1 = wh1[:, None]  # [N,1,2]
+     wh2 = wh2[None]  # [1,M,2]
+     inter = torch.min(wh1, wh2).prod(2)  # [N,M]
+     # iou = inter / (area1 + area2 - inter)
+     return inter / (wh1.prod(2) + wh2.prod(2) - inter)
+
+
+ def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
+     """Performs Non-Maximum Suppression (NMS) on face-detection inference results
+     Returns:
+         detections with shape nx16: (x1, y1, x2, y2, conf, 10 landmark coords, cls)
+     """
+
+     nc = prediction.shape[2] - 15  # number of classes
+     xc = prediction[..., 4] > conf_thres  # candidates
+
+     # Settings
+     min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
+     time_limit = 10.0  # seconds to quit after
+     redundant = True  # require redundant detections
+     multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)
+     merge = False  # use merge-NMS
+
+     t = time.time()
+     output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0]
+     for xi, x in enumerate(prediction):  # image index, image inference
+         # Apply constraints
+         # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
+         x = x[xc[xi]]  # confidence
+
+         # Cat apriori labels if autolabelling
+         if labels and len(labels[xi]):
+             lbl = labels[xi]
+             v = torch.zeros((len(lbl), nc + 15), device=x.device)
+             v[:, :4] = lbl[:, 1:5]  # box
+             v[:, 4] = 1.0  # conf
+             v[range(len(lbl)), lbl[:, 0].long() + 15] = 1.0  # cls
+             x = torch.cat((x, v), 0)
+
+         # If none remain, process the next image
+         if not x.shape[0]:
+             continue
+
+         # Compute conf
+         x[:, 15:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
+
+         # Box (center x, center y, width, height) to (x1, y1, x2, y2)
+         box = xywh2xyxy(x[:, :4])
+
+         # Detections matrix nx16 (xyxy, conf, landmarks, cls)
+         if multi_label:
+             i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T
+             x = torch.cat((box[i], x[i, j + 15, None], x[i, 5:15], j[:, None].float()), 1)
+         else:  # best class only
+             conf, j = x[:, 15:].max(1, keepdim=True)
+             x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres]
+
+         # Filter by class
+         if classes is not None:
+             x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+         # If none remain, process the next image
+         n = x.shape[0]  # number of boxes
+         if not n:
+             continue
+
+         # Batched NMS
+         c = x[:, 15:16] * (0 if agnostic else max_wh)  # classes
+         boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
+         i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
+         if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
+             # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+             iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
+             weights = iou * scores[None]  # box weights
+             x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
+             if redundant:
+                 i = i[iou.sum(1) > 1]  # require redundancy
+
+         output[xi] = x[i]
+         if (time.time() - t) > time_limit:
+             break  # time limit exceeded
+
+     return output
+
+
+ def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):
+     # Rescale 5-point landmark coords from img1_shape to img0_shape
+     if ratio_pad is None:  # calculate from img0_shape
+         gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
+         pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
+     else:
+         gain = ratio_pad[0][0]
+         pad = ratio_pad[1]
+
+     coords[:, [0, 2, 4, 6, 8]] -= pad[0]  # x padding
+     coords[:, [1, 3, 5, 7, 9]] -= pad[1]  # y padding
+     coords[:, :10] /= gain
+     coords[:, 0].clamp_(0, img0_shape[1])  # x1
+     coords[:, 1].clamp_(0, img0_shape[0])  # y1
+     coords[:, 2].clamp_(0, img0_shape[1])  # x2
+     coords[:, 3].clamp_(0, img0_shape[0])  # y2
+     coords[:, 4].clamp_(0, img0_shape[1])  # x3
+     coords[:, 5].clamp_(0, img0_shape[0])  # y3
+     coords[:, 6].clamp_(0, img0_shape[1])  # x4
+     coords[:, 7].clamp_(0, img0_shape[0])  # y4
+     coords[:, 8].clamp_(0, img0_shape[1])  # x5
+     coords[:, 9].clamp_(0, img0_shape[0])  # y5
+     return coords
+
+
+ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=False, scaleFill=False, scaleup=True):
+     # Resize image to a padded rectangle, https://github.com/ultralytics/yolov3/issues/232
+     shape = img.shape[:2]  # current shape [height, width]
+     if isinstance(new_shape, int):
+         new_shape = (new_shape, new_shape)
+
+     # Scale ratio (new / old)
+     r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+     if not scaleup:  # only scale down, do not scale up (for better test mAP)
+         r = min(r, 1.0)
+
+     # Compute padding
+     ratio = r, r  # width, height ratios
+     new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+     dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
+     if auto:  # minimum rectangle
+         dw, dh = np.mod(dw, 64), np.mod(dh, 64)  # wh padding
+     elif scaleFill:  # stretch
+         dw, dh = 0.0, 0.0
+         new_unpad = (new_shape[1], new_shape[0])
+         ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios
+
+     dw /= 2  # divide padding into 2 sides
+     dh /= 2
+
+     if shape[::-1] != new_unpad:  # resize
+         img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
+     top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+     left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+     img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
+     return img, ratio, (dw, dh)
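+
+
+ if __name__ == '__main__':
+     # Quick geometry check for letterbox (synthetic input, no files needed):
+     # a 360x640 frame fits a 640x640 canvas with 140 px of padding above and below.
+     dummy = np.zeros((360, 640, 3), dtype=np.uint8)
+     out, ratio, (dw, dh) = letterbox(dummy, new_shape=(640, 640))
+     print(out.shape, ratio, (dw, dh))  # -> (640, 640, 3) (1.0, 1.0) (0.0, 140.0)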
python/utils/restoration_helper.py ADDED
@@ -0,0 +1,544 @@
+ import os
+ import cv2
+ import math
+ import torch
+ import numpy as np
+ import axengine as ort
+
+ from utils.general import img2tensor, imwrite, is_gray, bgr2gray, adain_npy
+ from utils.face_detector import YoloDetector
+
+
+ def get_largest_face(det_faces, h, w):
+
+     def get_location(val, length):
+         if val < 0:
+             return 0
+         elif val > length:
+             return length
+         else:
+             return val
+
+     face_areas = []
+     for det_face in det_faces:
+         left = get_location(det_face[0], w)
+         right = get_location(det_face[2], w)
+         top = get_location(det_face[1], h)
+         bottom = get_location(det_face[3], h)
+         face_area = (right - left) * (bottom - top)
+         face_areas.append(face_area)
+     largest_idx = face_areas.index(max(face_areas))
+     return det_faces[largest_idx], largest_idx
+
+
+ def get_center_face(det_faces, h=0, w=0, center=None):
+     if center is not None:
+         center = np.array(center)
+     else:
+         center = np.array([w / 2, h / 2])
+     center_dist = []
+     for det_face in det_faces:
+         face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
+         dist = np.linalg.norm(face_center - center)
+         center_dist.append(dist)
+     center_idx = center_dist.index(min(center_dist))
+     return det_faces[center_idx], center_idx
+
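+ # e.g. get_center_face([[0, 0, 10, 10], [45, 45, 55, 55]], h=100, w=100)
+ #     -> ([45, 45, 55, 55], 1), since the second box is centered on (50, 50)
+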
+
+ class RestoreHelper(object):
+     """Helper for the whole-image face restoration pipeline."""
+
+     def __init__(self,
+                  upscale_factor,
+                  face_size=512,
+                  crop_ratio=(1, 1),
+                  det_model='yolov5l-face.onnx',
+                  res_model='codeformer.onnx',
+                  bg_model='realesrgan_x2.onnx',
+                  save_ext='png',
+                  template_3points=False,
+                  pad_blur=False,
+                  use_parse=False):
+
+         # face alignment params
+         self.template_3points = template_3points  # improves robustness
+         self.upscale_factor = int(upscale_factor)
+         # the cropped face ratio based on the square face
+         self.crop_ratio = crop_ratio  # (h, w)
+         assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ratio only supports >=1'
+         self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
+
+         # standard 5 landmarks for FFHQ faces with 512 x 512
+         # facexlib
+         self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
+                                        [201.26117, 371.41043], [313.08905, 371.15118]])
+
+         # dlib: left_eye: 36:41  right_eye: 42:47  nose: 30,32,33,34  left mouth corner: 48  right mouth corner: 54
+         # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
+         #                                [198.22603, 372.82502], [313.91018, 372.75659]])
+
+         self.face_template = self.face_template * (face_size / 512.0)
+         if self.crop_ratio[0] > 1:
+             self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
+         if self.crop_ratio[1] > 1:
+             self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
+         self.save_ext = save_ext
+         self.pad_blur = pad_blur
+         if self.pad_blur:
+             self.template_3points = False
+
+         self.all_landmarks_5 = []
+         self.det_faces = []
+         self.affine_matrices = []
+         self.inverse_affine_matrices = []
+         self.cropped_faces = []
+         self.restored_faces = []
+         self.pad_input_imgs = []
+
+         # init face detection model
+         self.face_detector = YoloDetector(model_path=det_model)
+
+         # init face parsing model (currently unused)
+         self.use_parse = use_parse
+         # self.face_parse = init_parsing_model(model_name='parsenet')
+
+         # init face restoration model
+         self.res_model = res_model
+         self.rs_session, self.rs_input, self.rs_output = self.init_face_restoration()
+
+         # init background upsampling model
+         self.tile = 108
+         self.tile_pad = 10
+         self.scale = 2
+         self.bg_model = bg_model
+         self.bg_session, self.bg_input, self.bg_output = self.init_background_upsampling()
+
+     def set_upscale_factor(self, upscale_factor):
+         self.upscale_factor = upscale_factor
+
+     def read_image(self, img):
+         """img can be an image path or a cv2-loaded image."""
+         # self.input_img is a numpy array, (h, w, c), BGR, uint8, [0, 255]
+         if isinstance(img, str):
+             img = cv2.imread(img)
+
+         if np.max(img) > 256:  # 16-bit image
+             img = img / 65535 * 255
+         if len(img.shape) == 2:  # gray image
+             img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+         elif img.shape[2] == 4:  # BGRA image with alpha channel
+             img = img[:, :, 0:3]
+
+         self.input_img = img
+         self.is_gray = is_gray(img, threshold=10)
+         if self.is_gray:
+             print('Grayscale input: True')
+
+         # upscale small inputs so the shorter side is at least 512 px
+         if min(self.input_img.shape[:2]) < 512:
+             f = 512.0 / min(self.input_img.shape[:2])
+             self.input_img = cv2.resize(self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
+
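+     # e.g. a 200x300 input is resized by 512/200 = 2.56x here, so the detector
+     # always sees at least 512 px on the shorter side
+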
+     def get_face_landmarks_5(self,
+                              only_keep_largest=False,
+                              only_center_face=False,
+                              resize=None,
+                              blur_ratio=0.01,
+                              eye_dist_threshold=None):
+
+         if resize is None:
+             scale = 1
+             input_img = self.input_img
+         else:
+             h, w = self.input_img.shape[0:2]
+             scale = resize / min(h, w)
+             # scale = max(1, scale)  # always scale up; comment this out for HD images, e.g., AIGC faces
+             h, w = int(h * scale), int(w * scale)
+             interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
+             input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
+
+         bboxes = self.face_detector.detect_faces(input_img)
+
+         if bboxes is None or bboxes.shape[0] == 0:
+             return 0
+         else:
+             bboxes = bboxes / scale
+
+         for bbox in bboxes:
+             # remove faces with too small an eye distance: side faces or tiny faces
+             eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
+             if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
+                 continue
+
+             if self.template_3points:
+                 landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
+             else:
+                 landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
+             self.all_landmarks_5.append(landmark)
+             self.det_faces.append(bbox[0:5])
+
+         if len(self.det_faces) == 0:
+             return 0
+         if only_keep_largest:
+             h, w, _ = self.input_img.shape
+             self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
+             self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
+         elif only_center_face:
+             h, w, _ = self.input_img.shape
+             self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
+             self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
+
+         # pad blurry images
+         if self.pad_blur:
+             self.pad_input_imgs = []
+             for landmarks in self.all_landmarks_5:
+                 # get landmarks
+                 eye_left = landmarks[0, :]
+                 eye_right = landmarks[1, :]
+                 eye_avg = (eye_left + eye_right) * 0.5
+                 mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
+                 eye_to_eye = eye_right - eye_left
+                 eye_to_mouth = mouth_avg - eye_avg
+
+                 # Get the oriented crop rectangle
+                 # x: half width of the oriented crop rectangle
+                 x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+                 # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 degrees clockwise
+                 # norm with the hypotenuse: get the direction
+                 x /= np.hypot(*x)  # get the hypotenuse of a right triangle
+                 rect_scale = 1.5
+                 x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+                 # y: half height of the oriented crop rectangle
+                 y = np.flipud(x) * [-1, 1]
+
+                 # c: center
+                 c = eye_avg + eye_to_mouth * 0.1
+                 # quad: (left_top, left_bottom, right_bottom, right_top)
+                 quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+                 # qsize: side length of the square
+                 qsize = np.hypot(*x) * 2
+                 border = max(int(np.rint(qsize * 0.1)), 3)
+
+                 # get pad
+                 # pad: (width_left, height_top, width_right, height_bottom)
+                 pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+                        int(np.ceil(max(quad[:, 1]))))
+                 pad = [
+                     max(-pad[0] + border, 1),
+                     max(-pad[1] + border, 1),
+                     max(pad[2] - self.input_img.shape[0] + border, 1),
+                     max(pad[3] - self.input_img.shape[1] + border, 1)
+                 ]
+
+                 if max(pad) > 1:
+                     # pad image
+                     pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+                     # modify landmark coords
+                     landmarks[:, 0] += pad[0]
+                     landmarks[:, 1] += pad[1]
+                     # blur padded regions
+                     h, w, _ = pad_img.shape
+                     y, x, _ = np.ogrid[:h, :w, :1]
+                     mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
+                                                        np.float32(w - 1 - x) / pad[2]),
+                                       1.0 - np.minimum(np.float32(y) / pad[1],
+                                                        np.float32(h - 1 - y) / pad[3]))
+                     blur = int(qsize * blur_ratio)
+                     if blur % 2 == 0:
+                         blur += 1
+                     blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
+                     # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
+
+                     pad_img = pad_img.astype('float32')
+                     pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+                     pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
+                     pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
+                     self.pad_input_imgs.append(pad_img)
+                 else:
+                     self.pad_input_imgs.append(np.copy(self.input_img))
+
+         return len(self.all_landmarks_5)
+
+     def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
+         """Align and warp faces with the face template."""
+         if self.pad_blur:
+             assert len(self.pad_input_imgs) == len(
+                 self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
+         for idx, landmark in enumerate(self.all_landmarks_5):
+             # use the 5 landmarks to get the affine matrix
+             # use the cv2.LMEDS method for equivalence with the skimage transform
+             # ref: https://blog.csdn.net/yichxi/article/details/115827338
+             affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
+             self.affine_matrices.append(affine_matrix)
+             # warp and crop faces
+             if border_mode == 'constant':
+                 border_mode = cv2.BORDER_CONSTANT
+             elif border_mode == 'reflect101':
+                 border_mode = cv2.BORDER_REFLECT101
+             elif border_mode == 'reflect':
+                 border_mode = cv2.BORDER_REFLECT
+             if self.pad_blur:
+                 input_img = self.pad_input_imgs[idx]
+             else:
+                 input_img = self.input_img
+             cropped_face = cv2.warpAffine(
+                 input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132))  # gray
+             self.cropped_faces.append(cropped_face)
+             # save the cropped face
+             if save_cropped_path is not None:
+                 path = os.path.splitext(save_cropped_path)[0]
+                 save_path = f'{path}_{idx:02d}.{self.save_ext}'
+                 imwrite(cropped_face, save_path)
+
+     def get_inverse_affine(self, save_inverse_affine_path=None):
+         """Get the inverse affine matrix for each face."""
+         for idx, affine_matrix in enumerate(self.affine_matrices):
+             inverse_affine = cv2.invertAffineTransform(affine_matrix)
+             inverse_affine *= self.upscale_factor
+             self.inverse_affine_matrices.append(inverse_affine)
+             # save inverse affine matrices
+             if save_inverse_affine_path is not None:
+                 path, _ = os.path.splitext(save_inverse_affine_path)
+                 save_path = f'{path}_{idx:02d}.pth'
+                 torch.save(inverse_affine, save_path)
+
+     def add_restored_face(self, restored_face, input_face=None):
+         if self.is_gray:
+             restored_face = bgr2gray(restored_face)  # convert the face to grayscale
+             if input_face is not None:
+                 restored_face = adain_npy(restored_face, input_face)  # transfer the color
+         self.restored_faces.append(restored_face)
+
+     def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
+         h, w, _ = self.input_img.shape
+         h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
+
+         if upsample_img is None:
+             # simply resize the background
+             upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
+         else:
+             upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+
+         assert len(self.restored_faces) == len(
+             self.inverse_affine_matrices), 'length of restored_faces and affine_matrices are different.'
+
+         inv_mask_borders = []
+         for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
+             if face_upsampler is not None:
+                 restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
+                 inverse_affine /= self.upscale_factor
+                 inverse_affine[:, 2] *= self.upscale_factor
+                 face_size = (self.face_size[0] * self.upscale_factor, self.face_size[1] * self.upscale_factor)
+             else:
+                 # add an offset to the inverse affine matrix for more precise back-alignment
+                 if self.upscale_factor > 1:
+                     extra_offset = 0.5 * self.upscale_factor
+                 else:
+                     extra_offset = 0
+                 inverse_affine[:, 2] += extra_offset
+                 face_size = self.face_size
+             inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
+
+             # always use a square mask
+             mask = np.ones(face_size, dtype=np.float32)
+             inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+             # remove the black borders
+             inv_mask_erosion = cv2.erode(
+                 inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+             pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+             total_face_area = np.sum(inv_mask_erosion)
+             # add border
+             if draw_box:
+                 h, w = face_size
+                 mask_border = np.ones((h, w, 3), dtype=np.float32)
+                 border = int(1400 / np.sqrt(total_face_area))
+                 mask_border[border:h - border, border:w - border, :] = 0
+                 inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+                 inv_mask_borders.append(inv_mask_border)
+             # compute the fusion edge based on the area of the face
+             w_edge = int(total_face_area**0.5) // 20
+             erosion_radius = w_edge * 2
+             inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+             blur_size = w_edge * 2
+             inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+             if len(upsample_img.shape) == 2:  # upsample_img is a gray image
+                 upsample_img = upsample_img[:, :, None]
+             inv_soft_mask = inv_soft_mask[:, :, None]
+
+             # parse-mask fusion (disabled; kept for reference)
+             # if self.use_parse:
+             #     # inference
+             #     face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
+             #     face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
+             #     normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+             #     face_input = torch.unsqueeze(face_input, 0)
+             #     with torch.no_grad():
+             #         out = self.face_parse(face_input)[0]
+             #     out = out.argmax(dim=1).squeeze().cpu().numpy()
+             #
+             #     parse_mask = np.zeros(out.shape)
+             #     MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
+             #     for idx, color in enumerate(MASK_COLORMAP):
+             #         parse_mask[out == idx] = color
+             #     # blur the mask
+             #     parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+             #     parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+             #     # remove the black borders
+             #     thres = 10
+             #     parse_mask[:thres, :] = 0
+             #     parse_mask[-thres:, :] = 0
+             #     parse_mask[:, :thres] = 0
+             #     parse_mask[:, -thres:] = 0
+             #     parse_mask = parse_mask / 255.
+             #
+             #     parse_mask = cv2.resize(parse_mask, face_size)
+             #     parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
+             #     inv_soft_parse_mask = parse_mask[:, :, None]
+             #     fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
+             #     inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)
+
+             if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:  # image with alpha channel
+                 alpha = upsample_img[:, :, 3:]
+                 upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
+                 upsample_img = np.concatenate((upsample_img, alpha), axis=2)
+             else:
+                 upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
+
+         if np.max(upsample_img) > 256:  # 16-bit image
+             upsample_img = upsample_img.astype(np.uint16)
+         else:
+             upsample_img = upsample_img.astype(np.uint8)
+
+         # draw bounding boxes
+         if draw_box:
+             img_color = np.ones([*upsample_img.shape], dtype=np.float32)
+             img_color[:, :, 0] = 0
+             img_color[:, :, 1] = 255
+             img_color[:, :, 2] = 0
+             for inv_mask_border in inv_mask_borders:
+                 upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
+
+         if save_path is not None:
+             path = os.path.splitext(save_path)[0]
+             save_path = f'{path}.{self.save_ext}'
+             imwrite(upsample_img, save_path)
+         return upsample_img
+
+     def init_face_restoration(self):
+         session = ort.InferenceSession(self.res_model)
+         input_name = session.get_inputs()[0].name
+         output_names = [x.name for x in session.get_outputs()]
+         return session, input_name, output_names
+
+     def pre_process(self, img):
+         # pad so that height and width are divisible by the tile size
+         tile_pad_h, tile_pad_w = 0, 0
+         h, w = img.shape[0:2]
+
+         if h % self.tile != 0:
+             tile_pad_h = self.tile - h % self.tile
+         if w % self.tile != 0:
+             tile_pad_w = self.tile - w % self.tile
+         img = np.pad(img, ((0, tile_pad_h), (0, tile_pad_w), (0, 0)), 'constant')
+
+         # boundary tile_pad
+         img = np.pad(img, ((self.tile_pad, self.tile_pad), (self.tile_pad, self.tile_pad), (0, 0)), 'constant')
+
+         # BGR [0, 255] -> RGB [0, 1], NCHW batch
+         img = (img[..., ::-1] / 255).astype(np.float32)
+         img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
+
+         return img
+
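+     # e.g. with tile=108 and tile_pad=10, a 501x376 BGR frame is padded to
+     # 540x432, bordered to 560x452, and returned as a (1, 3, 560, 452) RGB batch
+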
+     def init_background_upsampling(self):
+         session = ort.InferenceSession(self.bg_model)
+         input_name = session.get_inputs()[0].name
+         output_names = [x.name for x in session.get_outputs()]
+         return session, input_name, output_names
+
+     def tile_process(self, img, origin_shape):
+         """Crop the input image into tiles, process each tile, and merge the
+         processed tiles back into one image.
+         """
+         batch, channel, height, width = img.shape
+         output_height = int(round(height * self.scale))
+         output_width = int(round(width * self.scale))
+         output_shape = (batch, channel, output_height, output_width)
+         origin_h, origin_w = origin_shape[0:2]
+
+         # start with a black image
+         output = np.zeros(output_shape)
+         tiles_x = math.floor(width / self.tile)
+         tiles_y = math.floor(height / self.tile)
+
+         start_tile = int(round(self.tile_pad * self.scale))
+         end_tile = int(round(self.tile * self.scale)) + start_tile
+
+         # loop over all tiles
+         for y in range(tiles_y):
+             for x in range(tiles_x):
+                 # extract the tile from the input image
+                 ofs_x = x * self.tile
+                 ofs_y = y * self.tile
+                 # input tile area on the total image
+                 input_start_x = ofs_x
+                 input_end_x = min(ofs_x + self.tile, width)
+                 input_start_y = ofs_y
+                 input_end_y = min(ofs_y + self.tile, height)
+
+                 # input tile with its padding border
+                 input_tile = img[:, :, input_start_y:(input_end_y + 2 * self.tile_pad),
+                                  input_start_x:(input_end_x + 2 * self.tile_pad)]
+
+                 # upscale the tile
+                 try:
+                     output_tile = self.bg_session.run(self.bg_output, {self.bg_input: input_tile})
+                 except RuntimeError as error:
+                     print('Error', error)
+                     continue  # leave this tile black rather than reuse a stale result
+
+                 # output tile area on the total image
+                 output_start_x = int(round(input_start_x * self.scale))
+                 output_end_x = int(round(input_end_x * self.scale))
+                 output_start_y = int(round(input_start_y * self.scale))
+                 output_end_y = int(round(input_end_y * self.scale))
+
+                 output[:, :, output_start_y:output_end_y,
+                        output_start_x:output_end_x] = output_tile[0][:, :, start_tile:end_tile, start_tile:end_tile]
+
+         # remove the extra padded parts
+         output = output[:, :, :int(round(origin_h * self.scale)), :int(round(origin_w * self.scale))].squeeze(0)
+         output = np.transpose(output[::-1, ...], (1, 2, 0)).astype(np.float32)
+         output = np.clip(output * 255.0, 0, 255).astype(np.uint8)
+
+         # resize back to the original shape
+         output = cv2.resize(output, (origin_w, origin_h), interpolation=cv2.INTER_LINEAR)
+
+         return output
+
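+     # e.g. the (1, 3, 560, 452) batch from pre_process above is processed as a
+     # 5 x 4 grid of 108-px tiles, each read together with its 10-px border
+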
+     def background_upsampling(self, img):
+         """Background upsampling with Real-ESRGAN."""
+         # pre-process
+         img_input = self.pre_process(img)
+
+         # tile process
+         output = self.tile_process(img_input, img.shape)
+
+         return output
+
+     def clean_all(self):
+         self.all_landmarks_5 = []
+         self.restored_faces = []
+         self.affine_matrices = []
+         self.cropped_faces = []
+         self.inverse_affine_matrices = []
+         self.det_faces = []
+         self.pad_input_imgs = []
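+
+
+ if __name__ == '__main__':
+     # End-to-end sketch (hypothetical paths: assumes the three .axmodel files and
+     # pic/test.jpg exist in the working directory).
+     helper = RestoreHelper(1, det_model='yolov5l-face.axmodel',
+                            res_model='codeformer.axmodel', bg_model='realesrgan-x2.axmodel')
+     helper.read_image('pic/test.jpg')
+     print('faces detected:', helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5))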