|
|
import os |
|
|
import torch |
|
|
from ..utils import log |
|
|
import numpy as np |
|
|
|
|
|
import comfy.model_management as mm |
|
|
from comfy.utils import load_torch_file |
|
|
import folder_paths |
|
|
|
|
|
# Absolute path of the directory that contains this source file.
script_directory = os.path.dirname(os.path.abspath(__file__))

# Devices chosen by ComfyUI model management: `device` is where the nodes below
# run their models, `offload_device` is where models are moved back to after use.
device, offload_device = mm.get_torch_device(), mm.unet_offload_device()
|
|
|
|
|
from .resampler import Resampler |
|
|
|
|
|
class LoadLynxResampler:
    """Load a Lynx face-IP ``Resampler`` checkpoint from ``models/diffusion_models``.

    The resampler architecture hyper-parameters are fixed in this node; only the
    output width is taken from the checkpoint (first dim of ``proj_out.weight``).
    The module is built in the requested precision, moved to the offload device,
    and returned as a ``LYNXRESAMPLER`` (consumed by ``LynxEncodeFaceIP``).
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from 'ComfyUI/models/diffusion_models'"}),
                "precision": (["fp32", "bf16", "fp16"], {"default": "fp16"}),
            },
        }

    RETURN_TYPES = ("LYNXRESAMPLER",)
    RETURN_NAMES = ("resampler", )
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"

    def loadmodel(self, model_name, precision):
        """Build the resampler and load its weights.

        Args:
            model_name: file name inside the ``diffusion_models`` folder.
            precision: one of ``"fp32"``, ``"bf16"``, ``"fp16"``.

        Returns:
            1-tuple holding the ``Resampler`` in eval mode on the offload device.

        Raises:
            FileNotFoundError: if ``model_name`` cannot be resolved to a path.
            ValueError: if the checkpoint has no ``proj_out.weight`` (i.e. it is
                not a Lynx resampler checkpoint).
        """
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]

        model_path = folder_paths.get_full_path("diffusion_models", model_name)
        if model_path is None:
            raise FileNotFoundError(f"Lynx resampler model '{model_name}' was not found in 'diffusion_models'.")
        resampler_sd = load_torch_file(model_path, safe_load=True)

        # The output projection width is the only architecture detail read from
        # the checkpoint; a missing key means the wrong file was selected.
        if "proj_out.weight" not in resampler_sd:
            raise ValueError(f"'{model_name}' does not look like a Lynx resampler checkpoint (missing 'proj_out.weight').")
        output_dim = resampler_sd["proj_out.weight"].shape[0]

        resampler = Resampler(
            depth=4,
            dim=1280,
            dim_head=64,
            embedding_dim=512,
            ff_mult=4,
            heads=20,
            num_queries=16,
            output_dim=output_dim,
            dtype=dtype,
        ).eval()
        resampler.to(offload_device, dtype)
        # strict=True: every key in the checkpoint must match the fixed architecture above.
        resampler.load_state_dict(resampler_sd, strict=True)

        return resampler,
|
|
|
|
|
|
|
|
class LynxInsightFaceCrop:
    """Detect a face and produce the two aligned crops used by Lynx.

    Only the first image of the input batch is processed.

    Outputs:
        ip_image: InsightFace ``norm_crop`` alignment (``image_size`` square,
            112 by default), min-max stretched to [0, 1].
        ref_image: ``align_face`` crop (``extend_face_crop=True``, 256 px),
            min-max stretched to [0, 1], with channels reordered ``[2, 1, 0]``.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "IMAGE",)
    RETURN_NAMES = ("ip_image", "ref_image")
    FUNCTION = "encode"
    CATEGORY = "WanVideoWrapper"

    def encode(self, image, image_size=112):
        """Return ``(ip_image, ref_image)`` tensors of shape (1, H, W, C).

        Args:
            image: ComfyUI IMAGE batch; assumed (B, H, W, C) float in [0, 1].
            image_size: side length of the ``norm_crop`` output (previously
                ignored; 112 was always used).

        Raises:
            ValueError: if no face landmarks are returned for the image.
        """
        from .face.face_encoder import get_landmarks_from_image
        from .face.face_utils import align_face
        from insightface.utils import face_align

        def _stretch_to_unit(t):
            # Min-max rescale to [0, 1]. A constant crop would divide by zero
            # (NaN output), so it is returned unchanged instead.
            lo, hi = t.min(), t.max()
            if hi <= lo:
                return t
            return (t - lo) / (hi - lo)

        # Clip before the uint8 cast so values slightly outside [0, 1] do not wrap around.
        image_np = (image[0].cpu().numpy().clip(0.0, 1.0) * 255).astype(np.uint8)

        landmarks = get_landmarks_from_image(image_np)
        landmark = None if landmarks is None else np.array(landmarks)
        if landmark is None or landmark.size == 0:
            raise ValueError("LynxInsightFaceCrop: no face landmarks were found in the input image.")

        ip_face_aligned = face_align.norm_crop(image_np, landmark=landmark, image_size=image_size)
        ref_face_aligned = align_face(image_np, landmark, extend_face_crop=True, face_size=256)

        ip_face = _stretch_to_unit(torch.from_numpy(ip_face_aligned).unsqueeze(0).float() / 255.0)
        ref_face = _stretch_to_unit(torch.from_numpy(ref_face_aligned).unsqueeze(0).float() / 255.0)
        # Reverse channel order of the reference crop only (as in the original node).
        ref_face = ref_face[:, :, :, [2, 1, 0]]

        return ip_face, ref_face
|
|
|
|
|
|
|
|
class LynxEncodeFaceIP:
    """Encode an aligned face crop into Lynx IP tokens.

    The crop is embedded with ArcFace, then run through the Lynx resampler twice:
    once with the real embedding (``ip_x``) and once with an all-zero embedding
    (``ip_x_uncond``). Both outputs are cast to the resampler's dtype.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "resampler": ("LYNXRESAMPLER", {"tooltip": "lynx resampler model"}),
                "ip_image": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("LYNXIP",)
    RETURN_NAMES = ("lynx_face_embeds",)
    FUNCTION = "encode"
    CATEGORY = "WanVideoWrapper"

    def encode(self, resampler, ip_image):
        """Return a 1-tuple holding ``{'ip_x': ..., 'ip_x_uncond': ...}``.

        Args:
            resampler: module from ``LoadLynxResampler``; must expose ``.dtype``.
            ip_image: aligned face crop batch; assumed (B, H, W, C) in [0, 1].
        """
        from .face.face_encoder import FaceEncoderArcFace

        # (B, H, W, C) -> (B, C, H, W), rescaled from [0, 1] to [-1, 1] (assumed input range).
        image_in = ip_image.permute(0, 3, 1, 2).to(device) * 2 - 1

        # Pure inference: no autograd graph is needed for either model.
        with torch.no_grad():
            face_encoder = FaceEncoderArcFace()
            face_encoder.init_encoder_model(device)
            # Keep the first item of the encoder output and lay it out as
            # (1, tokens, 512) for the resampler.
            arcface_embed = face_encoder(image_in).to(device, resampler.dtype)[0]
            arcface_embed = arcface_embed.reshape([1, -1, 512])

            resampler.to(device)
            try:
                ip_x = resampler(arcface_embed)
                # Zero embedding drives the unconditional branch; zeros_like avoids
                # propagating NaN/inf the way `embed * 0` would.
                ip_x_uncond = resampler(torch.zeros_like(arcface_embed))
            finally:
                # Always return the resampler to the offload device, even on error.
                resampler.to(offload_device)

        out_dict = {
            'ip_x': ip_x.to(resampler.dtype),
            'ip_x_uncond': ip_x_uncond.to(resampler.dtype),
        }
        return out_dict,
|
|
|
|
|
class DrawArcFaceLandmarks:
    """Overlay face landmark points on an image for visualization/debugging.

    Reads ``lynx_face_embeds['landmarks']`` as an iterable of (x, y) pixel
    coordinates and draws a filled green dot (radius 3) at each one on the first
    image of the batch.
    """

    @classmethod
    def INPUT_TYPES(s):
        # "image" was previously declared in both "required" and "optional";
        # it is required by draw(), so only the required entry is kept.
        return {
            "required": {
                "lynx_face_embeds": ("LYNXIP", {"tooltip": "Lynx face embeds; must contain a 'landmarks' entry"}),
                "image": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("landmarked_image", )
    FUNCTION = "draw"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Draw face landmarks on an image for visualization/debugging"

    def draw(self, lynx_face_embeds, image):
        """Return a 1-tuple with the annotated first frame, shape (1, H, W, C).

        Raises:
            ValueError: if ``lynx_face_embeds`` has no ``'landmarks'`` entry.
        """
        # Validate before importing cv2 so the user gets a clear message rather
        # than a bare KeyError.
        try:
            landmarks = lynx_face_embeds["landmarks"]
        except KeyError:
            raise ValueError(
                "DrawArcFaceLandmarks: lynx_face_embeds has no 'landmarks' entry, so there is nothing to draw."
            ) from None

        import cv2

        # Work in 0-255 float so the drawing colour below is on the same scale.
        image_np = image[0].cpu().numpy() * 255

        for (x, y) in landmarks:
            cv2.circle(image_np, (int(x), int(y)), radius=3, color=(0, 255, 0), thickness=-1)

        image_out = torch.from_numpy(image_np / 255).unsqueeze(0).float()
        return image_out,
|
|
|
|
|
class WanVideoAddLynxEmbeds:
    """Attach Lynx identity conditioning to WanVideo image embeds.

    Returns a shallow copy of ``embeds`` with a ``"lynx_embeds"`` entry holding:
    the face IP tokens (cond and uncond) and the reference-image latents (the
    image and an all-zero image, encoded with the VAE). It also holds the
    reference text embeds, the scales, and the ``start_percent``..``end_percent``
    window. Finally, it holds the parsed reference block indices (``None`` means
    "all blocks").
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                "embeds": ("WANVIDIMAGE_EMBEDS",),
                "ip_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the ip adapter face feature"}),
                "ref_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the reference feature"}),
                "lynx_cfg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "If above 1.0 and main cfg_scale is above 1.0, run extra pass, suggested value 2.0"}),
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent to apply the ref "}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent to apply the ref "}),
            },
            "optional": {
                "vae": ("WANVAE", {"tooltip": "VAE model, only needed if ref_image is provided"}),
                "lynx_ip_embeds": ("LYNXIP", {"tooltip": "lynx face embeddings"}),
                "ref_image": ("IMAGE",),
                "ref_text_embed": ("WANVIDEOTEXTEMBEDS",),
                "ref_blocks_to_use": ("STRING", {"default": "", "forceInput": True, "tooltip": "Comma-separated list of block indices and ranges to use for reference feature, e.g. '0-20, 25, 28, 35-39'. If empty, use all blocks."}),
            }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    @staticmethod
    def _parse_ref_blocks(spec):
        """Parse e.g. ``'0-20, 25, 28'`` into a sorted list of unique block indices.

        Ranges are inclusive. Empty items (stray or trailing commas) are ignored.
        Malformed items, including ranges whose start exceeds their end, are
        reported with ``print`` and skipped.
        """
        blocks = set()
        for raw in spec.split(","):
            item = raw.strip()
            if not item:
                continue
            if "-" in item and not item.startswith("-"):
                try:
                    start_s, end_s = item.split("-", 1)
                    start, end = int(start_s.strip()), int(end_s.strip())
                except ValueError:
                    print(f"Invalid range format: {item}")
                    continue
                if start > end:
                    print(f"Invalid range format (start > end): {item}")
                    continue
                blocks.update(range(start, end + 1))
            elif item.isdecimal():
                # isdecimal (unlike isdigit) only accepts characters int() can parse.
                blocks.add(int(item))
            else:
                print(f"Invalid block specification: {item}")
        return sorted(blocks)

    def add(self, embeds, ip_scale, ref_scale, start_percent, end_percent, lynx_cfg_scale, vae=None, lynx_ip_embeds=None, ref_image=None, ref_text_embed=None, ref_blocks_to_use=""):
        """Return a 1-tuple with ``embeds`` extended by a ``"lynx_embeds"`` entry.

        Raises:
            ValueError: if ``ref_image`` is given without ``ref_text_embed`` or
                without ``vae``.
        """
        if ref_image is not None and ref_text_embed is None:
            raise ValueError("If ref_image is provided, ref_text_embed must also be provided.")
        if ref_image is not None and vae is None:
            raise ValueError("If ref_image is provided, vae must also be provided.")

        ref_latent = ref_latent_uncond = None
        if ref_image is not None:
            vae.to(device)
            try:
                # Drop any alpha channel, move channels first (C, B, H, W) and
                # rescale [0, 1] -> [-1, 1] (assumed ComfyUI IMAGE range).
                ref_image_in = (ref_image[..., :3].permute(3, 0, 1, 2) * 2 - 1).to(device, vae.dtype)
                ref_latent = vae.encode([ref_image_in], device, tiled=False, sample=True)
                ref_latent_uncond = vae.encode([torch.zeros_like(ref_image_in)], device, tiled=False, sample=True)
            finally:
                # Offload the VAE even if encoding fails.
                vae.to(offload_device)

        if ref_blocks_to_use is None or not ref_blocks_to_use.strip():
            ref_blocks_to_use = None
        else:
            ref_blocks_to_use = self._parse_ref_blocks(ref_blocks_to_use)
            print("Using ref blocks:", ref_blocks_to_use)

        new_entry = {
            "ip_x": lynx_ip_embeds["ip_x"] if lynx_ip_embeds is not None else None,
            "ip_x_uncond": lynx_ip_embeds["ip_x_uncond"] if lynx_ip_embeds is not None else None,
            "ref_latent": ref_latent,
            "ref_latent_uncond": ref_latent_uncond,
            "ref_text_embed": ref_text_embed,
            "ip_scale": ip_scale,
            "ref_scale": ref_scale,
            "cfg_scale": lynx_cfg_scale,
            "start_percent": start_percent,
            "end_percent": end_percent,
            "ref_blocks_to_use": ref_blocks_to_use,
        }

        # Shallow copy so the upstream embeds dict is not mutated.
        updated = dict(embeds)
        updated["lynx_embeds"] = new_entry
        return (updated,)
|
|
|
|
|
# Node id -> implementing class, exported for ComfyUI node registration.
NODE_CLASS_MAPPINGS = dict(
    LoadLynxResampler=LoadLynxResampler,
    LynxEncodeFaceIP=LynxEncodeFaceIP,
    DrawArcFaceLandmarks=DrawArcFaceLandmarks,
    WanVideoAddLynxEmbeds=WanVideoAddLynxEmbeds,
    LynxInsightFaceCrop=LynxInsightFaceCrop,
)

# Node id -> human-readable title shown for the node.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    LoadLynxResampler="Load Lynx Resampler",
    LynxEncodeFaceIP="Lynx Encode Face IP",
    DrawArcFaceLandmarks="Draw ArcFace Landmarks",
    WanVideoAddLynxEmbeds="WanVideo Add Lynx Embeds",
    LynxInsightFaceCrop="Lynx InsightFace Crop",
)
|
|
|