aliensmn's picture
Mirror from https://github.com/kijai/ComfyUI-WanVideoWrapper
cf812a0 verified
import os
import torch
from ..utils import log
import numpy as np
import comfy.model_management as mm
from comfy.utils import load_torch_file
import folder_paths
script_directory = os.path.dirname(os.path.abspath(__file__))
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
from .resampler import Resampler
class LoadLynxResampler:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from 'ComfyUI/models/diffusion_models'"}),
"precision": (["fp32", "bf16", "fp16"], {"default": "fp16"}),
},
}
RETURN_TYPES = ("LYNXRESAMPLER",)
RETURN_NAMES = ("resampler", )
FUNCTION = "loadmodel"
CATEGORY = "WanVideoWrapper"
def loadmodel(self, model_name, precision):
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
model_path = folder_paths.get_full_path("diffusion_models", model_name)
resampler_sd = load_torch_file(model_path, safe_load=True)
output_dim = resampler_sd["proj_out.weight"].shape[0]
resampler = Resampler(
depth=4,
dim=1280,
dim_head=64,
embedding_dim=512,
ff_mult=4,
heads=20,
num_queries=16,
output_dim=output_dim,
dtype=dtype,
).eval()
resampler.to(offload_device, dtype)
resampler.load_state_dict(resampler_sd, strict=True)
return resampler,
class LynxInsightFaceCrop:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE", {"tooltip": "Input images for the model"}),
},
}
RETURN_TYPES = ("IMAGE", "IMAGE",)
RETURN_NAMES = ("ip_image", "ref_image")
FUNCTION = "encode"
CATEGORY = "WanVideoWrapper"
def encode(self, image, image_size=112):
from .face.face_encoder import get_landmarks_from_image
from .face.face_utils import align_face
from insightface.utils import face_align
image_np = (image[0].numpy() * 255).astype(np.uint8)
landmarks = get_landmarks_from_image(image_np)
in_image = np.array(image_np)
landmark = np.array(landmarks)
ip_face_aligned = face_align.norm_crop(in_image, landmark=landmark, image_size=112)
ref_face_aligned = align_face(in_image, landmark, extend_face_crop=True, face_size=256)
ip_face_aligned = torch.from_numpy(ip_face_aligned).unsqueeze(0).float() / 255.0
ref_face_aligned = torch.from_numpy(ref_face_aligned).unsqueeze(0).float() / 255.0
ip_face_aligned = (ip_face_aligned - ip_face_aligned.min()) / (ip_face_aligned.max() - ip_face_aligned.min())
ref_face_aligned = (ref_face_aligned - ref_face_aligned.min()) / (ref_face_aligned.max() - ref_face_aligned.min())
ref_face_aligned = ref_face_aligned[:, :, :, [2, 1, 0]] # BGR to RGB
return ip_face_aligned, ref_face_aligned
class LynxEncodeFaceIP:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"resampler": ("LYNXRESAMPLER", {"tooltip": "lynx resampler model"}),
"ip_image": ("IMAGE", {"tooltip": "Input images for the model"}),
},
}
RETURN_TYPES = ("LYNXIP",)
RETURN_NAMES = ("lynx_face_embeds",)
FUNCTION = "encode"
CATEGORY = "WanVideoWrapper"
def encode(self, resampler, ip_image):
from .face.face_encoder import FaceEncoderArcFace
image_in = ip_image.permute(0, 3, 1, 2).to(device) * 2 - 1 # to [-1, 1]
# Face embedding via ArcFace
face_encoder = FaceEncoderArcFace()
face_encoder.init_encoder_model(device)
arcface_embed = face_encoder(image_in).to(device, resampler.dtype)[0]
arcface_embed = arcface_embed.reshape([1, -1, 512])
resampler.to(device)
ip_x = resampler(arcface_embed)
ip_x_uncond = resampler(arcface_embed * 0)
resampler.to(offload_device)
ip_x= ip_x.to(resampler.dtype)
out_dict = {
'ip_x': ip_x,
'ip_x_uncond': ip_x_uncond,
}
return out_dict,
class DrawArcFaceLandmarks:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"lynx_face_embeds": ("LYNXIP", {"tooltip": "lynx resampler model"}),
"image": ("IMAGE", {"tooltip": "Input images for the model"}),
},
"optional": {
"image": ("IMAGE",)
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("landmarked_image", )
FUNCTION = "draw"
CATEGORY = "WanVideoWrapper"
DESCRIPTION = "Draw face landmarks on an image for visualization/debugging"
def draw(self, lynx_face_embeds, image):
import cv2
landmarks = lynx_face_embeds['landmarks']
image_np = image[0].numpy() * 255
for (x, y) in landmarks:
cv2.circle(image_np, (int(x), int(y)), radius=3, color=(0, 255, 0), thickness=-1)
image_out = torch.from_numpy(image_np / 255).unsqueeze(0).float()
return image_out,
class WanVideoAddLynxEmbeds:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"embeds": ("WANVIDIMAGE_EMBEDS",),
"ip_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the ip adapter face feature"}),
"ref_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the reference feature"}),
"lynx_cfg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "If above 1.0 and main cfg_scale is above 1.0, run extra pass, default value 2.0"}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent to apply the ref "}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent to apply the ref "}),
},
"optional": {
"vae": ("WANVAE", {"tooltip": "VAE model, only needed if ref_image is provided"}),
"lynx_ip_embeds": ("LYNXIP", {"tooltip": "lynx face embeddings"}),
"ref_image": ("IMAGE",),
"ref_text_embed": ("WANVIDEOTEXTEMBEDS",),
"ref_blocks_to_use": ("STRING", {"default": "", "forceInput": True, "tooltip": "Comma-separated list of block indices and ranges to use for reference feature, e.g. '0-20, 25, 28, 35-39'. If empty, use all blocks."}),
}
}
RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
RETURN_NAMES = ("image_embeds",)
FUNCTION = "add"
CATEGORY = "WanVideoWrapper"
def add(self, embeds, ip_scale, ref_scale, start_percent, end_percent, lynx_cfg_scale, vae=None, lynx_ip_embeds=None, ref_image=None, ref_text_embed=None, ref_blocks_to_use=""):
if ref_image is not None and ref_text_embed is None:
raise ValueError("If ref_image is provided, ref_text_embed must also be provided.")
if ref_image is not None:
vae.to(device)
ref_image_in = (ref_image[..., :3].permute(3, 0, 1, 2) * 2 - 1).to(device, vae.dtype)
ref_latent = vae.encode([ref_image_in], device, tiled=False, sample=True)
ref_latent_uncond = vae.encode([torch.zeros_like(ref_image_in)], device, tiled=False, sample=True)
vae.to(offload_device)
if ref_blocks_to_use.strip() == "":
ref_blocks_to_use = None
else:
# Parse comma-separated blocks and ranges
blocks = []
for item in ref_blocks_to_use.split(","):
item = item.strip()
if "-" in item and not item.startswith("-"):
# Handle range like "0-20" or "35-39"
try:
start, end = item.split("-", 1)
start, end = int(start.strip()), int(end.strip())
blocks.extend(list(range(start, end + 1)))
except ValueError:
print(f"Invalid range format: {item}")
elif item.isdigit():
# Handle single number
blocks.append(int(item))
else:
print(f"Invalid block specification: {item}")
ref_blocks_to_use = sorted(list(set(blocks))) # Remove duplicates and sort
print("Using ref blocks:", ref_blocks_to_use)
new_entry = {
"ip_x": lynx_ip_embeds["ip_x"] if lynx_ip_embeds is not None else None,
"ip_x_uncond": lynx_ip_embeds["ip_x_uncond"] if lynx_ip_embeds is not None else None,
"ref_latent": ref_latent if ref_image is not None else None,
"ref_latent_uncond": ref_latent_uncond if ref_image is not None else None,
"ref_text_embed": ref_text_embed if ref_text_embed is not None else None,
"ip_scale": ip_scale,
"ref_scale": ref_scale,
"cfg_scale": lynx_cfg_scale,
"start_percent": start_percent,
"end_percent": end_percent,
"ref_blocks_to_use": ref_blocks_to_use,
}
updated = dict(embeds)
updated["lynx_embeds"] = new_entry
return (updated,)
NODE_CLASS_MAPPINGS = {
"LoadLynxResampler": LoadLynxResampler,
"LynxEncodeFaceIP": LynxEncodeFaceIP,
"DrawArcFaceLandmarks": DrawArcFaceLandmarks,
"WanVideoAddLynxEmbeds": WanVideoAddLynxEmbeds,
"LynxInsightFaceCrop": LynxInsightFaceCrop,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"LoadLynxResampler": "Load Lynx Resampler",
"LynxEncodeFaceIP": "Lynx Encode Face IP",
"DrawArcFaceLandmarks": "Draw ArcFace Landmarks",
"WanVideoAddLynxEmbeds": "WanVideo Add Lynx Embeds",
"LynxInsightFaceCrop": "Lynx InsightFace Crop",
}