|
|
""" |
|
|
input: image_path |
|
|
output: save a masked image and resized image |
|
|
""" |
|
|
import os |
|
|
import sys |
|
|
import urllib.request |
|
|
import numpy as np |
|
|
import torch |
|
|
import cv2 |
|
|
from PIL import Image |
|
|
from omegaconf import OmegaConf |
|
|
from torchvision import transforms |
|
|
from utils.face_detector import FaceDetector |
|
|
from pathlib import Path |
|
|
|
|
|
def generate_crop_bounding_box(h, w, center, size=512):
    """
    Compute the part of a square window around ``center`` that lies inside
    an image of height ``h`` and width ``w``.

    :param h: Image height; the bottom edge of the box is capped at this value
    :param w: Image width; the right edge of the box is capped at this value
    :param center: The window center as (y, x)
    :param size: Side length of the window (default is 512)
    :return: The clamped box as a list ``[x1, y1, x2, y2]``
    """
    radius = size // 2
    cy, cx = center[0], center[1]

    # Top/left edges cannot go below 0; bottom/right edges cannot pass h/w.
    top = max(cy - radius, 0)
    left = max(cx - radius, 0)
    bottom = min(cy + radius, h)
    right = min(cx + radius, w)

    # Note the order: x comes first in the result although center is (y, x).
    return [left, top, right, bottom]
|
|
|
|
|
def crop_from_bbox(image, center, bbox, size=512):
    """
    Copy the ``bbox`` region of ``image`` into a zero-filled square canvas
    of side ``size`` whose middle corresponds to ``center``.

    Parts of the canvas that fall outside ``bbox`` stay zero.

    :param image: The input image as a NumPy array, shape (H, W, C) or (H, W)
    :param center: The center point (y, x) of the canvas in image coordinates
    :param bbox: The region ``[x1, y1, x2, y2]`` of ``image`` to copy, in the
        format returned by ``generate_crop_bounding_box``
    :param size: The side length of the output canvas (default is 512)
    :return: The padded crop, shape (size, size, C) for a 3-D input and
        (size, size) for a 2-D input, with the dtype of ``image``
    """
    x1, y1, x2, y2 = bbox
    half_size = size // 2

    # Image coordinates of the canvas's top-left corner; subtracting them
    # converts image coordinates into canvas coordinates.
    origin_y = center[0] - half_size
    origin_x = center[1] - half_size

    # image.shape[2:] is empty for 2-D input, so grayscale arrays work too.
    cropped = np.zeros((size, size) + tuple(image.shape[2:]), dtype=image.dtype)
    cropped[y1 - origin_y:y2 - origin_y,
            x1 - origin_x:x2 - origin_x] = image[y1:y2, x1:x2]

    return cropped
|
|
|
|
|
# Module-level handle to the shared detector; it stays None until
# initialize_face_detector() assigns it.
face_detector = None

# Local path of the MediaPipe face landmarker model file.
model_path = "./utils/face_landmarker.task"

if not os.path.exists(model_path):
    print("Downloading face landmarker model...")
    url = "https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task"
    # Download to a temporary sibling file and rename it into place. Writing
    # directly to model_path would leave a truncated file behind if the
    # transfer were interrupted, and the existence check above would then
    # accept that broken file on every later run.
    _tmp_model_path = model_path + ".part"
    try:
        urllib.request.urlretrieve(url, _tmp_model_path)
        os.replace(_tmp_model_path, model_path)
    finally:
        # Only present here if the download or the rename failed.
        if os.path.exists(_tmp_model_path):
            os.remove(_tmp_model_path)
|
|
|
|
|
def initialize_face_detector():
    """
    Create the module-level ``face_detector`` if it has not been created yet.

    The detector is built from the model file at ``model_path`` with a
    detection confidence of 0.5 and at most one face. Calling this again
    after the detector exists does nothing.
    """
    global face_detector

    if face_detector is not None:
        return

    detector_options = dict(
        mediapipe_model_asset_path=model_path,
        face_detection_confidence=0.5,
        num_faces=1,
    )
    face_detector = FaceDetector(**detector_options)


# Build the detector when the module is loaded, so that code reading the
# module-level ``face_detector`` finds it ready.
initialize_face_detector()
|
|
|
|
|
def augmentation(images, transform, state=None):
    """
    Convert one image or a list of images to a tensor and apply ``transform``.

    :param images: A single image, or a list of images; each is converted
        with ``transforms.functional.to_tensor``
    :param transform: Callable applied to the resulting tensor
    :param state: Optional torch RNG state; when given it is installed with
        ``torch.set_rng_state`` before anything else happens
    :return: ``transform`` applied to the tensor; a list input is first
        stacked along a new leading dimension
    """
    if state is not None:
        torch.set_rng_state(state)

    to_tensor = transforms.functional.to_tensor

    if not isinstance(images, list):
        return transform(to_tensor(images))

    batch = torch.stack([to_tensor(frame) for frame in images], dim=0)
    return transform(batch)
|
|
|
|
|
def scale_bbox(bbox, h, w, scale=1.8):
    """
    Grow or shrink a box about its center and clamp it to the image bounds.

    :param bbox: The box as ``[x0, y0, x1, y1]``
    :param h: Upper clamp for the y coordinates
    :param w: Upper clamp for the x coordinates
    :param scale: Factor applied to the box's half-width and half-height
        (default is 1.8)
    :return: The scaled box as a list ``[x0, y0, x1, y1]``, each value
        clipped to ``[0, w]`` (x) or ``[0, h]`` (y)
    """
    left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]

    half_w = (right - left) / 2 * scale
    half_h = (bottom - top) / 2 * scale
    mid_x = (left + right) / 2
    mid_y = (top + bottom) / 2

    corners = (mid_x - half_w, mid_y - half_h, mid_x + half_w, mid_y + half_h)
    limits = (w, h, w, h)
    return [np.clip(value, 0, limit) for value, limit in zip(corners, limits)]
|
|
|
|
|
def get_mask(bbox, hd, wd, scale=1.0, return_pil=True):
    """
    Build a 3-channel mask that is 255 inside a (scaled) box and 0 elsewhere.

    :param bbox: The box as ``[x0, y0, x1, y1]``; every value must be >= 0
    :param hd: Mask height
    :param wd: Mask width
    :param scale: Factor passed to ``scale_bbox`` before the box is drawn
        (default is 1.0)
    :param return_pil: When true return a PIL image, otherwise the NumPy array
    :return: The mask, shape (hd, wd, 3), dtype uint8
    :raises ValueError: If any coordinate of ``bbox`` is negative
    """
    if min(bbox) < 0:
        # ValueError is a subclass of Exception, so callers that caught the
        # previous generic Exception keep working.
        raise ValueError(f"Invalid mask: bbox {bbox} contains a negative coordinate")

    bbox = scale_bbox(bbox, hd, wd, scale=scale)
    # Truncate to whole pixels so the values can be used as slice bounds.
    x0, y0, x1, y1 = [int(v) for v in bbox]

    mask = np.zeros((hd, wd, 3), dtype=np.uint8)
    mask[y0:y1, x0:x1, :] = 255

    if return_pil:
        return Image.fromarray(mask)
    return mask
|
|
|
|
|
def _extract_face_regions(det_res, person_id=0):
    """Read the mouth box, both eye boxes and the contour array for one face from a detector result."""
    # NOTE(review): the meaning of indices 6/7/8 is taken from the variable
    # names used here; confirm against FaceDetector.
    mouth_bbox = np.array(det_res[6][person_id])
    eye_bbox = det_res[7][person_id]
    face_contour = np.array(det_res[8][person_id])
    return mouth_bbox, eye_bbox["left_eye"], eye_bbox["right_eye"], face_contour


def generate_masked_image(
        image_path="./test_case/test_img.png",
        save_path="./test_case/test_img.png",
        crop=False,
        union_bbox_scale=1.3):
    """
    Detect the face in an image and save a resized copy plus a masked copy.

    In the masked copy the eye and mouth boxes show the image itself and
    everything else shows the detector's face-contour image.

    :param image_path: Path of the input image; it is loaded as RGB
    :param save_path: Output location; its extension is dropped and the files
        ``<base>_resize.png`` and ``<base>_masked.png`` are written
    :param crop: When true, first cut a square around the detected face
        (zero-padded where it leaves the image) and run detection again on it
    :param union_bbox_scale: Factor applied to the longer side of the face
        box to get the side length of that square (default is 1.3)
    :return: None; the results are written to disk

    Reads ``./configs/audio_head_animator.yaml`` for the mouth and eye box
    scales and uses the module-level ``face_detector``.
    """
    cfg = OmegaConf.load("./configs/audio_head_animator.yaml")
    bicubic = transforms.InterpolationMode.BICUBIC
    pixel_transform = transforms.Compose([
        # An int size scales the shorter side to 512 and keeps the aspect ratio.
        transforms.Resize(512, interpolation=bicubic),
        transforms.Normalize([0.5], [0.5]),
    ])

    img = Image.open(image_path).convert("RGB")
    state = torch.get_rng_state()

    def run_detector(frame):
        return face_detector.get_face_xy_rotation_and_keypoints(
            frame, cfg.data.mouth_bbox_scale, cfg.data.eye_bbox_scale
        )

    person_id = 0
    det_res = run_detector(np.array(img))

    if crop:
        # NOTE(review): index 5 is read as two (x, y) corner points of the
        # face box; confirm against FaceDetector.
        face_bbox = det_res[5][person_id]
        x1, y1 = face_bbox[0]
        x2, y2 = face_bbox[1]
        center = [(y1 + y2) // 2, (x1 + x2) // 2]

        max_size = int(max(x2 - x1, y2 - y1) * union_bbox_scale)

        wd, hd = img.size
        crop_bbox = generate_crop_bounding_box(hd, wd, center, max_size)
        cropped_img = crop_from_bbox(np.array(img), center, crop_bbox, size=max_size)
        img = Image.fromarray(cropped_img)

        # Box coordinates must refer to the cropped image, so detect again.
        det_res = run_detector(cropped_img)

    mouth_bbox, left_eye_bbox, right_eye_bbox, face_contour = _extract_face_regions(
        det_res, person_id
    )

    pixel_values_ref = augmentation([img], pixel_transform, state)
    # Undo Normalize([0.5], [0.5]) so the values are back in [0, 1].
    pixel_values_ref = (pixel_values_ref + 1) / 2
    new_wd, new_hd = img.size

    # Resize the masks and the contour to the image tensor's own height and
    # width. A fixed (512, 512) only matches for square images; for any other
    # aspect ratio the blend below would fail on mismatched shapes.
    target_hw = tuple(pixel_values_ref.shape[-2:])
    resize_transform = transforms.Resize(target_hw, interpolation=bicubic)

    mouth_mask = resize_transform(get_mask(mouth_bbox, new_hd, new_wd, scale=1.0))
    left_eye_mask = resize_transform(get_mask(left_eye_bbox, new_hd, new_wd, scale=1.0))
    right_eye_mask = resize_transform(get_mask(right_eye_bbox, new_hd, new_wd, scale=1.0))
    face_contour = resize_transform(Image.fromarray(face_contour))

    eye_mask = np.bitwise_or(np.array(left_eye_mask), np.array(right_eye_mask))
    combined_mask = np.bitwise_or(eye_mask, np.array(mouth_mask))

    # HWC arrays in [0, 255] -> 1xCxHxW tensors in [0, 1].
    combined_mask_tensor = torch.from_numpy(combined_mask / 255.0).permute(2, 0, 1).unsqueeze(0)
    face_contour_tensor = torch.from_numpy(np.array(face_contour) / 255.0).permute(2, 0, 1).unsqueeze(0)

    # Image inside the eye/mouth boxes, contour everywhere else.
    masked_ref = pixel_values_ref * combined_mask_tensor + face_contour_tensor * (1 - combined_mask_tensor)
    masked_ref = masked_ref.clamp(0, 1)
    masked_ref_np = (masked_ref.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    base, _ = os.path.splitext(save_path)
    resized_img = (pixel_values_ref.squeeze(0).permute(1, 2, 0).cpu().numpy().clip(0, 1) * 255).astype(np.uint8)
    Image.fromarray(resized_img).save(f"{base}_resize.png")
    Image.fromarray(masked_ref_np).save(f"{base}_masked.png")
|
|
|
|
|
if __name__ == "__main__":
    # Imported here so that fire is only needed when the file is run as a
    # script; it builds the command-line interface from the function.
    from fire import Fire

    Fire(generate_masked_image)
|
|
|
|
|
|
|
|
|