import os
import torch
import numpy as np

import comfy.model_management as mm
from comfy.utils import load_torch_file
import folder_paths

from ..utils import log
from .resampler import Resampler

script_directory = os.path.dirname(os.path.abspath(__file__))

device = mm.get_torch_device()
offload_device = mm.unet_offload_device()


class LoadLynxResampler:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from 'ComfyUI/models/diffusion_models'"}),
                "precision": (["fp32", "bf16", "fp16"], {"default": "fp16"}),
            },
        }

    RETURN_TYPES = ("LYNXRESAMPLER",)
    RETURN_NAMES = ("resampler",)
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"

    def loadmodel(self, model_name, precision):
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
        model_path = folder_paths.get_full_path("diffusion_models", model_name)
        resampler_sd = load_torch_file(model_path, safe_load=True)

        # Infer the output dimension from the checkpoint's projection layer
        output_dim = resampler_sd["proj_out.weight"].shape[0]

        resampler = Resampler(
            depth=4,
            dim=1280,
            dim_head=64,
            embedding_dim=512,
            ff_mult=4,
            heads=20,
            num_queries=16,
            output_dim=output_dim,
            dtype=dtype,
        ).eval()
        resampler.to(offload_device, dtype)
        resampler.load_state_dict(resampler_sd, strict=True)

        return (resampler,)


class LynxInsightFaceCrop:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "IMAGE",)
    RETURN_NAMES = ("ip_image", "ref_image")
    FUNCTION = "encode"
    CATEGORY = "WanVideoWrapper"

    def encode(self, image, image_size=112):
        from .face.face_encoder import get_landmarks_from_image
        from .face.face_utils import align_face
        from insightface.utils import face_align

        image_np = (image[0].numpy() * 255).astype(np.uint8)
        landmarks = get_landmarks_from_image(image_np)

        in_image = np.array(image_np)
        landmark = np.array(landmarks)
        # Tight crop for the IP/ArcFace branch, wider extended crop for the reference branch
        ip_face_aligned = face_align.norm_crop(in_image, landmark=landmark, image_size=image_size)
        ref_face_aligned = align_face(in_image, landmark, extend_face_crop=True, face_size=256)

        ip_face_aligned = torch.from_numpy(ip_face_aligned).unsqueeze(0).float() / 255.0
        ref_face_aligned = torch.from_numpy(ref_face_aligned).unsqueeze(0).float() / 255.0

        # Stretch each crop to the full [0, 1] range
        ip_face_aligned = (ip_face_aligned - ip_face_aligned.min()) / (ip_face_aligned.max() - ip_face_aligned.min())
        ref_face_aligned = (ref_face_aligned - ref_face_aligned.min()) / (ref_face_aligned.max() - ref_face_aligned.min())

        ref_face_aligned = ref_face_aligned[:, :, :, [2, 1, 0]]  # BGR to RGB

        return ip_face_aligned, ref_face_aligned


class LynxEncodeFaceIP:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "resampler": ("LYNXRESAMPLER", {"tooltip": "Lynx resampler model"}),
                "ip_image": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("LYNXIP",)
    RETURN_NAMES = ("lynx_face_embeds",)
    FUNCTION = "encode"
    CATEGORY = "WanVideoWrapper"

    def encode(self, resampler, ip_image):
        from .face.face_encoder import FaceEncoderArcFace

        image_in = ip_image.permute(0, 3, 1, 2).to(device) * 2 - 1  # to [-1, 1]

        # Face embedding via ArcFace
        face_encoder = FaceEncoderArcFace()
        face_encoder.init_encoder_model(device)
        arcface_embed = face_encoder(image_in).to(device, resampler.dtype)[0]
        arcface_embed = arcface_embed.reshape([1, -1, 512])

        # Resample to IP tokens; the zeroed embedding gives the unconditional branch
        resampler.to(device)
        ip_x = resampler(arcface_embed).to(resampler.dtype)
        ip_x_uncond = resampler(arcface_embed * 0).to(resampler.dtype)
        resampler.to(offload_device)

        out_dict = {
            "ip_x": ip_x,
            "ip_x_uncond": ip_x_uncond,
        }

        return (out_dict,)
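
# Shape sketch of the LYNXIP payload produced above (an assumption based on the
# Resampler config in LoadLynxResampler, not verified against a checkpoint):
# with num_queries=16, "ip_x" and "ip_x_uncond" should each be a
# [1, 16, output_dim] tensor in the resampler's dtype.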

class DrawArcFaceLandmarks:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "lynx_face_embeds": ("LYNXIP", {"tooltip": "Lynx face embeddings"}),
                "image": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("landmarked_image",)
    FUNCTION = "draw"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Draw face landmarks on an image for visualization/debugging"

    def draw(self, lynx_face_embeds, image):
        import cv2

        landmarks = lynx_face_embeds.get("landmarks")
        if landmarks is None:
            raise ValueError("lynx_face_embeds does not contain landmarks.")

        # cv2 drawing needs a contiguous, writable uint8 array
        image_np = np.ascontiguousarray((image[0].numpy() * 255).astype(np.uint8))
        for (x, y) in landmarks:
            cv2.circle(image_np, (int(x), int(y)), radius=3, color=(0, 255, 0), thickness=-1)
        image_out = torch.from_numpy(image_np / 255).unsqueeze(0).float()
        return (image_out,)
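
# Typical wiring (a sketch inferred from the input/output types above, not a
# required graph):
#   LynxInsightFaceCrop.ip_image      -> LynxEncodeFaceIP.ip_image
#   LoadLynxResampler.resampler       -> LynxEncodeFaceIP.resampler
#   LynxEncodeFaceIP.lynx_face_embeds -> WanVideoAddLynxEmbeds.lynx_ip_embeds
#   WanVideoAddLynxEmbeds.image_embeds is then passed on as WANVIDIMAGE_EMBEDS.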

class WanVideoAddLynxEmbeds:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "embeds": ("WANVIDIMAGE_EMBEDS",),
                "ip_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the IP adapter face feature"}),
                "ref_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the reference feature"}),
                "lynx_cfg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "If above 1.0 and the main cfg_scale is above 1.0, an extra pass is run; a typical value is 2.0"}),
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent to apply the reference"}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent to apply the reference"}),
            },
            "optional": {
                "vae": ("WANVAE", {"tooltip": "VAE model, only needed if ref_image is provided"}),
                "lynx_ip_embeds": ("LYNXIP", {"tooltip": "Lynx face embeddings"}),
                "ref_image": ("IMAGE",),
                "ref_text_embed": ("WANVIDEOTEXTEMBEDS",),
                "ref_blocks_to_use": ("STRING", {"default": "", "forceInput": True, "tooltip": "Comma-separated list of block indices and ranges to use for the reference feature, e.g. '0-20, 25, 28, 35-39'. If empty, all blocks are used."}),
            }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, embeds, ip_scale, ref_scale, start_percent, end_percent, lynx_cfg_scale,
            vae=None, lynx_ip_embeds=None, ref_image=None, ref_text_embed=None, ref_blocks_to_use=""):
        if ref_image is not None and ref_text_embed is None:
            raise ValueError("If ref_image is provided, ref_text_embed must also be provided.")

        if ref_image is not None:
            # Encode the reference image, plus a zeroed copy for the unconditional branch
            vae.to(device)
            ref_image_in = (ref_image[..., :3].permute(3, 0, 1, 2) * 2 - 1).to(device, vae.dtype)
            ref_latent = vae.encode([ref_image_in], device, tiled=False, sample=True)
            ref_latent_uncond = vae.encode([torch.zeros_like(ref_image_in)], device, tiled=False, sample=True)
            vae.to(offload_device)

        if ref_blocks_to_use.strip() == "":
            ref_blocks_to_use = None
        else:
            # Parse comma-separated block indices and ranges
            blocks = []
            for item in ref_blocks_to_use.split(","):
                item = item.strip()
                if "-" in item and not item.startswith("-"):
                    # Handle a range like "0-20" or "35-39"
                    try:
                        start, end = item.split("-", 1)
                        start, end = int(start.strip()), int(end.strip())
                        blocks.extend(range(start, end + 1))
                    except ValueError:
                        log.warning(f"Invalid range format: {item}")
                elif item.isdigit():
                    # Handle a single index
                    blocks.append(int(item))
                else:
                    log.warning(f"Invalid block specification: {item}")
            ref_blocks_to_use = sorted(set(blocks))  # Remove duplicates and sort
            log.info(f"Using ref blocks: {ref_blocks_to_use}")

        new_entry = {
            "ip_x": lynx_ip_embeds["ip_x"] if lynx_ip_embeds is not None else None,
            "ip_x_uncond": lynx_ip_embeds["ip_x_uncond"] if lynx_ip_embeds is not None else None,
            "ref_latent": ref_latent if ref_image is not None else None,
            "ref_latent_uncond": ref_latent_uncond if ref_image is not None else None,
            "ref_text_embed": ref_text_embed if ref_text_embed is not None else None,
            "ip_scale": ip_scale,
            "ref_scale": ref_scale,
            "cfg_scale": lynx_cfg_scale,
            "start_percent": start_percent,
            "end_percent": end_percent,
            "ref_blocks_to_use": ref_blocks_to_use,
        }

        updated = dict(embeds)
        updated["lynx_embeds"] = new_entry
        return (updated,)


NODE_CLASS_MAPPINGS = {
    "LoadLynxResampler": LoadLynxResampler,
    "LynxEncodeFaceIP": LynxEncodeFaceIP,
    "DrawArcFaceLandmarks": DrawArcFaceLandmarks,
    "WanVideoAddLynxEmbeds": WanVideoAddLynxEmbeds,
    "LynxInsightFaceCrop": LynxInsightFaceCrop,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "LoadLynxResampler": "Load Lynx Resampler",
    "LynxEncodeFaceIP": "Lynx Encode Face IP",
    "DrawArcFaceLandmarks": "Draw ArcFace Landmarks",
    "WanVideoAddLynxEmbeds": "WanVideo Add Lynx Embeds",
    "LynxInsightFaceCrop": "Lynx InsightFace Crop",
}
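
# Block-list parsing in WanVideoAddLynxEmbeds.add, illustrated (example values
# are hypothetical):
#   "0-2, 5, 5"  ->  [0, 1, 2, 5]   # ranges expanded, de-duplicated, sorted
#   ""           ->  None           # reference feature applied to all blocks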