"""
Input: image_path
Output: saves a masked image and a resized image
"""
import os
import sys
import urllib.request
import numpy as np
import torch
import cv2
from PIL import Image
from omegaconf import OmegaConf
from torchvision import transforms
from utils.face_detector import FaceDetector
from pathlib import Path
def generate_crop_bounding_box(h, w, center, size=512):
    """
    Compute the bounding box of a square crop of side `size` centred on
    `center`, clipped so it never leaves the image.

    :param h: Image height in pixels.
    :param w: Image width in pixels.
    :param center: Crop centre as (y, x).
    :param size: Side length of the (unclipped) square crop (default 512).
    :return: [x1, y1, x2, y2] — the clipped crop rectangle in pixel coords.
    """
    half_size = size // 2  # Half the size for the cropping region
    # Clamp each edge of the ideal square to the image boundaries.
    y1 = max(center[0] - half_size, 0)  # Ensure the y1 index is not less than 0
    x1 = max(center[1] - half_size, 0)  # Ensure the x1 index is not less than 0
    y2 = min(center[0] + half_size, h)  # Ensure the y2 index does not exceed the image height
    x2 = min(center[1] + half_size, w)  # Ensure the x2 index does not exceed the image width
    return [x1, y1, x2, y2]
def crop_from_bbox(image, center, bbox, size=512):
    """
    Extract a `size` x `size` crop centred on `center`, zero-filling any
    part of the crop that falls outside the image.

    :param image: Input image as a NumPy array of shape (H, W, C).
    :param center: Crop centre as (y, x); must be the same centre used to
        compute `bbox`.
    :param bbox: [x1, y1, x2, y2] — the in-image region to copy
        (e.g. from generate_crop_bounding_box).
    :param size: Side length of the output crop (default 512).
    :return: The cropped region with padding, shape (size, size, C).
    """
    x1, y1, x2, y2 = bbox
    half_size = size // 2  # Half the size for the cropping region
    # Zero-filled canvas: pixels outside the source image stay zero.
    cropped = np.zeros((size, size, image.shape[2]), dtype=image.dtype)
    # Offset of the valid region inside the canvas = distance from the
    # ideal (possibly out-of-bounds) top-left corner to the clipped one.
    top = y1 - (center[0] - half_size)
    left = x1 - (center[1] - half_size)
    cropped[top:top + (y2 - y1), left:left + (x2 - x1)] = image[y1:y2, x1:x2]
    return cropped
# Module-level singleton; created by initialize_face_detector() below.
face_detector = None
model_path = "./utils/face_landmarker.task"
# Fetch the MediaPipe face-landmarker model if it is not on disk yet.
# NOTE(review): this performs network I/O at import time — confirm intended.
if not os.path.exists(model_path):
    print("Downloading face landmarker model...")
    url = "https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task"
    urllib.request.urlretrieve(url, model_path)
def initialize_face_detector():
    """Lazily build the module-level FaceDetector singleton (idempotent)."""
    global face_detector
    if face_detector is not None:
        return  # already constructed; reuse the existing instance
    face_detector = FaceDetector(
        mediapipe_model_asset_path=model_path,
        num_faces=1,
        face_detection_confidence=0.5,
    )
# Build the detector eagerly at import time so generate_masked_image can use it.
initialize_face_detector()
def augmentation(images, transform, state=None):
    """Apply `transform` to one PIL image or a list of them.

    When `state` is given, the torch RNG state is restored first so any
    random ops inside `transform` are reproducible. A list of images is
    converted to tensors and stacked into a batch before the transform;
    a single image is converted and transformed directly.
    """
    if state is not None:
        torch.set_rng_state(state)
    to_tensor = transforms.functional.to_tensor
    if not isinstance(images, list):
        return transform(to_tensor(images))
    batch = torch.stack([to_tensor(item) for item in images], dim=0)
    return transform(batch)
def scale_bbox(bbox, h, w, scale=1.8):
    """Grow (or shrink) `bbox` about its centre by `scale`, clipped to the
    h x w canvas.

    :param bbox: [x0, y0, x1, y1] box.
    :param h: Canvas height (upper limit for the y coordinates).
    :param w: Canvas width (upper limit for the x coordinates).
    :param scale: Multiplier applied to each half-extent (default 1.8).
    :return: Scaled and clipped [x0, y0, x1, y1] list.
    """
    x0, y0, x1, y1 = bbox
    cx = (x0 + x1) / 2
    cy = (y0 + y1) / 2
    half_w = (x1 - x0) / 2 * scale
    half_h = (y1 - y0) / 2 * scale
    corners = [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
    limits = (w, h, w, h)  # x coords clip to width, y coords to height
    return [np.clip(value, 0, limit) for value, limit in zip(corners, limits)]
def get_mask(bbox, hd, wd, scale=1.0, return_pil=True):
    """Build a 0/255 rectangular RGB mask for `bbox` on an hd x wd canvas.

    The box is first rescaled about its centre by `scale` (clipped to the
    canvas) and the resulting rectangle is filled with 255; everything
    else stays 0. Raises if any incoming coordinate is negative.

    :param return_pil: When True return a PIL Image, otherwise the raw
        uint8 NumPy array of shape (hd, wd, 3).
    """
    if min(bbox) < 0:
        raise Exception("Invalid mask")
    x0, y0, x1, y1 = (int(v) for v in scale_bbox(bbox, hd, wd, scale=scale))
    canvas = np.zeros((hd, wd, 3), dtype=np.uint8)
    canvas[y0:y1, x0:x1, :] = 255
    return Image.fromarray(canvas) if return_pil else canvas
def generate_masked_image(
        image_path="./test_case/test_img.png",
        save_path="./test_case/test_img.png",
        crop=False,
        union_bbox_scale=1.3):
    """
    Detect a face in the input image and save two derived images next to
    `save_path`: "<base>_resize.png" (the 512x512 reference image) and
    "<base>_masked.png" (the reference with everything outside the
    mouth/eye boxes replaced by the face-contour image).

    :param image_path: Path of the input image (opened and converted to RGB).
    :param save_path: Path whose stem names the two output files.
    :param crop: When True, first crop a square region around the detected
        face (zero-padded where it leaves the frame) and re-run detection.
    :param union_bbox_scale: Multiplier on the larger face-box side when
        computing the crop size (only used when crop=True).
    """
    cfg = OmegaConf.load("./configs/audio_head_animator.yaml")
    # Resize shorter side to 512 and map pixel values from [0, 1] to [-1, 1].
    pixel_transform = transforms.Compose([
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.Normalize([0.5], [0.5]),
    ])
    resize_transform = transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BICUBIC)
    img = Image.open(image_path).convert("RGB")
    state = torch.get_rng_state()
    # Get face detection results first.
    # NOTE(review): det_res appears to be indexed as [5]=face bbox,
    # [6]=mouth bbox, [7]=eye bboxes dict, [8]=face contour image (per the
    # usage below) — confirm against FaceDetector's return contract.
    det_res = face_detector.get_face_xy_rotation_and_keypoints(
        np.array(img), cfg.data.mouth_bbox_scale, cfg.data.eye_bbox_scale
    )
    person_id = 0  # only the first detected face is used (num_faces=1)
    mouth_bbox = np.array(det_res[6][person_id])
    eye_bbox = det_res[7][person_id]
    face_contour = np.array(det_res[8][person_id])
    left_eye_bbox = eye_bbox["left_eye"]
    right_eye_bbox = eye_bbox["right_eye"]
    # If crop is True, crop the face region first
    if crop:
        # Get the face bounding box and calculate center
        face_bbox = det_res[5][person_id]  # Get the face bounding box from det_res[5]
        # face_bbox is [(x1, y1), (x2, y2)]
        x1, y1 = face_bbox[0]
        x2, y2 = face_bbox[1]
        center = [(y1 + y2) // 2, (x1 + x2) // 2]  # (y, x) order, as the crop helpers expect
        # Calculate the size for cropping
        width = x2 - x1
        height = y2 - y1
        max_size = int(max(width, height) * union_bbox_scale)
        # Get the image dimensions (PIL .size is (W, H))
        hd, wd = img.size[1], img.size[0]
        # Generate the crop bounding box
        crop_bbox = generate_crop_bounding_box(hd, wd, center, max_size)
        # Crop the image (zero-padded where the crop leaves the frame)
        img_array = np.array(img)
        cropped_img = crop_from_bbox(img_array, center, crop_bbox, size=max_size)
        img = Image.fromarray(cropped_img)
        # Update the face detection results for the cropped image
        det_res = face_detector.get_face_xy_rotation_and_keypoints(
            cropped_img, cfg.data.mouth_bbox_scale, cfg.data.eye_bbox_scale
        )
        mouth_bbox = np.array(det_res[6][person_id])
        eye_bbox = det_res[7][person_id]
        face_contour = np.array(det_res[8][person_id])
        left_eye_bbox = eye_bbox["left_eye"]
        right_eye_bbox = eye_bbox["right_eye"]
    # Reference tensor: apply transform (under the saved RNG state), then
    # undo the [-1, 1] normalization back to [0, 1].
    pixel_values_ref = augmentation([img], pixel_transform, state)
    pixel_values_ref = (pixel_values_ref + 1) / 2
    new_hd, new_wd = img.size[1], img.size[0]
    # Build 0/255 rectangle masks at the current image size, then resize to 512x512.
    mouth_mask = resize_transform(get_mask(mouth_bbox, new_hd, new_wd, scale=1.0))
    left_eye_mask = resize_transform(get_mask(left_eye_bbox, new_hd, new_wd, scale=1.0))
    right_eye_mask = resize_transform(get_mask(right_eye_bbox, new_hd, new_wd, scale=1.0))
    face_contour = resize_transform(Image.fromarray(face_contour))
    # Union of both eye masks, then union with the mouth mask.
    eye_mask = np.bitwise_or(np.array(left_eye_mask), np.array(right_eye_mask))
    combined_mask = np.bitwise_or(eye_mask, np.array(mouth_mask))
    combined_mask_tensor = torch.from_numpy(combined_mask / 255.0).permute(2, 0, 1).unsqueeze(0)
    face_contour_tensor = torch.from_numpy(np.array(face_contour) / 255.0).permute(2, 0, 1).unsqueeze(0)
    # Keep mouth/eye regions from the reference; take everything else from the contour image.
    masked_ref = pixel_values_ref * combined_mask_tensor + face_contour_tensor * (1 - combined_mask_tensor)
    masked_ref = masked_ref.clamp(0, 1)
    masked_ref_np = (masked_ref.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    base, _ = os.path.splitext(save_path)
    resized_img = (pixel_values_ref.squeeze(0).permute(1, 2, 0).cpu().numpy().clip(0, 1) * 255).astype(np.uint8)
    Image.fromarray(resized_img).save(f"{base}_resize.png")
    Image.fromarray(masked_ref_np).save(f"{base}_masked.png")
if __name__ == '__main__':
    # CLI entry point: fire exposes generate_masked_image's parameters as flags.
    import fire
    fire.Fire(generate_masked_image)
# Example invocations:
# python img_to_mask.py --image_path /mnt/weka/haiyang_workspace/ckpts/good_train_case/image_example/KristiNoem2-Scene-001.png --save_path /mnt/weka/haiyang_workspace/ckpts/good_train_case/image_example/KristiNoem2-Scene-001.png --crop True --union_bbox_scale 1.6
# python img_to_latent.py --mask_image_path ./test_case/ChrisVanHollen0-Scene-003_masked.png --save_npz_path ./test_case/ChrisVanHollen0-Scene-003_resize.npz
# python latent_two_video.py --npz_path ./test_case/ChrisVanHollen0-Scene-003_resize.npz --save_dir ./test_case/