| import os |
| import torch |
| from ..utils import log |
| import numpy as np |
|
|
| import comfy.model_management as mm |
| from comfy.utils import load_torch_file |
| import folder_paths |
|
|
# Directory that contains this module (dirname of the absolute module path).
script_directory = os.path.split(os.path.abspath(__file__))[0]

# Devices are queried from ComfyUI's model management once, at import time.
# `device` is where the nodes below run their models; `offload_device` is
# where they move them back to afterwards.
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
|
|
| from .resampler import Resampler |
|
|
class LoadLynxResampler:
    """ComfyUI node that builds a ``Resampler`` and fills it from a checkpoint.

    The checkpoint is picked from the ``diffusion_models`` folder. The module
    is created with fixed hyper-parameters; only the output width is read from
    the checkpoint itself.
    """

    RETURN_TYPES = ("LYNXRESAMPLER",)
    RETURN_NAMES = ("resampler", )
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: checkpoint file name and weight precision."""
        checkpoint_choices = folder_paths.get_filename_list("diffusion_models")
        required_inputs = {
            "model_name": (checkpoint_choices, {"tooltip": "These models are loaded from 'ComfyUI/models/diffusion_models'"}),
            "precision": (["fp32", "bf16", "fp16"], {"default": "fp16"}),
        }
        return {"required": required_inputs}

    def loadmodel(self, model_name, precision):
        """Load the resampler weights and return the module on the offload device.

        Args:
            model_name: File name inside the ``diffusion_models`` folder.
            precision: One of ``"fp32"``, ``"bf16"`` or ``"fp16"``.

        Returns:
            A 1-tuple holding the ``Resampler`` in eval mode.

        Raises:
            KeyError: If ``precision`` is not one of the supported names or the
                checkpoint has no ``proj_out.weight`` entry.
        """
        torch_dtypes = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
        dtype = torch_dtypes[precision]

        checkpoint_path = folder_paths.get_full_path("diffusion_models", model_name)
        state_dict = load_torch_file(checkpoint_path, safe_load=True)

        # The output width is not hard-coded: it is the row count of the
        # checkpoint's final projection weight.
        projection_rows = state_dict["proj_out.weight"].shape[0]

        resampler_kwargs = {
            "depth": 4,
            "dim": 1280,
            "dim_head": 64,
            "embedding_dim": 512,
            "ff_mult": 4,
            "heads": 20,
            "num_queries": 16,
            "output_dim": projection_rows,
            "dtype": dtype,
        }
        resampler = Resampler(**resampler_kwargs).eval()

        # Move/cast first, then load, so a mismatching checkpoint fails loudly
        # under strict=True.
        resampler.to(offload_device, dtype)
        resampler.load_state_dict(state_dict, strict=True)

        return (resampler,)
|
|
|
|
class LynxInsightFaceCrop:
    """ComfyUI node that produces two aligned face crops from the first input image.

    ``ip_image`` is an ArcFace-style normalised crop (``face_align.norm_crop``);
    ``ref_image`` is a 256-pixel extended crop from the project's ``align_face``.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: a single IMAGE batch."""
        return {
            "required": {
                "image": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "IMAGE",)
    RETURN_NAMES = ("ip_image", "ref_image")
    FUNCTION = "encode"
    CATEGORY = "WanVideoWrapper"

    @staticmethod
    def _stretch_to_unit_range(tensor):
        """Min-max rescale ``tensor`` to [0, 1]; a constant tensor maps to zeros."""
        lowest = tensor.min()
        span = tensor.max() - lowest
        if span == 0:
            # A constant crop would otherwise divide by zero and yield NaNs.
            return torch.zeros_like(tensor)
        return (tensor - lowest) / span

    def encode(self, image, image_size=112):
        """Detect face landmarks on ``image[0]`` and return two aligned crops.

        Args:
            image: Image batch tensor; only the first item is used. Values are
                scaled by 255 and converted to uint8, so they are presumably
                in [0, 1] -- TODO confirm against callers.
            image_size: Side length passed to ``face_align.norm_crop`` for the
                ``ip_image`` crop (default 112, the value that used to be
                hard-coded).

        Returns:
            ``(ip_image, ref_image)``: float tensors with a leading batch
            dimension of 1, each min-max rescaled to [0, 1]. The last axis of
            ``ref_image`` is reordered with ``[2, 1, 0]``.

        Raises:
            ValueError: If no landmarks are returned for the image.
        """
        from .face.face_encoder import get_landmarks_from_image
        from .face.face_utils import align_face
        from insightface.utils import face_align

        # .cpu() keeps this working for tensors that are not already on the
        # CPU; clipping prevents uint8 wrap-around for values outside [0, 1].
        first_frame = image[0].detach().cpu().numpy()
        image_np = np.clip(first_frame * 255, 0, 255).astype(np.uint8)

        landmarks = get_landmarks_from_image(image_np)
        # NOTE(review): assumes the helper signals "no face" with None -- confirm.
        if landmarks is None:
            raise ValueError("No face landmarks detected in the input image.")
        landmark = np.array(landmarks)

        ip_face_aligned = face_align.norm_crop(image_np, landmark=landmark, image_size=image_size)
        ref_face_aligned = align_face(image_np, landmark, extend_face_crop=True, face_size=256)

        ip_face_aligned = torch.from_numpy(ip_face_aligned).unsqueeze(0).float() / 255.0
        ref_face_aligned = torch.from_numpy(ref_face_aligned).unsqueeze(0).float() / 255.0

        ip_face_aligned = self._stretch_to_unit_range(ip_face_aligned)
        ref_face_aligned = self._stretch_to_unit_range(ref_face_aligned)

        # Reverse the channel order of the reference crop only.
        # NOTE(review): presumably an RGB<->BGR swap for the consumer -- confirm.
        ref_face_aligned = ref_face_aligned[:, :, :, [2, 1, 0]]

        return ip_face_aligned, ref_face_aligned
|
|
|
|
class LynxEncodeFaceIP:
    """ComfyUI node that turns a face crop into Lynx IP-adapter embeddings.

    The crop is encoded with ``FaceEncoderArcFace`` and then passed through
    the resampler twice: once as-is and once zeroed, giving the conditional
    and unconditional embeddings.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: the loaded resampler and the face crop."""
        return {
            "required": {
                "resampler": ("LYNXRESAMPLER", {"tooltip": "lynx resampler model"}),
                "ip_image": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("LYNXIP",)
    RETURN_NAMES = ("lynx_face_embeds",)
    FUNCTION = "encode"
    CATEGORY = "WanVideoWrapper"

    def encode(self, resampler, ip_image):
        """Compute conditional and unconditional face embeddings.

        Args:
            resampler: Resampler module; its ``dtype`` attribute selects the
                dtype of the embeddings.
            ip_image: Channels-last image batch; it is permuted to
                channels-first and mapped with ``x * 2 - 1`` before encoding.

        Returns:
            A 1-tuple holding ``{"ip_x": ..., "ip_x_uncond": ...}``, both cast
            to ``resampler.dtype``.
        """
        from .face.face_encoder import FaceEncoderArcFace

        image_in = ip_image.permute(0, 3, 1, 2).to(device) * 2 - 1

        face_encoder = FaceEncoderArcFace()
        face_encoder.init_encoder_model(device)

        # Only the first encoder result is used; it is regrouped into rows of
        # 512 features, matching the resampler's embedding width.
        arcface_embed = face_encoder(image_in).to(device, resampler.dtype)[0]
        arcface_embed = arcface_embed.reshape([1, -1, 512])

        resampler.to(device)
        try:
            ip_x = resampler(arcface_embed)
            ip_x_uncond = resampler(torch.zeros_like(arcface_embed))
        finally:
            # Always release the compute device, even if the forward pass fails.
            resampler.to(offload_device)

        # Cast both outputs; previously only ip_x was cast, so the pair could
        # end up with different dtypes.
        out_dict = {
            'ip_x': ip_x.to(resampler.dtype),
            'ip_x_uncond': ip_x_uncond.to(resampler.dtype),
        }

        return (out_dict,)
|
|
class DrawArcFaceLandmarks:
    """ComfyUI debug node that draws landmark dots onto the first input image."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs.

        ``image`` is declared once, as required; it used to be listed under
        both ``required`` and ``optional``.
        """
        return {
            "required": {
                "lynx_face_embeds": ("LYNXIP", {"tooltip": "lynx face embeddings"}),
                "image": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("landmarked_image", )
    FUNCTION = "draw"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Draw face landmarks on an image for visualization/debugging"

    def draw(self, lynx_face_embeds, image):
        """Draw a filled circle of radius 3 at every landmark on ``image[0]``.

        Args:
            lynx_face_embeds: Mapping with a ``'landmarks'`` entry that
                iterates as ``(x, y)`` pairs.
            image: Image batch tensor; only the first item is drawn on.
                Presumably values are in [0, 1] -- TODO confirm.

        Returns:
            A 1-tuple holding the annotated image as a float tensor with a
            leading batch dimension of 1.

        Raises:
            KeyError: If ``lynx_face_embeds`` has no ``'landmarks'`` entry.
        """
        # Fail early with an actionable message: the LYNXIP dict built by
        # LynxEncodeFaceIP in this file carries only 'ip_x' and 'ip_x_uncond'.
        if "landmarks" not in lynx_face_embeds:
            raise KeyError("lynx_face_embeds has no 'landmarks' entry; there is nothing to draw")

        import cv2

        landmarks = lynx_face_embeds["landmarks"]

        # Work in the 0-255 range so the (0, 255, 0) colour is meaningful.
        # The multiplication yields a fresh array, so the input is not modified.
        canvas = image[0].detach().cpu().numpy() * 255

        for (x, y) in landmarks:
            cv2.circle(canvas, (int(x), int(y)), radius=3, color=(0, 255, 0), thickness=-1)

        image_out = torch.from_numpy(canvas / 255).unsqueeze(0).float()

        return (image_out,)
|
|
class WanVideoAddLynxEmbeds:
    """ComfyUI node that attaches Lynx conditioning to WanVideo image embeds.

    The result is a shallow copy of ``embeds`` with a ``"lynx_embeds"`` entry
    holding the optional face embeddings, the optional VAE-encoded reference
    latents, and the scale / schedule settings.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the required scale/schedule inputs and the optional conditioning inputs."""
        return {"required": {
                    "embeds": ("WANVIDIMAGE_EMBEDS",),
                    "ip_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the ip adapter face feature"}),
                    "ref_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the reference feature"}),
                    "lynx_cfg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "If above 1.0 and main cfg_scale is above 1.0, run extra pass, default value 2.0"}),
                    "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent to apply the ref "}),
                    "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent to apply the ref "}),
                },
                "optional": {
                    "vae": ("WANVAE", {"tooltip": "VAE model, only needed if ref_image is provided"}),
                    "lynx_ip_embeds": ("LYNXIP", {"tooltip": "lynx face embeddings"}),
                    "ref_image": ("IMAGE",),
                    "ref_text_embed": ("WANVIDEOTEXTEMBEDS",),
                    "ref_blocks_to_use": ("STRING", {"default": "", "forceInput": True, "tooltip": "Comma-separated list of block indices and ranges to use for reference feature, e.g. '0-20, 25, 28, 35-39'. If empty, use all blocks."}),
                }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    @staticmethod
    def _parse_ref_blocks(spec):
        """Parse a block spec such as ``'0-20, 25, 35-39'`` into sorted unique ints.

        Returns ``None`` for a missing or blank spec, which the tooltip defines
        as "use all blocks". Malformed items are reported on stdout and
        skipped rather than aborting the node.
        """
        if spec is None or not spec.strip():
            return None

        blocks = set()
        for raw_item in spec.split(","):
            item = raw_item.strip()
            if not item:
                # Tolerate stray or trailing commas.
                continue
            if "-" in item and not item.startswith("-"):
                first_txt, last_txt = item.split("-", 1)
                try:
                    first, last = int(first_txt.strip()), int(last_txt.strip())
                except ValueError:
                    print(f"Invalid range format: {item}")
                    continue
                if last < first:
                    # A reversed range selects nothing; report it instead of
                    # dropping it silently.
                    print(f"Invalid range format: {item}")
                    continue
                blocks.update(range(first, last + 1))
            elif item.isdigit():
                blocks.add(int(item))
            else:
                print(f"Invalid block specification: {item}")
        return sorted(blocks)

    def add(self, embeds, ip_scale, ref_scale, start_percent, end_percent, lynx_cfg_scale, vae=None, lynx_ip_embeds=None, ref_image=None, ref_text_embed=None, ref_blocks_to_use=""):
        """Return a copy of ``embeds`` extended with a ``"lynx_embeds"`` entry.

        Args:
            embeds: Mapping of existing image embeds; it is copied, not mutated.
            ip_scale: Strength of the face (IP) feature.
            ref_scale: Strength of the reference feature.
            start_percent: Start of the range in which the conditioning applies.
            end_percent: End of the range in which the conditioning applies.
            lynx_cfg_scale: Stored under the ``"cfg_scale"`` key.
            vae: VAE used to encode ``ref_image``; required when it is given.
            lynx_ip_embeds: Optional mapping with ``"ip_x"`` and ``"ip_x_uncond"``.
            ref_image: Optional reference image batch; only its first three
                channels are encoded.
            ref_text_embed: Text embeds for the reference; required when
                ``ref_image`` is given.
            ref_blocks_to_use: Comma-separated block indices and ranges; blank
                or ``None`` means all blocks.

        Returns:
            A 1-tuple holding the updated embeds mapping.

        Raises:
            ValueError: If ``ref_image`` is given without ``ref_text_embed``
                or without ``vae``.
        """
        ref_latent = None
        ref_latent_uncond = None

        if ref_image is not None:
            if ref_text_embed is None:
                raise ValueError("If ref_image is provided, ref_text_embed must also be provided.")
            if vae is None:
                raise ValueError("If ref_image is provided, vae must also be provided.")

            vae.to(device)
            try:
                # Drop any alpha channel, move channels first, map to x * 2 - 1.
                ref_image_in = (ref_image[..., :3].permute(3, 0, 1, 2) * 2 - 1).to(device, vae.dtype)
                ref_latent = vae.encode([ref_image_in], device, tiled=False, sample=True)
                # The unconditional latent is the encoding of an all-zero input.
                ref_latent_uncond = vae.encode([torch.zeros_like(ref_image_in)], device, tiled=False, sample=True)
            finally:
                # Always release the compute device, even if encoding fails.
                vae.to(offload_device)

        ref_blocks_to_use = self._parse_ref_blocks(ref_blocks_to_use)
        if ref_blocks_to_use is not None:
            print("Using ref blocks:", ref_blocks_to_use)

        has_ip_embeds = lynx_ip_embeds is not None
        new_entry = {
            "ip_x": lynx_ip_embeds["ip_x"] if has_ip_embeds else None,
            "ip_x_uncond": lynx_ip_embeds["ip_x_uncond"] if has_ip_embeds else None,
            "ref_latent": ref_latent,
            "ref_latent_uncond": ref_latent_uncond,
            "ref_text_embed": ref_text_embed,
            "ip_scale": ip_scale,
            "ref_scale": ref_scale,
            "cfg_scale": lynx_cfg_scale,
            "start_percent": start_percent,
            "end_percent": end_percent,
            "ref_blocks_to_use": ref_blocks_to_use,
        }

        # Shallow copy so the caller's embeds mapping is not modified.
        updated = dict(embeds)
        updated["lynx_embeds"] = new_entry
        return (updated,)
|
|
# Single registration table: (node class, display name). The registry key of
# each node is its class name, so renaming a class also renames its key.
_LYNX_NODES = (
    (LoadLynxResampler, "Load Lynx Resampler"),
    (LynxEncodeFaceIP, "Lynx Encode Face IP"),
    (DrawArcFaceLandmarks, "Draw ArcFace Landmarks"),
    (WanVideoAddLynxEmbeds, "WanVideo Add Lynx Embeds"),
    (LynxInsightFaceCrop, "Lynx InsightFace Crop"),
)

# Node key -> implementing class.
NODE_CLASS_MAPPINGS = {node_cls.__name__: node_cls for node_cls, _ in _LYNX_NODES}

# Node key -> human-readable title.
NODE_DISPLAY_NAME_MAPPINGS = {node_cls.__name__: title for node_cls, title in _LYNX_NODES}
|
|