import os

import numpy as np
import torch

import comfy.model_management as mm
import folder_paths
from comfy.utils import load_torch_file

from ..utils import log
from .resampler import Resampler

script_directory = os.path.dirname(os.path.abspath(__file__))
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
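
# The nodes below implement Lynx face-identity conditioning for WanVideoWrapper:
# loading the resampler, cropping and aligning faces, encoding ArcFace identity
# tokens, and attaching the resulting embeds to WanVideo image embeds.
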
class LoadLynxResampler:
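    """Load a Lynx Resampler checkpoint from 'ComfyUI/models/diffusion_models'.

    The projection output size is read from the checkpoint itself, so
    checkpoints with different output dimensions load without extra settings.
    """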
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from 'ComfyUI/models/diffusion_models'"}),
                "precision": (["fp32", "bf16", "fp16"], {"default": "fp16"}),
            },
        }

    RETURN_TYPES = ("LYNXRESAMPLER",)
    RETURN_NAMES = ("resampler",)
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"

    def loadmodel(self, model_name, precision):
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
        model_path = folder_paths.get_full_path("diffusion_models", model_name)
        resampler_sd = load_torch_file(model_path, safe_load=True)
        # Infer the projection size from the checkpoint rather than hardcoding it
        output_dim = resampler_sd["proj_out.weight"].shape[0]
        resampler = Resampler(
            depth=4,
            dim=1280,
            dim_head=64,
            embedding_dim=512,
            ff_mult=4,
            heads=20,
            num_queries=16,
            output_dim=output_dim,
            dtype=dtype,
        ).eval()
        resampler.to(offload_device, dtype)
        resampler.load_state_dict(resampler_sd, strict=True)
        return (resampler,)


class LynxInsightFaceCrop:
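    """Detect a face in the input image and return two aligned crops:
    a 112x112 ArcFace-aligned crop (ip_image) for identity encoding, and a
    wider 256x256 crop (ref_image) for reference conditioning.
    """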
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "IMAGE",)
    RETURN_NAMES = ("ip_image", "ref_image")
    FUNCTION = "encode"
    CATEGORY = "WanVideoWrapper"

    def encode(self, image, image_size=112):
        from .face.face_encoder import get_landmarks_from_image
        from .face.face_utils import align_face
        from insightface.utils import face_align

        image_np = (image[0].numpy() * 255).astype(np.uint8)
        landmarks = get_landmarks_from_image(image_np)

        in_image = np.array(image_np)
        landmark = np.array(landmarks)
        # ArcFace-template crop for the identity encoder, and a wider crop for reference conditioning
        ip_face_aligned = face_align.norm_crop(in_image, landmark=landmark, image_size=image_size)
        ref_face_aligned = align_face(in_image, landmark, extend_face_crop=True, face_size=256)

        ip_face_aligned = torch.from_numpy(ip_face_aligned).unsqueeze(0).float() / 255.0
        ref_face_aligned = torch.from_numpy(ref_face_aligned).unsqueeze(0).float() / 255.0
        # Stretch both crops to the full [0, 1] range
        ip_face_aligned = (ip_face_aligned - ip_face_aligned.min()) / (ip_face_aligned.max() - ip_face_aligned.min())
        ref_face_aligned = (ref_face_aligned - ref_face_aligned.min()) / (ref_face_aligned.max() - ref_face_aligned.min())
        ref_face_aligned = ref_face_aligned[:, :, :, [2, 1, 0]]  # BGR to RGB

        return ip_face_aligned, ref_face_aligned


class LynxEncodeFaceIP:
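    """Encode an aligned face crop into Lynx IP embeddings.

    ArcFace produces a 512-d identity embedding, which the resampler projects
    into IP tokens; a second pass on a zeroed embedding gives the
    unconditional tokens used for classifier-free guidance.
    """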
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "resampler": ("LYNXRESAMPLER", {"tooltip": "lynx resampler model"}),
                "ip_image": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("LYNXIP",)
    RETURN_NAMES = ("lynx_face_embeds",)
    FUNCTION = "encode"
    CATEGORY = "WanVideoWrapper"

    def encode(self, resampler, ip_image):
        from .face.face_encoder import FaceEncoderArcFace

        image_in = ip_image.permute(0, 3, 1, 2).to(device) * 2 - 1  # to [-1, 1]

        # Face embedding via ArcFace
        face_encoder = FaceEncoderArcFace()
        face_encoder.init_encoder_model(device)
        arcface_embed = face_encoder(image_in).to(device, resampler.dtype)[0]
        arcface_embed = arcface_embed.reshape([1, -1, 512])

        resampler.to(device)
        ip_x = resampler(arcface_embed)
        ip_x_uncond = resampler(arcface_embed * 0)  # zeroed identity for the unconditional branch
        resampler.to(offload_device)

        ip_x = ip_x.to(resampler.dtype)

        out_dict = {
            'ip_x': ip_x,
            'ip_x_uncond': ip_x_uncond,
        }
        return (out_dict,)


class DrawArcFaceLandmarks:
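    """Draw detected face landmarks onto an image for debugging.

    Expects a 'landmarks' entry in lynx_face_embeds; note that
    LynxEncodeFaceIP above does not add one, so the embeds must come from a
    source that does.
    """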
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "lynx_face_embeds": ("LYNXIP", {"tooltip": "lynx face embeddings"}),
                "image": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("landmarked_image",)
    FUNCTION = "draw"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Draw face landmarks on an image for visualization/debugging"

    def draw(self, lynx_face_embeds, image):
        import cv2

        landmarks = lynx_face_embeds.get('landmarks')
        if landmarks is None:
            raise ValueError("No 'landmarks' found in lynx_face_embeds")
        # cv2.circle expects a contiguous uint8 array
        image_np = np.ascontiguousarray(image[0].numpy() * 255).astype(np.uint8)
        for (x, y) in landmarks:
            cv2.circle(image_np, (int(x), int(y)), radius=3, color=(0, 255, 0), thickness=-1)
        image_out = torch.from_numpy(image_np / 255).unsqueeze(0).float()
        return (image_out,)


class WanVideoAddLynxEmbeds:
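    """Attach Lynx conditioning to WanVideo image embeds.

    Packs the IP face tokens, optional VAE-encoded reference latents, their
    strength scales and the start/end schedule into a 'lynx_embeds' entry
    for the sampler to consume.
    """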
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "embeds": ("WANVIDIMAGE_EMBEDS",),
                "ip_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the IP adapter face feature"}),
                "ref_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the reference feature"}),
                "lynx_cfg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "If both this and the main cfg_scale are above 1.0, an extra pass is run; suggested value 2.0"}),
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent at which to apply the reference"}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent at which to apply the reference"}),
            },
            "optional": {
                "vae": ("WANVAE", {"tooltip": "VAE model, only needed if ref_image is provided"}),
                "lynx_ip_embeds": ("LYNXIP", {"tooltip": "lynx face embeddings"}),
                "ref_image": ("IMAGE",),
                "ref_text_embed": ("WANVIDEOTEXTEMBEDS",),
                "ref_blocks_to_use": ("STRING", {"default": "", "forceInput": True, "tooltip": "Comma-separated list of block indices and ranges to use for the reference feature, e.g. '0-20, 25, 28, 35-39'. If empty, all blocks are used."}),
            }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, embeds, ip_scale, ref_scale, start_percent, end_percent, lynx_cfg_scale, vae=None, lynx_ip_embeds=None, ref_image=None, ref_text_embed=None, ref_blocks_to_use=""):
        if ref_image is not None and ref_text_embed is None:
            raise ValueError("If ref_image is provided, ref_text_embed must also be provided.")

        if ref_image is not None:
            vae.to(device)
            ref_image_in = (ref_image[..., :3].permute(3, 0, 1, 2) * 2 - 1).to(device, vae.dtype)
            ref_latent = vae.encode([ref_image_in], device, tiled=False, sample=True)
            # The latent of a black image serves as the unconditional reference
            ref_latent_uncond = vae.encode([torch.zeros_like(ref_image_in)], device, tiled=False, sample=True)
            vae.to(offload_device)

        if ref_blocks_to_use.strip() == "":
            ref_blocks_to_use = None
        else:
            # Parse comma-separated block indices and ranges, e.g. "0-20, 25, 28, 35-39"
            blocks = []
            for item in ref_blocks_to_use.split(","):
                item = item.strip()
                if "-" in item and not item.startswith("-"):
                    # Handle a range like "0-20" or "35-39"
                    try:
                        start, end = item.split("-", 1)
                        start, end = int(start.strip()), int(end.strip())
                        blocks.extend(list(range(start, end + 1)))
                    except ValueError:
                        log.warning(f"Invalid range format: {item}")
                elif item.isdigit():
                    # Handle a single index
                    blocks.append(int(item))
                else:
                    log.warning(f"Invalid block specification: {item}")
            ref_blocks_to_use = sorted(set(blocks))  # Remove duplicates and sort
            log.info(f"Using ref blocks: {ref_blocks_to_use}")

        new_entry = {
            "ip_x": lynx_ip_embeds["ip_x"] if lynx_ip_embeds is not None else None,
            "ip_x_uncond": lynx_ip_embeds["ip_x_uncond"] if lynx_ip_embeds is not None else None,
            "ref_latent": ref_latent if ref_image is not None else None,
            "ref_latent_uncond": ref_latent_uncond if ref_image is not None else None,
            "ref_text_embed": ref_text_embed,
            "ip_scale": ip_scale,
            "ref_scale": ref_scale,
            "cfg_scale": lynx_cfg_scale,
            "start_percent": start_percent,
            "end_percent": end_percent,
            "ref_blocks_to_use": ref_blocks_to_use,
        }

        updated = dict(embeds)
        updated["lynx_embeds"] = new_entry
        return (updated,)


NODE_CLASS_MAPPINGS = {
    "LoadLynxResampler": LoadLynxResampler,
    "LynxEncodeFaceIP": LynxEncodeFaceIP,
    "DrawArcFaceLandmarks": DrawArcFaceLandmarks,
    "WanVideoAddLynxEmbeds": WanVideoAddLynxEmbeds,
    "LynxInsightFaceCrop": LynxInsightFaceCrop,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "LoadLynxResampler": "Load Lynx Resampler",
    "LynxEncodeFaceIP": "Lynx Encode Face IP",
    "DrawArcFaceLandmarks": "Draw ArcFace Landmarks",
    "WanVideoAddLynxEmbeds": "WanVideo Add Lynx Embeds",
    "LynxInsightFaceCrop": "Lynx InsightFace Crop",
}