# Upload metadata captured from the Hugging Face file page (not Python code):
# Eji-Sensei14's picture
# Upload folder using huggingface_hub
# c6535db verified
import os
import torch
from tqdm import tqdm
import numpy as np
import folder_paths
import cv2
import json
import logging
# Absolute path of the directory that contains this file.
script_directory = os.path.dirname(os.path.abspath(__file__))
from comfy import model_management as mm
from comfy.utils import ProgressBar
# ComfyUI's main torch device and its UNet offload device.
# NOTE(review): neither name is referenced by the nodes in this file; the ONNX models
# take their execution provider from the loader's `onnx_device` input instead.
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
# Register ComfyUI/models/detection under the "detection" key so that
# folder_paths.get_filename_list("detection") / get_full_path_or_raise("detection", ...)
# used by the loader node can find the ONNX model files.
folder_paths.add_model_folder_path("detection", os.path.join(folder_paths.models_dir, "detection"))
from .models.onnx_models import ViTPose, Yolo
from .pose_utils.pose2d_utils import load_pose_metas_from_kp2ds_seq, crop, bbox_from_detector
from .utils import get_face_bboxes, padding_resize, resize_by_area, resize_to_bounds
from .pose_utils.human_visualization import AAPoseMeta, draw_aapose_by_meta_new
from .retarget_pose import get_retarget_pose
class OnnxDetectionModelLoader:
    """ComfyUI node that builds the ViTPose + YOLO ONNX model pair used by the pose nodes.

    Both model files are chosen from the ``detection`` model folder and are created on the
    selected ONNX execution provider. The result is a plain dict with the keys
    ``"vitpose"`` and ``"yolo"``.
    """

    CATEGORY = "WanAnimatePreprocess"
    FUNCTION = "loadmodel"
    RETURN_TYPES = ("POSEMODEL",)
    RETURN_NAMES = ("model", )
    DESCRIPTION = "Loads ONNX models for pose and face detection. ViTPose for pose estimation and YOLO for object detection."

    @classmethod
    def INPUT_TYPES(s):
        # Both model selectors list the same folder; each gets its own list and tooltip dict.
        required = {
            field: (
                folder_paths.get_filename_list("detection"),
                {"tooltip": "These models are loaded from the 'ComfyUI/models/detection' -folder"},
            )
            for field in ("vitpose_model", "yolo_model")
        }
        required["onnx_device"] = (
            ["CUDAExecutionProvider", "CPUExecutionProvider"],
            {"default": "CUDAExecutionProvider", "tooltip": "Device to run the ONNX models on"},
        )
        return {"required": required}

    def loadmodel(self, vitpose_model, yolo_model, onnx_device):
        """Resolve both file names, then instantiate ViTPose and YOLO on ``onnx_device``.

        Both paths are resolved before any model is built, so a missing file fails early.
        """
        vitpose_path, yolo_path = (
            folder_paths.get_full_path_or_raise("detection", file_name)
            for file_name in (vitpose_model, yolo_model)
        )
        return ({"vitpose": ViTPose(vitpose_path, onnx_device), "yolo": Yolo(yolo_path, onnx_device)}, )
class PoseAndFaceDetection:
    """Detect per-frame body/face keypoints and face crops, optionally retargeting to a reference.

    Pipeline per frame:
    1. The detector gives a person bbox.
    2. ViTPose runs on a crop around that bbox.
    3. The keypoints become pose metas (``load_pose_metas_from_kp2ds_seq``).
    4. A 512x512 face crop is taken from the face keypoints.

    When ``retarget_image`` is given, its pose is detected the same way, and the frame poses
    are passed through ``get_retarget_pose`` against it.
    """

    # Per-channel mean/std applied to the [0, 1] RGB crop before ViTPose.
    _IMG_NORM_MEAN = np.array([0.485, 0.456, 0.406])
    _IMG_NORM_STD = np.array([0.229, 0.224, 0.225])
    # Crop size handed to crop()/bbox_from_detector(), and the bbox enlargement factor.
    _INPUT_RESOLUTION = (256, 192)
    _RESCALE = 1.25
    # Square size the detector input is resized to.
    _DETECTOR_SIZE = (640, 640)
    # Rows of 'keypoints_body' exported as the key-frame point prompt.
    _KEY_POINTS_INDEX = [0, 1, 2, 5, 8, 11, 10, 13]
    # Output size of every face crop.
    _FACE_SIZE = (512, 512)

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("POSEMODEL",),
                "images": ("IMAGE",),
                "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 1, "tooltip": "Width of the generation"}),
                "height": ("INT", {"default": 480, "min": 64, "max": 2048, "step": 1, "tooltip": "Height of the generation"}),
            },
            "optional": {
                "retarget_image": ("IMAGE", {"default": None, "tooltip": "Optional reference image for pose retargeting"}),
                "face_padding": ("INT", {"default": 0, "min": 0, "max": 512, "step": 1, "tooltip": "When > 0, the detected face images are padded and resized to 512x512"}),
            },
        }

    # The last type previously read "BBOX," (a stray comma inside the string), which does not
    # match the "BBOX" type used for the bboxes output.
    RETURN_TYPES = ("POSEDATA", "IMAGE", "STRING", "BBOX", "BBOX")
    RETURN_NAMES = ("pose_data", "face_images", "key_frame_body_points", "bboxes", "face_bboxes")
    FUNCTION = "process"
    CATEGORY = "WanAnimatePreprocess"
    DESCRIPTION = "Detects human poses and face images from input images. Optionally retargets poses based on a reference image."

    @staticmethod
    def _is_unusable_bbox(bbox):
        """Return True for a missing bbox, a non-positive last element, or a side under 10 px.

        NOTE(review): the last element is presumably the detection score — confirm in Yolo.
        """
        return bbox is None or bbox[-1] <= 0 or (bbox[2] - bbox[0]) < 10 or (bbox[3] - bbox[1]) < 10

    @classmethod
    def _detect_bbox(cls, detector, img):
        """Run the detector on ``img`` (HWC) and return the first result's ``"bbox"``.

        The image's own (H, W) is passed as the shape argument. The frame loop checks the
        returned box against that same image's size, so this keeps the box in ``img``'s
        pixel space for the reference image too.
        """
        shape = np.array(img.shape[:2])[None]
        resized = cv2.resize(np.asarray(img, dtype=np.float32), cls._DETECTOR_SIZE)
        return detector(resized.transpose(2, 0, 1)[None], shape)[0][0]["bbox"]

    @classmethod
    def _estimate_keypoints(cls, pose_model, img, bbox):
        """Run ViTPose on a crop of ``img`` around ``bbox``, or around the whole image if unusable.

        Returns ``(keypoints, crop)``, where ``crop`` is the un-normalised crop fed to the model.
        """
        if cls._is_unusable_bbox(bbox):
            bbox = np.array([0, 0, img.shape[1], img.shape[0]])
        center, scale = bbox_from_detector(bbox, cls._INPUT_RESOLUTION, rescale=cls._RESCALE)
        cropped = crop(img, center, scale, (cls._INPUT_RESOLUTION[0], cls._INPUT_RESOLUTION[1]))[0]
        img_norm = ((cropped - cls._IMG_NORM_MEAN) / cls._IMG_NORM_STD).transpose(2, 0, 1).astype(np.float32)
        keypoints = pose_model(img_norm[None], np.array(center)[None], np.array(scale)[None])
        return keypoints, cropped

    def _detect_reference(self, detector, pose_model, retarget_image, width, height):
        """Return ``(reference crop, reference pose meta)`` for ``retarget_image[0]``."""
        refer_img = resize_by_area(retarget_image[0].numpy() * 255, width * height, divisor=16) / 255.0
        ref_bbox = self._detect_bbox(detector, refer_img)
        ref_keypoints, refer_crop = self._estimate_keypoints(pose_model, refer_img, ref_bbox)
        # The bbox and crop were computed on the resized refer_img, so the keypoints are in
        # its pixel space. Normalise by that size, as the frame path does with (W, H).
        ref_h, ref_w = refer_img.shape[:2]
        refer_pose_meta = load_pose_metas_from_kp2ds_seq(ref_keypoints, width=ref_w, height=ref_h)[0]
        return refer_crop, refer_pose_meta

    @classmethod
    def _crop_face(cls, frame, meta, face_padding, idx):
        """Return ``((x1, y1, x2, y2), face image resized to 512x512)`` for one HWC frame."""
        H, W, C = frame.shape
        x1, x2, y1, y2 = get_face_bboxes(meta['keypoints_face'][:, :2], scale=1.3, image_shape=(H, W))
        if face_padding > 0:
            x1 = max(0, x1 - face_padding)
            y1 = max(0, y1 - face_padding)
            x2 = min(W, x2 + face_padding)
            y2 = min(H, y2 + face_padding)
        face_image = frame[y1:y2, x1:x2]
        if face_image.size == 0:
            logging.warning("Empty face crop on frame %d, creating fallback image.", idx)
            # Fallback: a centred square (30% of the short side) starting 10% down the frame.
            # max(1, ...) keeps cv2.resize from receiving a 0x0 image on tiny inputs.
            fallback_size = max(1, int(min(H, W) * 0.3))
            fx1 = (W - fallback_size) // 2
            fy1 = int(H * 0.1)
            face_image = frame[fy1:fy1 + fallback_size, fx1:fx1 + fallback_size]
            if face_image.size == 0:
                face_image = np.zeros((fallback_size, fallback_size, C), dtype=frame.dtype)
        # The reported bbox is the (padded) detected one, even when the fallback crop was used.
        return (x1, y1, x2, y2), cv2.resize(face_image, cls._FACE_SIZE)

    @classmethod
    def _key_frame_body_points(cls, pose_metas, num_frames):
        """Return selected body keypoints of one key frame as ``[{"x": int, "y": int}, ...]`` pixels.

        NOTE(review): this used to walk up to four key frames but reset the point list on
        each pass, so only the last key frame's points were ever emitted. That output is
        kept here without the discarded work. Confirm whether emitting all key frames was
        intended.
        """
        key_frame_num = 4 if num_frames >= 4 else 1
        key_frame_step = len(pose_metas) // key_frame_num
        key_frame_index = range(0, len(pose_metas), key_frame_step)[-1]
        body_key_points = pose_metas[key_frame_index]['keypoints_body']
        selected = [body_key_points[i] for i in cls._KEY_POINTS_INDEX if body_key_points[i] is not None]
        keypoints_body = np.array(selected)[:, :2]
        # Keypoints are normalised, so scale back to pixels with the frame size.
        wh = np.array([[pose_metas[0]['width'], pose_metas[0]['height']]])
        points = (keypoints_body * wh).astype(np.int32)
        return [{"x": int(x), "y": int(y)} for x, y in points]

    def process(self, model, images, width, height, retarget_image=None, face_padding=0):
        """Run detection on ``images`` (B, H, W, C) and build the node outputs.

        Returns ``(pose_data, face_images, key_frame_body_points_json, [first_bbox], face_bboxes)``.
        """
        detector = model["yolo"]
        pose_model = model["vitpose"]
        B, H, W, C = images.shape
        images_np = images.numpy()

        refer_img = None
        refer_pose_meta = None
        comfy_pbar = ProgressBar(B * 2)
        progress = 0

        detector.reinit()
        pose_model.reinit()
        try:
            try:
                if retarget_image is not None:
                    refer_img, refer_pose_meta = self._detect_reference(
                        detector, pose_model, retarget_image, width, height)
                bboxes = []
                for img in tqdm(images_np, total=len(images_np), desc="Detecting bboxes"):
                    bboxes.append(self._detect_bbox(detector, img))
                    progress += 1
                    if progress % 10 == 0:
                        comfy_pbar.update_absolute(progress)
            finally:
                # Free the detector before the keypoint pass, even if detection raised.
                detector.cleanup()

            kp2ds = []
            for img, bbox in tqdm(zip(images_np, bboxes), total=len(images_np), desc="Extracting keypoints"):
                keypoints, _ = self._estimate_keypoints(pose_model, img, bbox)
                kp2ds.append(keypoints)
                progress += 1
                if progress % 10 == 0:
                    comfy_pbar.update_absolute(progress)
        finally:
            pose_model.cleanup()
        comfy_pbar.update_absolute(progress)

        pose_metas = load_pose_metas_from_kp2ds_seq(np.concatenate(kp2ds, 0), width=W, height=H)

        face_bboxes = []
        face_images = []
        for idx, meta in enumerate(pose_metas):
            face_bbox, face_image = self._crop_face(images_np[idx], meta, face_padding, idx)
            face_bboxes.append(face_bbox)
            face_images.append(face_image)
        face_images_tensor = torch.from_numpy(np.stack(face_images, 0))

        if retarget_image is not None and refer_pose_meta is not None:
            retarget_pose_metas = get_retarget_pose(pose_metas[0], refer_pose_meta, pose_metas, None, None)
        else:
            retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in pose_metas]

        # Only the first frame's detector bbox is exported; fall back to zeros if it has < 4 values.
        first_bbox = np.array(bboxes[0]).flatten()
        bbox_ints = tuple(int(v) for v in first_bbox[:4]) if first_bbox.shape[0] >= 4 else (0, 0, 0, 0)

        key_frame_points = self._key_frame_body_points(pose_metas, B)

        pose_data = {
            # NOTE: this is the cropped reference fed to ViTPose, not the full reference image.
            "retarget_image": refer_img,
            "pose_metas": retarget_pose_metas,
            "refer_pose_meta": refer_pose_meta,
            "pose_metas_original": pose_metas,
        }
        return (pose_data, face_images_tensor, json.dumps(key_frame_points), [bbox_ints], face_bboxes)
class DrawViTPose:
    """Render ``pose_data["pose_metas"]`` into stick-figure images.

    Each meta is drawn with ``draw_aapose_by_meta_new`` on a black ``(height, width, 3)`` uint8
    canvas. It is then resized either with ``resize_to_bounds`` (when a retarget image exists
    and ``retarget_padding > 0``) or with ``padding_resize``. The output is a float tensor
    scaled to [0, 1].
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pose_data": ("POSEDATA",),
                "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 1, "tooltip": "Width of the generation"}),
                "height": ("INT", {"default": 480, "min": 64, "max": 2048, "step": 1, "tooltip": "Height of the generation"}),
                "retarget_padding": ("INT", {"default": 16, "min": 0, "max": 512, "step": 1, "tooltip": "When > 0, the retargeted pose image is padded and resized to the target size"}),
                "body_stick_width": ("INT", {"default": -1, "min": -1, "max": 20, "step": 1, "tooltip": "Width of the body sticks. Set to 0 to disable body drawing, -1 for auto"}),
                "hand_stick_width": ("INT", {"default": -1, "min": -1, "max": 20, "step": 1, "tooltip": "Width of the hand sticks. Set to 0 to disable hand drawing, -1 for auto"}),
                # BOOLEAN widgets take a real bool default; the string "True" was used before.
                "draw_head": ("BOOLEAN", {"default": True, "tooltip": "Whether to draw head keypoints"}),
            },
        }

    RETURN_TYPES = ("IMAGE", )
    RETURN_NAMES = ("pose_images", )
    FUNCTION = "process"
    CATEGORY = "WanAnimatePreprocess"
    DESCRIPTION = "Draws pose images from pose data."

    # NOTE(review): the UI default for retarget_padding is 16; the Python default of 64 below
    # only applies to direct calls that omit the argument.
    def process(self, pose_data, width, height, body_stick_width, hand_stick_width, draw_head, retarget_padding=64):
        """Draw every pose meta and return ``(images,)`` as a (N, height, width, 3) float tensor.

        Raises:
            ValueError: if ``pose_data["pose_metas"]`` is empty.
        """
        retarget_image = pose_data.get("retarget_image", None)
        pose_metas = pose_data["pose_metas"]
        draw_hand = hand_stick_width != 0
        use_retarget_resize = retarget_padding > 0 and retarget_image is not None

        comfy_pbar = ProgressBar(len(pose_metas))
        # The first drawn pose is the reference that resize_to_bounds uses for every frame.
        crop_target_image = None
        pose_images = []
        for progress, meta in enumerate(tqdm(pose_metas, desc="Drawing pose images"), start=1):
            canvas = np.zeros((height, width, 3), dtype=np.uint8)
            pose_image = draw_aapose_by_meta_new(canvas, meta, draw_hand=draw_hand, draw_head=draw_head, body_stick_width=body_stick_width, hand_stick_width=hand_stick_width)
            if crop_target_image is None:
                crop_target_image = pose_image
            if use_retarget_resize:
                pose_image = resize_to_bounds(pose_image, height, width, crop_target_image=crop_target_image, extra_padding=retarget_padding)
            else:
                pose_image = padding_resize(pose_image, height, width)
            pose_images.append(pose_image)
            if progress % 10 == 0:
                comfy_pbar.update_absolute(progress)
        # Make sure the bar reaches its total when the count is not a multiple of 10.
        comfy_pbar.update_absolute(len(pose_images))

        if not pose_images:
            raise ValueError("pose_data contains no pose metas to draw")
        pose_images_tensor = torch.from_numpy(np.stack(pose_images, 0)).float() / 255.0
        return (pose_images_tensor, )
class PoseRetargetPromptHelper:
    """Build editing prompts from which limbs are confidently visible in the template poses.

    Reads ``pose_data["pose_metas_original"]`` (template poses) and ``pose_data["refer_pose_meta"]``,
    as produced by PoseAndFaceDetection. If leg keypoints are visible in any template frame,
    the "legs" prompts are used. Otherwise arm keypoints select the "arms" prompts, and with
    neither the prompt is just "face forward". A landscape canvas (width > height) gets the
    T-pose wording.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"pose_data": ("POSEDATA",)}}

    RETURN_TYPES = ("STRING", "STRING", )
    RETURN_NAMES = ("prompt", "retarget_prompt", )
    FUNCTION = "process"
    CATEGORY = "WanAnimatePreprocess"
    DESCRIPTION = "Generates text prompts for pose retargeting based on visibility of arms and legs in the template pose. Originally used for Flux Kontext"

    # 'keypoints_body' rows checked for each limb group, in evaluation order.
    # NOTE(review): these look like OpenPose-18 elbow/wrist and knee/ankle rows — confirm.
    _ARM_ROWS = (3, 4, 6, 7)
    _LEG_ROWS = (9, 12, 10, 13)

    _FACE_FORWARD = "Change the person to face forward."
    # group -> (prompt for width > height, prompt otherwise)
    _PROMPTS = {
        "legs": (
            "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image.",
            "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image.",
        ),
        "arms": (
            "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image.",
            "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image.",
        ),
    }

    @staticmethod
    def _rows_visible(keypoints, rows):
        """True if some row in ``rows`` has x <= 1, y <= 1 and score >= 0.75.

        Skipped entirely unless at least one of the rows is non-zero.
        """
        if not any(np.any(keypoints[r]) != 0 for r in rows):
            return False
        return any(
            keypoints[r][0] <= 1 and keypoints[r][1] <= 1 and keypoints[r][2] >= 0.75
            for r in rows
        )

    @classmethod
    def _pick(cls, group, meta):
        """Choose the landscape or portrait wording of ``group`` for ``meta``'s canvas."""
        landscape, portrait = cls._PROMPTS[group]
        return landscape if meta['width'] > meta['height'] else portrait

    def process(self, pose_data):
        """Return ``(template_prompt, reference_prompt)``."""
        refer_pose_meta = pose_data.get("refer_pose_meta", None)
        if refer_pose_meta is None:
            return (self._FACE_FORWARD, self._FACE_FORWARD, )

        arm_visible = False
        leg_visible = False
        tpl_pose_meta = None
        for tpl_pose_meta in pose_data["pose_metas_original"]:
            keypoints = np.array(tpl_pose_meta['keypoints_body'])
            if self._rows_visible(keypoints, self._ARM_ROWS):
                arm_visible = True
            if self._rows_visible(keypoints, self._LEG_ROWS):
                leg_visible = True
            if arm_visible and leg_visible:
                break

        if leg_visible:
            group = "legs"
        elif arm_visible:
            group = "arms"
        else:
            return (self._FACE_FORWARD, self._FACE_FORWARD, )
        # tpl_pose_meta is the frame the scan stopped on (the last one unless it broke early).
        return (self._pick(group, tpl_pose_meta), self._pick(group, refer_pose_meta), )
class PoseDetectionOneToAllAnimation:
    """Pose detection plus reference alignment for the One-to-All-Animation model.

    Frame poses are detected with the detector + ViTPose pair and converted to DWPose-style
    data (``aaposemeta_to_dwpose``). When ``ref_image`` is given, they are aligned according
    to ``align_to``:
    - ``"ref"``: align the frame poses to the reference pose.
    - ``"pose"``: warp the reference image to the first frame pose, then align the poses.
    - ``"none"``: leave the poses as detected.
    """

    # Per-channel mean/std applied to the [0, 1] RGB crop before ViTPose.
    _IMG_NORM_MEAN = np.array([0.485, 0.456, 0.406])
    _IMG_NORM_STD = np.array([0.229, 0.224, 0.225])
    # Crop size handed to crop()/bbox_from_detector(), and the bbox enlargement factor.
    _INPUT_RESOLUTION = (256, 192)
    _RESCALE = 1.25
    # Square size the detector input is resized to.
    _DETECTOR_SIZE = (640, 640)

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("POSEMODEL",),
                "images": ("IMAGE",),
                "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 2, "tooltip": "Width of the generation"}),
                "height": ("INT", {"default": 480, "min": 64, "max": 2048, "step": 2, "tooltip": "Height of the generation"}),
                "align_to": (["ref", "pose", "none"], {"default": "ref", "tooltip": "Alignment mode for poses"}),
                "draw_face_points": (["full", "weak", "none"], {"default": "full", "tooltip": "Whether to draw face keypoints on the pose images"}),
                "draw_head": (["full", "weak", "none"], {"default": "full", "tooltip": "Whether to draw head keypoints on the pose images"}),
            },
            "optional": {
                "ref_image": ("IMAGE", {"default": None, "tooltip": "Optional reference image for pose retargeting"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE", "MASK",)
    RETURN_NAMES = ("pose_images", "ref_pose_image", "ref_image", "ref_mask")
    FUNCTION = "process"
    CATEGORY = "WanAnimatePreprocess"
    DESCRIPTION = "Specialized pose detection and alignment for OneToAllAnimation model https://github.com/ssj9596/One-to-All-Animation. Detects poses from input images and aligns them based on a reference image if provided."

    @staticmethod
    def _is_unusable_bbox(bbox):
        """Return True for a missing bbox, a non-positive last element, or a side under 10 px.

        NOTE(review): the last element is presumably the detection score — confirm in Yolo.
        """
        return bbox is None or bbox[-1] <= 0 or (bbox[2] - bbox[0]) < 10 or (bbox[3] - bbox[1]) < 10

    @classmethod
    def _detect_bbox(cls, detector, img):
        """Run the detector on ``img`` (HWC) and return the first result's ``"bbox"``.

        The image's own (H, W) is passed as the shape argument. The frame loop checks the
        returned box against that same image's size, so this keeps the box in ``img``'s
        pixel space for the reference image too.
        """
        shape = np.array(img.shape[:2])[None]
        resized = cv2.resize(np.asarray(img, dtype=np.float32), cls._DETECTOR_SIZE)
        return detector(resized.transpose(2, 0, 1)[None], shape)[0][0]["bbox"]

    @classmethod
    def _estimate_keypoints(cls, pose_model, img, bbox):
        """Run ViTPose on a crop of ``img`` around ``bbox``, or around the whole image if unusable."""
        if cls._is_unusable_bbox(bbox):
            bbox = np.array([0, 0, img.shape[1], img.shape[0]])
        center, scale = bbox_from_detector(bbox, cls._INPUT_RESOLUTION, rescale=cls._RESCALE)
        cropped = crop(img, center, scale, (cls._INPUT_RESOLUTION[0], cls._INPUT_RESOLUTION[1]))[0]
        img_norm = ((cropped - cls._IMG_NORM_MEAN) / cls._IMG_NORM_STD).transpose(2, 0, 1).astype(np.float32)
        return pose_model(img_norm[None], np.array(center)[None], np.array(scale)[None])

    @staticmethod
    def _image_to_batch_tensor(img):
        """``np.stack`` ``img`` on axis 0, add a batch dim and scale [0, 255] to [0, 1] float."""
        return torch.from_numpy(np.stack(img, 0)).unsqueeze(0).float() / 255.0

    def process(self, model, images, width, height, align_to, draw_face_points, draw_head, ref_image=None):
        """Detect poses in ``images`` (B, H, W, C) and optionally align them to ``ref_image``.

        Returns ``(pose_images, ref_pose_image, ref_image, ref_mask)``. Without ``ref_image``,
        the last three are zero tensors of size ``height`` x ``width``.

        Raises:
            ValueError: if ``ref_image`` is given and ``align_to`` is not "ref", "pose" or "none".
        """
        from .onetoall.infer_function import aaposemeta_to_dwpose, align_to_reference, align_to_pose
        from .onetoall.utils import draw_pose_aligned, warp_ref_to_pose
        detector = model["yolo"]
        pose_model = model["vitpose"]
        B, H, W, _ = images.shape
        images_np = images.numpy()

        refer_img_np = None
        refer_pose_meta = None
        comfy_pbar = ProgressBar(B * 2)
        progress = 0

        detector.reinit()
        pose_model.reinit()
        try:
            try:
                if ref_image is not None:
                    refer_img_np = ref_image[0].numpy() * 255
                    refer_img = resize_by_area(refer_img_np, width * height, divisor=16) / 255.0
                    ref_bbox = self._detect_bbox(detector, refer_img)
                    ref_keypoints = self._estimate_keypoints(pose_model, refer_img, ref_bbox)
                    # Keypoints are in the resized refer_img pixel space; normalise by its size,
                    # matching how frame keypoints are normalised by (W, H).
                    ref_h, ref_w = refer_img.shape[:2]
                    refer_pose_meta = load_pose_metas_from_kp2ds_seq(ref_keypoints, width=ref_w, height=ref_h)[0]
                bboxes = []
                for img in tqdm(images_np, total=len(images_np), desc="Detecting bboxes"):
                    bboxes.append(self._detect_bbox(detector, img))
                    progress += 1
                    if progress % 10 == 0:
                        comfy_pbar.update_absolute(progress)
            finally:
                # Free the detector before the keypoint pass, even if detection raised.
                detector.cleanup()

            kp2ds = []
            for img, bbox in tqdm(zip(images_np, bboxes), total=len(images_np), desc="Extracting keypoints"):
                kp2ds.append(self._estimate_keypoints(pose_model, img, bbox))
                progress += 1
                if progress % 10 == 0:
                    comfy_pbar.update_absolute(progress)
        finally:
            pose_model.cleanup()
        comfy_pbar.update_absolute(progress)

        pose_metas = load_pose_metas_from_kp2ds_seq(np.concatenate(kp2ds, 0), width=W, height=H)
        tpl_dwposes = [aaposemeta_to_dwpose(meta) for meta in pose_metas]

        if ref_image is None:
            ref_pose_image_tensor = torch.zeros((1, height, width, 3), dtype=torch.float32, device="cpu")
            image_input_tensor = torch.zeros((1, height, width, 3), dtype=torch.float32, device="cpu")
            image_mask_tensor = torch.zeros(1, height, width, dtype=torch.float32, device="cpu")
        else:
            ref_dwpose = aaposemeta_to_dwpose(refer_pose_meta)
            if align_to == "pose":
                image_input, ref_pose_image_np, image_mask = warp_ref_to_pose(refer_img_np, tpl_dwposes[0], ref_dwpose)
                ref_pose_image_tensor = self._image_to_batch_tensor(ref_pose_image_np)
                tpl_dwposes = align_to_pose(ref_dwpose, tpl_dwposes, anchor_idx=0)
                image_input_tensor = torch.from_numpy(image_input).unsqueeze(0).float() / 255.0
                image_mask_tensor = torch.from_numpy(image_mask).unsqueeze(0).float() / 255.0
            elif align_to in ("ref", "none"):
                ref_pose_image = draw_pose_aligned(ref_dwpose, height, width, without_face=True)
                ref_pose_image_tensor = self._image_to_batch_tensor(ref_pose_image)
                if align_to == "ref":
                    tpl_dwposes = align_to_reference(refer_pose_meta, pose_metas, tpl_dwposes, anchor_idx=0)
                # The reference image is passed through unchanged with an all-zero mask.
                image_input_tensor = ref_image
                image_mask_tensor = torch.zeros(1, ref_image.shape[1], ref_image.shape[2], dtype=torch.float32, device="cpu")
            else:
                raise ValueError(f"Unknown align_to mode: {align_to!r}")

        pose_imgs = []
        for pose_np in tpl_dwposes:
            pose_img = draw_pose_aligned(pose_np, height, width, without_face=(draw_face_points == "none"), face_change=(draw_face_points == "weak"), head_strength=draw_head)
            pose_imgs.append(torch.from_numpy(np.array(pose_img)))
        pose_tensor = torch.stack(pose_imgs).cpu().float() / 255.0
        return (pose_tensor, ref_pose_image_tensor, image_input_tensor, image_mask_tensor)
# (node id, node class, display name) for every node exported by this module.
_NODE_REGISTRY = (
    ("OnnxDetectionModelLoader", OnnxDetectionModelLoader, "ONNX Detection Model Loader"),
    ("PoseAndFaceDetection", PoseAndFaceDetection, "Pose and Face Detection"),
    ("DrawViTPose", DrawViTPose, "Draw ViT Pose"),
    ("PoseRetargetPromptHelper", PoseRetargetPromptHelper, "Pose Retarget Prompt Helper"),
    ("PoseDetectionOneToAllAnimation", PoseDetectionOneToAllAnimation, "Pose Detection OneToAll Animation"),
)

# ComfyUI custom-node registration tables: node id -> class, and node id -> UI title.
NODE_CLASS_MAPPINGS = {node_id: node_cls for node_id, node_cls, _ in _NODE_REGISTRY}
NODE_DISPLAY_NAME_MAPPINGS = {node_id: title for node_id, _, title in _NODE_REGISTRY}