import torch
from ..utils import log
import comfy.model_management as mm

device = mm.get_torch_device()
offload_device = mm.unet_offload_device()


def _encode_images(vae, images):
    """Encode ``images`` with ``vae`` and return the first encoded latent.

    The VAE is moved to ``device`` for the encode and is always moved back
    to ``offload_device`` afterwards, even when encoding raises, so a failed
    run does not leave the VAE on the compute device.

    Args:
        vae: Object exposing ``to``, ``dtype`` and ``encode``.
        images: 4-D tensor whose last axis holds at least 3 entries; only
            the first 3 are used.
            NOTE(review): presumably (frames, H, W, C) in [0, 1] as ComfyUI
            IMAGE tensors usually are -- confirm against callers.

    Returns:
        The first element of ``vae.encode([...], device, tiled=False)``.
    """
    try:
        vae.to(device)
        # Keep the first 3 entries of the last axis, move that axis to the
        # front, and rescale with x * 2 - 1 before handing off to the VAE.
        pixels = (images[..., :3].permute(3, 0, 1, 2) * 2 - 1).to(device, vae.dtype)
        return vae.encode([pixels], device, tiled=False)[0]
    finally:
        vae.to(offload_device)


def _append_ones_mask(latent):
    """Concatenate an all-ones block, shaped like ``latent[:4]``, along dim 0."""
    return torch.cat([latent, torch.ones_like(latent[:4])], dim=0)


def _copy_with_own_scail_embeds(embeds):
    """Return a shallow copy of ``embeds`` holding its own ``scail_embeds`` dict.

    ``dict(embeds)`` alone would share the nested ``scail_embeds`` dict with
    the caller, so writing into it would silently modify the input embeds.
    Copying the nested dict keeps the input untouched.
    """
    updated = dict(embeds)
    updated["scail_embeds"] = dict(updated.get("scail_embeds") or {})
    return updated


class WanVideoAddSCAILReferenceEmbeds:
    """Node that encodes a reference image and stores it under ``scail_embeds``."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification (required and optional sockets)."""
        return {
            "required": {
                "embeds": ("WANVIDIMAGE_EMBEDS",),
                "vae": ("WANVAE", {"tooltip": "VAE model"}),
                "ref_image": ("IMAGE",),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Strength of the reference embedding"}),
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the embedding application"}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the embedding application"}),
            },
            "optional": {
                "clip_embeds": ("WANVIDIMAGE_CLIPEMBEDS", {"tooltip": "Clip vision encoded image"}),
            },
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, embeds, vae, ref_image, strength, start_percent, end_percent, clip_embeds=None):
        """Attach reference-image latents to a copy of ``embeds``.

        Args:
            embeds: Mapping of existing embeds; it is not modified.
            vae: VAE used to encode ``ref_image``.
            ref_image: Image tensor to encode.
            strength: Multiplier applied to the stored positive latent.
            start_percent: Stored as ``ref_start_percent``.
            end_percent: Stored as ``ref_end_percent``.
            clip_embeds: Optional mapping; its ``"clip_embeds"`` entry is
                stored as ``clip_context``.

        Returns:
            A 1-tuple holding the updated embeds dict.
        """
        updated = _copy_with_own_scail_embeds(embeds)

        ref_latent = _encode_images(vae, ref_image)
        # Logged before the mask is appended, so this is the raw VAE output shape.
        log.info(f"SCAIL ref_latent shape: {ref_latent.shape}")
        ref_latent = _append_ones_mask(ref_latent)

        scail = updated["scail_embeds"]
        # The scale is applied to the whole tensor, including the appended mask.
        scail["ref_latent_pos"] = ref_latent * strength
        scail["ref_latent_neg"] = torch.zeros_like(ref_latent)
        scail["ref_start_percent"] = start_percent
        scail["ref_end_percent"] = end_percent

        # clip_context is always written: it becomes None when no clip_embeds
        # are supplied, replacing any value already present in ``embeds``.
        updated["clip_context"] = clip_embeds.get("clip_embeds", None) if clip_embeds is not None else None
        return (updated,)


class WanVideoAddSCAILPoseEmbeds:
    """Node that encodes pose images and stores them under ``scail_embeds``."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification (required sockets only)."""
        return {
            "required": {
                "embeds": ("WANVIDIMAGE_EMBEDS",),
                "vae": ("WANVAE", {"tooltip": "VAE model"}),
                "pose_images": ("IMAGE", {"tooltip": "Pose images for the entire video"}),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Strength of the pose control"}),
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the pose control application"}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the pose control application"}),
            },
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, embeds, vae, pose_images, strength, start_percent=0.0, end_percent=1.0):
        """Attach pose latents to a copy of ``embeds``.

        Args:
            embeds: Mapping of existing embeds; it is not modified.
            vae: VAE used to encode ``pose_images``.
            pose_images: Image tensor to encode.
            strength: Stored unchanged as ``pose_strength`` (the latent
                itself is not scaled here).
            start_percent: Stored as ``pose_start_percent``.
            end_percent: Stored as ``pose_end_percent``.

        Returns:
            A 1-tuple holding the updated embeds dict.
        """
        updated = _copy_with_own_scail_embeds(embeds)

        pose_latent = _append_ones_mask(_encode_images(vae, pose_images))
        # Logged after the mask is appended, so the shape includes the mask.
        log.info(f"SCAIL pose_latent shape: {pose_latent.shape}")

        scail = updated["scail_embeds"]
        scail["pose_latent"] = pose_latent
        scail["pose_strength"] = strength
        scail["pose_start_percent"] = start_percent
        scail["pose_end_percent"] = end_percent
        return (updated,)


NODE_CLASS_MAPPINGS = {
    "WanVideoAddSCAILPoseEmbeds": WanVideoAddSCAILPoseEmbeds,
    "WanVideoAddSCAILReferenceEmbeds": WanVideoAddSCAILReferenceEmbeds,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoAddSCAILReferenceEmbeds": "WanVideo Add SCAIL Reference Embeds",
    "WanVideoAddSCAILPoseEmbeds": "WanVideo Add SCAIL Pose Embeds",
}