| import os |
| import torch |
| from tqdm import tqdm |
| import numpy as np |
| import folder_paths |
| import cv2 |
| import json |
| import logging |
# Absolute path of the directory that contains this module file.
script_directory = os.path.dirname(os.path.abspath(__file__))
|
|
| from comfy import model_management as mm |
| from comfy.utils import ProgressBar |
# Devices reported by ComfyUI's model management: the torch compute device and
# the device it designates for offloading UNet weights.
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()


# Register "<models_dir>/detection" under the "detection" model-folder key so
# that folder_paths lookups for that key (see OnnxDetectionModelLoader) search it.
folder_paths.add_model_folder_path("detection", os.path.join(folder_paths.models_dir, "detection"))
|
|
| from .models.onnx_models import ViTPose, Yolo |
| from .pose_utils.pose2d_utils import load_pose_metas_from_kp2ds_seq, crop, bbox_from_detector |
| from .utils import get_face_bboxes, padding_resize, resize_by_area, resize_to_bounds |
| from .pose_utils.human_visualization import AAPoseMeta, draw_aapose_by_meta_new |
| from .retarget_pose import get_retarget_pose |
|
|
class OnnxDetectionModelLoader:
    """ComfyUI node that builds the ViTPose + YOLO ONNX model pair.

    Both model files are picked from the "detection" model folder and are
    wrapped in the project's ``ViTPose`` / ``Yolo`` classes using the chosen
    ONNX execution provider.
    """

    RETURN_TYPES = ("POSEMODEL",)
    RETURN_NAMES = ("model", )
    FUNCTION = "loadmodel"
    CATEGORY = "WanAnimatePreprocess"
    DESCRIPTION = "Loads ONNX models for pose and face detection. ViTPose for pose estimation and YOLO for object detection."

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: two model file pickers and the ONNX provider."""
        folder_hint = "These models are loaded from the 'ComfyUI/models/detection' -folder"
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]

        required = {}
        # Each picker gets its own freshly queried file list and option dict.
        for input_name in ("vitpose_model", "yolo_model"):
            required[input_name] = (
                folder_paths.get_filename_list("detection"),
                {"tooltip": folder_hint},
            )
        required["onnx_device"] = (
            providers,
            {"default": "CUDAExecutionProvider", "tooltip": "Device to run the ONNX models on"},
        )
        return {"required": required}

    def loadmodel(self, vitpose_model, yolo_model, onnx_device):
        """Resolve both model files and instantiate the ONNX wrappers.

        Args:
            vitpose_model: File name of the ViTPose model inside the "detection" folder.
            yolo_model: File name of the YOLO model inside the "detection" folder.
            onnx_device: ONNX execution provider name passed to both wrappers.

        Returns:
            A 1-tuple holding ``{"vitpose": ViTPose, "yolo": Yolo}``.
        """
        # Resolve both paths before constructing anything, so a missing file
        # raises before any model is loaded.
        resolved = {
            key: folder_paths.get_full_path_or_raise("detection", filename)
            for key, filename in (("vitpose", vitpose_model), ("yolo", yolo_model))
        }
        return ({
            "vitpose": ViTPose(resolved["vitpose"], onnx_device),
            "yolo": Yolo(resolved["yolo"], onnx_device),
        },)
|
|
class PoseAndFaceDetection:
    """ComfyUI node: per-frame bounding boxes, pose keypoints and face crops.

    For every input frame the YOLO model yields a bounding box, ViTPose
    estimates keypoints inside that box, and a 512x512 face crop is cut out
    around the face keypoints.  When ``retarget_image`` is supplied, the same
    detector/pose pipeline is run on it and the frame poses are passed through
    ``get_retarget_pose``.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs (model, frames, target size, optional extras)."""
        return {
            "required": {
                "model": ("POSEMODEL",),
                "images": ("IMAGE",),
                "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 1, "tooltip": "Width of the generation"}),
                "height": ("INT", {"default": 480, "min": 64, "max": 2048, "step": 1, "tooltip": "Height of the generation"}),
            },
            "optional": {
                "retarget_image": ("IMAGE", {"default": None, "tooltip": "Optional reference image for pose retargeting"}),
                "face_padding": ("INT", {"default": 0, "min": 0, "max": 512, "step": 1, "tooltip": "When > 0, the detected face images are padded and resized to 512x512"}),
            },
        }

    # The fifth entry used to be the typo "BBOX," (comma inside the string).
    RETURN_TYPES = ("POSEDATA", "IMAGE", "STRING", "BBOX", "BBOX")
    RETURN_NAMES = ("pose_data", "face_images", "key_frame_body_points", "bboxes", "face_bboxes")
    FUNCTION = "process"
    CATEGORY = "WanAnimatePreprocess"
    DESCRIPTION = "Detects human poses and face images from input images. Optionally retargets poses based on a reference image."

    def process(self, model, images, width, height, retarget_image=None, face_padding=0):
        """Run detection, pose estimation and face cropping over ``images``.

        Args:
            model: Dict with the ``"yolo"`` detector and ``"vitpose"`` pose model.
            images: Frame batch; unpacked as ``(B, H, W, C)``.
            width: Target generation width (used to size the reference image).
            height: Target generation height (used to size the reference image).
            retarget_image: Optional reference image batch; only index 0 is used.
            face_padding: Pixels added on every side of the face box (clamped
                to the frame) before cropping.

        Returns:
            A 5-tuple ``(pose_data, face_images, key_frame_body_points,
            bboxes, face_bboxes)``:
            ``pose_data`` is a dict with the (optionally retargeted) pose
            metas, the original pose metas and the reference data;
            ``face_images`` is a tensor of one 512x512 crop per frame;
            ``key_frame_body_points`` is a JSON list of ``{"x", "y"}`` points;
            ``bboxes`` is a one-element list with the first frame's box as
            ints; ``face_bboxes`` is a list of ``(x1, y1, x2, y2)`` per frame.
        """
        detector = model["yolo"]
        pose_model = model["vitpose"]
        B, H, W, C = images.shape

        # Frame size handed to the detector together with the 640x640 input.
        shape = np.array([H, W])[None]
        images_np = images.numpy()

        # Per-channel mean/std applied to crops before they go to ViTPose.
        IMG_NORM_MEAN = np.array([0.485, 0.456, 0.406])
        IMG_NORM_STD = np.array([0.229, 0.224, 0.225])
        input_resolution = (256, 192)
        rescale = 1.25

        detector.reinit()
        pose_model.reinit()

        refer_img = None
        refer_pose_meta = None
        if retarget_image is not None:
            refer_img = resize_by_area(retarget_image[0].numpy() * 255, width * height, divisor=16) / 255.0
            # NOTE(review): `shape` describes the input frames, not the resized
            # reference image - confirm this is what the detector expects here.
            ref_bbox = detector(
                cv2.resize(refer_img.astype(np.float32), (640, 640)).transpose(2, 0, 1)[None],
                shape,
            )[0][0]["bbox"]

            # Fall back to the whole image for a missing, non-positive-last-value
            # (presumably the score - TODO confirm) or tiny box.
            if ref_bbox is None or ref_bbox[-1] <= 0 or (ref_bbox[2] - ref_bbox[0]) < 10 or (ref_bbox[3] - ref_bbox[1]) < 10:
                ref_bbox = np.array([0, 0, refer_img.shape[1], refer_img.shape[0]])

            center, scale = bbox_from_detector(ref_bbox, input_resolution, rescale=rescale)
            refer_img = crop(refer_img, center, scale, (input_resolution[0], input_resolution[1]))[0]

            img_norm = (refer_img - IMG_NORM_MEAN) / IMG_NORM_STD
            img_norm = img_norm.transpose(2, 0, 1).astype(np.float32)

            ref_keypoints = pose_model(img_norm[None], np.array(center)[None], np.array(scale)[None])
            refer_pose_meta = load_pose_metas_from_kp2ds_seq(ref_keypoints, width=retarget_image.shape[2], height=retarget_image.shape[1])[0]

        # Two passes over the frames (boxes, then keypoints) share one bar.
        comfy_pbar = ProgressBar(B*2)
        progress = 0
        bboxes = []
        for img in tqdm(images_np, total=len(images_np), desc="Detecting bboxes"):
            bboxes.append(detector(
                cv2.resize(img, (640, 640)).transpose(2, 0, 1)[None],
                shape,
            )[0][0]["bbox"])
            progress += 1
            if progress % 10 == 0:
                comfy_pbar.update_absolute(progress)

        detector.cleanup()

        kp2ds = []
        for img, bbox in tqdm(zip(images_np, bboxes), total=len(images_np), desc="Extracting keypoints"):
            # Same whole-frame fallback as for the reference image.
            if bbox is None or bbox[-1] <= 0 or (bbox[2] - bbox[0]) < 10 or (bbox[3] - bbox[1]) < 10:
                bbox = np.array([0, 0, img.shape[1], img.shape[0]])

            center, scale = bbox_from_detector(bbox, input_resolution, rescale=rescale)
            img = crop(img, center, scale, (input_resolution[0], input_resolution[1]))[0]

            img_norm = (img - IMG_NORM_MEAN) / IMG_NORM_STD
            img_norm = img_norm.transpose(2, 0, 1).astype(np.float32)

            keypoints = pose_model(img_norm[None], np.array(center)[None], np.array(scale)[None])
            kp2ds.append(keypoints)
            progress += 1
            if progress % 10 == 0:
                comfy_pbar.update_absolute(progress)

        pose_model.cleanup()

        kp2ds = np.concatenate(kp2ds, 0)
        pose_metas = load_pose_metas_from_kp2ds_seq(kp2ds, width=W, height=H)

        face_images = []
        face_bboxes = []
        for idx, meta in enumerate(pose_metas):
            face_bbox_for_image = get_face_bboxes(meta['keypoints_face'][:, :2], scale=1.3, image_shape=(H, W))
            # The helper's result is unpacked in x1, x2, y1, y2 order.
            x1, x2, y1, y2 = face_bbox_for_image
            if face_padding > 0:
                x1 = max(0, x1 - face_padding)
                y1 = max(0, y1 - face_padding)
                x2 = min(W, x2 + face_padding)
                y2 = min(H, y2 + face_padding)
            face_bboxes.append((x1, y1, x2, y2))
            face_image = images_np[idx][y1:y2, x1:x2]

            if face_image.size == 0:
                logging.warning(f"Empty face crop on frame {idx}, creating fallback image.")

                # Square crop, 30% of the short side, horizontally centred and
                # starting 10% below the top edge.
                fallback_size = int(min(H, W) * 0.3)
                fallback_x1 = (W - fallback_size) // 2
                fallback_x2 = fallback_x1 + fallback_size
                fallback_y1 = int(H * 0.1)
                fallback_y2 = fallback_y1 + fallback_size
                face_image = images_np[idx][fallback_y1:fallback_y2, fallback_x1:fallback_x2]

                # Last resort: a black image, so the resize below has input.
                if face_image.size == 0:
                    face_image = np.zeros((fallback_size, fallback_size, C), dtype=images_np.dtype)
            face_image = cv2.resize(face_image, (512, 512))
            face_images.append(face_image)

        face_images_np = np.stack(face_images, 0)
        face_images_tensor = torch.from_numpy(face_images_np)

        if refer_pose_meta is not None:
            retarget_pose_metas = get_retarget_pose(pose_metas[0], refer_pose_meta, pose_metas, None, None)
        else:
            retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in pose_metas]

        # Only the first frame's box is exported, as plain ints.
        bbox = np.array(bboxes[0]).flatten()
        if bbox.shape[0] >= 4:
            bbox_ints = tuple(int(v) for v in bbox[:4])
        else:
            bbox_ints = (0, 0, 0, 0)

        key_frame_num = 4 if B >= 4 else 1
        key_frame_step = len(pose_metas) // key_frame_num
        key_frame_index_list = list(range(0, len(pose_metas), key_frame_step))

        key_points_index = [0, 1, 2, 5, 8, 11, 10, 13]

        # NOTE(review): every key frame overwrites the previous result, so only
        # the last sampled key frame's points are returned - confirm intent.
        points_dict_list = []
        for key_frame_index in key_frame_index_list:
            body_key_points = pose_metas[key_frame_index]['keypoints_body']
            keypoints_body_list = [
                body_key_points[each_index]
                for each_index in key_points_index
                if body_key_points[each_index] is not None
            ]

            points_dict_list = []
            if not keypoints_body_list:
                # Nothing usable on this key frame; indexing an empty array
                # with [:, :2] would raise IndexError.
                continue

            keypoints_body = np.array(keypoints_body_list)[:, :2]
            wh = np.array([[pose_metas[0]['width'], pose_metas[0]['height']]])
            points = (keypoints_body * wh).astype(np.int32)
            points_dict_list = [{"x": int(point[0]), "y": int(point[1])} for point in points]

        pose_data = {
            "retarget_image": refer_img,
            "pose_metas": retarget_pose_metas,
            "refer_pose_meta": refer_pose_meta,
            "pose_metas_original": pose_metas,
        }

        return (pose_data, face_images_tensor, json.dumps(points_dict_list), [bbox_ints], face_bboxes)
|
|
class DrawViTPose:
    """ComfyUI node that renders pose metas into a batch of pose images."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs (pose data, canvas size and drawing options)."""
        return {
            "required": {
                "pose_data": ("POSEDATA",),
                "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 1, "tooltip": "Width of the generation"}),
                "height": ("INT", {"default": 480, "min": 64, "max": 2048, "step": 1, "tooltip": "Height of the generation"}),
                "retarget_padding": ("INT", {"default": 16, "min": 0, "max": 512, "step": 1, "tooltip": "When > 0, the retargeted pose image is padded and resized to the target size"}),
                "body_stick_width": ("INT", {"default": -1, "min": -1, "max": 20, "step": 1, "tooltip": "Width of the body sticks. Set to 0 to disable body drawing, -1 for auto"}),
                "hand_stick_width": ("INT", {"default": -1, "min": -1, "max": 20, "step": 1, "tooltip": "Width of the hand sticks. Set to 0 to disable hand drawing, -1 for auto"}),
                # A BOOLEAN input takes a real bool default, not the string "True".
                "draw_head": ("BOOLEAN", {"default": True, "tooltip": "Whether to draw head keypoints"}),
            },
        }

    RETURN_TYPES = ("IMAGE", )
    RETURN_NAMES = ("pose_images", )
    FUNCTION = "process"
    CATEGORY = "WanAnimatePreprocess"
    DESCRIPTION = "Draws pose images from pose data."

    def process(self, pose_data, width, height, body_stick_width, hand_stick_width, draw_head, retarget_padding=64):
        """Draw one pose image per entry of ``pose_data["pose_metas"]``.

        Args:
            pose_data: Dict with ``"pose_metas"`` and, optionally,
                ``"retarget_image"``.
            width: Canvas / output width in pixels.
            height: Canvas / output height in pixels.
            body_stick_width: Forwarded to the drawing helper.
            hand_stick_width: Forwarded to the drawing helper; ``0`` also
                turns hand drawing off.
            draw_head: Forwarded to the drawing helper.
            retarget_padding: Extra padding used when fitting retargeted poses.
                The signature default (64) only applies to direct Python
                callers; the node definition declares its own default of 16.

        Returns:
            A 1-tuple with a float tensor of the stacked pose images scaled
            to the 0..1 range.
        """
        retarget_image = pose_data.get("retarget_image", None)
        pose_metas = pose_data["pose_metas"]

        draw_hand = hand_stick_width != 0
        # Bounds-fitting is only used for retargeted poses with padding enabled.
        use_retarget_resize = retarget_padding > 0 and retarget_image is not None

        comfy_pbar = ProgressBar(len(pose_metas))
        progress = 0
        crop_target_image = None
        pose_images = []

        for meta in tqdm(pose_metas, desc="Drawing pose images"):
            canvas = np.zeros((height, width, 3), dtype=np.uint8)
            pose_image = draw_aapose_by_meta_new(canvas, meta, draw_hand=draw_hand, draw_head=draw_head, body_stick_width=body_stick_width, hand_stick_width=hand_stick_width)

            # The first drawn frame is passed as crop target for every frame.
            if crop_target_image is None:
                crop_target_image = pose_image

            if use_retarget_resize:
                pose_image = resize_to_bounds(pose_image, height, width, crop_target_image=crop_target_image, extra_padding=retarget_padding)
            else:
                pose_image = padding_resize(pose_image, height, width)

            pose_images.append(pose_image)
            progress += 1
            if progress % 10 == 0:
                comfy_pbar.update_absolute(progress)

        pose_images_np = np.stack(pose_images, 0)
        pose_images_tensor = torch.from_numpy(pose_images_np).float() / 255.0

        return (pose_images_tensor, )
|
|
class PoseRetargetPromptHelper:
    """ComfyUI node that builds text prompts from limb visibility in the poses.

    The prompt wording depends on whether any inspected arm or leg keypoint is
    visible in the original pose metas, and on whether the respective image is
    wider than it is tall.
    """

    # Indices into ``keypoints_body`` that are inspected for arms and legs.
    _ARM_KEYPOINTS = (3, 4, 6, 7)
    _LEG_KEYPOINTS = (9, 12, 10, 13)

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs (only the pose data)."""
        return {
            "required": {
                "pose_data": ("POSEDATA",),
            },
        }

    RETURN_TYPES = ("STRING", "STRING", )
    RETURN_NAMES = ("prompt", "retarget_prompt", )
    FUNCTION = "process"
    CATEGORY = "WanAnimatePreprocess"
    DESCRIPTION = "Generates text prompts for pose retargeting based on visibility of arms and legs in the template pose. Originally used for Flux Kontext"

    @staticmethod
    def _is_visible(keypoint):
        """Return True when ``keypoint`` (x, y, confidence) counts as visible.

        A keypoint is visible when x and y are both <= 1 and the confidence is
        at least 0.75.  ``None`` entries are treated as not visible.
        """
        if keypoint is None:
            return False
        return bool(keypoint[0] <= 1 and keypoint[1] <= 1 and keypoint[2] >= 0.75)

    def process(self, pose_data):
        """Build the template prompt and the reference prompt.

        Args:
            pose_data: Dict with ``"refer_pose_meta"`` (may be missing or
                ``None``) and ``"pose_metas_original"``; each meta provides
                ``"keypoints_body"``, ``"width"`` and ``"height"``.

        Returns:
            ``(prompt, retarget_prompt)``.  Both are the generic
            "face forward" prompt when there is no reference pose or no
            visible limb.
        """
        face_forward = "Change the person to face forward."

        refer_pose_meta = pose_data.get("refer_pose_meta", None)
        if refer_pose_meta is None:
            return (face_forward, face_forward, )

        arm_visible = False
        leg_visible = False
        tpl_pose_meta = None
        # Visibility accumulates over frames; stop once both are established.
        # Keypoints are checked one by one instead of converting the whole
        # list to an array, which breaks when some entries are None.
        for tpl_pose_meta in pose_data["pose_metas_original"]:
            tpl_keypoints = tpl_pose_meta['keypoints_body']
            if any(self._is_visible(tpl_keypoints[i]) for i in self._ARM_KEYPOINTS):
                arm_visible = True
            if any(self._is_visible(tpl_keypoints[i]) for i in self._LEG_KEYPOINTS):
                leg_visible = True
            if arm_visible and leg_visible:
                break

        # Visible legs take precedence over visible arms.
        if leg_visible:
            visibility = " The person is standing. Feet and Hands are visible in the image."
        elif arm_visible:
            visibility = " Hands are visible in the image."
        else:
            return (face_forward, face_forward, )

        def prompt_for(meta):
            # Wide images get the T-pose wording, all others arms-down.
            if meta['width'] > meta['height']:
                stance = "Change the person to a standard T-pose (facing forward with arms extended)."
            else:
                stance = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides."
            return stance + visibility

        # The template prompt uses the meta at which the scan above stopped.
        return (prompt_for(tpl_pose_meta), prompt_for(refer_pose_meta), )
|
|
class PoseDetectionOneToAllAnimation:
    """ComfyUI node: pose detection plus alignment for One-to-All-Animation.

    Runs the YOLO detector and ViTPose on every input frame, converts the
    results to DWPose form, optionally aligns them with a reference image and
    draws the resulting pose images.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs (model, frames, size, alignment and drawing modes)."""
        return {
            "required": {
                "model": ("POSEMODEL",),
                "images": ("IMAGE",),
                "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 2, "tooltip": "Width of the generation"}),
                "height": ("INT", {"default": 480, "min": 64, "max": 2048, "step": 2, "tooltip": "Height of the generation"}),
                "align_to": (["ref", "pose", "none"], {"default": "ref", "tooltip": "Alignment mode for poses"}),
                "draw_face_points": (["full", "weak", "none"], {"default": "full", "tooltip": "Whether to draw face keypoints on the pose images"}),
                "draw_head": (["full", "weak", "none"], {"default": "full", "tooltip": "Whether to draw head keypoints on the pose images"}),
            },
            "optional": {
                "ref_image": ("IMAGE", {"default": None, "tooltip": "Optional reference image for pose retargeting"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE", "MASK",)
    RETURN_NAMES = ("pose_images", "ref_pose_image", "ref_image", "ref_mask")
    FUNCTION = "process"
    CATEGORY = "WanAnimatePreprocess"
    DESCRIPTION = "Specialized pose detection and alignment for OneToAllAnimation model https://github.com/ssj9596/One-to-All-Animation. Detects poses from input images and aligns them based on a reference image if provided."

    def process(self, model, images, width, height, align_to, draw_face_points, draw_head, ref_image=None):
        """Detect poses on ``images`` and draw them, optionally aligned to ``ref_image``.

        Args:
            model: Dict with the ``"yolo"`` detector and ``"vitpose"`` pose model.
            images: Frame batch; unpacked as ``(B, H, W, C)``.
            width: Output width of the drawn pose images.
            height: Output height of the drawn pose images.
            align_to: ``"ref"``, ``"pose"`` or ``"none"``; only consulted when
                ``ref_image`` is given.
            draw_face_points: ``"none"`` draws without face, ``"weak"`` enables
                the helper's ``face_change`` flag.
            draw_head: Passed to the drawing helper as ``head_strength``.
            ref_image: Optional reference image batch; only index 0 is used.

        Returns:
            ``(pose_images, ref_pose_image, ref_image, ref_mask)`` tensors.
            Without ``ref_image`` the last three are all-zero placeholders of
            the requested ``height`` x ``width``.
        """
        # Imported here rather than at module level.
        from .onetoall.infer_function import aaposemeta_to_dwpose, align_to_reference, align_to_pose
        from .onetoall.utils import draw_pose_aligned, warp_ref_to_pose
        detector = model["yolo"]
        pose_model = model["vitpose"]
        B, H, W, C = images.shape

        # Frame size handed to the detector together with the 640x640 input.
        shape = np.array([H, W])[None]
        images_np = images.numpy()

        # Per-channel mean/std applied to crops before they go to ViTPose.
        IMG_NORM_MEAN = np.array([0.485, 0.456, 0.406])
        IMG_NORM_STD = np.array([0.229, 0.224, 0.225])
        input_resolution=(256, 192)
        rescale = 1.25

        detector.reinit()
        pose_model.reinit()

        if ref_image is not None:
            # Keep the unresized 0..255 reference; the "pose" branch below warps it.
            refer_img_np = ref_image[0].numpy() * 255
            refer_img = resize_by_area(refer_img_np, width * height, divisor=16) / 255.0
            # NOTE(review): `shape` describes the input frames, not the resized
            # reference image - confirm this is what the detector expects here.
            ref_bbox = (detector(
                cv2.resize(refer_img.astype(np.float32), (640, 640)).transpose(2, 0, 1)[None],
                shape
                )[0][0]["bbox"])

            # Fall back to the whole image for a missing, non-positive-last-value
            # (presumably the score - TODO confirm) or tiny box.
            if ref_bbox is None or ref_bbox[-1] <= 0 or (ref_bbox[2] - ref_bbox[0]) < 10 or (ref_bbox[3] - ref_bbox[1]) < 10:
                ref_bbox = np.array([0, 0, refer_img.shape[1], refer_img.shape[0]])

            center, scale = bbox_from_detector(ref_bbox, input_resolution, rescale=rescale)
            refer_img = crop(refer_img, center, scale, (input_resolution[0], input_resolution[1]))[0]

            img_norm = (refer_img - IMG_NORM_MEAN) / IMG_NORM_STD
            img_norm = img_norm.transpose(2, 0, 1).astype(np.float32)

            ref_keypoints = pose_model(img_norm[None], np.array(center)[None], np.array(scale)[None])
            refer_pose_meta = load_pose_metas_from_kp2ds_seq(ref_keypoints, width=ref_image.shape[2], height=ref_image.shape[1])[0]

            ref_dwpose = aaposemeta_to_dwpose(refer_pose_meta)

        # Two passes over the frames (boxes, then keypoints) share one bar,
        # which is refreshed every 10 steps.
        comfy_pbar = ProgressBar(B*2)
        progress = 0
        bboxes = []
        for img in tqdm(images_np, total=len(images_np), desc="Detecting bboxes"):
            bboxes.append(detector(
                cv2.resize(img, (640, 640)).transpose(2, 0, 1)[None],
                shape
                )[0][0]["bbox"])
            progress += 1
            if progress % 10 == 0:
                comfy_pbar.update_absolute(progress)

        detector.cleanup()

        kp2ds = []
        for img, bbox in tqdm(zip(images_np, bboxes), total=len(images_np), desc="Extracting keypoints"):
            # Same whole-frame fallback as for the reference image.
            if bbox is None or bbox[-1] <= 0 or (bbox[2] - bbox[0]) < 10 or (bbox[3] - bbox[1]) < 10:
                bbox = np.array([0, 0, img.shape[1], img.shape[0]])

            bbox_xywh = bbox
            center, scale = bbox_from_detector(bbox_xywh, input_resolution, rescale=rescale)
            img = crop(img, center, scale, (input_resolution[0], input_resolution[1]))[0]

            img_norm = (img - IMG_NORM_MEAN) / IMG_NORM_STD
            img_norm = img_norm.transpose(2, 0, 1).astype(np.float32)

            keypoints = pose_model(img_norm[None], np.array(center)[None], np.array(scale)[None])
            kp2ds.append(keypoints)
            progress += 1
            if progress % 10 == 0:
                comfy_pbar.update_absolute(progress)

        pose_model.cleanup()

        kp2ds = np.concatenate(kp2ds, 0)
        pose_metas = load_pose_metas_from_kp2ds_seq(kp2ds, width=W, height=H)
        tpl_dwposes = [aaposemeta_to_dwpose(meta) for meta in pose_metas]

        ref_pose_image_tensor = None
        if ref_image is not None:
            if align_to == "ref":
                # Draw the reference pose and align the frame poses to the
                # reference; the reference image passes through with a zero mask.
                ref_pose_image = draw_pose_aligned(ref_dwpose, height, width, without_face=True)
                ref_pose_image_np = np.stack(ref_pose_image, 0)
                ref_pose_image_tensor = torch.from_numpy(ref_pose_image_np).unsqueeze(0).float() / 255.0
                tpl_dwposes = align_to_reference(refer_pose_meta, pose_metas, tpl_dwposes, anchor_idx=0)
                image_input_tensor = ref_image
                image_mask_tensor = torch.zeros(1, ref_image.shape[1], ref_image.shape[2], dtype=torch.float32, device="cpu")
            elif align_to == "pose":
                # Warp the reference image using the first frame's pose; the
                # helper supplies the warped image, its pose image and a mask.
                image_input, ref_pose_image_np, image_mask = warp_ref_to_pose(refer_img_np, tpl_dwposes[0], ref_dwpose)
                ref_pose_image_np = np.stack(ref_pose_image_np, 0)
                ref_pose_image_tensor = torch.from_numpy(ref_pose_image_np).unsqueeze(0).float() / 255.0
                tpl_dwposes = align_to_pose(ref_dwpose, tpl_dwposes, anchor_idx=0)
                image_input_tensor = torch.from_numpy(image_input).unsqueeze(0).float() / 255.0
                image_mask_tensor = torch.from_numpy(image_mask).unsqueeze(0).float() / 255.0
            elif align_to == "none":
                # Same outputs as "ref", but the frame poses are left unaligned.
                ref_pose_image = draw_pose_aligned(ref_dwpose, height, width, without_face=True)
                ref_pose_image_np = np.stack(ref_pose_image, 0)
                ref_pose_image_tensor = torch.from_numpy(ref_pose_image_np).unsqueeze(0).float() / 255.0
                image_input_tensor = ref_image
                image_mask_tensor = torch.zeros(1, ref_image.shape[1], ref_image.shape[2], dtype=torch.float32, device="cpu")
        else:
            # No reference image: return zero placeholders of the target size.
            ref_pose_image_tensor = torch.zeros((1, height, width, 3), dtype=torch.float32, device="cpu")
            image_input_tensor = torch.zeros((1, height, width, 3), dtype=torch.float32, device="cpu")
            image_mask_tensor = torch.zeros(1, height, width, dtype=torch.float32, device="cpu")

        pose_imgs = []
        for pose_np in tpl_dwposes:
            pose_img = draw_pose_aligned(pose_np, height, width, without_face=(draw_face_points=="none"), face_change=(draw_face_points=="weak"), head_strength=draw_head)
            pose_img = torch.from_numpy(np.array(pose_img))
            pose_imgs.append(pose_img)

        # Stack the drawn frames and scale them to the 0..1 range.
        pose_tensor = torch.stack(pose_imgs).cpu().float() / 255.0

        return (pose_tensor, ref_pose_image_tensor, image_input_tensor, image_mask_tensor)
|
|
# Node id -> implementing class for every node defined in this module.
NODE_CLASS_MAPPINGS = {
    "OnnxDetectionModelLoader": OnnxDetectionModelLoader,
    "PoseAndFaceDetection": PoseAndFaceDetection,
    "DrawViTPose": DrawViTPose,
    "PoseRetargetPromptHelper": PoseRetargetPromptHelper,
    "PoseDetectionOneToAllAnimation": PoseDetectionOneToAllAnimation,
}
# Node id -> human-readable display name; keys mirror NODE_CLASS_MAPPINGS.
NODE_DISPLAY_NAME_MAPPINGS = {
    "OnnxDetectionModelLoader": "ONNX Detection Model Loader",
    "PoseAndFaceDetection": "Pose and Face Detection",
    "DrawViTPose": "Draw ViT Pose",
    "PoseRetargetPromptHelper": "Pose Retarget Prompt Helper",
    "PoseDetectionOneToAllAnimation": "Pose Detection OneToAll Animation",
}
|
|