File size: 4,390 Bytes
c6535db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
from ..utils import log
import comfy.model_management as mm

device = mm.get_torch_device()
offload_device = mm.unet_offload_device()

class WanVideoAddSCAILReferenceEmbeds:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "embeds": ("WANVIDIMAGE_EMBEDS",),
                    "vae": ("WANVAE", {"tooltip": "VAE model"}),
                    "ref_image": ("IMAGE",),
                    "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Strength of the reference embedding"}),
                    "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the embedding application"}),
                    "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the embedding application"}),
                },
                "optional": {
                    "clip_embeds": ("WANVIDIMAGE_CLIPEMBEDS", {"tooltip": "Clip vision encoded image"}),
                }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, embeds, vae, ref_image, strength, start_percent, end_percent, clip_embeds=None):
        updated = dict(embeds)

        vae.to(device)
        ref_image_in = (ref_image[..., :3].permute(3, 0, 1, 2) * 2 - 1).to(device, vae.dtype)
        ref_latent = vae.encode([ref_image_in], device, tiled=False)[0]
        log.info(f"SCAIL ref_latent shape: {ref_latent.shape}")

        ref_mask = torch.ones_like(ref_latent[:4])
        ref_latent = torch.cat([ref_latent, ref_mask], dim=0)
        vae.to(offload_device)

        updated.setdefault("scail_embeds", {})
        updated["scail_embeds"]["ref_latent_pos"] = ref_latent * strength
        updated["scail_embeds"]["ref_latent_neg"] = torch.zeros_like(ref_latent)
        updated["scail_embeds"]["ref_start_percent"] = start_percent
        updated["scail_embeds"]["ref_end_percent"] = end_percent
        updated["clip_context"] = clip_embeds.get("clip_embeds", None) if clip_embeds is not None else None

        return (updated,)

class WanVideoAddSCAILPoseEmbeds:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "embeds": ("WANVIDIMAGE_EMBEDS",),
                    "vae": ("WANVAE", {"tooltip": "VAE model"}),
                    "pose_images": ("IMAGE", {"tooltip": "Pose images for the entire video"}),
                    "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Strength of the pose control"}),
                    "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the pose control application"}),
                    "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the pose control application"}),
                },
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, embeds, vae, pose_images, strength, start_percent=0.0, end_percent=1.0):
        updated = dict(embeds)

        vae.to(device)
        pose_images_in = (pose_images[..., :3].permute(3, 0, 1, 2) * 2 - 1).to(device, vae.dtype)
        pose_latent = vae.encode([pose_images_in], device, tiled=False)[0]
        pose_mask = torch.ones_like(pose_latent[:4])
        pose_latent = torch.cat([pose_latent, pose_mask], dim=0)
        log.info(f"SCAIL pose_latent shape: {pose_latent.shape}")

        vae.to(offload_device)

        updated.setdefault("scail_embeds", {})
        updated["scail_embeds"]["pose_latent"] = pose_latent
        updated["scail_embeds"]["pose_strength"] = strength
        updated["scail_embeds"]["pose_start_percent"] = start_percent
        updated["scail_embeds"]["pose_end_percent"] = end_percent

        return (updated,)


NODE_CLASS_MAPPINGS = {
    "WanVideoAddSCAILPoseEmbeds": WanVideoAddSCAILPoseEmbeds,
    "WanVideoAddSCAILReferenceEmbeds": WanVideoAddSCAILReferenceEmbeds,
    }
NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoAddSCAILReferenceEmbeds": "WanVideo Add SCAIL Reference Embeds",
    "WanVideoAddSCAILPoseEmbeds": "WanVideo Add SCAIL Pose Embeds",
    }