"""Modified from https://github.com/kijai/ComfyUI-EasyAnimateWrapper/blob/main/nodes.py """ import copy import gc import json import os import cv2 import numpy as np import torch from diffusers import FlowMatchEulerDiscreteScheduler from einops import rearrange from omegaconf import OmegaConf from PIL import Image import comfy.model_management as mm import folder_paths from comfy.utils import ProgressBar, load_torch_file from ...videox_fun.data.bucket_sampler import (ASPECT_RATIO_512, get_closest_ratio) from ...videox_fun.data.dataset_image_video import process_pose_params from ...videox_fun.models import (AutoencoderKLWan, AutoencoderKLWan3_8, AutoTokenizer, CLIPModel, Wan2_2Transformer3DModel, WanT5EncoderModel) from ...videox_fun.models.cache_utils import get_teacache_coefficients from ...videox_fun.pipeline import (Wan2_2FunControlPipeline, Wan2_2FunInpaintPipeline, Wan2_2FunPipeline) from ...videox_fun.ui.controller import all_cheduler_dict from ...videox_fun.utils.fp8_optimization import ( convert_model_weight_to_float8, convert_weight_dtype_wrapper, undo_convert_weight_dtype_wrapper, replace_parameters_by_name) from ...videox_fun.utils.lora_utils import merge_lora, unmerge_lora from ...videox_fun.utils.utils import (filter_kwargs, get_image_latent, get_image_to_video_latent, get_video_to_video_latent, save_videos_grid) from ..wan2_1.nodes import get_wan_scheduler from ..comfyui_utils import (eas_cache_dir, script_directory, search_model_in_possible_folders, to_pil) # Used in lora cache transformer_cpu_cache = {} transformer_high_cpu_cache = {} # lora path before lora_path_before = "" lora_high_path_before = "" class LoadWan2_2FunModel: @classmethod def INPUT_TYPES(s): return { "required": { "model": ( [ 'Wan2.2-Fun-A14B-InP', 'Wan2.2-Fun-A14B-Control', 'Wan2.2-Fun-A14B-Control-Camera', 'Wan2.2-Fun-5B-InP', 'Wan2.2-Fun-5B-Control', 'Wan2.2-Fun-5B-Control-Camera', ], { "default": 'Wan2.2-Fun-A14B-InP', } ), "model_type": ( ["Inpaint", "Control"], { "default": "Inpaint", } ), "GPU_memory_mode":( ["model_full_load", "model_full_load_and_qfloat8","model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"], { "default": "model_cpu_offload", } ), "config": ( [ "wan2.2/wan_civitai_i2v.yaml", "wan2.2/wan_civitai_5b.yaml", ], { "default": "wan2.2/wan_civitai_i2v.yaml", } ), "precision": ( ['fp16', 'bf16'], { "default": 'fp16' } ), }, } RETURN_TYPES = ("FunModels",) RETURN_NAMES = ("funmodels",) FUNCTION = "loadmodel" CATEGORY = "CogVideoXFUNWrapper" def loadmodel(self, GPU_memory_mode, model_type, model, precision, config): # Init weight_dtype and device device = mm.get_torch_device() offload_device = mm.unet_offload_device() weight_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] mm.unload_all_models() mm.cleanup_models() mm.soft_empty_cache() # Init processbar pbar = ProgressBar(5) # Load config config_path = f"{script_directory}/config/{config}" config = OmegaConf.load(config_path) # Detect model is existing or not possible_folders = ["CogVideoX_Fun", "Fun_Models", "VideoX_Fun", "Wan-AI"] + \ [os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "models/Diffusion_Transformer")] # Possible folder names to check # Initialize model_name as None model_name = search_model_in_possible_folders(possible_folders, model) Chosen_AutoencoderKL = { "AutoencoderKLWan": AutoencoderKLWan, "AutoencoderKLWan3_8": AutoencoderKLWan3_8 }[config['vae_kwargs'].get('vae_type', 'AutoencoderKLWan')] vae = Chosen_AutoencoderKL.from_pretrained( os.path.join(model_name, config['vae_kwargs'].get('vae_subpath', 'vae')), additional_kwargs=OmegaConf.to_container(config['vae_kwargs']), ).to(weight_dtype) # Update pbar pbar.update(1) # Load Sampler print("Load Sampler.") scheduler = FlowMatchEulerDiscreteScheduler( **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs'])) ) # Update pbar pbar.update(1) # Get Transformer transformer = Wan2_2Transformer3DModel.from_pretrained( os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer')), transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), low_cpu_mem_usage=True, torch_dtype=weight_dtype, ) if config['transformer_additional_kwargs'].get('transformer_combination_type', 'single') == "moe": transformer_2 = Wan2_2Transformer3DModel.from_pretrained( os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer')), transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), low_cpu_mem_usage=True, torch_dtype=weight_dtype, ) else: transformer_2 = None # Update pbar pbar.update(1) # Get tokenizer and text_encoder tokenizer = AutoTokenizer.from_pretrained( os.path.join(model_name, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), ) pbar.update(1) text_encoder = WanT5EncoderModel.from_pretrained( os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')), additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']), low_cpu_mem_usage=True, torch_dtype=weight_dtype, ) pbar.update(1) # Get pipeline if model_type == "Inpaint": if transformer.config.in_channels != vae.config.latent_channels: pipeline = Wan2_2FunInpaintPipeline( vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, transformer_2=transformer_2, scheduler=scheduler, ) else: pipeline = Wan2_2FunPipeline( vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, transformer_2=transformer_2, scheduler=scheduler, ) else: pipeline = Wan2_2FunControlPipeline( vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, transformer_2=transformer_2, scheduler=scheduler, ) pipeline.remove_all_hooks() undo_convert_weight_dtype_wrapper(transformer) if GPU_memory_mode == "sequential_cpu_offload": replace_parameters_by_name(transformer, ["modulation",], device=device) transformer.freqs = transformer.freqs.to(device=device) if transformer_2 is not None: replace_parameters_by_name(transformer_2, ["modulation",], device=device) transformer_2.freqs = transformer_2.freqs.to(device=device) pipeline.enable_sequential_cpu_offload(device=device) elif GPU_memory_mode == "model_cpu_offload_and_qfloat8": convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device) convert_weight_dtype_wrapper(transformer, weight_dtype) if transformer_2 is not None: convert_model_weight_to_float8(transformer_2, exclude_module_name=["modulation",], device=device) convert_weight_dtype_wrapper(transformer_2, weight_dtype) pipeline.enable_model_cpu_offload(device=device) elif GPU_memory_mode == "model_cpu_offload": pipeline.enable_model_cpu_offload(device=device) elif GPU_memory_mode == "model_full_load_and_qfloat8": convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device) convert_weight_dtype_wrapper(transformer, weight_dtype) if transformer_2 is not None: convert_model_weight_to_float8(transformer_2, exclude_module_name=["modulation",], device=device) convert_weight_dtype_wrapper(transformer_2, weight_dtype) pipeline.to(device=device) else: pipeline.to(device=device) funmodels = { 'pipeline': pipeline, 'dtype': weight_dtype, 'model_name': model_name, 'model_type': model_type, 'loras': [], 'strength_model': [], 'config': config, } return (funmodels,) class LoadWan2_2FunLora: @classmethod def INPUT_TYPES(s): return { "required": { "funmodels": ("FunModels",), "lora_name": (folder_paths.get_filename_list("loras"), {"default": None,}), "lora_high_name": (folder_paths.get_filename_list("loras"), {"default": None,}), "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}), "lora_cache":([False, True], {"default": False,}), } } RETURN_TYPES = ("FunModels",) RETURN_NAMES = ("funmodels",) FUNCTION = "load_lora" CATEGORY = "CogVideoXFUNWrapper" def load_lora(self, funmodels, lora_name, lora_high_name, strength_model, lora_cache): new_funmodels = dict(funmodels) if lora_name is not None: loras = list(new_funmodels.get("loras", [])) + [folder_paths.get_full_path("loras", lora_name)] loras_high = list(new_funmodels.get("loras_high", [])) + [folder_paths.get_full_path("loras", lora_high_name)] strength_models = list(new_funmodels.get("strength_model", [])) + [strength_model] new_funmodels['loras'] = loras new_funmodels['loras_high'] = loras_high new_funmodels['strength_model'] = strength_models new_funmodels['lora_cache'] = lora_cache return (new_funmodels,) class Wan2_2FunT2VSampler: @classmethod def INPUT_TYPES(s): return { "required": { "funmodels": ( "FunModels", ), "prompt": ( "STRING_PROMPT", ), "negative_prompt": ( "STRING_PROMPT", ), "video_length": ( "INT", {"default": 81, "min": 5, "max": 161, "step": 4} ), "width": ( "INT", {"default": 832, "min": 64, "max": 2048, "step": 16} ), "height": ( "INT", {"default": 480, "min": 64, "max": 2048, "step": 16} ), "is_image":( [ False, True ], { "default": False, } ), "seed": ( "INT", {"default": 43, "min": 0, "max": 0xffffffffffffffff} ), "steps": ( "INT", {"default": 50, "min": 1, "max": 200, "step": 1} ), "cfg": ( "FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01} ), "scheduler": ( ["Flow", "Flow_Unipc", "Flow_DPM++"], { "default": 'Flow' } ), "shift": ( "INT", {"default": 5, "min": 1, "max": 100, "step": 1} ), "boundary": ( "FLOAT", {"default": 0.875, "min": 0.00, "max": 1.00, "step": 0.001} ), "teacache_threshold": ( "FLOAT", {"default": 0.10, "min": 0.00, "max": 1.00, "step": 0.005} ), "enable_teacache":( [False, True], {"default": True,} ), "num_skip_start_steps": ( "INT", {"default": 5, "min": 0, "max": 50, "step": 1} ), "teacache_offload":( [False, True], {"default": True,} ), "cfg_skip_ratio":( "FLOAT", {"default": 0, "min": 0, "max": 1, "step": 0.01} ), }, "optional": { "riflex_k": ("RIFLEXT_ARGS",), }, } RETURN_TYPES = ("IMAGE",) RETURN_NAMES =("images",) FUNCTION = "process" CATEGORY = "CogVideoXFUNWrapper" def process(self, funmodels, prompt, negative_prompt, video_length, width, height, is_image, seed, steps, cfg, scheduler, shift, boundary, teacache_threshold, enable_teacache, num_skip_start_steps, teacache_offload, cfg_skip_ratio, riflex_k=0): global transformer_cpu_cache global transformer_high_cpu_cache global lora_path_before global lora_high_path_before device = mm.get_torch_device() offload_device = mm.unet_offload_device() mm.soft_empty_cache() gc.collect() # Get Pipeline pipeline = funmodels['pipeline'] model_name = funmodels['model_name'] weight_dtype = funmodels['dtype'] # Load Sampler pipeline.scheduler = get_wan_scheduler(scheduler, shift) coefficients = get_teacache_coefficients(model_name) if enable_teacache else None if coefficients is not None: print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.") pipeline.transformer.enable_teacache( coefficients, steps, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload ) if pipeline.transformer_2 is not None: pipeline.transformer_2.share_teacache(transformer=pipeline.transformer) else: pipeline.transformer.disable_teacache() if pipeline.transformer_2 is not None: pipeline.transformer_2.disable_teacache() if cfg_skip_ratio is not None: print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.") pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, steps) if pipeline.transformer_2 is not None: pipeline.transformer_2.share_cfg_skip(transformer=pipeline.transformer) generator= torch.Generator(device).manual_seed(seed) video_length = 1 if is_image else video_length with torch.no_grad(): video_length = int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 if riflex_k > 0: latent_frames = (video_length - 1) // pipeline.vae.config.temporal_compression_ratio + 1 pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames) if pipeline.transformer_2 is not None: pipeline.transformer_2.enable_riflex(k = riflex_k, L_test = latent_frames) input_video, input_video_mask, clip_image = get_image_to_video_latent(None, None, video_length=video_length, sample_size=(height, width)) # Apply lora if funmodels.get("lora_cache", False): if len(funmodels.get("loras", [])) != 0: # Save the original weights to cpu if len(transformer_cpu_cache) == 0: print('Save transformer state_dict to cpu memory') transformer_state_dict = pipeline.transformer.state_dict() for key in transformer_state_dict: transformer_cpu_cache[key] = transformer_state_dict[key].clone().cpu() lora_path_now = str(funmodels.get("loras", []) + funmodels.get("strength_model", [])) if lora_path_now != lora_path_before: print('Merge Lora with Cache') lora_path_before = copy.deepcopy(lora_path_now) pipeline.transformer.load_state_dict(transformer_cpu_cache) for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])): pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype) if pipeline.transformer_2 is not None: # Save the original weights to cpu if len(transformer_high_cpu_cache) == 0: print('Save transformer high state_dict to cpu memory') transformer_high_state_dict = pipeline.transformer_2.state_dict() for key in transformer_high_state_dict: transformer_high_cpu_cache[key] = transformer_high_state_dict[key].clone().cpu() lora_high_path_now = str(funmodels.get("loras_high", []) + funmodels.get("strength_model", [])) if lora_high_path_now != lora_high_path_before: print('Merge Lora High with Cache') lora_high_path_before = copy.deepcopy(lora_high_path_now) pipeline.transformer_2.load_state_dict(transformer_cpu_cache) for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])): pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") else: print('Merge Lora') # Clear lora when switch from lora_cache=True to lora_cache=False. if len(transformer_cpu_cache) != 0: pipeline.transformer.load_state_dict(transformer_cpu_cache) transformer_cpu_cache = {} lora_path_before = "" gc.collect() for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])): pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype) # Clear lora when switch from lora_cache=True to lora_cache=False. if pipeline.transformer_2 is not None: if len(transformer_high_cpu_cache) != 0: pipeline.transformer_2.load_state_dict(transformer_high_cpu_cache) transformer_high_cpu_cache = {} lora_high_path_before = "" gc.collect() for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])): pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") sample = pipeline( prompt, num_frames = video_length, negative_prompt = negative_prompt, height = height, width = width, generator = generator, guidance_scale = cfg, num_inference_steps = steps, video = input_video, mask_video = input_video_mask, boundary = boundary, comfyui_progressbar = True, ).videos videos = rearrange(sample, "b c t h w -> (b t) h w c") if not funmodels.get("lora_cache", False): print('Unmerge Lora') for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])): pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype) if pipeline.transformer_2 is not None: for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])): pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") return (videos,) class Wan2_2FunInpaintSampler: @classmethod def INPUT_TYPES(s): return { "required": { "funmodels": ( "FunModels", ), "prompt": ( "STRING_PROMPT", ), "negative_prompt": ( "STRING_PROMPT", ), "video_length": ( "INT", {"default": 81, "min": 5, "max": 161, "step": 4} ), "base_resolution": ( [ 512, 640, 768, 896, 960, 1024, ], {"default": 640} ), "seed": ( "INT", {"default": 43, "min": 0, "max": 0xffffffffffffffff} ), "steps": ( "INT", {"default": 50, "min": 1, "max": 200, "step": 1} ), "cfg": ( "FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01} ), "scheduler": ( ["Flow", "Flow_Unipc", "Flow_DPM++"], { "default": 'Flow' } ), "shift": ( "INT", {"default": 5, "min": 1, "max": 100, "step": 1} ), "boundary": ( "FLOAT", {"default": 0.900, "min": 0.00, "max": 1.00, "step": 0.001} ), "teacache_threshold": ( "FLOAT", {"default": 0.10, "min": 0.00, "max": 1.00, "step": 0.005} ), "enable_teacache":( [False, True], {"default": True,} ), "num_skip_start_steps": ( "INT", {"default": 5, "min": 0, "max": 50, "step": 1} ), "teacache_offload":( [False, True], {"default": True,} ), "cfg_skip_ratio":( "FLOAT", {"default": 0, "min": 0, "max": 1, "step": 0.01} ), }, "optional": { "start_img": ("IMAGE",), "end_img": ("IMAGE",), "riflex_k": ("RIFLEXT_ARGS",), }, } RETURN_TYPES = ("IMAGE",) RETURN_NAMES =("images",) FUNCTION = "process" CATEGORY = "CogVideoXFUNWrapper" def process(self, funmodels, prompt, negative_prompt, video_length, base_resolution, seed, steps, cfg, scheduler, shift, boundary, teacache_threshold, enable_teacache, num_skip_start_steps, teacache_offload, cfg_skip_ratio, start_img=None, end_img=None, riflex_k=0): global transformer_cpu_cache global transformer_high_cpu_cache global lora_path_before global lora_high_path_before device = mm.get_torch_device() offload_device = mm.unet_offload_device() mm.soft_empty_cache() gc.collect() # Get Pipeline pipeline = funmodels['pipeline'] model_name = funmodels['model_name'] weight_dtype = funmodels['dtype'] start_img = [to_pil(_start_img) for _start_img in start_img] if start_img is not None else None end_img = [to_pil(_end_img) for _end_img in end_img] if end_img is not None else None # Count most suitable height and width aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} original_width, original_height = start_img[0].size if type(start_img) is list else Image.open(start_img).size closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size) height, width = [int(x / 16) * 16 for x in closest_size] # Load Sampler pipeline.scheduler = get_wan_scheduler(scheduler, shift) coefficients = get_teacache_coefficients(model_name) if enable_teacache else None if coefficients is not None: print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.") pipeline.transformer.enable_teacache( coefficients, steps, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload ) if pipeline.transformer_2 is not None: pipeline.transformer_2.share_teacache(transformer=pipeline.transformer) else: pipeline.transformer.disable_teacache() if pipeline.transformer_2 is not None: pipeline.transformer_2.disable_teacache() if cfg_skip_ratio is not None: print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.") pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, steps) if pipeline.transformer_2 is not None: pipeline.transformer_2.share_cfg_skip(transformer=pipeline.transformer) generator= torch.Generator(device).manual_seed(seed) with torch.no_grad(): video_length = int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 if riflex_k > 0: latent_frames = (video_length - 1) // pipeline.vae.config.temporal_compression_ratio + 1 pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames) if pipeline.transformer_2 is not None: pipeline.transformer_2.enable_riflex(k = riflex_k, L_test = latent_frames) input_video, input_video_mask, clip_image = get_image_to_video_latent(start_img, end_img, video_length=video_length, sample_size=(height, width)) # Apply lora if funmodels.get("lora_cache", False): if len(funmodels.get("loras", [])) != 0: # Save the original weights to cpu if len(transformer_cpu_cache) == 0: print('Save transformer state_dict to cpu memory') transformer_state_dict = pipeline.transformer.state_dict() for key in transformer_state_dict: transformer_cpu_cache[key] = transformer_state_dict[key].clone().cpu() lora_path_now = str(funmodels.get("loras", []) + funmodels.get("strength_model", [])) if lora_path_now != lora_path_before: print('Merge Lora with Cache') lora_path_before = copy.deepcopy(lora_path_now) pipeline.transformer.load_state_dict(transformer_cpu_cache) for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])): pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype) if pipeline.transformer_2 is not None: # Save the original weights to cpu if len(transformer_high_cpu_cache) == 0: print('Save transformer high state_dict to cpu memory') transformer_high_state_dict = pipeline.transformer_2.state_dict() for key in transformer_high_state_dict: transformer_high_cpu_cache[key] = transformer_high_state_dict[key].clone().cpu() lora_high_path_now = str(funmodels.get("loras_high", []) + funmodels.get("strength_model", [])) if lora_high_path_now != lora_high_path_before: print('Merge Lora High with Cache') lora_high_path_before = copy.deepcopy(lora_high_path_now) pipeline.transformer_2.load_state_dict(transformer_cpu_cache) for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])): pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") else: print('Merge Lora') # Clear lora when switch from lora_cache=True to lora_cache=False. if len(transformer_cpu_cache) != 0: pipeline.transformer.load_state_dict(transformer_cpu_cache) transformer_cpu_cache = {} lora_path_before = "" gc.collect() for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])): pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype) # Clear lora when switch from lora_cache=True to lora_cache=False. if pipeline.transformer_2 is not None: if len(transformer_high_cpu_cache) != 0: pipeline.transformer_2.load_state_dict(transformer_high_cpu_cache) transformer_high_cpu_cache = {} lora_high_path_before = "" gc.collect() for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])): pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") sample = pipeline( prompt, num_frames = video_length, negative_prompt = negative_prompt, height = height, width = width, generator = generator, guidance_scale = cfg, num_inference_steps = steps, video = input_video, mask_video = input_video_mask, boundary = boundary, comfyui_progressbar = True, ).videos videos = rearrange(sample, "b c t h w -> (b t) h w c") if not funmodels.get("lora_cache", False): print('Unmerge Lora') for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])): pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype) if pipeline.transformer_2 is not None: for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])): pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") return (videos,) class Wan2_2FunV2VSampler: @classmethod def INPUT_TYPES(s): return { "required": { "funmodels": ( "FunModels", ), "prompt": ( "STRING_PROMPT", ), "negative_prompt": ( "STRING_PROMPT", ), "video_length": ( "INT", {"default": 81, "min": 1, "max": 161, "step": 4} ), "base_resolution": ( [ 512, 640, 768, 896, 960, 1024, ], {"default": 640} ), "seed": ( "INT", {"default": 43, "min": 0, "max": 0xffffffffffffffff} ), "steps": ( "INT", {"default": 50, "min": 1, "max": 200, "step": 1} ), "cfg": ( "FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01} ), "denoise_strength": ( "FLOAT", {"default": 1.00, "min": 0.05, "max": 1.00, "step": 0.01} ), "scheduler": ( ["Flow", "Flow_Unipc", "Flow_DPM++"], { "default": 'Flow' } ), "shift": ( "INT", {"default": 5, "min": 1, "max": 100, "step": 1} ), "boundary": ( "FLOAT", {"default": 0.900, "min": 0.00, "max": 1.00, "step": 0.001} ), "teacache_threshold": ( "FLOAT", {"default": 0.10, "min": 0.00, "max": 1.00, "step": 0.005} ), "enable_teacache":( [False, True], {"default": True,} ), "num_skip_start_steps": ( "INT", {"default": 5, "min": 0, "max": 50, "step": 1} ), "teacache_offload":( [False, True], {"default": True,} ), "cfg_skip_ratio":( "FLOAT", {"default": 0, "min": 0, "max": 1, "step": 0.01} ), }, "optional": { "validation_video": ("IMAGE",), "control_video": ("IMAGE",), "start_image": ("IMAGE",), "end_image": ("IMAGE",), "ref_image": ("IMAGE",), "camera_conditions": ("STRING", {"forceInput": True}), "riflex_k": ("RIFLEXT_ARGS",), }, } RETURN_TYPES = ("IMAGE",) RETURN_NAMES =("images",) FUNCTION = "process" CATEGORY = "CogVideoXFUNWrapper" def process(self, funmodels, prompt, negative_prompt, video_length, base_resolution, seed, steps, cfg, denoise_strength, scheduler, shift, boundary, teacache_threshold, enable_teacache, num_skip_start_steps, teacache_offload, cfg_skip_ratio, validation_video=None, control_video=None, start_image=None, end_image=None, ref_image=None, camera_conditions=None, riflex_k=0): global transformer_cpu_cache global transformer_high_cpu_cache global lora_path_before global lora_high_path_before device = mm.get_torch_device() offload_device = mm.unet_offload_device() mm.soft_empty_cache() gc.collect() # Get Pipeline pipeline = funmodels['pipeline'] model_name = funmodels['model_name'] weight_dtype = funmodels['dtype'] model_type = funmodels['model_type'] # Count most suitable height and width aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} if model_type == "Inpaint": if type(validation_video) is str: original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size else: validation_video = np.array(validation_video.cpu().numpy() * 255, np.uint8) original_width, original_height = Image.fromarray(validation_video[0]).size else: if control_video is not None and type(control_video) is str: original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size elif control_video is not None: control_video = np.array(control_video.cpu().numpy() * 255, np.uint8) original_width, original_height = Image.fromarray(control_video[0]).size else: original_width, original_height = 384 / 512 * base_resolution, 672 / 512 * base_resolution if ref_image is not None: ref_image = [to_pil(_ref_image) for _ref_image in ref_image] original_width, original_height = ref_image[0].size if type(ref_image) is list else Image.open(ref_image).size if start_image is not None: start_image = [to_pil(_start_image) for _start_image in start_image] original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size if end_image is not None: end_image = [to_pil(_end_image) for _end_image in end_image] closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size) height, width = [int(x / 16) * 16 for x in closest_size] # Load Sampler pipeline.scheduler = get_wan_scheduler(scheduler, shift) coefficients = get_teacache_coefficients(model_name) if enable_teacache else None if coefficients is not None: print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.") pipeline.transformer.enable_teacache( coefficients, steps, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload ) if pipeline.transformer_2 is not None: pipeline.transformer_2.share_teacache(transformer=pipeline.transformer) else: pipeline.transformer.disable_teacache() if pipeline.transformer_2 is not None: pipeline.transformer_2.disable_teacache() if cfg_skip_ratio is not None: print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.") pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, steps) if pipeline.transformer_2 is not None: pipeline.transformer_2.share_cfg_skip(transformer=pipeline.transformer) generator= torch.Generator(device).manual_seed(seed) with torch.no_grad(): video_length = int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 if riflex_k > 0: latent_frames = (video_length - 1) // pipeline.vae.config.temporal_compression_ratio + 1 pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames) if pipeline.transformer_2 is not None: pipeline.transformer_2.enable_riflex(k = riflex_k, L_test = latent_frames) if model_type == "Inpaint": input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(validation_video, video_length=video_length, sample_size=(height, width), fps=16, ref_image=ref_image[0] if ref_image is not None else ref_image) else: if ref_image is not None: clip_image = ref_image[0].convert("RGB") elif start_image is not None: clip_image = start_image[0].convert("RGB") else: clip_image = None inpaint_video, inpaint_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=video_length, sample_size=(height, width)) if ref_image is not None: ref_image = get_image_latent(ref_image[0] if ref_image is not None else ref_image, sample_size=(height, width)) if camera_conditions is not None and len(camera_conditions) > 0: poses = json.loads(camera_conditions) cam_params = np.array([[float(x) for x in pose] for pose in poses]) cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1) control_camera_video = process_pose_params(cam_params, width=width, height=height) control_camera_video = control_camera_video[:video_length].permute([3, 0, 1, 2]).unsqueeze(0) input_video, input_video_mask = None, None else: control_camera_video = None input_video, input_video_mask, _, _ = get_video_to_video_latent(control_video, video_length=video_length, sample_size=(height, width), fps=16, ref_image=None) # Apply lora if funmodels.get("lora_cache", False): if len(funmodels.get("loras", [])) != 0: # Save the original weights to cpu if len(transformer_cpu_cache) == 0: print('Save transformer state_dict to cpu memory') transformer_state_dict = pipeline.transformer.state_dict() for key in transformer_state_dict: transformer_cpu_cache[key] = transformer_state_dict[key].clone().cpu() lora_path_now = str(funmodels.get("loras", []) + funmodels.get("strength_model", [])) if lora_path_now != lora_path_before: print('Merge Lora with Cache') lora_path_before = copy.deepcopy(lora_path_now) pipeline.transformer.load_state_dict(transformer_cpu_cache) for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])): pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype) if pipeline.transformer_2 is not None: # Save the original weights to cpu if len(transformer_high_cpu_cache) == 0: print('Save transformer high state_dict to cpu memory') transformer_high_state_dict = pipeline.transformer_2.state_dict() for key in transformer_high_state_dict: transformer_high_cpu_cache[key] = transformer_high_state_dict[key].clone().cpu() lora_high_path_now = str(funmodels.get("loras_high", []) + funmodels.get("strength_model", [])) if lora_high_path_now != lora_high_path_before: print('Merge Lora High with Cache') lora_high_path_before = copy.deepcopy(lora_high_path_now) pipeline.transformer_2.load_state_dict(transformer_cpu_cache) for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])): pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") else: print('Merge Lora') # Clear lora when switch from lora_cache=True to lora_cache=False. if len(transformer_cpu_cache) != 0: pipeline.transformer.load_state_dict(transformer_cpu_cache) transformer_cpu_cache = {} lora_path_before = "" gc.collect() for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])): pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype) # Clear lora when switch from lora_cache=True to lora_cache=False. if pipeline.transformer_2 is not None: if len(transformer_high_cpu_cache) != 0: pipeline.transformer_2.load_state_dict(transformer_high_cpu_cache) transformer_high_cpu_cache = {} lora_high_path_before = "" gc.collect() for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])): pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") if model_type == "Inpaint": sample = pipeline( prompt, num_frames = video_length, negative_prompt = negative_prompt, height = height, width = width, generator = generator, guidance_scale = cfg, num_inference_steps = steps, video = input_video, mask_video = input_video_mask, clip_image = clip_image, strength = float(denoise_strength), comfyui_progressbar = True, ).videos else: sample = pipeline( prompt, num_frames = video_length, negative_prompt = negative_prompt, height = height, width = width, generator = generator, guidance_scale = cfg, num_inference_steps = steps, video = inpaint_video, mask_video = inpaint_video_mask, control_video = input_video, control_camera_video = control_camera_video, ref_image = ref_image, boundary = boundary, comfyui_progressbar = True, ).videos videos = rearrange(sample, "b c t h w -> (b t) h w c") if not funmodels.get("lora_cache", False): print('Unmerge Lora') for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])): pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype) if pipeline.transformer_2 is not None: for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])): pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") return (videos,)