import os
import torch
import numpy as np
from ..utils import log

from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

import comfy.model_management as mm
from comfy.utils import load_torch_file, ProgressBar
import folder_paths

script_directory = os.path.dirname(os.path.abspath(__file__))
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()


class WanVideoAddSteadyDancerEmbeds:
    """ComfyUI node that attaches SteadyDancer pose conditioning to WanVideo image embeds.

    The incoming ``embeds`` mapping is never mutated: the node returns a
    shallow copy of it with an extra ``"sdancer_embeds"`` entry holding the
    pose latents, strengths, schedule window and optional CLIP vision embeds.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's required and optional inputs for ComfyUI."""
        return {
            "required": {
                "embeds": ("WANVIDIMAGE_EMBEDS",),
                "pose_latents_positive": ("LATENT",),
                "pose_strength_spatial": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the pose embedding"}),
                "pose_strength_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the pose embedding"}),
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the embedding application"}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the embedding application"}),
            },
            "optional": {
                "pose_latents_negative": ("LATENT",),
                "clip_vision_embeds": ("WANVIDIMAGE_CLIPEMBEDS",),
            },
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, embeds, pose_latents_positive, pose_strength_spatial, pose_strength_temporal,
            start_percent=0.0, end_percent=1.0, pose_latents_negative=None, clip_vision_embeds=None):
        """Return a copy of ``embeds`` with SteadyDancer pose conditioning attached.

        Args:
            embeds: Incoming WanVideo image-embeds mapping; copied, not mutated.
            pose_latents_positive: LATENT dict; its ``"samples"[0]`` is stored as ``cond_pos``.
            pose_strength_spatial: Spatial pose strength, stored unchanged.
            pose_strength_temporal: Temporal pose strength, stored unchanged.
            start_percent: Start of the application window, stored unchanged.
            end_percent: End of the application window, stored unchanged.
            pose_latents_negative: Optional LATENT dict; when given, its
                ``"samples"[0]`` is stored as ``cond_neg``, otherwise ``None``.
            clip_vision_embeds: Optional CLIP vision embeds, stored as ``clip_fea``.

        Returns:
            A one-element tuple containing the updated embeds mapping.
        """
        # NOTE(review): "[0]" takes the first entry of the latent "samples",
        # presumably dropping a batch dimension of size 1 -- confirm upstream.
        cond_pos = pose_latents_positive["samples"][0]
        # Compare against the None default explicitly instead of relying on
        # dict truthiness for the optional input.
        cond_neg = None
        if pose_latents_negative is not None:
            cond_neg = pose_latents_negative["samples"][0]

        sdancer_embeds = {
            "cond_pos": cond_pos,
            "cond_neg": cond_neg,
            "pose_strength_spatial": pose_strength_spatial,
            "pose_strength_temporal": pose_strength_temporal,
            "start_percent": start_percent,
            "end_percent": end_percent,
            "clip_fea": clip_vision_embeds,
        }

        # Shallow copy so the caller's embeds dict is left untouched.
        updated = dict(embeds)
        updated["sdancer_embeds"] = sdancer_embeds
        return (updated,)


NODE_CLASS_MAPPINGS = {
    "WanVideoAddSteadyDancerEmbeds": WanVideoAddSteadyDancerEmbeds,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoAddSteadyDancerEmbeds": "WanVideo Add SteadyDancer Embeds",
}