"""
input: image_path
output: save a masked image and resized image

Runs the face detector on the input image, optionally crops a square region
around the detected face, resizes to 512x512 and writes two PNGs next to
``save_path``:

* ``<base>_resize.png`` -- the resized image;
* ``<base>_masked.png`` -- eye and mouth regions of the resized image composited
  over the (resized) face-contour image returned by the detector.
"""
import os
import sys
import urllib.request
from pathlib import Path

import numpy as np
import torch
import cv2
from PIL import Image
from omegaconf import OmegaConf
from torchvision import transforms

from utils.face_detector import FaceDetector


def generate_crop_bounding_box(h, w, center, size=512):
    """Compute the in-image part of a ``size`` x ``size`` window centred on ``center``.

    :param h: Image height in pixels.
    :param w: Image width in pixels.
    :param center: Window centre as ``(y, x)``.
    :param size: Side length of the square window (default 512).
    :return: ``[x1, y1, x2, y2]`` -- the window clipped to ``[0, w] x [0, h]``.
        The part of the window outside the image is dropped here; pass the
        same ``center``/``size`` to :func:`crop_from_bbox` to get a zero-padded
        full window.
    """
    half_size = size // 2
    # Window origin in image coordinates (may be negative near the border).
    top = center[0] - half_size
    left = center[1] - half_size
    y1 = max(top, 0)
    x1 = max(left, 0)
    # End at origin + size rather than centre + half_size, so odd sizes cover
    # the whole window instead of leaving the last row/column empty.
    y2 = min(top + size, h)
    x2 = min(left + size, w)
    return [x1, y1, x2, y2]


def crop_from_bbox(image, center, bbox, size=512):
    """Cut a ``size`` x ``size`` window centred on ``center`` out of ``image``.

    Pixels of the window that fall outside the image are zero-filled.

    :param image: Image as a NumPy array of shape (H, W) or (H, W, C).
    :param center: Window centre as ``(y, x)``; must be the one passed to
        :func:`generate_crop_bounding_box`.
    :param bbox: ``[x1, y1, x2, y2]`` from :func:`generate_crop_bounding_box`.
    :param size: Side length of the window (default 512).
    :return: Array of shape ``(size, size) + image.shape[2:]`` with the same
        dtype as ``image``.
    """
    x1, y1, x2, y2 = bbox
    half_size = size // 2
    top = center[0] - half_size
    left = center[1] - half_size
    cropped = np.zeros((size, size) + image.shape[2:], dtype=image.dtype)
    # Shift the in-image region from image coordinates into window coordinates.
    cropped[y1 - top:y2 - top, x1 - left:x2 - left] = image[y1:y2, x1:x2]
    return cropped


face_detector = None
model_path = "./utils/face_landmarker.task"

if not os.path.exists(model_path):
    print("Downloading face landmarker model...")
    url = "https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task"
    os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
    # Download under a temporary name and rename only on success, so an
    # interrupted download never leaves a truncated file that the existence
    # check above would accept on the next run.
    _partial_path = model_path + ".part"
    urllib.request.urlretrieve(url, _partial_path)
    os.replace(_partial_path, model_path)


def initialize_face_detector():
    """Create the module-level ``face_detector`` once; later calls are no-ops."""
    global face_detector
    if face_detector is None:
        face_detector = FaceDetector(
            mediapipe_model_asset_path=model_path,
            face_detection_confidence=0.5,
            num_faces=1,
        )


initialize_face_detector()


def augmentation(images, transform, state=None):
    """Convert PIL image(s) to tensor(s) and apply ``transform``.

    :param images: A single PIL image or a list of them. A list is stacked
        into one (N, C, H, W) tensor before ``transform`` is applied.
    :param transform: Callable applied to the tensor.
    :param state: Optional torch RNG state restored before transforming, so
        random transforms can be replayed identically.
    """
    if state is not None:
        torch.set_rng_state(state)
    if isinstance(images, list):
        transformed = [transforms.functional.to_tensor(img) for img in images]
        return transform(torch.stack(transformed, dim=0))
    return transform(transforms.functional.to_tensor(images))


def scale_bbox(bbox, h, w, scale=1.8):
    """Scale ``[x0, y0, x1, y1]`` about its centre by ``scale`` and clip it to
    ``[0, w] x [0, h]``. Returns a list of four numbers (may be non-integer)."""
    sw = (bbox[2] - bbox[0]) / 2
    sh = (bbox[3] - bbox[1]) / 2
    cx = (bbox[0] + bbox[2]) / 2
    cy = (bbox[1] + bbox[3]) / 2
    sw *= scale
    sh *= scale
    scaled = [cx - sw, cy - sh, cx + sw, cy + sh]
    scaled[0] = np.clip(scaled[0], 0, w)
    scaled[2] = np.clip(scaled[2], 0, w)
    scaled[1] = np.clip(scaled[1], 0, h)
    scaled[3] = np.clip(scaled[3], 0, h)
    return scaled


def get_mask(bbox, hd, wd, scale=1.0, return_pil=True):
    """Build an (hd, wd, 3) uint8 mask that is 255 inside the (scaled) bbox, 0 elsewhere.

    :param bbox: ``[x0, y0, x1, y1]`` in pixels.
    :param hd: Mask height.
    :param wd: Mask width.
    :param scale: Factor passed to :func:`scale_bbox` before rasterising.
    :param return_pil: Return a PIL image (default) instead of a NumPy array.
    :raises ValueError: If any bbox coordinate is negative.
    """
    if min(bbox) < 0:
        raise ValueError("Invalid mask")
    bbox = scale_bbox(bbox, hd, wd, scale=scale)
    x0, y0, x1, y1 = [int(v) for v in bbox]
    mask = np.zeros((hd, wd, 3), dtype=np.uint8)
    mask[y0:y1, x0:x1, :] = 255
    if return_pil:
        return Image.fromarray(mask)
    return mask


def _detect_faces(image_array, cfg):
    """Run the shared face detector with the bbox scales from ``cfg.data``."""
    return face_detector.get_face_xy_rotation_and_keypoints(
        image_array, cfg.data.mouth_bbox_scale, cfg.data.eye_bbox_scale
    )


def _extract_regions(det_res, person_id):
    """Return ``(mouth_bbox, left_eye_bbox, right_eye_bbox, face_contour)`` for one person.

    Indices 6/7/8 of the detector result are read as mouth bbox, eye-bbox dict
    (keys ``"left_eye"``/``"right_eye"``) and face-contour image respectively.
    """
    mouth_bbox = np.array(det_res[6][person_id])
    eye_bbox = det_res[7][person_id]
    face_contour = np.array(det_res[8][person_id])
    return mouth_bbox, eye_bbox["left_eye"], eye_bbox["right_eye"], face_contour


def _crop_around_face(img, face_bbox, union_bbox_scale):
    """Return a zero-padded square crop (as an array) centred on ``face_bbox``.

    The side length is the longer face-box side times ``union_bbox_scale``.
    ``face_bbox`` is ``[(x1, y1), (x2, y2)]``.
    """
    x1, y1 = face_bbox[0]
    x2, y2 = face_bbox[1]
    # int() so slicing also works if the detector returns numpy floats.
    center = [int((y1 + y2) // 2), int((x1 + x2) // 2)]
    max_size = int(max(x2 - x1, y2 - y1) * union_bbox_scale)
    hd, wd = img.size[1], img.size[0]
    crop_bbox = generate_crop_bounding_box(hd, wd, center, max_size)
    return crop_from_bbox(np.array(img), center, crop_bbox, size=max_size)


def generate_masked_image(
    image_path="./test_case/test_img.png",
    save_path="./test_case/test_img.png",
    crop=False,
    union_bbox_scale=1.3,
):
    """Write ``<base>_resize.png`` and ``<base>_masked.png`` for ``image_path``.

    :param image_path: Input image.
    :param save_path: Output base path; its extension is replaced by the
        ``_resize.png`` / ``_masked.png`` suffixes.
    :param crop: If True, first crop a square region around the detected face
        and re-run detection on the crop.
    :param union_bbox_scale: Crop side length as a multiple of the longer side
        of the detected face box (only used when ``crop`` is True).
    """
    initialize_face_detector()  # idempotent; guards against an uninitialised global
    cfg = OmegaConf.load("./configs/audio_head_animator.yaml")
    # Resize to exactly 512x512 (not "shorter side = 512") so the image lines up
    # with the masks below, which are resized to (512, 512); otherwise a
    # non-square input fails to broadcast when compositing.
    pixel_transform = transforms.Compose([
        transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.Normalize([0.5], [0.5]),
    ])
    resize_transform = transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BICUBIC)

    img = Image.open(image_path).convert("RGB")
    state = torch.get_rng_state()
    person_id = 0

    det_res = _detect_faces(np.array(img), cfg)
    if crop:
        cropped_img = _crop_around_face(img, det_res[5][person_id], union_bbox_scale)
        img = Image.fromarray(cropped_img)
        # Landmarks must be re-detected in the cropped frame.
        det_res = _detect_faces(cropped_img, cfg)
    mouth_bbox, left_eye_bbox, right_eye_bbox, face_contour = _extract_regions(det_res, person_id)

    # Normalize to [-1, 1] then map back to [0, 1].
    pixel_values_ref = augmentation([img], pixel_transform, state)
    pixel_values_ref = (pixel_values_ref + 1) / 2

    new_hd, new_wd = img.size[1], img.size[0]
    mouth_mask = resize_transform(get_mask(mouth_bbox, new_hd, new_wd, scale=1.0))
    left_eye_mask = resize_transform(get_mask(left_eye_bbox, new_hd, new_wd, scale=1.0))
    right_eye_mask = resize_transform(get_mask(right_eye_bbox, new_hd, new_wd, scale=1.0))
    face_contour = resize_transform(Image.fromarray(face_contour))

    eye_mask = np.bitwise_or(np.array(left_eye_mask), np.array(right_eye_mask))
    combined_mask = np.bitwise_or(eye_mask, np.array(mouth_mask))
    combined_mask_tensor = torch.from_numpy(combined_mask / 255.0).permute(2, 0, 1).unsqueeze(0)
    # NOTE(review): permute(2, 0, 1) assumes the face-contour image is HxWx3 -- confirm detector output.
    face_contour_tensor = torch.from_numpy(np.array(face_contour) / 255.0).permute(2, 0, 1).unsqueeze(0)

    # Keep image pixels inside the eye/mouth masks, face-contour pixels elsewhere.
    masked_ref = pixel_values_ref * combined_mask_tensor + face_contour_tensor * (1 - combined_mask_tensor)
    masked_ref = masked_ref.clamp(0, 1)
    masked_ref_np = (masked_ref.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    base, _ = os.path.splitext(save_path)
    os.makedirs(os.path.dirname(base) or ".", exist_ok=True)
    resized_img = (pixel_values_ref.squeeze(0).permute(1, 2, 0).cpu().numpy().clip(0, 1) * 255).astype(np.uint8)
    Image.fromarray(resized_img).save(f"{base}_resize.png")
    Image.fromarray(masked_ref_np).save(f"{base}_masked.png")


if __name__ == '__main__':
    import fire
    fire.Fire(generate_masked_image)

# python img_to_mask.py --image_path /mnt/weka/haiyang_workspace/ckpts/good_train_case/image_example/KristiNoem2-Scene-001.png --save_path /mnt/weka/haiyang_workspace/ckpts/good_train_case/image_example/KristiNoem2-Scene-001.png --crop True --union_bbox_scale 1.6
# python img_to_latent.py --mask_image_path ./test_case/ChrisVanHollen0-Scene-003_masked.png --save_npz_path ./test_case/ChrisVanHollen0-Scene-003_resize.npz
# python latent_two_video.py --npz_path ./test_case/ChrisVanHollen0-Scene-003_resize.npz --save_dir ./test_case/