| | """Modified from https://github.com/kijai/ComfyUI-EasyAnimateWrapper/blob/main/nodes.py |
| | """ |
| | import copy |
| | import gc |
| | import json |
| | import os |
| |
|
| | import cv2 |
| | import numpy as np |
| | import torch |
| | from diffusers import FlowMatchEulerDiscreteScheduler |
| | from einops import rearrange |
| | from omegaconf import OmegaConf |
| | from PIL import Image |
| |
|
| | import comfy.model_management as mm |
| | import folder_paths |
| | from comfy.utils import ProgressBar, load_torch_file |
| |
|
| | from ...videox_fun.data.bucket_sampler import (ASPECT_RATIO_512, |
| | get_closest_ratio) |
| | from ...videox_fun.data.dataset_image_video import process_pose_params |
| | from ...videox_fun.models import (AutoencoderKLWan, AutoencoderKLWan3_8, |
| | AutoTokenizer, CLIPModel, |
| | Wan2_2Transformer3DModel, WanT5EncoderModel) |
| | from ...videox_fun.models.cache_utils import get_teacache_coefficients |
| | from ...videox_fun.pipeline import (Wan2_2FunControlPipeline, |
| | Wan2_2FunInpaintPipeline, |
| | Wan2_2FunPipeline) |
| | from ...videox_fun.ui.controller import all_cheduler_dict |
| | from ...videox_fun.utils.fp8_optimization import ( |
| | convert_model_weight_to_float8, convert_weight_dtype_wrapper, undo_convert_weight_dtype_wrapper, |
| | replace_parameters_by_name) |
| | from ...videox_fun.utils.lora_utils import merge_lora, unmerge_lora |
| | from ...videox_fun.utils.utils import (filter_kwargs, get_image_latent, |
| | get_image_to_video_latent, |
| | get_video_to_video_latent, |
| | save_videos_grid) |
| | from ..wan2_1.nodes import get_wan_scheduler |
| | from ..comfyui_utils import (eas_cache_dir, script_directory, |
| | search_model_in_possible_folders, to_pil) |
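
# Module-level caches used by the sampler nodes when `lora_cache` is enabled: they hold a
# CPU copy of the pristine (un-merged) transformer state dicts, plus the LoRA paths and
# strengths that were last merged, so the base weights can be restored and LoRAs re-merged
# only when the selection actually changes between runs.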
transformer_cpu_cache = {}
transformer_high_cpu_cache = {}

lora_path_before = ""
lora_high_path_before = ""


class LoadWan2_2FunModel:
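    """ComfyUI node that loads a Wan2.2-Fun pipeline (VAE, scheduler, low/high-noise
    transformers, tokenizer and text encoder) from a local model folder, applies the
    selected GPU memory/offload mode, and returns a `FunModels` dict for the sampler nodes."""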
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (
                    [
                        'Wan2.2-Fun-A14B-InP',
                        'Wan2.2-Fun-A14B-Control',
                        'Wan2.2-Fun-A14B-Control-Camera',
                        'Wan2.2-Fun-5B-InP',
                        'Wan2.2-Fun-5B-Control',
                        'Wan2.2-Fun-5B-Control-Camera',
                    ],
                    {
                        "default": 'Wan2.2-Fun-A14B-InP',
                    }
                ),
                "model_type": (
                    ["Inpaint", "Control"],
                    {
                        "default": "Inpaint",
                    }
                ),
                "GPU_memory_mode": (
                    ["model_full_load", "model_full_load_and_qfloat8", "model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"],
                    {
                        "default": "model_cpu_offload",
                    }
                ),
                "config": (
                    [
                        "wan2.2/wan_civitai_i2v.yaml",
                        "wan2.2/wan_civitai_5b.yaml",
                    ],
                    {
                        "default": "wan2.2/wan_civitai_i2v.yaml",
                    }
                ),
                "precision": (
                    ['fp16', 'bf16'],
                    {
                        "default": 'fp16'
                    }
                ),
            },
        }

    RETURN_TYPES = ("FunModels",)
    RETURN_NAMES = ("funmodels",)
    FUNCTION = "loadmodel"
    CATEGORY = "CogVideoXFUNWrapper"

    def loadmodel(self, GPU_memory_mode, model_type, model, precision, config):
        # Init device and weight dtype
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        weight_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]

        mm.unload_all_models()
        mm.cleanup_models()
        mm.soft_empty_cache()

        pbar = ProgressBar(5)

        # Load config
        config_path = f"{script_directory}/config/{config}"
        config = OmegaConf.load(config_path)

        # Detect the model path
        possible_folders = ["CogVideoX_Fun", "Fun_Models", "VideoX_Fun", "Wan-AI"] + \
            [os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "models/Diffusion_Transformer")]
        model_name = search_model_in_possible_folders(possible_folders, model)

        # Load VAE
        Chosen_AutoencoderKL = {
            "AutoencoderKLWan": AutoencoderKLWan,
            "AutoencoderKLWan3_8": AutoencoderKLWan3_8
        }[config['vae_kwargs'].get('vae_type', 'AutoencoderKLWan')]
        vae = Chosen_AutoencoderKL.from_pretrained(
            os.path.join(model_name, config['vae_kwargs'].get('vae_subpath', 'vae')),
            additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
        ).to(weight_dtype)
        pbar.update(1)

        # Load scheduler
        print("Load Sampler.")
        scheduler = FlowMatchEulerDiscreteScheduler(
            **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs']))
        )
        pbar.update(1)

        # Load the low-noise transformer, and the high-noise transformer when the config uses the MoE combination
        transformer = Wan2_2Transformer3DModel.from_pretrained(
            os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer')),
            transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
            low_cpu_mem_usage=True,
            torch_dtype=weight_dtype,
        )

        if config['transformer_additional_kwargs'].get('transformer_combination_type', 'single') == "moe":
            transformer_2 = Wan2_2Transformer3DModel.from_pretrained(
                os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer')),
                transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
                low_cpu_mem_usage=True,
                torch_dtype=weight_dtype,
            )
        else:
            transformer_2 = None
        pbar.update(1)

        # Load tokenizer and text encoder
        tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(model_name, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
        )
        pbar.update(1)

        text_encoder = WanT5EncoderModel.from_pretrained(
            os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
            additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
            low_cpu_mem_usage=True,
            torch_dtype=weight_dtype,
        )
        pbar.update(1)

        # Build the pipeline that matches the selected model type
        if model_type == "Inpaint":
            if transformer.config.in_channels != vae.config.latent_channels:
                pipeline = Wan2_2FunInpaintPipeline(
                    vae=vae,
                    tokenizer=tokenizer,
                    text_encoder=text_encoder,
                    transformer=transformer,
                    transformer_2=transformer_2,
                    scheduler=scheduler,
                )
            else:
                pipeline = Wan2_2FunPipeline(
                    vae=vae,
                    tokenizer=tokenizer,
                    text_encoder=text_encoder,
                    transformer=transformer,
                    transformer_2=transformer_2,
                    scheduler=scheduler,
                )
        else:
            pipeline = Wan2_2FunControlPipeline(
                vae=vae,
                tokenizer=tokenizer,
                text_encoder=text_encoder,
                transformer=transformer,
                transformer_2=transformer_2,
                scheduler=scheduler,
            )

        pipeline.remove_all_hooks()
        undo_convert_weight_dtype_wrapper(transformer)

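        # Apply the requested memory strategy: "sequential_cpu_offload" streams sub-modules to
        # the GPU one at a time, the "*_qfloat8" modes additionally store transformer weights in
        # float8 (modulation layers excluded), "model_cpu_offload" moves whole components between
        # CPU and GPU as they are used, and "model_full_load" keeps everything on the GPU.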
        if GPU_memory_mode == "sequential_cpu_offload":
            replace_parameters_by_name(transformer, ["modulation",], device=device)
            transformer.freqs = transformer.freqs.to(device=device)
            if transformer_2 is not None:
                replace_parameters_by_name(transformer_2, ["modulation",], device=device)
                transformer_2.freqs = transformer_2.freqs.to(device=device)
            pipeline.enable_sequential_cpu_offload(device=device)
        elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
            convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device)
            convert_weight_dtype_wrapper(transformer, weight_dtype)
            if transformer_2 is not None:
                convert_model_weight_to_float8(transformer_2, exclude_module_name=["modulation",], device=device)
                convert_weight_dtype_wrapper(transformer_2, weight_dtype)
            pipeline.enable_model_cpu_offload(device=device)
        elif GPU_memory_mode == "model_cpu_offload":
            pipeline.enable_model_cpu_offload(device=device)
        elif GPU_memory_mode == "model_full_load_and_qfloat8":
            convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device)
            convert_weight_dtype_wrapper(transformer, weight_dtype)
            if transformer_2 is not None:
                convert_model_weight_to_float8(transformer_2, exclude_module_name=["modulation",], device=device)
                convert_weight_dtype_wrapper(transformer_2, weight_dtype)
            pipeline.to(device=device)
        else:
            pipeline.to(device=device)

        funmodels = {
            'pipeline': pipeline,
            'dtype': weight_dtype,
            'model_name': model_name,
            'model_type': model_type,
            'loras': [],
            'strength_model': [],
            'config': config,
        }
        return (funmodels,)


class LoadWan2_2FunLora:
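    """ComfyUI node that appends a (low-noise, high-noise) LoRA pair and its strength to a
    `FunModels` dict; the LoRAs themselves are merged into the transformers by the sampler nodes."""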
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "funmodels": ("FunModels",),
                "lora_name": (folder_paths.get_filename_list("loras"), {"default": None,}),
                "lora_high_name": (folder_paths.get_filename_list("loras"), {"default": None,}),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
                "lora_cache": ([False, True], {"default": False,}),
            }
        }
    RETURN_TYPES = ("FunModels",)
    RETURN_NAMES = ("funmodels",)
    FUNCTION = "load_lora"
    CATEGORY = "CogVideoXFUNWrapper"

    def load_lora(self, funmodels, lora_name, lora_high_name, strength_model, lora_cache):
        new_funmodels = dict(funmodels)
        if lora_name is not None:
            loras = list(new_funmodels.get("loras", [])) + [folder_paths.get_full_path("loras", lora_name)]
            loras_high = list(new_funmodels.get("loras_high", [])) + [folder_paths.get_full_path("loras", lora_high_name)]
            strength_models = list(new_funmodels.get("strength_model", [])) + [strength_model]
            new_funmodels['loras'] = loras
            new_funmodels['loras_high'] = loras_high
            new_funmodels['strength_model'] = strength_models
        new_funmodels['lora_cache'] = lora_cache
        return (new_funmodels,)


class Wan2_2FunT2VSampler:
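    """Text-to-video sampler node for Wan2.2-Fun: builds an empty input video/mask of the
    requested size, optionally merges LoRAs and enables TeaCache/CFG-skip, then runs the
    pipeline and returns the decoded frames as a ComfyUI IMAGE batch."""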
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "funmodels": (
                    "FunModels",
                ),
                "prompt": (
                    "STRING_PROMPT",
                ),
                "negative_prompt": (
                    "STRING_PROMPT",
                ),
                "video_length": (
                    "INT", {"default": 81, "min": 5, "max": 161, "step": 4}
                ),
                "width": (
                    "INT", {"default": 832, "min": 64, "max": 2048, "step": 16}
                ),
                "height": (
                    "INT", {"default": 480, "min": 64, "max": 2048, "step": 16}
                ),
                "is_image": (
                    [
                        False,
                        True
                    ],
                    {
                        "default": False,
                    }
                ),
                "seed": (
                    "INT", {"default": 43, "min": 0, "max": 0xffffffffffffffff}
                ),
                "steps": (
                    "INT", {"default": 50, "min": 1, "max": 200, "step": 1}
                ),
                "cfg": (
                    "FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}
                ),
                "scheduler": (
                    ["Flow", "Flow_Unipc", "Flow_DPM++"],
                    {
                        "default": 'Flow'
                    }
                ),
                "shift": (
                    "INT", {"default": 5, "min": 1, "max": 100, "step": 1}
                ),
                "boundary": (
                    "FLOAT", {"default": 0.875, "min": 0.00, "max": 1.00, "step": 0.001}
                ),
                "teacache_threshold": (
                    "FLOAT", {"default": 0.10, "min": 0.00, "max": 1.00, "step": 0.005}
                ),
                "enable_teacache": (
                    [False, True], {"default": True,}
                ),
                "num_skip_start_steps": (
                    "INT", {"default": 5, "min": 0, "max": 50, "step": 1}
                ),
                "teacache_offload": (
                    [False, True], {"default": True,}
                ),
                "cfg_skip_ratio": (
                    "FLOAT", {"default": 0, "min": 0, "max": 1, "step": 0.01}
                ),
            },
            "optional": {
                "riflex_k": ("RIFLEXT_ARGS",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "process"
    CATEGORY = "CogVideoXFUNWrapper"

    def process(self, funmodels, prompt, negative_prompt, video_length, width, height, is_image, seed, steps, cfg, scheduler, shift, boundary, teacache_threshold, enable_teacache, num_skip_start_steps, teacache_offload, cfg_skip_ratio, riflex_k=0):
        global transformer_cpu_cache
        global transformer_high_cpu_cache
        global lora_path_before
        global lora_high_path_before
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()

        mm.soft_empty_cache()
        gc.collect()

        # Get pipeline
        pipeline = funmodels['pipeline']
        model_name = funmodels['model_name']
        weight_dtype = funmodels['dtype']

        # Load the sampler and (optionally) TeaCache
        pipeline.scheduler = get_wan_scheduler(scheduler, shift)
        coefficients = get_teacache_coefficients(model_name) if enable_teacache else None
        if coefficients is not None:
            print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.")
            pipeline.transformer.enable_teacache(
                coefficients, steps, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload
            )
            if pipeline.transformer_2 is not None:
                pipeline.transformer_2.share_teacache(transformer=pipeline.transformer)
        else:
            pipeline.transformer.disable_teacache()
            if pipeline.transformer_2 is not None:
                pipeline.transformer_2.disable_teacache()

        if cfg_skip_ratio is not None:
            print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.")
            pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, steps)
            if pipeline.transformer_2 is not None:
                pipeline.transformer_2.share_cfg_skip(transformer=pipeline.transformer)

        generator = torch.Generator(device).manual_seed(seed)

        video_length = 1 if is_image else video_length
        with torch.no_grad():
            # Snap the frame count to the VAE's temporal compression ratio
            video_length = int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1

            if riflex_k > 0:
                latent_frames = (video_length - 1) // pipeline.vae.config.temporal_compression_ratio + 1
                pipeline.transformer.enable_riflex(k=riflex_k, L_test=latent_frames)
                if pipeline.transformer_2 is not None:
                    pipeline.transformer_2.enable_riflex(k=riflex_k, L_test=latent_frames)

            input_video, input_video_mask, clip_image = get_image_to_video_latent(None, None, video_length=video_length, sample_size=(height, width))

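            # LoRA handling: with `lora_cache` enabled, keep a CPU copy of the pristine
            # transformer weights and only restore + re-merge when the selected LoRA paths or
            # strengths change between runs; without the cache, merge the LoRAs now and
            # unmerge them again after sampling.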
            if funmodels.get("lora_cache", False):
                if len(funmodels.get("loras", [])) != 0:
                    # Save the pristine low-noise transformer weights to CPU on first use
                    if len(transformer_cpu_cache) == 0:
                        print('Save transformer state_dict to cpu memory')
                        transformer_state_dict = pipeline.transformer.state_dict()
                        for key in transformer_state_dict:
                            transformer_cpu_cache[key] = transformer_state_dict[key].clone().cpu()

                    lora_path_now = str(funmodels.get("loras", []) + funmodels.get("strength_model", []))
                    if lora_path_now != lora_path_before:
                        print('Merge Lora with Cache')
                        lora_path_before = copy.deepcopy(lora_path_now)
                        pipeline.transformer.load_state_dict(transformer_cpu_cache)
                        for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])):
                            pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype)

                    if pipeline.transformer_2 is not None:
                        # Save the pristine high-noise transformer weights to CPU on first use
                        if len(transformer_high_cpu_cache) == 0:
                            print('Save transformer high state_dict to cpu memory')
                            transformer_high_state_dict = pipeline.transformer_2.state_dict()
                            for key in transformer_high_state_dict:
                                transformer_high_cpu_cache[key] = transformer_high_state_dict[key].clone().cpu()

                        lora_high_path_now = str(funmodels.get("loras_high", []) + funmodels.get("strength_model", []))
                        if lora_high_path_now != lora_high_path_before:
                            print('Merge Lora High with Cache')
                            lora_high_path_before = copy.deepcopy(lora_high_path_now)
                            # Restore the high-noise transformer from its own cache before re-merging
                            pipeline.transformer_2.load_state_dict(transformer_high_cpu_cache)
                            for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])):
                                pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")
            else:
                print('Merge Lora')
                # If a cached state dict is left over from a previous lora_cache run, restore it first
                if len(transformer_cpu_cache) != 0:
                    pipeline.transformer.load_state_dict(transformer_cpu_cache)
                    transformer_cpu_cache = {}
                    lora_path_before = ""
                    gc.collect()

                for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])):
                    pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype)

                if pipeline.transformer_2 is not None:
                    if len(transformer_high_cpu_cache) != 0:
                        pipeline.transformer_2.load_state_dict(transformer_high_cpu_cache)
                        transformer_high_cpu_cache = {}
                        lora_high_path_before = ""
                        gc.collect()

                    for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])):
                        pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")

            sample = pipeline(
                prompt,
                num_frames = video_length,
                negative_prompt = negative_prompt,
                height = height,
                width = width,
                generator = generator,
                guidance_scale = cfg,
                num_inference_steps = steps,

                video = input_video,
                mask_video = input_video_mask,
                boundary = boundary,
                comfyui_progressbar = True,
            ).videos
            videos = rearrange(sample, "b c t h w -> (b t) h w c")

            if not funmodels.get("lora_cache", False):
                print('Unmerge Lora')
                for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])):
                    pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype)
                if pipeline.transformer_2 is not None:
                    for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])):
                        pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")
        return (videos,)


class Wan2_2FunInpaintSampler:
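    """Image-to-video (inpaint) sampler node for Wan2.2-Fun: conditions generation on an
    optional start/end frame, picks the output resolution from the closest aspect-ratio
    bucket at `base_resolution`, and returns the decoded frames as an IMAGE batch."""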
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "funmodels": (
                    "FunModels",
                ),
                "prompt": (
                    "STRING_PROMPT",
                ),
                "negative_prompt": (
                    "STRING_PROMPT",
                ),
                "video_length": (
                    "INT", {"default": 81, "min": 5, "max": 161, "step": 4}
                ),
                "base_resolution": (
                    [
                        512,
                        640,
                        768,
                        896,
                        960,
                        1024,
                    ], {"default": 640}
                ),
                "seed": (
                    "INT", {"default": 43, "min": 0, "max": 0xffffffffffffffff}
                ),
                "steps": (
                    "INT", {"default": 50, "min": 1, "max": 200, "step": 1}
                ),
                "cfg": (
                    "FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}
                ),
                "scheduler": (
                    ["Flow", "Flow_Unipc", "Flow_DPM++"],
                    {
                        "default": 'Flow'
                    }
                ),
                "shift": (
                    "INT", {"default": 5, "min": 1, "max": 100, "step": 1}
                ),
                "boundary": (
                    "FLOAT", {"default": 0.900, "min": 0.00, "max": 1.00, "step": 0.001}
                ),
                "teacache_threshold": (
                    "FLOAT", {"default": 0.10, "min": 0.00, "max": 1.00, "step": 0.005}
                ),
                "enable_teacache": (
                    [False, True], {"default": True,}
                ),
                "num_skip_start_steps": (
                    "INT", {"default": 5, "min": 0, "max": 50, "step": 1}
                ),
                "teacache_offload": (
                    [False, True], {"default": True,}
                ),
                "cfg_skip_ratio": (
                    "FLOAT", {"default": 0, "min": 0, "max": 1, "step": 0.01}
                ),
            },
            "optional": {
                "start_img": ("IMAGE",),
                "end_img": ("IMAGE",),
                "riflex_k": ("RIFLEXT_ARGS",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "process"
    CATEGORY = "CogVideoXFUNWrapper"

    def process(self, funmodels, prompt, negative_prompt, video_length, base_resolution, seed, steps, cfg, scheduler, shift, boundary, teacache_threshold, enable_teacache, num_skip_start_steps, teacache_offload, cfg_skip_ratio, start_img=None, end_img=None, riflex_k=0):
        global transformer_cpu_cache
        global transformer_high_cpu_cache
        global lora_path_before
        global lora_high_path_before

        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()

        mm.soft_empty_cache()
        gc.collect()

        # Get pipeline
        pipeline = funmodels['pipeline']
        model_name = funmodels['model_name']
        weight_dtype = funmodels['dtype']

        start_img = [to_pil(_start_img) for _start_img in start_img] if start_img is not None else None
        end_img = [to_pil(_end_img) for _end_img in end_img] if end_img is not None else None

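        # Choose the output size by snapping the first input image's aspect ratio to the
        # closest bucket at `base_resolution`, then rounding both sides down to multiples of 16.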
        aspect_ratio_sample_size = {key: [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
        original_width, original_height = start_img[0].size if type(start_img) is list else Image.open(start_img).size
        closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
        height, width = [int(x / 16) * 16 for x in closest_size]

        # Load the sampler and (optionally) TeaCache
        pipeline.scheduler = get_wan_scheduler(scheduler, shift)
        coefficients = get_teacache_coefficients(model_name) if enable_teacache else None
        if coefficients is not None:
            print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.")
            pipeline.transformer.enable_teacache(
                coefficients, steps, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload
            )
            if pipeline.transformer_2 is not None:
                pipeline.transformer_2.share_teacache(transformer=pipeline.transformer)
        else:
            pipeline.transformer.disable_teacache()
            if pipeline.transformer_2 is not None:
                pipeline.transformer_2.disable_teacache()

        if cfg_skip_ratio is not None:
            print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.")
            pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, steps)
            if pipeline.transformer_2 is not None:
                pipeline.transformer_2.share_cfg_skip(transformer=pipeline.transformer)

        generator = torch.Generator(device).manual_seed(seed)

        with torch.no_grad():
            video_length = int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1

            if riflex_k > 0:
                latent_frames = (video_length - 1) // pipeline.vae.config.temporal_compression_ratio + 1
                pipeline.transformer.enable_riflex(k=riflex_k, L_test=latent_frames)
                if pipeline.transformer_2 is not None:
                    pipeline.transformer_2.enable_riflex(k=riflex_k, L_test=latent_frames)

            input_video, input_video_mask, clip_image = get_image_to_video_latent(start_img, end_img, video_length=video_length, sample_size=(height, width))

            # Same LoRA cache/merge strategy as the T2V sampler
            if funmodels.get("lora_cache", False):
                if len(funmodels.get("loras", [])) != 0:
                    if len(transformer_cpu_cache) == 0:
                        print('Save transformer state_dict to cpu memory')
                        transformer_state_dict = pipeline.transformer.state_dict()
                        for key in transformer_state_dict:
                            transformer_cpu_cache[key] = transformer_state_dict[key].clone().cpu()

                    lora_path_now = str(funmodels.get("loras", []) + funmodels.get("strength_model", []))
                    if lora_path_now != lora_path_before:
                        print('Merge Lora with Cache')
                        lora_path_before = copy.deepcopy(lora_path_now)
                        pipeline.transformer.load_state_dict(transformer_cpu_cache)
                        for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])):
                            pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype)

                    if pipeline.transformer_2 is not None:
                        if len(transformer_high_cpu_cache) == 0:
                            print('Save transformer high state_dict to cpu memory')
                            transformer_high_state_dict = pipeline.transformer_2.state_dict()
                            for key in transformer_high_state_dict:
                                transformer_high_cpu_cache[key] = transformer_high_state_dict[key].clone().cpu()

                        lora_high_path_now = str(funmodels.get("loras_high", []) + funmodels.get("strength_model", []))
                        if lora_high_path_now != lora_high_path_before:
                            print('Merge Lora High with Cache')
                            lora_high_path_before = copy.deepcopy(lora_high_path_now)
                            # Restore the high-noise transformer from its own cache before re-merging
                            pipeline.transformer_2.load_state_dict(transformer_high_cpu_cache)
                            for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])):
                                pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")
            else:
                print('Merge Lora')
                if len(transformer_cpu_cache) != 0:
                    pipeline.transformer.load_state_dict(transformer_cpu_cache)
                    transformer_cpu_cache = {}
                    lora_path_before = ""
                    gc.collect()

                for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])):
                    pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype)

                if pipeline.transformer_2 is not None:
                    if len(transformer_high_cpu_cache) != 0:
                        pipeline.transformer_2.load_state_dict(transformer_high_cpu_cache)
                        transformer_high_cpu_cache = {}
                        lora_high_path_before = ""
                        gc.collect()

                    for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])):
                        pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")

            sample = pipeline(
                prompt,
                num_frames = video_length,
                negative_prompt = negative_prompt,
                height = height,
                width = width,
                generator = generator,
                guidance_scale = cfg,
                num_inference_steps = steps,

                video = input_video,
                mask_video = input_video_mask,
                boundary = boundary,
                comfyui_progressbar = True,
            ).videos
            videos = rearrange(sample, "b c t h w -> (b t) h w c")

            if not funmodels.get("lora_cache", False):
                print('Unmerge Lora')
                for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])):
                    pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype)
                if pipeline.transformer_2 is not None:
                    for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])):
                        pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")
        return (videos,)


class Wan2_2FunV2VSampler:
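    """Video-to-video / control sampler node for Wan2.2-Fun: for "Inpaint" models it denoises
    an existing video with the given strength; for "Control" models it conditions on a control
    video, camera trajectory, reference image and/or start/end frames."""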
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "funmodels": (
                    "FunModels",
                ),
                "prompt": (
                    "STRING_PROMPT",
                ),
                "negative_prompt": (
                    "STRING_PROMPT",
                ),
                "video_length": (
                    "INT", {"default": 81, "min": 1, "max": 161, "step": 4}
                ),
                "base_resolution": (
                    [
                        512,
                        640,
                        768,
                        896,
                        960,
                        1024,
                    ], {"default": 640}
                ),
                "seed": (
                    "INT", {"default": 43, "min": 0, "max": 0xffffffffffffffff}
                ),
                "steps": (
                    "INT", {"default": 50, "min": 1, "max": 200, "step": 1}
                ),
                "cfg": (
                    "FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}
                ),
                "denoise_strength": (
                    "FLOAT", {"default": 1.00, "min": 0.05, "max": 1.00, "step": 0.01}
                ),
                "scheduler": (
                    ["Flow", "Flow_Unipc", "Flow_DPM++"],
                    {
                        "default": 'Flow'
                    }
                ),
                "shift": (
                    "INT", {"default": 5, "min": 1, "max": 100, "step": 1}
                ),
                "boundary": (
                    "FLOAT", {"default": 0.900, "min": 0.00, "max": 1.00, "step": 0.001}
                ),
                "teacache_threshold": (
                    "FLOAT", {"default": 0.10, "min": 0.00, "max": 1.00, "step": 0.005}
                ),
                "enable_teacache": (
                    [False, True], {"default": True,}
                ),
                "num_skip_start_steps": (
                    "INT", {"default": 5, "min": 0, "max": 50, "step": 1}
                ),
                "teacache_offload": (
                    [False, True], {"default": True,}
                ),
                "cfg_skip_ratio": (
                    "FLOAT", {"default": 0, "min": 0, "max": 1, "step": 0.01}
                ),
            },
            "optional": {
                "validation_video": ("IMAGE",),
                "control_video": ("IMAGE",),
                "start_image": ("IMAGE",),
                "end_image": ("IMAGE",),
                "ref_image": ("IMAGE",),
                "camera_conditions": ("STRING", {"forceInput": True}),
                "riflex_k": ("RIFLEXT_ARGS",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "process"
    CATEGORY = "CogVideoXFUNWrapper"

    def process(self, funmodels, prompt, negative_prompt, video_length, base_resolution, seed, steps, cfg, denoise_strength, scheduler, shift, boundary, teacache_threshold, enable_teacache, num_skip_start_steps, teacache_offload, cfg_skip_ratio, validation_video=None, control_video=None, start_image=None, end_image=None, ref_image=None, camera_conditions=None, riflex_k=0):
        global transformer_cpu_cache
        global transformer_high_cpu_cache
        global lora_path_before
        global lora_high_path_before

        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()

        mm.soft_empty_cache()
        gc.collect()

        # Get pipeline
        pipeline = funmodels['pipeline']
        model_name = funmodels['model_name']
        weight_dtype = funmodels['dtype']
        model_type = funmodels['model_type']

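        # Pick the reference size for aspect-ratio bucketing: the validation video (Inpaint)
        # or control video (Control) if given, otherwise a default portrait size, with any
        # reference or start image taking precedence.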
        aspect_ratio_sample_size = {key: [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
        if model_type == "Inpaint":
            if type(validation_video) is str:
                original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
            else:
                validation_video = np.array(validation_video.cpu().numpy() * 255, np.uint8)
                original_width, original_height = Image.fromarray(validation_video[0]).size
        else:
            if control_video is not None and type(control_video) is str:
                original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size
            elif control_video is not None:
                control_video = np.array(control_video.cpu().numpy() * 255, np.uint8)
                original_width, original_height = Image.fromarray(control_video[0]).size
            else:
                original_width, original_height = 384 / 512 * base_resolution, 672 / 512 * base_resolution

        if ref_image is not None:
            ref_image = [to_pil(_ref_image) for _ref_image in ref_image]
            original_width, original_height = ref_image[0].size if type(ref_image) is list else Image.open(ref_image).size

        if start_image is not None:
            start_image = [to_pil(_start_image) for _start_image in start_image]
            original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size

        if end_image is not None:
            end_image = [to_pil(_end_image) for _end_image in end_image]

        closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
        height, width = [int(x / 16) * 16 for x in closest_size]

        # Load the sampler and (optionally) TeaCache
        pipeline.scheduler = get_wan_scheduler(scheduler, shift)
        coefficients = get_teacache_coefficients(model_name) if enable_teacache else None
        if coefficients is not None:
            print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.")
            pipeline.transformer.enable_teacache(
                coefficients, steps, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload
            )
            if pipeline.transformer_2 is not None:
                pipeline.transformer_2.share_teacache(transformer=pipeline.transformer)
        else:
            pipeline.transformer.disable_teacache()
            if pipeline.transformer_2 is not None:
                pipeline.transformer_2.disable_teacache()

        if cfg_skip_ratio is not None:
            print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.")
            pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, steps)
            if pipeline.transformer_2 is not None:
                pipeline.transformer_2.share_cfg_skip(transformer=pipeline.transformer)

        generator = torch.Generator(device).manual_seed(seed)

        with torch.no_grad():
            video_length = int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1

            if riflex_k > 0:
                latent_frames = (video_length - 1) // pipeline.vae.config.temporal_compression_ratio + 1
                pipeline.transformer.enable_riflex(k=riflex_k, L_test=latent_frames)
                if pipeline.transformer_2 is not None:
                    pipeline.transformer_2.enable_riflex(k=riflex_k, L_test=latent_frames)

            if model_type == "Inpaint":
                input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(validation_video, video_length=video_length, sample_size=(height, width), fps=16, ref_image=ref_image[0] if ref_image is not None else ref_image)
            else:
                if ref_image is not None:
                    clip_image = ref_image[0].convert("RGB")
                elif start_image is not None:
                    clip_image = start_image[0].convert("RGB")
                else:
                    clip_image = None

                inpaint_video, inpaint_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=video_length, sample_size=(height, width))

                if ref_image is not None:
                    ref_image = get_image_latent(ref_image[0] if ref_image is not None else ref_image, sample_size=(height, width))

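                # Camera control: a JSON camera trajectory is turned into a per-frame camera
                # conditioning tensor and replaces the pixel control video; otherwise the
                # control video itself is encoded as the conditioning input.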
                if camera_conditions is not None and len(camera_conditions) > 0:
                    poses = json.loads(camera_conditions)
                    cam_params = np.array([[float(x) for x in pose] for pose in poses])
                    cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1)
                    control_camera_video = process_pose_params(cam_params, width=width, height=height)
                    control_camera_video = control_camera_video[:video_length].permute([3, 0, 1, 2]).unsqueeze(0)
                    input_video, input_video_mask = None, None
                else:
                    control_camera_video = None
                    input_video, input_video_mask, _, _ = get_video_to_video_latent(control_video, video_length=video_length, sample_size=(height, width), fps=16, ref_image=None)

            # Same LoRA cache/merge strategy as the T2V sampler
            if funmodels.get("lora_cache", False):
                if len(funmodels.get("loras", [])) != 0:
                    if len(transformer_cpu_cache) == 0:
                        print('Save transformer state_dict to cpu memory')
                        transformer_state_dict = pipeline.transformer.state_dict()
                        for key in transformer_state_dict:
                            transformer_cpu_cache[key] = transformer_state_dict[key].clone().cpu()

                    lora_path_now = str(funmodels.get("loras", []) + funmodels.get("strength_model", []))
                    if lora_path_now != lora_path_before:
                        print('Merge Lora with Cache')
                        lora_path_before = copy.deepcopy(lora_path_now)
                        pipeline.transformer.load_state_dict(transformer_cpu_cache)
                        for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])):
                            pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype)

                    if pipeline.transformer_2 is not None:
                        if len(transformer_high_cpu_cache) == 0:
                            print('Save transformer high state_dict to cpu memory')
                            transformer_high_state_dict = pipeline.transformer_2.state_dict()
                            for key in transformer_high_state_dict:
                                transformer_high_cpu_cache[key] = transformer_high_state_dict[key].clone().cpu()

                        lora_high_path_now = str(funmodels.get("loras_high", []) + funmodels.get("strength_model", []))
                        if lora_high_path_now != lora_high_path_before:
                            print('Merge Lora High with Cache')
                            lora_high_path_before = copy.deepcopy(lora_high_path_now)
                            # Restore the high-noise transformer from its own cache before re-merging
                            pipeline.transformer_2.load_state_dict(transformer_high_cpu_cache)
                            for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])):
                                pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")
            else:
                print('Merge Lora')
                if len(transformer_cpu_cache) != 0:
                    pipeline.transformer.load_state_dict(transformer_cpu_cache)
                    transformer_cpu_cache = {}
                    lora_path_before = ""
                    gc.collect()

                for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])):
                    pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype)

                if pipeline.transformer_2 is not None:
                    if len(transformer_high_cpu_cache) != 0:
                        pipeline.transformer_2.load_state_dict(transformer_high_cpu_cache)
                        transformer_high_cpu_cache = {}
                        lora_high_path_before = ""
                        gc.collect()

                    for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])):
                        pipeline = merge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")

            if model_type == "Inpaint":
                sample = pipeline(
                    prompt,
                    num_frames = video_length,
                    negative_prompt = negative_prompt,
                    height = height,
                    width = width,
                    generator = generator,
                    guidance_scale = cfg,
                    num_inference_steps = steps,

                    video = input_video,
                    mask_video = input_video_mask,
                    clip_image = clip_image,
                    strength = float(denoise_strength),
                    comfyui_progressbar = True,
                ).videos
            else:
                sample = pipeline(
                    prompt,
                    num_frames = video_length,
                    negative_prompt = negative_prompt,
                    height = height,
                    width = width,
                    generator = generator,
                    guidance_scale = cfg,
                    num_inference_steps = steps,

                    video = inpaint_video,
                    mask_video = inpaint_video_mask,
                    control_video = input_video,
                    control_camera_video = control_camera_video,
                    ref_image = ref_image,
                    boundary = boundary,
                    comfyui_progressbar = True,
                ).videos
            videos = rearrange(sample, "b c t h w -> (b t) h w c")

            if not funmodels.get("lora_cache", False):
                print('Unmerge Lora')
                for _lora_path, _lora_weight in zip(funmodels.get("loras", []), funmodels.get("strength_model", [])):
                    pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype)
                if pipeline.transformer_2 is not None:
                    for _lora_path, _lora_weight in zip(funmodels.get("loras_high", []), funmodels.get("strength_model", [])):
                        pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")
        return (videos,)