Upload 4 files
- python/run_whole_image.py +136 -0
- python/utils/face_detector.py +136 -0
- python/utils/general.py +361 -0
- python/utils/restoration_helper.py +544 -0
python/run_whole_image.py
ADDED
@@ -0,0 +1,136 @@
import os
import cv2
import argparse
import glob

import numpy as np
from utils.general import imwrite
from utils.restoration_helper import RestoreHelper

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('-i', '--input_path', type=str, default='./pic',
                        help='Input image, video or folder. Default: ./pic')
    parser.add_argument('-o', '--output_path', type=str, default=None,
                        help='Output folder. Default: results/<input_name>_<w>')
    parser.add_argument('-s', '--upscale', type=int, default=1,
                        help='The final upsampling scale of the image. Default: 1')
    parser.add_argument('--detect_model', type=str, default='yolov5l-face.axmodel', help='face detection model path')
    parser.add_argument('--restore_model', type=str, default='codeformer.axmodel', help='face restore model path')
    parser.add_argument('--bg_model', type=str, default='realesrgan-x2.axmodel', help='background upsampler model path')
    parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces. Default: False')
    parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face. Default: False')
    parser.add_argument('--draw_box', action='store_true', help='Draw the bounding box for the detected faces. Default: False')
    parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces. Default: None')

    args = parser.parse_args()

    # ------------------------ input & output ------------------------
    if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')):  # input single img path
        input_img_list = [args.input_path]
        result_root = f'results/test_img_{args.upscale}'
    else:  # input img folder
        if args.input_path.endswith('/'):  # strip a trailing /
            args.input_path = args.input_path[:-1]
        # scan all the jpg and png images
        input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]')))
        result_root = 'results'

    if args.output_path is not None:  # set output path
        result_root = args.output_path

    test_img_num = len(input_img_list)
    if test_img_num == 0:
        raise FileNotFoundError('No input image/video is found...\n'
                                '\tNote that --input_path for video should end with .mp4|.mov|.avi')

    # ------------------ set up FaceRestoreHelper -------------------
    restore_helper = RestoreHelper(
        args.upscale,
        face_size=512,
        crop_ratio=(1, 1),
        det_model=args.detect_model,
        res_model=args.restore_model,
        bg_model=args.bg_model,
        save_ext='png',
        use_parse=True
    )

    # -------------------- start processing ---------------------
    for i, img_path in enumerate(input_img_list):
        # clean all the intermediate results to process the next image
        restore_helper.clean_all()

        if isinstance(img_path, str):
            img_name = os.path.basename(img_path)
            basename, ext = os.path.splitext(img_name)
            print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
            img = cv2.imread(img_path, cv2.IMREAD_COLOR)

        restore_helper.read_image(img)
        # get face landmarks for each face
        num_det_faces = restore_helper.get_face_landmarks_5(
            only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
        print(f'\tdetect {num_det_faces} faces')
        # align and warp each face
        restore_helper.align_warp_face()
        # face restoration for each cropped face
        for idx, cropped_face in enumerate(restore_helper.cropped_faces):
            # prepare data: BGR uint8 HWC [0, 255] -> RGB float32 NCHW [-1, 1]
            cropped_face_t = (cropped_face.astype(np.float32) / 255.0) * 2.0 - 1.0
            cropped_face_t = np.transpose(
                np.expand_dims(np.ascontiguousarray(cropped_face_t[..., ::-1]), axis=0),
                (0, 3, 1, 2)
            )
            # print('cropped_face_t', cropped_face_t.shape)

            try:
                ort_outs = restore_helper.rs_session.run(
                    restore_helper.rs_output,
                    {restore_helper.rs_input: cropped_face_t}
                )
                restored_face = ort_outs[0]
                restored_face = (restored_face.squeeze().transpose(1, 2, 0) * 0.5 + 0.5) * 255
                restored_face = np.clip(restored_face[..., ::-1], 0, 255).astype(np.uint8)
            except Exception as error:
                print(f'\tFailed inference for CodeFormer: {error}')
                # fall back to the unrestored input face (flip back to BGR as well)
                restored_face = (cropped_face_t.squeeze().transpose(1, 2, 0) * 0.5 + 0.5) * 255
                restored_face = np.clip(restored_face[..., ::-1], 0, 255).astype(np.uint8)

            restore_helper.add_restored_face(restored_face, cropped_face)

        # paste_back
        if not args.has_aligned:
            # upsample the background
            # Now only support RealESRGAN for upsampling background
            bg_img = restore_helper.background_upsampling(img)
            restore_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            restored_img = restore_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box)

        # save faces
        # for idx, (cropped_face, restored_face) in enumerate(zip(restore_helper.cropped_faces, restore_helper.restored_faces)):
        #     # save cropped face
        #     if not args.has_aligned:
        #         save_crop_path = os.path.join(result_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
        #         imwrite(cropped_face, save_crop_path)
        #     # save restored face
        #     if args.has_aligned:
        #         save_face_name = f'{basename}.png'
        #     else:
        #         save_face_name = f'{basename}_{idx:02d}.png'
        #     if args.suffix is not None:
        #         save_face_name = f'{save_face_name[:-4]}_{args.suffix}.png'
        #     save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name)
        #     imwrite(restored_face, save_restore_path)

        # save restored img
        if not args.has_aligned and restored_img is not None:
            if args.suffix is not None:
                basename = f'{basename}_{args.suffix}'
            save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
            imwrite(restored_img, save_restore_path)
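The trickiest part of the script above is the tensor layout handling around the restore session: OpenCV images are BGR uint8 HWC, while the model consumes RGB float32 NCHW in [-1, 1] and returns the same layout. A minimal NumPy-only sketch of that round trip (random data stands in for a real 512x512 crop):

import numpy as np

face = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)  # BGR HWC, stand-in for a cropped face

# forward: BGR uint8 [0, 255] -> RGB float32 NCHW [-1, 1], as fed to the restore model
t = (face.astype(np.float32) / 255.0) * 2.0 - 1.0
t = np.transpose(np.expand_dims(np.ascontiguousarray(t[..., ::-1]), axis=0), (0, 3, 1, 2))
assert t.shape == (1, 3, 512, 512) and -1.0 <= t.min() and t.max() <= 1.0

# inverse: NCHW [-1, 1] -> BGR uint8 HWC, as applied to the model output
back = (t.squeeze().transpose(1, 2, 0) * 0.5 + 0.5) * 255
back = np.clip(back[..., ::-1], 0, 255).astype(np.uint8)
# exact up to float32 rounding: the final truncating cast can be off by one gray level
assert np.max(np.abs(back.astype(np.int16) - face.astype(np.int16))) <= 1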
python/utils/face_detector.py
ADDED
@@ -0,0 +1,136 @@
import cv2
import copy
import torch
import numpy as np
import axengine as ort

from pathlib import Path
# from lib.datasets import letterbox
from utils.general import (
    non_max_suppression_face,
    scale_coords,
    scale_coords_landmarks,
    letterbox,
)


def isListempty(inList):
    if isinstance(inList, list):  # is a list
        return all(map(isListempty, inList))
    return False  # not a list


class YoloDetector:
    def __init__(
        self,
        model_path='yolov5l-face.onnx',
        min_face=10,
        target_size=None,
    ):
        """
        model_path : path to the detection model file.
        min_face   : minimal face size in pixels.
        target_size: target size of the smaller image axis (choose lower for faster work), e.g. 480, 720, 1080.
                     None keeps the original resolution.
        """
        self._class_path = Path(__file__).parent.absolute()
        self.target_size = target_size
        self.min_face = min_face
        self.session = ort.InferenceSession(model_path)
        self.input_name = self.session.get_inputs()[0].name
        self.output_names = [x.name for x in self.session.get_outputs()]

    def _preprocess(self, imgs):
        """
        Preprocess images before passing them through the network: resize, letterbox to 640x640
        and convert to a float32 batch.
        """
        pp_imgs = []
        for img in imgs:
            h0, w0 = img.shape[:2]  # orig hw
            if self.target_size:
                r = self.target_size / min(h0, w0)  # resize image to img_size
                if r < 1:
                    img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR)

            # imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max())  # check img_size
            imgsz = (640, 640)
            img = letterbox(img, new_shape=imgsz)[0]
            pp_imgs.append(img)
        pp_imgs = np.array(pp_imgs)
        # pp_imgs = pp_imgs.transpose(0, 3, 1, 2)

        pp_imgs = pp_imgs.astype(np.float32)  # uint8 to fp32
        return pp_imgs

    def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres):
        """
        Postprocess the raw model output.
        Returns:
            bboxes: list of arrays with 4 coordinates of bounding boxes in x1,y1,x2,y2 format.
            points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lip corners).
        """
        bboxes = [[] for _ in range(len(origimgs))]
        landmarks = [[] for _ in range(len(origimgs))]

        pred = non_max_suppression_face(pred, conf_thres, iou_thres)

        for image_id, origimg in enumerate(origimgs):
            img_shape = origimg.shape
            image_height, image_width = img_shape[:2]
            gn = torch.tensor(img_shape)[[1, 0, 1, 0]]  # normalization gain whwh
            gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]]  # normalization gain landmarks
            det = pred[image_id].cpu()
            scale_coords(imgs[image_id].shape[0:], det[:, :4], img_shape).round()
            scale_coords_landmarks(imgs[image_id].shape[0:], det[:, 5:15], img_shape).round()

            for j in range(det.size()[0]):
                box = (det[j, :4].view(1, 4) / gn).view(-1).tolist()
                box = list(
                    map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height])
                )
                if box[3] - box[1] < self.min_face:
                    continue
                lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist()
                lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)]))
                lm = [lm[i : i + 2] for i in range(0, len(lm), 2)]
                bboxes[image_id].append(box)
                landmarks[image_id].append(lm)
        return bboxes, landmarks

    def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5):
        """
        Get bbox coordinates and keypoints of faces on the original image.
        Params:
            imgs: image or list of images to detect faces on, in BGR order (converted to RGB for inference)
            conf_thres: confidence threshold for each prediction
            iou_thres: threshold for NMS (filters intersecting bboxes)
        Returns:
            An (N, 15) array: [x1, y1, x2, y2, padding, 5 landmark (x, y) pairs] per face,
            or None if no face is found.
        """
        # Pass input images through face detector
        images = imgs if isinstance(imgs, list) else [imgs]
        images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
        origimgs = copy.deepcopy(images)

        images = self._preprocess(images)

        # run the detection model
        pred = self.session.run(self.output_names, {self.input_name: images})[0]
        pred = torch.from_numpy(pred)

        # postprocess the output
        bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres)

        # return bboxes, points
        if not isListempty(points):
            bboxes = np.array(bboxes).reshape(-1, 4)
            points = np.array(points).reshape(-1, 10)
            # dummy score column so downstream indexing matches [x1, y1, x2, y2, score, landmarks]
            padding = bboxes[:, 0].reshape(-1, 1)
            return np.concatenate((bboxes, padding, points), axis=1)
        else:
            return None

    def __call__(self, *args):
        return self.detect_faces(*args)
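A minimal usage sketch for the detector above. The model path and image file are placeholders; any model that axengine can load with the yolov5-face output layout should work:

import cv2
from utils.face_detector import YoloDetector

detector = YoloDetector(model_path='yolov5l-face.axmodel')  # placeholder path
img = cv2.imread('test.jpg')  # BGR image of any size (placeholder)
faces = detector.detect_faces(img, conf_thres=0.7, iou_thres=0.5)
if faces is not None:
    # each row: [x1, y1, x2, y2, padding, x/y pairs for eyes, nose, mouth corners]
    for row in faces:
        x1, y1, x2, y2 = row[:4].astype(int)
        landmarks = row[5:15].reshape(5, 2).astype(int)
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        for (lx, ly) in landmarks:
            cv2.circle(img, (int(lx), int(ly)), 2, (0, 0, 255), -1)
    cv2.imwrite('test_det.jpg', img)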
python/utils/general.py
ADDED
@@ -0,0 +1,361 @@
# General utils

import glob
import os
import random
import time

import cv2
import numpy as np
import torch
import torchvision
from PIL import Image

# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads


def init_seeds(seed=0):
    # Initialize random number generator (RNG) seeds
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # replaces the undefined init_torch_seeds() from YOLOv5's torch_utils


def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''


def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write image to file.

    Args:
        img (ndarray): Image array to be written.
        file_path (str): Image file path.
        params (None or list): Same as opencv's :func:`imwrite` interface.
        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
            whether to create it automatically.

    Returns:
        bool: Successful or not.
    """
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(dir_name, exist_ok=True)
    return cv2.imwrite(file_path, img, params)


def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. If returned results only have
            one element, just return tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            if img.dtype == 'float64':
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    else:
        return _totensor(imgs, bgr2rgb, float32)


def is_gray(img, threshold=10):
    img = Image.fromarray(img)
    if len(img.getbands()) == 1:
        return True
    img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16)
    img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16)
    img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16)
    diff1 = (img1 - img2).var()
    diff2 = (img2 - img3).var()
    diff3 = (img3 - img1).var()
    diff_sum = (diff1 + diff2 + diff3) / 3.0
    if diff_sum <= threshold:
        return True
    else:
        return False


def rgb2gray(img, out_channel=3):
    r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
    if out_channel == 3:
        gray = gray[:, :, np.newaxis].repeat(3, axis=2)
    return gray


def bgr2gray(img, out_channel=3):
    b, g, r = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
    if out_channel == 3:
        gray = gray[:, :, np.newaxis].repeat(3, axis=2)
    return gray


def calc_mean_std(feat, eps=1e-5):
    """
    Args:
        feat (ndarray): 3D array, (h, w, c).
    """
    size = feat.shape
    assert len(size) == 3, 'The input feature should be 3D tensor.'
    c = size[2]
    feat_var = feat.reshape(-1, c).var(axis=0) + eps
    feat_std = np.sqrt(feat_var).reshape(1, 1, c)
    feat_mean = feat.reshape(-1, c).mean(axis=0).reshape(1, 1, c)
    return feat_mean, feat_std


def adain_npy(content_feat, style_feat):
    """Adaptive instance normalization for numpy.

    Args:
        content_feat (ndarray): The input feature.
        style_feat (ndarray): The reference feature.
    """
    size = content_feat.shape
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - np.broadcast_to(content_mean, size)) / np.broadcast_to(content_std, size)
    return normalized_feat * np.broadcast_to(style_std, size) + np.broadcast_to(style_mean, size)


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y


def xywhn2xyxy(x, w=640, h=640, padw=32, padh=32):
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right y
    return y


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords


def clip_coords(boxes, img_shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    boxes[:, 0].clamp_(0, img_shape[1])  # x1
    boxes[:, 1].clamp_(0, img_shape[0])  # y1
    boxes[:, 2].clamp_(0, img_shape[1])  # x2
    boxes[:, 3].clamp_(0, img_shape[0])  # y2


def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) -
             torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    # iou = inter / (area1 + area2 - inter)
    return inter / (area1[:, None] + area2 - inter)


def wh_iou(wh1, wh2):
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
    wh1 = wh1[:, None]  # [N,1,2]
    wh2 = wh2[None]  # [1,M,2]
    inter = torch.min(wh1, wh2).prod(2)  # [N,M]
    # iou = inter / (area1 + area2 - inter)
    return inter / (wh1.prod(2) + wh2.prod(2) - inter)


def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
    """Performs Non-Maximum Suppression (NMS) on inference results
    Returns:
        detections with shape: nx16 (x1, y1, x2, y2, conf, 10 landmark coords, cls)
    """

    nc = prediction.shape[2] - 15  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 15), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 15] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 15:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx16 (xyxy, conf, landmarks, cls)
        if multi_label:
            i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 15, None], x[i, 5:15], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 15:].max(1, keepdim=True)
            x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # If none remain process next image
        n = x.shape[0]  # number of boxes
        if not n:
            continue

        # Batched NMS
        c = x[:, 15:16] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        # if i.shape[0] > max_det:  # limit detections
        #     i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            break  # time limit exceeded

    return output


def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale landmark coords (x1,y1,...,x5,y5) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2, 4, 6, 8]] -= pad[0]  # x padding
    coords[:, [1, 3, 5, 7, 9]] -= pad[1]  # y padding
    coords[:, :10] /= gain
    coords[:, 0].clamp_(0, img0_shape[1])  # x1
    coords[:, 1].clamp_(0, img0_shape[0])  # y1
    coords[:, 2].clamp_(0, img0_shape[1])  # x2
    coords[:, 3].clamp_(0, img0_shape[0])  # y2
    coords[:, 4].clamp_(0, img0_shape[1])  # x3
    coords[:, 5].clamp_(0, img0_shape[0])  # y3
    coords[:, 6].clamp_(0, img0_shape[1])  # x4
    coords[:, 7].clamp_(0, img0_shape[0])  # y4
    coords[:, 8].clamp_(0, img0_shape[1])  # x5
    coords[:, 9].clamp_(0, img0_shape[0])  # y5
    return coords


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=False, scaleFill=False, scaleup=True):
    # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, 64), np.mod(dh, 64)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
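Two of the helpers above are easy to sanity-check in isolation: letterbox pads the resized image symmetrically and scale_coords undoes exactly that padding and gain when mapping detections back to the original image, while adain_npy moves the per-channel mean/std of one array onto another (how restoration_helper transfers color onto a grayscale-restored face). A small sketch of both, with illustrative shapes and values only:

import numpy as np
import torch
from utils.general import letterbox, scale_coords, adain_npy, calc_mean_std

# letterbox + scale_coords round trip: a 480x640 image fits 640x640 with 80 px of top/bottom padding
img0 = np.zeros((480, 640, 3), dtype=np.uint8)
img1, ratio, (dw, dh) = letterbox(img0, new_shape=(640, 640))
print(img1.shape, ratio, (dw, dh))  # (640, 640, 3) (1.0, 1.0) (0.0, 80.0)

# a box at (100, 100)-(200, 200) in img0 appears at (100, 180)-(200, 280) in img1;
# scale_coords maps it back in place
boxes = torch.tensor([[100.0, 180.0, 200.0, 280.0]])
scale_coords(img1.shape, boxes, img0.shape)
print(boxes)  # tensor([[100., 100., 200., 200.]])

# adain_npy: the output carries the reference array's channel statistics
rng = np.random.default_rng(0)
content = rng.normal(0.2, 0.5, (64, 64, 3)).astype(np.float32)
style = rng.normal(0.7, 0.1, (64, 64, 3)).astype(np.float32)
out = adain_npy(content, style)
out_mean, out_std = calc_mean_std(out)
style_mean, style_std = calc_mean_std(style)
assert np.allclose(out_mean, style_mean, atol=1e-3) and np.allclose(out_std, style_std, atol=1e-3)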
python/utils/restoration_helper.py
ADDED
@@ -0,0 +1,544 @@
import os
import cv2
import math
import torch
import numpy as np
import axengine as ort

from utils.general import img2tensor, imwrite, is_gray, bgr2gray, adain_npy
from utils.face_detector import YoloDetector


def get_largest_face(det_faces, h, w):

    def get_location(val, length):
        if val < 0:
            return 0
        elif val > length:
            return length
        else:
            return val

    face_areas = []
    for det_face in det_faces:
        left = get_location(det_face[0], w)
        right = get_location(det_face[2], w)
        top = get_location(det_face[1], h)
        bottom = get_location(det_face[3], h)
        face_area = (right - left) * (bottom - top)
        face_areas.append(face_area)
    largest_idx = face_areas.index(max(face_areas))
    return det_faces[largest_idx], largest_idx


def get_center_face(det_faces, h=0, w=0, center=None):
    if center is not None:
        center = np.array(center)
    else:
        center = np.array([w / 2, h / 2])
    center_dist = []
    for det_face in det_faces:
        face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
        dist = np.linalg.norm(face_center - center)
        center_dist.append(dist)
    center_idx = center_dist.index(min(center_dist))
    return det_faces[center_idx], center_idx


class RestoreHelper(object):
    """Helper for the restoration pipeline (base class)."""

    def __init__(self,
                 upscale_factor,
                 face_size=512,
                 crop_ratio=(1, 1),
                 det_model='yolov5l-face.onnx',
                 res_model='codeformer.onnx',
                 bg_model='realesrgan_x2.onnx',
                 save_ext='png',
                 template_3points=False,
                 pad_blur=False,
                 use_parse=False,
                 ):

        # face alignment params
        self.template_3points = template_3points  # improve robustness
        self.upscale_factor = int(upscale_factor)
        # the cropped face ratio based on the square face
        self.crop_ratio = crop_ratio  # (h, w)
        assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ratio only supports >=1'
        self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))

        # standard 5 landmarks for FFHQ faces with 512 x 512
        # facexlib
        self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
                                       [201.26117, 371.41043], [313.08905, 371.15118]])

        # dlib: left_eye: 36:41  right_eye: 42:47  nose: 30,32,33,34  left mouth corner: 48  right mouth corner: 54
        # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
        #                                [198.22603, 372.82502], [313.91018, 372.75659]])

        self.face_template = self.face_template * (face_size / 512.0)
        if self.crop_ratio[0] > 1:
            self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
        if self.crop_ratio[1] > 1:
            self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
        self.save_ext = save_ext
        self.pad_blur = pad_blur
        if self.pad_blur is True:
            self.template_3points = False

        self.all_landmarks_5 = []
        self.det_faces = []
        self.affine_matrices = []
        self.inverse_affine_matrices = []
        self.cropped_faces = []
        self.restored_faces = []
        self.pad_input_imgs = []

        # init face detection model
        self.face_detector = YoloDetector(model_path=det_model)

        # init face parsing model
        self.use_parse = use_parse
        # self.face_parse = init_parsing_model(model_name='parsenet')

        # init face restore model
        self.res_model = res_model
        self.rs_session, self.rs_input, self.rs_output = self.init_face_restoration()

        # init background upsampling model
        self.tile = 108
        self.tile_pad = 10
        self.scale = 2
        self.bg_model = bg_model
        self.bg_session, self.bg_input, self.bg_output = self.init_background_upsampling()

    def set_upscale_factor(self, upscale_factor):
        self.upscale_factor = upscale_factor

    def read_image(self, img):
        """img can be image path or cv2 loaded image."""
        # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
        if isinstance(img, str):
            img = cv2.imread(img)

        if np.max(img) > 256:  # 16-bit image
            img = img / 65535 * 255
        if len(img.shape) == 2:  # gray image
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:  # BGRA image with alpha channel
            img = img[:, :, 0:3]

        self.input_img = img
        self.is_gray = is_gray(img, threshold=10)
        if self.is_gray:
            print('Grayscale input: True')

        if min(self.input_img.shape[:2]) < 512:
            f = 512.0 / min(self.input_img.shape[:2])
            self.input_img = cv2.resize(self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)

    def get_face_landmarks_5(self,
                             only_keep_largest=False,
                             only_center_face=False,
                             resize=None,
                             blur_ratio=0.01,
                             eye_dist_threshold=None):

        if resize is None:
            scale = 1
            input_img = self.input_img
        else:
            h, w = self.input_img.shape[0:2]
            scale = resize / min(h, w)
            # scale = max(1, scale)  # always scale up; comment this out for HD images, e.g., AIGC faces.
            h, w = int(h * scale), int(w * scale)
            interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
            input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)

        # with torch.no_grad():
        bboxes = self.face_detector.detect_faces(input_img)

        if bboxes is None or bboxes.shape[0] == 0:
            return 0
        else:
            bboxes = bboxes / scale

        for bbox in bboxes:
            # remove faces with too small eye distance: side faces or too small faces
            eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
            if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
                continue

            if self.template_3points:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
            else:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
            self.all_landmarks_5.append(landmark)
            self.det_faces.append(bbox[0:5])

        if len(self.det_faces) == 0:
            return 0
        if only_keep_largest:
            h, w, _ = self.input_img.shape
            self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
        elif only_center_face:
            h, w, _ = self.input_img.shape
            self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]

        # pad blurry images
        if self.pad_blur:
            self.pad_input_imgs = []
            for landmarks in self.all_landmarks_5:
                # get landmarks
                eye_left = landmarks[0, :]
                eye_right = landmarks[1, :]
                eye_avg = (eye_left + eye_right) * 0.5
                mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
                eye_to_eye = eye_right - eye_left
                eye_to_mouth = mouth_avg - eye_avg

                # Get the oriented crop rectangle
                # x: half width of the oriented crop rectangle
                x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
                # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
                # norm with the hypotenuse: get the direction
                x /= np.hypot(*x)  # get the hypotenuse of a right triangle
                rect_scale = 1.5
                x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
                # y: half height of the oriented crop rectangle
                y = np.flipud(x) * [-1, 1]

                # c: center
                c = eye_avg + eye_to_mouth * 0.1
                # quad: (left_top, left_bottom, right_bottom, right_top)
                quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
                # qsize: side length of the square
                qsize = np.hypot(*x) * 2
                border = max(int(np.rint(qsize * 0.1)), 3)

                # get pad
                # pad: (width_left, height_top, width_right, height_bottom)
                pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
                       int(np.ceil(max(quad[:, 1]))))
                pad = [
                    max(-pad[0] + border, 1),
                    max(-pad[1] + border, 1),
                    max(pad[2] - self.input_img.shape[0] + border, 1),
                    max(pad[3] - self.input_img.shape[1] + border, 1)
                ]

                if max(pad) > 1:
                    # pad image
                    pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
                    # modify landmark coords
                    landmarks[:, 0] += pad[0]
                    landmarks[:, 1] += pad[1]
                    # blur pad images
                    h, w, _ = pad_img.shape
                    y, x, _ = np.ogrid[:h, :w, :1]
                    mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
                                                       np.float32(w - 1 - x) / pad[2]),
                                      1.0 - np.minimum(np.float32(y) / pad[1],
                                                       np.float32(h - 1 - y) / pad[3]))
                    blur = int(qsize * blur_ratio)
                    if blur % 2 == 0:
                        blur += 1
                    blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
                    # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)

                    pad_img = pad_img.astype('float32')
                    pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
                    pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
                    pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
                    self.pad_input_imgs.append(pad_img)
                else:
                    self.pad_input_imgs.append(np.copy(self.input_img))

        return len(self.all_landmarks_5)

    def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
        """Align and warp faces with face template."""
        if self.pad_blur:
            assert len(self.pad_input_imgs) == len(
                self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
        for idx, landmark in enumerate(self.all_landmarks_5):
            # use 5 landmarks to get affine matrix
            # use cv2.LMEDS method for the equivalence to skimage transform
            # ref: https://blog.csdn.net/yichxi/article/details/115827338
            affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
            self.affine_matrices.append(affine_matrix)
            # warp and crop faces
            if border_mode == 'constant':
                border_mode = cv2.BORDER_CONSTANT
            elif border_mode == 'reflect101':
                border_mode = cv2.BORDER_REFLECT101
            elif border_mode == 'reflect':
                border_mode = cv2.BORDER_REFLECT
            if self.pad_blur:
                input_img = self.pad_input_imgs[idx]
            else:
                input_img = self.input_img
            cropped_face = cv2.warpAffine(
                input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132))  # gray
            self.cropped_faces.append(cropped_face)
            # save the cropped face
            if save_cropped_path is not None:
                path = os.path.splitext(save_cropped_path)[0]
                save_path = f'{path}_{idx:02d}.{self.save_ext}'
                imwrite(cropped_face, save_path)

    def get_inverse_affine(self, save_inverse_affine_path=None):
        """Get inverse affine matrix."""
        for idx, affine_matrix in enumerate(self.affine_matrices):
            inverse_affine = cv2.invertAffineTransform(affine_matrix)
            inverse_affine *= self.upscale_factor
            self.inverse_affine_matrices.append(inverse_affine)
            # save inverse affine matrices
            if save_inverse_affine_path is not None:
                path, _ = os.path.splitext(save_inverse_affine_path)
                save_path = f'{path}_{idx:02d}.pth'
                torch.save(inverse_affine, save_path)

    def add_restored_face(self, restored_face, input_face=None):
        if self.is_gray:
            restored_face = bgr2gray(restored_face)  # convert img into grayscale
            if input_face is not None:
                restored_face = adain_npy(restored_face, input_face)  # transfer the color
        self.restored_faces.append(restored_face)

    def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
        h, w, _ = self.input_img.shape
        h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)

        if upsample_img is None:
            # simply resize the background
            # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
            upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
        else:
            upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)

        assert len(self.restored_faces) == len(
            self.inverse_affine_matrices), 'length of restored_faces and affine_matrices are different.'

        inv_mask_borders = []
        for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
            if face_upsampler is not None:
                restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
                inverse_affine /= self.upscale_factor
                inverse_affine[:, 2] *= self.upscale_factor
                face_size = (self.face_size[0] * self.upscale_factor, self.face_size[1] * self.upscale_factor)
            else:
                # Add an offset to the inverse affine matrix, for more precise back alignment
                if self.upscale_factor > 1:
                    extra_offset = 0.5 * self.upscale_factor
                else:
                    extra_offset = 0
                inverse_affine[:, 2] += extra_offset
                face_size = self.face_size
            inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))

            # always use square mask
            mask = np.ones(face_size, dtype=np.float32)
            inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
            # remove the black borders
            inv_mask_erosion = cv2.erode(
                inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
            pasted_face = inv_mask_erosion[:, :, None] * inv_restored
            total_face_area = np.sum(inv_mask_erosion)  # // 3
            # add border
            if draw_box:
                h, w = face_size
                mask_border = np.ones((h, w, 3), dtype=np.float32)
                border = int(1400 / np.sqrt(total_face_area))
                mask_border[border:h - border, border:w - border, :] = 0
                inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
                inv_mask_borders.append(inv_mask_border)
            # compute the fusion edge based on the area of face
            w_edge = int(total_face_area**0.5) // 20
            erosion_radius = w_edge * 2
            inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
            blur_size = w_edge * 2
            inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
            if len(upsample_img.shape) == 2:  # upsample_img is gray image
                upsample_img = upsample_img[:, :, None]
            inv_soft_mask = inv_soft_mask[:, :, None]

            # parse mask
            # if self.use_parse:
            #     # inference
            #     face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
            #     face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
            #     normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            #     face_input = torch.unsqueeze(face_input, 0)
            #     with torch.no_grad():
            #         out = self.face_parse(face_input)[0]
            #     out = out.argmax(dim=1).squeeze().cpu().numpy()
            #
            #     parse_mask = np.zeros(out.shape)
            #     MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
            #     for idx, color in enumerate(MASK_COLORMAP):
            #         parse_mask[out == idx] = color
            #     # blur the mask
            #     parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
            #     parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
            #     # remove the black borders
            #     thres = 10
            #     parse_mask[:thres, :] = 0
            #     parse_mask[-thres:, :] = 0
            #     parse_mask[:, :thres] = 0
            #     parse_mask[:, -thres:] = 0
            #     parse_mask = parse_mask / 255.
            #
            #     parse_mask = cv2.resize(parse_mask, face_size)
            #     parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
            #     inv_soft_parse_mask = parse_mask[:, :, None]
            #     # pasted_face = inv_restored
            #     fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
            #     inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)

            if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:  # alpha channel
                alpha = upsample_img[:, :, 3:]
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
                upsample_img = np.concatenate((upsample_img, alpha), axis=2)
            else:
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img

        if np.max(upsample_img) > 256:  # 16-bit image
            upsample_img = upsample_img.astype(np.uint16)
        else:
            upsample_img = upsample_img.astype(np.uint8)

        # draw bounding box
        if draw_box:
            # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
            img_color = np.ones([*upsample_img.shape], dtype=np.float32)
            img_color[:, :, 0] = 0
            img_color[:, :, 1] = 255
            img_color[:, :, 2] = 0
            for inv_mask_border in inv_mask_borders:
                upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
                # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img

        if save_path is not None:
            path = os.path.splitext(save_path)[0]
            save_path = f'{path}.{self.save_ext}'
            imwrite(upsample_img, save_path)
        return upsample_img

    def init_face_restoration(self):
        session = ort.InferenceSession(self.res_model)
        input_name = session.get_inputs()[0].name
        output_names = [x.name for x in session.get_outputs()]
        return session, input_name, output_names

    def pre_process(self, img):
        # mod tile_pad for divisible borders
        tile_pad_h, tile_pad_w = 0, 0
        h, w = img.shape[0:2]

        if h % self.tile != 0:
            tile_pad_h = self.tile - h % self.tile
        if w % self.tile != 0:
            tile_pad_w = self.tile - w % self.tile
        img = np.pad(img, ((0, tile_pad_h), (0, tile_pad_w), (0, 0)), 'constant')  # mode='reflect')

        # boundary tile_pad
        img = np.pad(img, ((self.tile_pad, self.tile_pad), (self.tile_pad, self.tile_pad), (0, 0)), 'constant')

        # to CHW-Batch format
        img = (img[..., ::-1] / 255).astype(np.float32)
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)

        return img

    def init_background_upsampling(self):
        session = ort.InferenceSession(self.bg_model)
        input_name = session.get_inputs()[0].name
        output_names = [x.name for x in session.get_outputs()]
        return session, input_name, output_names

    def tile_process(self, img, origin_shape):
        """It first crops the input image into tiles and processes each tile.
        Finally, all the processed tiles are merged into one image.
        """
        # tile
        batch, channel, height, width = img.shape
        output_height = int(round(height * self.scale))
        output_width = int(round(width * self.scale))
        output_shape = (batch, channel, output_height, output_width)
        origin_h, origin_w = origin_shape[0:2]

        # start with black image
        output = np.zeros(output_shape)
        tiles_x = math.floor(width / self.tile)
        tiles_y = math.floor(height / self.tile)
        # print(f'Tile {tiles_x} x {tiles_y}')

        start_tile = int(round(self.tile_pad * self.scale))
        end_tile = int(round(self.tile * self.scale)) + start_tile

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile
                ofs_y = y * self.tile
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile, height)

                # input tile dimensions (tile plus tile_pad of context on each side)
                input_tile = img[:, :, input_start_y:(input_end_y + 2 * self.tile_pad),
                                 input_start_x:(input_end_x + 2 * self.tile_pad)]

                # upscale tile
                try:
                    output_tile = self.bg_session.run(self.bg_output, {self.bg_input: input_tile})
                except RuntimeError as error:
                    print('Error', error)

                # output tile area on total image
                output_start_x = int(round(input_start_x * self.scale))
                output_end_x = int(round(input_end_x * self.scale))
                output_start_y = int(round(input_start_y * self.scale))
                output_end_y = int(round(input_end_y * self.scale))

                output[:, :, output_start_y:output_end_y,
                       output_start_x:output_end_x] = output_tile[0][:, :, start_tile:end_tile, start_tile:end_tile]

        # remove extra tile_padding parts
        output = output[:, :, :int(round(origin_h * self.scale)), :int(round(origin_w * self.scale))].squeeze(0)
        output = np.transpose(output[::-1, ...], (1, 2, 0)).astype(np.float32)
        output = np.clip(output * 255.0, 0, 255).astype(np.uint8)

        # resize to origin shape
        output = cv2.resize(output, (origin_w, origin_h), interpolation=cv2.INTER_LINEAR)

        return output

    def background_upsampling(self, img):
        """Background upsampling with Real-ESRGAN."""
        # pre-process
        img_input = self.pre_process(img)

        # tile process
        output = self.tile_process(img_input, img.shape)

        return output

    def clean_all(self):
        self.all_landmarks_5 = []
        self.restored_faces = []
        self.affine_matrices = []
        self.cropped_faces = []
        self.inverse_affine_matrices = []
        self.det_faces = []
        self.pad_input_imgs = []
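A note on the tiling geometry in pre_process/tile_process above: pre_process first zero-pads the image up to a multiple of tile (108), then adds tile_pad (10) pixels of context on every side, so each input slice spans tile + 2*tile_pad pixels; after the model's 2x upscale, the inner tile*scale region is cut back out via start_tile/end_tile. The sketch below checks that geometry end to end with np.kron as a hypothetical stand-in for the Real-ESRGAN session (a nearest-neighbour 2x "model"); the tiled result must equal a whole-image upscale:

import math
import numpy as np

tile, tile_pad, scale = 108, 10, 2  # the defaults set in RestoreHelper.__init__
h0, w0 = 200, 300

def fake_2x(x):
    # nearest-neighbour 2x upscale standing in for the bg model
    return np.kron(x, np.ones((1, 1, scale, scale), dtype=x.dtype))

src = np.arange(h0 * w0, dtype=np.float32).reshape(1, 1, h0, w0)

# mirror pre_process: pad up to a tile multiple, then tile_pad of border on every side
img = np.pad(src, ((0, 0), (0, 0), (0, (tile - h0 % tile) % tile), (0, (tile - w0 % tile) % tile)), 'constant')
img = np.pad(img, ((0, 0), (0, 0), (tile_pad, tile_pad), (tile_pad, tile_pad)), 'constant')

# mirror tile_process
b, c, height, width = img.shape
out = np.zeros((b, c, height * scale, width * scale), dtype=img.dtype)
start_tile = tile_pad * scale
end_tile = tile * scale + start_tile
for y in range(math.floor(height / tile)):
    for x in range(math.floor(width / tile)):
        ys, xs = y * tile, x * tile
        in_tile = img[:, :, ys:ys + tile + 2 * tile_pad, xs:xs + tile + 2 * tile_pad]
        out_tile = fake_2x(in_tile)
        out[:, :, ys * scale:(ys + tile) * scale,
            xs * scale:(xs + tile) * scale] = out_tile[:, :, start_tile:end_tile, start_tile:end_tile]

# crop back to the original size times scale; the tiles must reassemble exactly
out = out[:, :, :h0 * scale, :w0 * scale]
assert np.array_equal(out, fake_2x(src))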