| | |
| | |
| |
|
| | import torch |
| | import numpy as np |
| | import skimage.io as io |
| |
|
| | |
| | import matplotlib.pyplot as plt |
| | from matplotlib.patches import Rectangle |
| | from skimage.transform import SimilarityTransform |
| | from skimage.transform import warp |
| | from PIL import Image, ImageFilter |
| | import torch.nn.functional as F |
| | import torchvision as tv |
| | import torchvision.utils as vutils |
| | import time |
| | import cv2 |
| | import os |
| | from skimage import img_as_ubyte |
| | import json |
| | import argparse |
| | import dlib |
| |
|
| |
|
| | def calculate_cdf(histogram): |
| | """ |
| | This method calculates the cumulative distribution function |
| | :param array histogram: The values of the histogram |
| | :return: normalized_cdf: The normalized cumulative distribution function |
| | :rtype: array |
| | """ |
| | |
| | cdf = histogram.cumsum() |
| |
|
| | |
| | normalized_cdf = cdf / float(cdf.max()) |
| |
|
| | return normalized_cdf |
| |
|
| |
|
| | def calculate_lookup(src_cdf, ref_cdf): |
| | """ |
| | This method creates the lookup table |
| | :param array src_cdf: The cdf for the source image |
| | :param array ref_cdf: The cdf for the reference image |
| | :return: lookup_table: The lookup table |
| | :rtype: array |
| | """ |
| | lookup_table = np.zeros(256) |
| | lookup_val = 0 |
| | for src_pixel_val in range(len(src_cdf)): |
| | lookup_val |
| | for ref_pixel_val in range(len(ref_cdf)): |
| | if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]: |
| | lookup_val = ref_pixel_val |
| | break |
| | lookup_table[src_pixel_val] = lookup_val |
| | return lookup_table |
| |
|
| |
|
| | def match_histograms(src_image, ref_image): |
| | """ |
| | This method matches the source image histogram to the |
| | reference signal |
| | :param image src_image: The original source image |
| | :param image ref_image: The reference image |
| | :return: image_after_matching |
| | :rtype: image (array) |
| | """ |
| | |
| | |
| | src_b, src_g, src_r = cv2.split(src_image) |
| | ref_b, ref_g, ref_r = cv2.split(ref_image) |
| |
|
| | |
| | |
| | |
| | src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256]) |
| | src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256]) |
| | src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256]) |
| | ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256]) |
| | ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256]) |
| | ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256]) |
| |
|
| | |
| | src_cdf_blue = calculate_cdf(src_hist_blue) |
| | src_cdf_green = calculate_cdf(src_hist_green) |
| | src_cdf_red = calculate_cdf(src_hist_red) |
| | ref_cdf_blue = calculate_cdf(ref_hist_blue) |
| | ref_cdf_green = calculate_cdf(ref_hist_green) |
| | ref_cdf_red = calculate_cdf(ref_hist_red) |
| |
|
| | |
| | blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue) |
| | green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green) |
| | red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red) |
| |
|
| | |
| | |
| | blue_after_transform = cv2.LUT(src_b, blue_lookup_table) |
| | green_after_transform = cv2.LUT(src_g, green_lookup_table) |
| | red_after_transform = cv2.LUT(src_r, red_lookup_table) |
| |
|
| | |
| | image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform]) |
| | image_after_matching = cv2.convertScaleAbs(image_after_matching) |
| |
|
| | return image_after_matching |
| |
|
| |
|
| | def _standard_face_pts(): |
| | pts = ( |
| | np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0 |
| | - 1.0 |
| | ) |
| |
|
| | return np.reshape(pts, (5, 2)) |
| |
|
| |
|
| | def _origin_face_pts(): |
| | pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) |
| |
|
| | return np.reshape(pts, (5, 2)) |
| |
|
| |
|
| | def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0): |
| |
|
| | std_pts = _standard_face_pts() |
| | target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0 |
| |
|
| | |
| |
|
| | h, w, c = img.shape |
| | if normalize == True: |
| | landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 |
| | landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 |
| |
|
| | |
| |
|
| | affine = SimilarityTransform() |
| |
|
| | affine.estimate(target_pts, landmark) |
| |
|
| | return affine |
| |
|
| |
|
| | def compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0): |
| |
|
| | std_pts = _standard_face_pts() |
| | target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0 |
| |
|
| | |
| |
|
| | h, w, c = img.shape |
| | if normalize == True: |
| | landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 |
| | landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 |
| |
|
| | |
| |
|
| | affine = SimilarityTransform() |
| |
|
| | affine.estimate(landmark, target_pts) |
| |
|
| | return affine |
| |
|
| |
|
| | def show_detection(image, box, landmark): |
| | plt.imshow(image) |
| | print(box[2] - box[0]) |
| | plt.gca().add_patch( |
| | Rectangle( |
| | (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none" |
| | ) |
| | ) |
| | plt.scatter(landmark[0][0], landmark[0][1]) |
| | plt.scatter(landmark[1][0], landmark[1][1]) |
| | plt.scatter(landmark[2][0], landmark[2][1]) |
| | plt.scatter(landmark[3][0], landmark[3][1]) |
| | plt.scatter(landmark[4][0], landmark[4][1]) |
| | plt.show() |
| |
|
| |
|
| | def affine2theta(affine, input_w, input_h, target_w, target_h): |
| | |
| | param = affine |
| | theta = np.zeros([2, 3]) |
| | theta[0, 0] = param[0, 0] * input_h / target_h |
| | theta[0, 1] = param[0, 1] * input_w / target_h |
| | theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1 |
| | theta[1, 0] = param[1, 0] * input_h / target_w |
| | theta[1, 1] = param[1, 1] * input_w / target_w |
| | theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1 |
| | return theta |
| |
|
| |
|
| | def blur_blending(im1, im2, mask): |
| |
|
| | mask *= 255.0 |
| |
|
| | kernel = np.ones((10, 10), np.uint8) |
| | mask = cv2.erode(mask, kernel, iterations=1) |
| |
|
| | mask = Image.fromarray(mask.astype("uint8")).convert("L") |
| | im1 = Image.fromarray(im1.astype("uint8")) |
| | im2 = Image.fromarray(im2.astype("uint8")) |
| |
|
| | mask_blur = mask.filter(ImageFilter.GaussianBlur(20)) |
| | im = Image.composite(im1, im2, mask) |
| |
|
| | im = Image.composite(im, im2, mask_blur) |
| |
|
| | return np.array(im) / 255.0 |
| |
|
| |
|
| | def blur_blending_cv2(im1, im2, mask): |
| |
|
| | mask *= 255.0 |
| |
|
| | kernel = np.ones((9, 9), np.uint8) |
| | mask = cv2.erode(mask, kernel, iterations=3) |
| |
|
| | mask_blur = cv2.GaussianBlur(mask, (25, 25), 0) |
| | mask_blur /= 255.0 |
| |
|
| | im = im1 * mask_blur + (1 - mask_blur) * im2 |
| |
|
| | im /= 255.0 |
| | im = np.clip(im, 0.0, 1.0) |
| |
|
| | return im |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| | def Poisson_blending(im1, im2, mask): |
| |
|
| | |
| | mask *= 255 |
| | kernel = np.ones((10, 10), np.uint8) |
| | mask = cv2.erode(mask, kernel, iterations=1) |
| | mask /= 255 |
| | mask = 1 - mask |
| | mask *= 255 |
| |
|
| | mask = mask[:, :, 0] |
| | width, height, channels = im1.shape |
| | center = (int(height / 2), int(width / 2)) |
| | result = cv2.seamlessClone( |
| | im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.MIXED_CLONE |
| | ) |
| |
|
| | return result / 255.0 |
| |
|
| |
|
| | def Poisson_B(im1, im2, mask, center): |
| |
|
| | mask *= 255 |
| |
|
| | result = cv2.seamlessClone( |
| | im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.NORMAL_CLONE |
| | ) |
| |
|
| | return result / 255 |
| |
|
| |
|
| | def seamless_clone(old_face, new_face, raw_mask): |
| |
|
| | height, width, _ = old_face.shape |
| | height = height // 2 |
| | width = width // 2 |
| |
|
| | y_indices, x_indices, _ = np.nonzero(raw_mask) |
| | y_crop = slice(np.min(y_indices), np.max(y_indices)) |
| | x_crop = slice(np.min(x_indices), np.max(x_indices)) |
| | y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height)) |
| | x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width)) |
| |
|
| | insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8") |
| | insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8") |
| | insertion_mask[insertion_mask != 0] = 255 |
| | prior = np.rint(np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant")).astype( |
| | "uint8" |
| | ) |
| | |
| | n_mask = insertion_mask[1:-1, 1:-1, :] |
| | n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0) |
| | print(n_mask.shape) |
| | x, y, w, h = cv2.boundingRect(n_mask[:, :, 0]) |
| | if w < 4 or h < 4: |
| | blended = prior |
| | else: |
| | blended = cv2.seamlessClone( |
| | insertion, |
| | prior, |
| | insertion_mask, |
| | (x_center, y_center), |
| | cv2.NORMAL_CLONE, |
| | ) |
| |
|
| | blended = blended[height:-height, width:-width] |
| |
|
| | return blended.astype("float32") / 255.0 |
| |
|
| |
|
| | def get_landmark(face_landmarks, id): |
| | part = face_landmarks.part(id) |
| | x = part.x |
| | y = part.y |
| |
|
| | return (x, y) |
| |
|
| |
|
| | def search(face_landmarks): |
| |
|
| | x1, y1 = get_landmark(face_landmarks, 36) |
| | x2, y2 = get_landmark(face_landmarks, 39) |
| | x3, y3 = get_landmark(face_landmarks, 42) |
| | x4, y4 = get_landmark(face_landmarks, 45) |
| |
|
| | x_nose, y_nose = get_landmark(face_landmarks, 30) |
| |
|
| | x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48) |
| | x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54) |
| |
|
| | x_left_eye = int((x1 + x2) / 2) |
| | y_left_eye = int((y1 + y2) / 2) |
| | x_right_eye = int((x3 + x4) / 2) |
| | y_right_eye = int((y3 + y4) / 2) |
| |
|
| | results = np.array( |
| | [ |
| | [x_left_eye, y_left_eye], |
| | [x_right_eye, y_right_eye], |
| | [x_nose, y_nose], |
| | [x_left_mouth, y_left_mouth], |
| | [x_right_mouth, y_right_mouth], |
| | ] |
| | ) |
| |
|
| | return results |
| |
|
| |
|
| | if __name__ == "__main__": |
| |
|
| | parser = argparse.ArgumentParser() |
| | parser.add_argument("--origin_url", type=str, default="./", help="origin images") |
| | parser.add_argument("--replace_url", type=str, default="./", help="restored faces") |
| | parser.add_argument("--save_url", type=str, default="./save") |
| | opts = parser.parse_args() |
| |
|
| | origin_url = opts.origin_url |
| | replace_url = opts.replace_url |
| | save_url = opts.save_url |
| |
|
| | if not os.path.exists(save_url): |
| | os.makedirs(save_url) |
| |
|
| | face_detector = dlib.get_frontal_face_detector() |
| | landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") |
| |
|
| | count = 0 |
| |
|
| | for x in os.listdir(origin_url): |
| | img_url = os.path.join(origin_url, x) |
| | pil_img = Image.open(img_url).convert("RGB") |
| |
|
| | origin_width, origin_height = pil_img.size |
| | image = np.array(pil_img) |
| |
|
| | start = time.time() |
| | faces = face_detector(image) |
| | done = time.time() |
| |
|
| | if len(faces) == 0: |
| | print("Warning: There is no face in %s" % (x)) |
| | continue |
| |
|
| | blended = image |
| | for face_id in range(len(faces)): |
| |
|
| | current_face = faces[face_id] |
| | face_landmarks = landmark_locator(image, current_face) |
| | current_fl = search(face_landmarks) |
| |
|
| | forward_mask = np.ones_like(image).astype("uint8") |
| | affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3) |
| | aligned_face = warp(image, affine, output_shape=(256, 256, 3), preserve_range=True) |
| | forward_mask = warp( |
| | forward_mask, affine, output_shape=(256, 256, 3), order=0, preserve_range=True |
| | ) |
| |
|
| | affine_inverse = affine.inverse |
| | cur_face = aligned_face |
| | if replace_url != "": |
| |
|
| | face_name = x[:-4] + "_" + str(face_id + 1) + ".png" |
| | cur_url = os.path.join(replace_url, face_name) |
| | restored_face = Image.open(cur_url).convert("RGB") |
| | restored_face = np.array(restored_face) |
| | cur_face = restored_face |
| |
|
| | |
| | A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR) |
| | B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR) |
| | B = match_histograms(B, A) |
| | cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB) |
| |
|
| | warped_back = warp( |
| | cur_face, |
| | affine_inverse, |
| | output_shape=(origin_height, origin_width, 3), |
| | order=3, |
| | preserve_range=True, |
| | ) |
| |
|
| | backward_mask = warp( |
| | forward_mask, |
| | affine_inverse, |
| | output_shape=(origin_height, origin_width, 3), |
| | order=0, |
| | preserve_range=True, |
| | ) |
| |
|
| | blended = blur_blending_cv2(warped_back, blended, backward_mask) |
| | blended *= 255.0 |
| |
|
| | io.imsave(os.path.join(save_url, x), img_as_ubyte(blended / 255.0)) |
| |
|
| | count += 1 |
| |
|
| | if count % 1000 == 0: |
| | print("%d have finished ..." % (count)) |
| |
|
| |
|