# aliensmn's picture
# Mirror from https://github.com/kijai/ComfyUI-WanVideoWrapper
# cf812a0 verified
import os, gc, math
import torch
import torch.nn.functional as F
import numpy as np
import hashlib
from .wanvideo.schedulers import get_scheduler, scheduler_list
from .utils import(log, clip_encode_image_tiled, add_noise_to_reference_video, set_module_tensor_to_device)
from .taehv import TAEHV
from comfy import model_management as mm
from comfy.utils import ProgressBar, common_upscale
from comfy.clip_vision import clip_preprocess, ClipVisionModel
import folder_paths
# Absolute path of the directory containing this file; the on-disk text-embed cache lives beneath it.
script_directory = os.path.dirname(os.path.abspath(__file__))
# Main torch device and the offload device, both as reported by ComfyUI's model management.
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
# NOTE(review): presumably the VAE compression factors as (temporal, height, width) - confirm against the VAE.
VAE_STRIDE = (4, 8, 8)
# NOTE(review): presumably the transformer patch size as (temporal, height, width) - confirm against the model.
PATCH_SIZE = (1, 2, 2)
class WanVideoEnhanceAVideo:
    """Node that bundles the Enhance-A-Video widget values into one FETAARGS output."""

    RETURN_TYPES = ("FETAARGS",)
    RETURN_NAMES = ("feta_args",)
    FUNCTION = "setargs"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs: a weight plus the start/end step percentages."""
        required_inputs = {
            "weight": ("FLOAT", {"default": 2.0, "min": 0, "max": 100, "step": 0.01, "tooltip": "The feta Weight of the Enhance-A-Video"}),
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply Enhance-A-Video"}),
            "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply Enhance-A-Video"}),
        }
        return {"required": required_inputs}

    def setargs(self, **kwargs):
        """Return the received keyword arguments unchanged as a one-element tuple."""
        feta_args = kwargs
        return (feta_args,)
class WanVideoSetBlockSwap:
    """Node that attaches block-swap arguments to a cloned model's transformer options."""

    RETURN_TYPES = ("WANVIDEOMODEL",)
    RETURN_NAMES = ("model", )
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs: a required model and optional block-swap arguments."""
        required_inputs = {"model": ("WANVIDEOMODEL", )}
        optional_inputs = {"block_swap_args": ("BLOCKSWAPARGS", )}
        return {"required": required_inputs, "optional": optional_inputs}

    def loadmodel(self, model, block_swap_args=None):
        """Return the model, cloned and tagged with ``block_swap_args`` when they are given.

        Without block-swap arguments the incoming model object is passed through as-is.
        """
        if block_swap_args is not None:
            patcher = model.clone()
            # Create the transformer_options dict on demand, then record the arguments in it.
            options = patcher.model_options.setdefault('transformer_options', {})
            options["block_swap_args"] = block_swap_args
            return (patcher,)
        return (model,)
class WanVideoSetRadialAttention:
    """Node that stores radial-attention settings in a cloned model's transformer options."""

    RETURN_TYPES = ("WANVIDEOMODEL",)
    RETURN_NAMES = ("model", )
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Sets radial attention parameters, dense attention refers to normal attention"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs: the model plus the dense/sparse attention settings."""
        dense_modes = [
            "sdpa",
            "flash_attn_2",
            "flash_attn_3",
            "sageattn",
            "sparse_sage_attention",
        ]
        required_inputs = {
            "model": ("WANVIDEOMODEL", ),
            "dense_attention_mode": (dense_modes, {"default": "sageattn", "tooltip": "The attention mode for dense attention"}),
            "dense_blocks": ("INT", {"default": 1, "min": 0, "max": 40, "step": 1, "tooltip": "Number of blocks to apply normal attention to"}),
            "dense_vace_blocks": ("INT", {"default": 1, "min": 0, "max": 15, "step": 1, "tooltip": "Number of vace blocks to apply normal attention to"}),
            "dense_timesteps": ("INT", {"default": 2, "min": 0, "max": 100, "step": 1, "tooltip": "The step to start applying sparse attention"}),
            "decay_factor": ("FLOAT", {"default": 0.2, "min": 0, "max": 1, "step": 0.01, "tooltip": "Controls how quickly the attention window shrinks as the distance between frames increases in the sparse attention mask."}),
            "block_size":([128, 64], {"default": 128, "tooltip": "Radial attention block size, larger blocks are faster but restricts usable dimensions more."}),
        }
        return {"required": required_inputs}

    def loadmodel(self, model, dense_attention_mode, dense_blocks, dense_vace_blocks, dense_timesteps, decay_factor, block_size):
        """Clone ``model`` and record the radial-attention settings on the clone.

        Raises:
            Exception: if the model's attention mode does not contain "radial".
        """
        if "radial" not in model.model.diffusion_model.attention_mode:
            raise Exception("Enable radial attention first in the model loader.")

        patcher = model.clone()
        options = patcher.model_options.setdefault('transformer_options', {})
        options.update({
            "dense_attention_mode": dense_attention_mode,
            "dense_blocks": dense_blocks,
            "dense_vace_blocks": dense_vace_blocks,
            "dense_timesteps": dense_timesteps,
            "decay_factor": decay_factor,
            "block_size": block_size,
        })
        return (patcher,)
class WanVideoBlockList:
    """Node that parses a textual block specification into a list of integers."""

    RETURN_TYPES = ("INT",)
    RETURN_NAMES = ("block_list", )
    FUNCTION = "create_list"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Comma separated list of blocks to apply block swap to, can also use ranges like '0-5' or '0,2,3-5' etc., can be connected to the dense_blocks input of 'WanVideoSetRadialAttention' node"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's single multiline string input."""
        return {"required": {"blocks": ("STRING", {"default": "1", "multiline":True})}}

    def create_list(self, blocks):
        """Expand ``blocks`` into a list of ints, returned as a one-element tuple.

        Entries are separated by commas and/or newlines; blank entries are
        ignored. An entry containing '-' is read as an inclusive ``start-end``
        range, anything else as a single integer.

        Raises:
            ValueError: if an entry cannot be parsed as an integer or a range.
        """
        block_list = []
        entries = (piece.strip() for row in blocks.splitlines() for piece in row.split(","))
        for entry in entries:
            if not entry:
                continue
            if "-" not in entry:
                try:
                    block_list.append(int(entry))
                except Exception:
                    raise ValueError(f"Invalid integer: '{entry}'")
                continue
            try:
                first, last = map(int, entry.split("-", 1))
                block_list.extend(range(first, last + 1))
            except Exception:
                raise ValueError(f"Invalid range: '{entry}'")
        return (block_list,)
# In-memory cache for prompt extender output, keyed by (original prompt, str(extender_args))
_extender_cache = {}
# Directory holding the on-disk text embedding cache files, located next to this file
cache_dir = os.path.join(script_directory, 'text_embed_cache')
def get_cache_path(prompt):
    """Return the cache file path for ``prompt``.

    The file name is the SHA-256 hex digest of the whitespace-stripped,
    UTF-8 encoded prompt with a ``.pt`` suffix, placed inside ``cache_dir``.
    """
    digest = hashlib.sha256(prompt.strip().encode('utf-8')).hexdigest()
    return os.path.join(cache_dir, digest + ".pt")
def get_cached_text_embeds(positive_prompt, negative_prompt):
    """Load cached embeddings for both prompts from disk where available.

    Creates the cache directory if needed. Each prompt is looked up
    independently; a missing or unreadable cache file yields ``None`` for that
    prompt (load failures are logged as warnings, not raised).

    Returns:
        Tuple ``(context, context_null)`` for the positive and negative prompt.
    """
    os.makedirs(cache_dir, exist_ok=True)

    cache_paths = (get_cache_path(positive_prompt), get_cache_path(negative_prompt))
    loaded = [None, None]

    # Positive prompt first, then negative.
    for slot, path in enumerate(cache_paths):
        if not os.path.exists(path):
            continue
        try:
            log.info(f"Loading prompt embeds from cache: {path}")
            # NOTE(review): torch.load may unpickle arbitrary objects depending on the
            # installed torch's weights_only default; cache files are assumed trusted.
            loaded[slot] = torch.load(path)
        except Exception as e:
            log.warning(f"Failed to load cache: {e}, will re-encode.")

    return loaded[0], loaded[1]
class WanVideoTextEncodeCached:
    """Node that encodes prompts, loading the T5 encoder only when the disk cache misses.

    An optional prompt extender can rewrite the positive prompt first; its
    results are memoised in the module-level ``_extender_cache``.
    """

    RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", "WANVIDEOTEXTEMBEDS", "STRING")
    RETURN_NAMES = ("text_embeds", "negative_text_embeds", "positive_prompt")
    OUTPUT_TOOLTIPS = ("The text embeddings for both prompts", "The text embeddings for the negative prompt only (for NAG)", "Positive prompt to display prompt extender results")
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = """Encodes text prompts into text embeddings. This node loads and completely unloads the T5 after done,
leaving no VRAM or RAM imprint. If prompts have been cached before T5 is not loaded at all.
negative output is meant to be used with NAG, it contains only negative prompt embeddings.
Additionally you can provide a Qwen LLM model to extend the positive prompt with either one
of the original Wan templates or a custom system prompt.
"""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs: encoder file, precision, prompts and caching options."""
        required_inputs = {
            "model_name": (folder_paths.get_filename_list("text_encoders"), {"tooltip": "These models are loaded from 'ComfyUI/models/text_encoders'"}),
            "precision": (["fp32", "bf16"],
                    {"default": "bf16"}
                ),
            "positive_prompt": ("STRING", {"default": "", "multiline": True} ),
            "negative_prompt": ("STRING", {"default": "", "multiline": True} ),
            "quantization": (['disabled', 'fp8_e4m3fn'], {"default": 'disabled', "tooltip": "optional quantization method"}),
            "use_disk_cache": ("BOOLEAN", {"default": True, "tooltip": "Cache the text embeddings to disk for faster re-use, under the custom_nodes/ComfyUI-WanVideoWrapper/text_embed_cache directory"}),
            "device": (["gpu", "cpu"], {"default": "gpu", "tooltip": "Device to run the text encoding on."}),
        }
        optional_inputs = {
            "extender_args": ("WANVIDEOPROMPTEXTENDER_ARGS", {"tooltip": "Use this node to extend the prompt with additional text."}),
        }
        return {"required": required_inputs, "optional": optional_inputs}

    def process(self, model_name, precision, positive_prompt, negative_prompt, quantization='disabled', use_disk_cache=True, device="gpu", extender_args=None):
        """Encode the prompts, preferring cached results over loading the T5 encoder.

        Returns:
            Tuple of (embeds dict for both prompts, dict holding only the
            negative embeds under "prompt_embeds", the possibly extended
            positive prompt string).
        """
        from .nodes_model_loading import LoadWanVideoT5TextEncoder

        progress = ProgressBar(3)

        # The EchoShot marker is checked on the prompt as given, before any extension.
        echoshot = "[1]" in positive_prompt

        # Optional prompt extension, memoised per (prompt, extender settings).
        orig_prompt = positive_prompt
        if extender_args is not None:
            extender_key = (orig_prompt, str(extender_args))
            if extender_key not in _extender_cache:
                from .qwen.qwen import QwenLoader, WanVideoPromptExtender
                log.info("Using WanVideoPromptExtender to process prompts")
                qwen_device = "main_device" if device == "gpu" else "cpu"
                qwen, = QwenLoader().load(
                    extender_args["model"],
                    load_device=qwen_device,
                    precision=precision)
                positive_prompt, = WanVideoPromptExtender().generate(
                    qwen=qwen,
                    max_new_tokens=extender_args["max_new_tokens"],
                    prompt=orig_prompt,
                    device=device,
                    force_offload=False,
                    custom_system_prompt=extender_args["system_prompt"],
                    seed=extender_args["seed"]
                    )
                log.info(f"Extended positive prompt: {positive_prompt}")
                _extender_cache[extender_key] = positive_prompt
                del qwen
            else:
                positive_prompt = _extender_cache[extender_key]
                log.info(f"Loaded extended prompt from in-memory cache: {positive_prompt}")
        progress.update(1)

        # Disk cache lookup uses the (possibly extended) positive prompt.
        if use_disk_cache:
            context, context_null = get_cached_text_embeds(positive_prompt, negative_prompt)
            both_cached = context is not None and context_null is not None
            if both_cached:
                cached_embeds = {
                    "prompt_embeds": context,
                    "negative_prompt_embeds": context_null,
                    "echoshot": echoshot,
                }
                return cached_embeds, {"prompt_embeds": context_null}, positive_prompt

        # Cache miss: load the T5 encoder and delegate the encoding.
        t5, = LoadWanVideoT5TextEncoder().loadmodel(model_name, precision, "main_device", quantization)
        progress.update(1)

        prompt_embeds_dict, = WanVideoTextEncode().process(
            positive_prompt=positive_prompt,
            negative_prompt=negative_prompt,
            t5=t5,
            force_offload=False,
            model_to_offload=None,
            use_disk_cache=use_disk_cache,
            device=device
        )
        progress.update(1)

        # Drop the encoder reference and reclaim memory.
        del t5
        mm.soft_empty_cache()
        gc.collect()

        negative_only = {"prompt_embeds": prompt_embeds_dict["negative_prompt_embeds"]}
        return (prompt_embeds_dict, negative_only, positive_prompt)
#region TextEncode
class WanVideoTextEncode:
    """Node that encodes a positive/negative prompt pair with a T5 text encoder.

    The positive prompt may hold several prompts separated by '|', or EchoShot
    segments introduced by numbered markers such as "[1]". Segments may carry
    "(text:weight)" annotations. Results can optionally be cached on disk.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs: the two prompts plus optional encoder and flags."""
        return {"required": {
            "positive_prompt": ("STRING", {"default": "", "multiline": True} ),
            "negative_prompt": ("STRING", {"default": "", "multiline": True} ),
            },
            "optional": {
                "t5": ("WANTEXTENCODER",),
                "force_offload": ("BOOLEAN", {"default": True}),
                "model_to_offload": ("WANVIDEOMODEL", {"tooltip": "Model to move to offload_device before encoding"}),
                "use_disk_cache": ("BOOLEAN", {"default": False, "tooltip": "Cache the text embeddings to disk for faster re-use, under the custom_nodes/ComfyUI-WanVideoWrapper/text_embed_cache directory"}),
                "device": (["gpu", "cpu"], {"default": "gpu", "tooltip": "Device to run the text encoding on."}),
            }
        }

    RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", )
    RETURN_NAMES = ("text_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Encodes text prompts into text embeddings. For rudimentary prompt travel you can input multiple prompts separated by '|', they will be equally spread over the video length"

    def process(self, positive_prompt, negative_prompt, t5=None, force_offload=True, model_to_offload=None, use_disk_cache=False, device="gpu"):
        """Encode both prompts and return a one-element tuple holding the embeds dict.

        Args:
            positive_prompt: Prompt text; may contain '|' separators, EchoShot
                markers and "(text:weight)" annotations.
            negative_prompt: Negative prompt text, encoded as a single prompt.
            t5: Dict with "model" and "dtype" entries; may be ``None`` only if
                both prompts are found in the disk cache.
            force_offload: Move the encoder to the offload device afterwards.
            model_to_offload: Optional model moved to the offload device before
                encoding when running on the GPU (best effort).
            use_disk_cache: Read/write embeddings from/to the disk cache.
            device: "gpu" for the main torch device, anything else for CPU.

        Returns:
            ``(dict,)`` with keys "prompt_embeds", "negative_prompt_embeds" and
            "echoshot".

        Raises:
            ValueError: if no encoder is given and the cache cannot supply both
                embeddings.
            AssertionError: if an EchoShot prompt does not have 2 to 6 shots.
        """
        if t5 is None and not use_disk_cache:
            raise ValueError("T5 encoder is required for text encoding. Please provide a valid T5 encoder or enable disk cache.")

        echoshot = "[1]" in positive_prompt

        # Always bound, so the later "was it cached?" checks do not rely on
        # short-circuit evaluation of use_disk_cache.
        context = None
        context_null = None
        if use_disk_cache:
            context, context_null = get_cached_text_embeds(positive_prompt, negative_prompt)
            if context is not None and context_null is not None:
                return ({
                    "prompt_embeds": context,
                    "negative_prompt_embeds": context_null,
                    "echoshot": echoshot,
                },)

        if t5 is None:
            raise ValueError("No cached text embeds found for prompts, please provide a T5 encoder.")

        if model_to_offload is not None and device == "gpu":
            # Best effort: a failed offload must not abort the encoding, but it
            # should be visible in the log rather than silently swallowed.
            try:
                log.info(f"Moving video model to {offload_device}")
                model_to_offload.model.to(offload_device)
            except Exception as e:
                log.warning(f"Failed to move video model to {offload_device}: {e}")

        encoder = t5["model"]
        dtype = t5["dtype"]

        # Split the positive prompt into its individual prompts
        if "|" in positive_prompt:
            log.info("Multiple positive prompts detected, splitting by '|'")
            positive_prompts_raw = [p.strip() for p in positive_prompt.split('|')]
        elif "[1]" in positive_prompt:
            log.info("Multiple positive prompts detected, splitting by [#] and enabling EchoShot")
            import re
            segments = re.split(r'\[\d+\]', positive_prompt)
            positive_prompts_raw = [segment.strip() for segment in segments if segment.strip()]
            # NOTE(review): assert is stripped under `python -O`; kept as-is so the
            # raised exception type stays unchanged for existing callers.
            assert len(positive_prompts_raw) > 1 and len(positive_prompts_raw) < 7, 'Input shot num must between 2~6 !'
        else:
            positive_prompts_raw = [positive_prompt.strip()]

        # Strip "(text:weight)" annotations and remember the weights per prompt
        positive_prompts = []
        all_weights = []
        for p in positive_prompts_raw:
            cleaned_prompt, weights = self.parse_prompt_weights(p)
            positive_prompts.append(cleaned_prompt)
            all_weights.append(weights)

        mm.soft_empty_cache()

        if device == "gpu":
            device_to = mm.get_torch_device()
        else:
            device_to = torch.device("cpu")

        if encoder.quantization == "fp8_e4m3fn":
            cast_dtype = torch.float8_e4m3fn
        else:
            cast_dtype = encoder.dtype

        # Parameters whose name contains one of these keywords keep the t5 dtype;
        # all others are cast to cast_dtype.
        params_to_keep = {'norm', 'pos_embedding', 'token_embedding'}
        for name, param in encoder.model.named_parameters():
            dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype
            value = encoder.state_dict[name] if hasattr(encoder, 'state_dict') else encoder.model.state_dict()[name]
            set_module_tensor_to_device(encoder.model, name, device=device_to, dtype=dtype_to_use, value=value)
        if hasattr(encoder, 'state_dict'):
            del encoder.state_dict
            mm.soft_empty_cache()
            gc.collect()

        with torch.autocast(device_type=mm.get_autocast_device(device_to), dtype=encoder.dtype, enabled=encoder.quantization != 'disabled'):
            # Encode positive if not loaded from cache
            if context is None:
                context = encoder(positive_prompts, device_to)
                # Apply weights to embeddings if any were extracted; every weight
                # found in a prompt scales that prompt's whole embedding.
                for i, weights in enumerate(all_weights):
                    for text, weight in weights.items():
                        log.info(f"Applying weight {weight} to prompt: {text}")
                        context[i] = context[i] * weight

            # Encode negative if not loaded from cache
            if context_null is None:
                context_null = encoder([negative_prompt], device_to)

        if force_offload:
            encoder.model.to(offload_device)
            mm.soft_empty_cache()
            gc.collect()

        prompt_embeds_dict = {
            "prompt_embeds": context,
            "negative_prompt_embeds": context_null,
            "echoshot": echoshot,
        }

        # Save each part to its own cache file if needed
        if use_disk_cache:
            pos_cache_path = get_cache_path(positive_prompt)
            neg_cache_path = get_cache_path(negative_prompt)
            try:
                if not os.path.exists(pos_cache_path):
                    torch.save(context, pos_cache_path)
                    log.info(f"Saved prompt embeds to cache: {pos_cache_path}")
            except Exception as e:
                log.warning(f"Failed to save cache: {e}")
            try:
                if not os.path.exists(neg_cache_path):
                    torch.save(context_null, neg_cache_path)
                    log.info(f"Saved prompt embeds to cache: {neg_cache_path}")
            except Exception as e:
                log.warning(f"Failed to save cache: {e}")

        return (prompt_embeds_dict,)

    def parse_prompt_weights(self, prompt):
        """Extract text and weights from prompts with (text:weight) format.

        Returns:
            Tuple ``(cleaned_prompt, weights)`` where each "(text:weight)"
            occurrence in the prompt is replaced by just ``text`` and
            ``weights`` maps each such text to its weight as a float.
        """
        import re

        # Parse all instances of (text:weight) in the prompt
        pattern = r'\((.*?):([\d\.]+)\)'
        matches = re.findall(pattern, prompt)

        # Replace each match with just the text part
        cleaned_prompt = prompt
        weights = {}
        for text, weight in matches:
            orig_text = f"({text}:{weight})"
            cleaned_prompt = cleaned_prompt.replace(orig_text, text)
            weights[text] = float(weight)

        return cleaned_prompt, weights
class WanVideoTextEncodeSingle:
    """Node that encodes a single prompt with a T5 text encoder, with optional disk caching."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs: the prompt plus optional encoder and flags."""
        return {"required": {
            "prompt": ("STRING", {"default": "", "multiline": True} ),
            },
            "optional": {
                "t5": ("WANTEXTENCODER",),
                "force_offload": ("BOOLEAN", {"default": True}),
                "model_to_offload": ("WANVIDEOMODEL", {"tooltip": "Model to move to offload_device before encoding"}),
                "use_disk_cache": ("BOOLEAN", {"default": False, "tooltip": "Cache the text embeddings to disk for faster re-use, under the custom_nodes/ComfyUI-WanVideoWrapper/text_embed_cache directory"}),
                "device": (["gpu", "cpu"], {"default": "gpu", "tooltip": "Device to run the text encoding on."}),
            }
        }

    RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", )
    RETURN_NAMES = ("text_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Encodes text prompt into text embedding."

    def process(self, prompt, t5=None, force_offload=True, model_to_offload=None, use_disk_cache=False, device="gpu"):
        """Encode ``prompt`` and return a one-element tuple holding the embeds dict.

        Args:
            prompt: Prompt text to encode.
            t5: Dict with "model" and "dtype" entries; may be ``None`` only if
                the prompt is found in the disk cache.
            force_offload: Move the encoder to the offload device afterwards.
            model_to_offload: Optional model moved to the offload device before
                encoding when running on the GPU (best effort).
            use_disk_cache: Read/write the embedding from/to the disk cache.
            device: "gpu" for the main torch device, anything else for CPU.

        Returns:
            ``(dict,)`` with "prompt_embeds", "negative_prompt_embeds" (always
            ``None``) and "echoshot".

        Raises:
            ValueError: if no encoder is given and nothing was loaded from cache.
        """
        encoded = None
        echoshot = "[1]" in prompt

        if use_disk_cache:
            # Use the module-level cache helpers so every text-encode node
            # derives cache file names the same way.
            os.makedirs(cache_dir, exist_ok=True)
            cache_path = get_cache_path(prompt)
            if os.path.exists(cache_path):
                try:
                    log.info(f"Loading prompt embeds from cache: {cache_path}")
                    encoded = torch.load(cache_path)
                except Exception as e:
                    log.warning(f"Failed to load cache: {e}, will re-encode.")

        if t5 is None and encoded is None:
            raise ValueError("No cached text embeds found for prompts, please provide a T5 encoder.")

        if encoded is None:
            if model_to_offload is not None and device == "gpu":
                # Best effort: a failed offload must not abort the encoding, but it
                # should be visible in the log rather than silently swallowed.
                try:
                    log.info(f"Moving video model to {offload_device}")
                    model_to_offload.model.to(offload_device)
                    mm.soft_empty_cache()
                except Exception as e:
                    log.warning(f"Failed to move video model to {offload_device}: {e}")

            encoder = t5["model"]
            dtype = t5["dtype"]

            if device == "gpu":
                device_to = mm.get_torch_device()
            else:
                device_to = torch.device("cpu")

            if encoder.quantization == "fp8_e4m3fn":
                cast_dtype = torch.float8_e4m3fn
            else:
                cast_dtype = encoder.dtype

            # Parameters whose name contains one of these keywords keep the t5 dtype;
            # all others are cast to cast_dtype.
            params_to_keep = {'norm', 'pos_embedding', 'token_embedding'}
            for name, param in encoder.model.named_parameters():
                dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype
                value = encoder.state_dict[name] if hasattr(encoder, 'state_dict') else encoder.model.state_dict()[name]
                set_module_tensor_to_device(encoder.model, name, device=device_to, dtype=dtype_to_use, value=value)
            if hasattr(encoder, 'state_dict'):
                del encoder.state_dict
                mm.soft_empty_cache()
                gc.collect()

            with torch.autocast(device_type=mm.get_autocast_device(device_to), dtype=encoder.dtype, enabled=encoder.quantization != 'disabled'):
                encoded = encoder([prompt], device_to)

            if force_offload:
                encoder.model.to(offload_device)
                mm.soft_empty_cache()

            # Save to cache if enabled
            if use_disk_cache:
                try:
                    if not os.path.exists(cache_path):
                        torch.save(encoded, cache_path)
                        log.info(f"Saved prompt embeds to cache: {cache_path}")
                except Exception as e:
                    log.warning(f"Failed to save cache: {e}")

        prompt_embeds_dict = {
            "prompt_embeds": encoded,
            "negative_prompt_embeds": None,
            "echoshot": echoshot
        }
        return (prompt_embeds_dict,)
class WanVideoApplyNAG:
    """Node that attaches NAG prompt embeds and parameters to a copy of the text embeds."""

    RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", )
    RETURN_NAMES = ("text_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Adds NAG prompt embeds to original prompt embeds: 'https://github.com/ChenDarYen/Normalized-Attention-Guidance'"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs: two embeds dicts and the three NAG parameters."""
        required_inputs = {
            "original_text_embeds": ("WANVIDEOTEXTEMBEDS",),
            "nag_text_embeds": ("WANVIDEOTEXTEMBEDS",),
            "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.1}),
            "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.1}),
            "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}),
        }
        return {"required": required_inputs}

    def process(self, original_text_embeds, nag_text_embeds, nag_scale, nag_tau, nag_alpha):
        """Return a shallow copy of ``original_text_embeds`` extended with the NAG entries.

        The input dict is not modified; the copy gains "nag_prompt_embeds"
        (taken from ``nag_text_embeds["prompt_embeds"]``) and "nag_params".
        """
        combined = original_text_embeds.copy()
        combined["nag_prompt_embeds"] = nag_text_embeds["prompt_embeds"]
        combined["nag_params"] = {
            "nag_scale": nag_scale,
            "nag_tau": nag_tau,
            "nag_alpha": nag_alpha,
        }
        return (combined,)
class WanVideoTextEmbedBridge:
    """Node that repackages ComfyUI CONDITIONING inputs as a WANVIDEOTEXTEMBEDS dict."""

    RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", )
    RETURN_NAMES = ("text_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Bridge between ComfyUI native text embedding and WanVideoWrapper text embedding"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs: required positive and optional negative conditioning."""
        required_inputs = {"positive": ("CONDITIONING",)}
        optional_inputs = {"negative": ("CONDITIONING",)}
        return {"required": required_inputs, "optional": optional_inputs}

    def process(self, positive, negative=None):
        """Move the ``[0][0]`` element of each conditioning to the main device.

        Returns a one-element tuple holding a dict with "prompt_embeds" and
        "negative_prompt_embeds" (``None`` when no negative is supplied).
        """
        positive_embeds = positive[0][0].to(device)
        negative_embeds = None
        if negative is not None:
            negative_embeds = negative[0][0].to(device)
        return ({
            "prompt_embeds": positive_embeds,
            "negative_prompt_embeds": negative_embeds,
        },)
#region clip vision
class WanVideoClipVisionEncode:
    """Node that encodes one or two images with a CLIP vision model and combines the embeds."""

    RETURN_TYPES = ("WANVIDIMAGE_CLIPEMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs: model, images, strengths and combine options."""
        required_inputs = {
            "clip_vision": ("CLIP_VISION",),
            "image_1": ("IMAGE", {"tooltip": "Image to encode"}),
            "strength_1": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}),
            "strength_2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}),
            "crop": (["center", "disabled"], {"default": "center", "tooltip": "Crop image to 224x224 before encoding"}),
            "combine_embeds": (["average", "sum", "concat", "batch"], {"default": "average", "tooltip": "Method to combine multiple clip embeds"}),
            "force_offload": ("BOOLEAN", {"default": True}),
        }
        optional_inputs = {
            "image_2": ("IMAGE", ),
            "negative_image": ("IMAGE", {"tooltip": "image to use for uncond"}),
            "tiles": ("INT", {"default": 0, "min": 0, "max": 16, "step": 2, "tooltip": "Use matteo's tiled image encoding for improved accuracy"}),
            "ratio": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Ratio of the tile average"}),
        }
        return {"required": required_inputs, "optional": optional_inputs}

    def process(self, clip_vision, image_1, strength_1, strength_2, force_offload, crop, combine_embeds, image_2=None, negative_image=None, tiles=0, ratio=1.0):
        """Encode the image(s) and return ``({"clip_embeds", "negative_clip_embeds"},)``.

        ``image_2`` (if given) is concatenated to ``image_1`` along dim 0 before
        encoding. The first embed is scaled by ``strength_1``, the second by
        ``strength_2``, any further ones are used unscaled, and the results are
        merged according to ``combine_embeds``.
        """
        image_mean = [0.48145466, 0.4578275, 0.40821073]
        image_std = [0.26862954, 0.26130258, 0.27577711]

        image = image_1 if image_2 is None else torch.cat([image_1, image_2], dim=0)

        clip_vision.model.to(device)

        negative_clip_embeds = None
        has_negative = negative_image is not None

        if tiles > 0:
            log.info("Using tiled image encoding")
            clip_embeds = clip_encode_image_tiled(clip_vision, image.to(device), tiles=tiles, ratio=ratio)
            if has_negative:
                negative_clip_embeds = clip_encode_image_tiled(clip_vision, negative_image.to(device), tiles=tiles, ratio=ratio)
        elif isinstance(clip_vision, ClipVisionModel):
            # ComfyUI clip vision wrapper: use its own encode_image
            clip_embeds = clip_vision.encode_image(image).penultimate_hidden_states.to(device)
            if has_negative:
                negative_clip_embeds = clip_vision.encode_image(negative_image).penultimate_hidden_states.to(device)
        else:
            # Any other model object: preprocess here and call its `visual` directly
            do_crop = crop != "disabled"
            pixel_values = clip_preprocess(image.to(device), size=224, mean=image_mean, std=image_std, crop=do_crop).float()
            clip_embeds = clip_vision.visual(pixel_values)
            if has_negative:
                pixel_values = clip_preprocess(negative_image.to(device), size=224, mean=image_mean, std=image_std, crop=do_crop).float()
                negative_clip_embeds = clip_vision.visual(pixel_values)

        log.info(f"Clip embeds shape: {clip_embeds.shape}, dtype: {clip_embeds.dtype}")

        # Scale the first two embeds; any additional ones are added as-is
        embed_count = clip_embeds.shape[0]
        weighted_embeds = [clip_embeds[0:1] * strength_1]
        if embed_count > 1:
            weighted_embeds.append(clip_embeds[1:2] * strength_2)
            weighted_embeds.extend(clip_embeds[i:i+1] for i in range(2, embed_count))

        # Combine all weighted embeddings
        if combine_embeds in ("average", "sum"):
            stacked = torch.stack(weighted_embeds)
            if combine_embeds == "average":
                clip_embeds = torch.mean(stacked, dim=0)
            else:
                clip_embeds = torch.sum(stacked, dim=0)
        elif combine_embeds in ("concat", "batch"):
            cat_dim = 1 if combine_embeds == "concat" else 0
            clip_embeds = torch.cat(weighted_embeds, dim=cat_dim)
        else:
            clip_embeds = weighted_embeds[0]

        log.info(f"Combined clip embeds shape: {clip_embeds.shape}")

        if force_offload:
            clip_vision.model.to(offload_device)
            mm.soft_empty_cache()

        return ({
            "clip_embeds": clip_embeds,
            "negative_clip_embeds": negative_clip_embeds
        },)
class WanVideoRealisDanceLatents:
    """Node that packs a reference latent and pose latents into an ADD_COND_LATENTS dict."""

    RETURN_TYPES = ("ADD_COND_LATENTS",)
    RETURN_NAMES = ("add_cond_latents",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs: reference latent, pose window and optional pose latents."""
        required_inputs = {
            "ref_latent": ("LATENT", {"tooltip": "Reference image to encode"}),
            "pose_cond_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the SMPL model"}),
            "pose_cond_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the SMPL model"}),
        }
        optional_inputs = {
            "smpl_latent": ("LATENT", {"tooltip": "SMPL pose image to encode"}),
            "hamer_latent": ("LATENT", {"tooltip": "Hamer hand pose image to encode"}),
        }
        return {"required": required_inputs, "optional": optional_inputs}

    def process(self, ref_latent, pose_cond_start_percent, pose_cond_end_percent, hamer_latent=None, smpl_latent=None):
        """Build the conditioning dict; a missing pose latent is replaced by zeros.

        The SMPL and Hamer samples are concatenated along dim 1 (SMPL first).

        Raises:
            Exception: if neither ``smpl_latent`` nor ``hamer_latent`` is given.
        """
        if smpl_latent is None and hamer_latent is None:
            raise Exception("At least one of smpl_latent or hamer_latent must be provided")

        # Whichever latent is missing becomes zeros shaped like the one provided.
        smpl = smpl_latent["samples"] if smpl_latent is not None else torch.zeros_like(hamer_latent["samples"])
        hamer = hamer_latent["samples"] if hamer_latent is not None else torch.zeros_like(smpl_latent["samples"])

        return ({
            "ref_latent": ref_latent["samples"],
            "pose_latent": torch.cat((smpl, hamer), dim=1),
            "pose_cond_start_percent": pose_cond_start_percent,
            "pose_cond_end_percent": pose_cond_end_percent,
        },)
class WanVideoAddStandInLatent:
    """Node that adds a "standin_input" entry to a copy of the image embeds."""

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs: the embeds, the reference latent and a frequency offset."""
        required_inputs = {
            "embeds": ("WANVIDIMAGE_EMBEDS",),
            "ip_image_latent": ("LATENT", {"tooltip": "Reference image to encode"}),
            "freq_offset": ("INT", {"default": 1, "min": 0, "max": 100, "step": 1, "tooltip": "EXPERIMENTAL: RoPE frequency offset between the reference and rest of the sequence"}),
            #"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent to apply the ref "}),
            #"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent to apply the ref "}),
        }
        return {"required": required_inputs}

    def add(self, embeds, ip_image_latent, freq_offset):
        """Return a shallow copy of ``embeds`` with "standin_input" set.

        The entry holds the latent's "samples" and ``freq_offset``; the input
        dict itself is left unmodified.
        """
        standin_entry = {
            "ip_image_latent": ip_image_latent["samples"],
            "freq_offset": freq_offset,
            #"ip_start_percent": start_percent,
            #"ip_end_percent": end_percent,
        }
        # Copy first so the caller's dict is never mutated.
        updated = dict(embeds, standin_input=standin_entry)
        return (updated,)
class WanVideoAddMTVMotion:
    """Node that adds an "mtv_crafter_motion" entry to a copy of the image embeds."""

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs: the embeds, the motion data, strength and step window."""
        required_inputs = {
            "embeds": ("WANVIDIMAGE_EMBEDS",),
            "mtv_crafter_motion": ("MTVCRAFTERMOTION",),
            "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the MTV motion"}),
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent to apply the ref "}),
            "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent to apply the ref "}),
        }
        return {"required": required_inputs}

    def add(self, embeds, mtv_crafter_motion, strength, start_percent, end_percent):
        """Return a shallow copy of ``embeds`` with "mtv_crafter_motion" set.

        The entry carries the motion tokens, global mean and global std from
        ``mtv_crafter_motion`` together with the strength and the start/end
        percentages; the input dict itself is left unmodified.
        """
        motion_entry = {
            "mtv_motion_tokens": mtv_crafter_motion["mtv_motion_tokens"],
            "strength": strength,
            "start_percent": start_percent,
            "end_percent": end_percent,
            "global_mean": mtv_crafter_motion["global_mean"],
            "global_std": mtv_crafter_motion["global_std"]
        }
        # Copy first so the caller's dict is never mutated.
        updated = dict(embeds, mtv_crafter_motion=motion_entry)
        return (updated,)
#region I2V encode
class WanVideoImageToVideoEncode:
    """ComfyUI node that builds image-to-video conditioning embeds from start/end images.

    The optional start and/or end images are resized to ``width`` x ``height``,
    padded with zero frames up to ``num_frames``, encoded with the supplied VAE and
    returned together with a per-frame mask and bookkeeping values in a
    ``WANVIDIMAGE_EMBEDS`` dict.
    """
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}),
            "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}),
            "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}),
            "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of noise augmentation, helpful for I2V where some noise can add motion and give sharper results"}),
            "start_latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for I2V where lower values allow for more motion"}),
            "end_latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for I2V where lower values allow for more motion"}),
            "force_offload": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "vae": ("WANVAE",),
                "clip_embeds": ("WANVIDIMAGE_CLIPEMBEDS", {"tooltip": "Clip vision encoded image"}),
                "start_image": ("IMAGE", {"tooltip": "Image to encode"}),
                "end_image": ("IMAGE", {"tooltip": "end frame"}),
                "control_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "Control signal for the Fun -model"}),
                "fun_or_fl2v_model": ("BOOLEAN", {"default": True, "tooltip": "Enable when using official FLF2V or Fun model"}),
                "temporal_mask": ("MASK", {"tooltip": "mask"}),
                "extra_latents": ("LATENT", {"tooltip": "Extra latents to add to the input front, used for Skyreels A2 reference images"}),
                "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}),
                "add_cond_latents": ("ADD_COND_LATENTS", {"advanced": True, "tooltip": "Additional cond latents WIP"}),
            }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    def process(self, width, height, num_frames, force_offload, noise_aug_strength,
                start_latent_strength, end_latent_strength, start_image=None, end_image=None, control_embeds=None, fun_or_fl2v_model=False,
                temporal_mask=None, extra_latents=None, clip_embeds=None, tiled_vae=False, add_cond_latents=None, vae=None):
        """Encode the reference image(s) and assemble the image embeds dict.

        When no start image, end image or ``add_cond_latents`` is given, the call
        is delegated to ``WanVideoEmptyEmbeds().process`` and no VAE is needed.

        Returns:
            A one-element tuple holding the embeds dict (keys include
            ``image_embeds``, ``mask``, ``max_seq_len``, ``num_frames``,
            ``lat_h``, ``lat_w``, ``clip_context`` and ``control_embeds``).

        Raises:
            ValueError: if encoding is required but ``vae`` is None.
        """
        # Nothing to encode: fall back to empty embeds that only describe the target shape.
        if start_image is None and end_image is None and add_cond_latents is None:
            return WanVideoEmptyEmbeds().process(
                num_frames, width, height, control_embeds=control_embeds, extra_latents=extra_latents,
            )
        if vae is None:
            raise ValueError("VAE is required for image encoding.")
        H = height
        W = width
        # Latent-space spatial size, derived from the VAE's upsampling factor.
        lat_h = H // vae.upsampling_factor
        lat_w = W // vae.upsampling_factor
        # Round the frame count down to the nearest value of the form 4k + 1.
        num_frames = ((num_frames - 1) // 4) * 4 + 1
        two_ref_images = start_image is not None and end_image is not None
        if start_image is None and end_image is not None:
            fun_or_fl2v_model = True # end image alone only works with this option
        # One extra mask frame is reserved when both images are given and the
        # fun/FLF2V layout is not used.
        base_frames = num_frames + (1 if two_ref_images and not fun_or_fl2v_model else 0)
        if temporal_mask is None:
            # Default mask: 1 on the frames covered by the provided images, 0 elsewhere.
            mask = torch.zeros(1, base_frames, lat_h, lat_w, device=device, dtype=vae.dtype)
            if start_image is not None:
                mask[:, 0:start_image.shape[0]] = 1 # First frame
            if end_image is not None:
                mask[:, -end_image.shape[0]:] = 1 # End frame if exists
        else:
            # User mask: resize to latent resolution, then crop or zero-pad to base_frames.
            mask = common_upscale(temporal_mask.unsqueeze(1).to(device), lat_w, lat_h, "nearest", "disabled").squeeze(1)
            if mask.shape[0] > base_frames:
                mask = mask[:base_frames]
            elif mask.shape[0] < base_frames:
                mask = torch.cat([mask, torch.zeros(base_frames - mask.shape[0], lat_h, lat_w, device=device)])
            mask = mask.unsqueeze(0).to(device, vae.dtype)
        # Repeat first frame and optionally end frame
        start_mask_repeated = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1) # T, C, H, W
        if end_image is not None and not fun_or_fl2v_model:
            end_mask_repeated = torch.repeat_interleave(mask[:, -1:], repeats=4, dim=1) # T, C, H, W
            mask = torch.cat([start_mask_repeated, mask[:, 1:-1], end_mask_repeated], dim=1)
        else:
            mask = torch.cat([start_mask_repeated, mask[:, 1:]], dim=1)
        # Reshape mask into groups of 4 frames
        mask = mask.view(1, mask.shape[1] // 4, 4, lat_h, lat_w) # 1, T, C, H, W
        mask = mask.movedim(1, 2)[0]# C, T, H, W
        # Resize and rearrange the input image dimensions
        if start_image is not None:
            # Keep only the first three channels (drops e.g. an alpha channel).
            start_image = start_image[..., :3]
            if start_image.shape[1] != H or start_image.shape[2] != W:
                resized_start_image = common_upscale(start_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1)
            else:
                resized_start_image = start_image.permute(3, 0, 1, 2) # C, T, H, W
            # Rescale with x * 2 - 1 (presumably [0, 1] -> [-1, 1]; confirm IMAGE value range).
            resized_start_image = resized_start_image * 2 - 1
            if noise_aug_strength > 0.0:
                resized_start_image = add_noise_to_reference_video(resized_start_image, ratio=noise_aug_strength)
        if end_image is not None:
            end_image = end_image[..., :3]
            if end_image.shape[1] != H or end_image.shape[2] != W:
                resized_end_image = common_upscale(end_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1)
            else:
                resized_end_image = end_image.permute(3, 0, 1, 2) # C, T, H, W
            resized_end_image = resized_end_image * 2 - 1
            if noise_aug_strength > 0.0:
                resized_end_image = add_noise_to_reference_video(resized_end_image, ratio=noise_aug_strength)
        # Concatenate image with zero frames and encode
        if temporal_mask is None:
            if start_image is not None and end_image is None:
                # Start image(s) followed by zero frames.
                zero_frames = torch.zeros(3, num_frames-start_image.shape[0], H, W, device=device, dtype=vae.dtype)
                concatenated = torch.cat([resized_start_image.to(device, dtype=vae.dtype), zero_frames], dim=1)
                del resized_start_image, zero_frames
            elif start_image is None and end_image is not None:
                # Zero frames followed by the end image(s).
                zero_frames = torch.zeros(3, num_frames-end_image.shape[0], H, W, device=device, dtype=vae.dtype)
                concatenated = torch.cat([zero_frames, resized_end_image.to(device, dtype=vae.dtype)], dim=1)
                del zero_frames
            elif start_image is None and end_image is None:
                # Only reachable when add_cond_latents is set (see the early return above).
                concatenated = torch.zeros(3, num_frames, H, W, device=device, dtype=vae.dtype)
            else:
                # Both images: zero frames in between; the filler length depends on the layout.
                if fun_or_fl2v_model:
                    zero_frames = torch.zeros(3, num_frames-(start_image.shape[0]+end_image.shape[0]), H, W, device=device, dtype=vae.dtype)
                else:
                    zero_frames = torch.zeros(3, num_frames-1, H, W, device=device, dtype=vae.dtype)
                concatenated = torch.cat([resized_start_image.to(device, dtype=vae.dtype), zero_frames, resized_end_image.to(device, dtype=vae.dtype)], dim=1)
                del resized_start_image, zero_frames
        else:
            # NOTE(review): this branch reads resized_start_image, which is only
            # assigned when start_image is not None - confirm that temporal_mask is
            # never supplied without a start image.
            temporal_mask = common_upscale(temporal_mask.unsqueeze(1), W, H, "nearest", "disabled").squeeze(1)
            concatenated = resized_start_image[:,:num_frames].to(vae.dtype) * temporal_mask[:num_frames].unsqueeze(0).to(vae.dtype)
            del resized_start_image, temporal_mask
        mm.soft_empty_cache()
        gc.collect()
        vae.to(device)
        # end_ is set only for the non-fun/FLF2V layout with an end image.
        y = vae.encode([concatenated], device, end_=(end_image is not None and not fun_or_fl2v_model),tiled=tiled_vae)[0]
        del concatenated
        has_ref = False
        if extra_latents is not None:
            # Prepend the extra latents along dim 1 and extend the mask with ones;
            # each prepended latent frame adds 4 to num_frames.
            samples = extra_latents["samples"].squeeze(0)
            y = torch.cat([samples, y], dim=1)
            mask = torch.cat([torch.ones_like(mask[:, 0:samples.shape[1]]), mask], dim=1)
            num_frames += samples.shape[1] * 4
            has_ref = True
        # Scale the first and last latent frame (dim 1) in place.
        y[:, :1] *= start_latent_strength
        y[:, -1:] *= end_latent_strength
        # Calculate maximum sequence length
        patches_per_frame = lat_h * lat_w // (PATCH_SIZE[1] * PATCH_SIZE[2])
        frames_per_stride = (num_frames - 1) // 4 + (2 if end_image is not None and not fun_or_fl2v_model else 1)
        max_seq_len = frames_per_stride * patches_per_frame
        if add_cond_latents is not None:
            # NOTE(review): this writes into the caller-supplied dict in place.
            add_cond_latents["ref_latent_neg"] = vae.encode(torch.zeros(1, 3, 1, H, W, device=device, dtype=vae.dtype), device)
        if force_offload:
            vae.model.to(offload_device)
            mm.soft_empty_cache()
            gc.collect()
        image_embeds = {
            "image_embeds": y,
            "clip_context": clip_embeds.get("clip_embeds", None) if clip_embeds is not None else None,
            "negative_clip_context": clip_embeds.get("negative_clip_embeds", None) if clip_embeds is not None else None,
            "max_seq_len": max_seq_len,
            "num_frames": num_frames,
            "lat_h": lat_h,
            "lat_w": lat_w,
            "control_embeds": control_embeds["control_embeds"] if control_embeds is not None else None,
            "end_image": resized_end_image if end_image is not None else None,
            "fun_or_fl2v_model": fun_or_fl2v_model,
            "has_ref": has_ref,
            "add_cond_latents": add_cond_latents,
            "mask": mask
        }
        return (image_embeds,)
# region WanAnimate
class WanVideoAnimateEmbeds:
    """ComfyUI node that prepares the embeds dict for WanAnimate.

    Reference, pose, background and face images are resized and rescaled; when the
    clip fits in a single window they are VAE-encoded here, otherwise the resized
    pixel tensors are passed through in the returned dict.
    """
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "vae": ("WANVAE",),
            "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}),
            "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}),
            "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}),
            "force_offload": ("BOOLEAN", {"default": True}),
            "frame_window_size": ("INT", {"default": 77, "min": 1, "max": 1000, "step": 1, "tooltip": "Number of frames to use for temporal attention window"}),
            "colormatch": (
            [
                'disabled',
                'mkl',
                'hm',
                'reinhard',
                'mvgd',
                'hm-mvgd-hm',
                'hm-mkl-hm',
            ], {
               "default": 'disabled', "tooltip": "Color matching method to use between the windows"
            },),
            "pose_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional multiplier for the pose"}),
            "face_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional multiplier for the face"}),
            },
            "optional": {
                "clip_embeds": ("WANVIDIMAGE_CLIPEMBEDS", {"tooltip": "Clip vision encoded image"}),
                "ref_images": ("IMAGE", {"tooltip": "Image to encode"}),
                "pose_images": ("IMAGE", {"tooltip": "end frame"}),
                "face_images": ("IMAGE", {"tooltip": "end frame"}),
                "bg_images": ("IMAGE", {"tooltip": "background images"}),
                "mask": ("MASK", {"tooltip": "mask"}),
                "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}),
            }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    def process(self, vae, width, height, num_frames, force_offload, frame_window_size, colormatch, pose_strength, face_strength,
                ref_images=None, pose_images=None, face_images=None, clip_embeds=None, tiled_vae=False, bg_images=None, mask=None):
        """Build the WanAnimate embeds dict.

        ``looping`` is true when ``num_frames`` (rounded down to 4k + 1) exceeds
        ``frame_window_size``. In that case pose/background pixels are kept
        unencoded in the result; otherwise they are VAE-encoded here.

        Returns:
            A one-element tuple holding the embeds dict. The dict also carries the
            ``vae`` object itself plus the colormatch/strength settings.
        """
        H = height
        W = width
        lat_h = H // vae.upsampling_factor
        lat_w = W // vae.upsampling_factor
        num_refs = ref_images.shape[0] if ref_images is not None else 0
        # Round the frame count down to the nearest value of the form 4k + 1.
        num_frames = ((num_frames - 1) // 4) * 4 + 1
        looping = num_frames > frame_window_size
        if num_frames < frame_window_size:
            frame_window_size = num_frames
        target_shape = (16, (num_frames - 1) // 4 + 1 + num_refs, lat_h, lat_w)
        latent_window_size = ((frame_window_size - 1) // 4)
        if not looping:
            # Single window: account for 4 pixel frames per reference image.
            num_frames = num_frames + num_refs * 4
        else:
            latent_window_size = latent_window_size + 1
        mm.soft_empty_cache()
        gc.collect()
        vae.to(device)
        # Resize and rearrange the input image dimensions
        # NOTE(review): ref_latents is assigned here but not used again in this method.
        pose_latents = ref_latents = ref_latent = None
        if pose_images is not None:
            pose_images = pose_images[..., :3]
            if pose_images.shape[1] != H or pose_images.shape[2] != W:
                resized_pose_images = common_upscale(pose_images.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1)
            else:
                resized_pose_images = pose_images.permute(3, 0, 1, 2) # C, T, H, W
            resized_pose_images = resized_pose_images * 2 - 1
            if not looping:
                pose_latents = vae.encode([resized_pose_images.to(device, vae.dtype)], device,tiled=tiled_vae)
                pose_latents = pose_latents.to(offload_device)
                # Zero-pad along dim 2 when fewer latent frames than the window were produced.
                if pose_latents.shape[2] < latent_window_size:
                    log.info(f"WanAnimate: Padding pose latents from {pose_latents.shape} to length {latent_window_size}")
                    pad_len = latent_window_size - pose_latents.shape[2]
                    pad = torch.zeros(pose_latents.shape[0], pose_latents.shape[1], pad_len, pose_latents.shape[3], pose_latents.shape[4], device=pose_latents.device, dtype=pose_latents.dtype)
                    pose_latents = torch.cat([pose_latents, pad], dim=2)
                del resized_pose_images
            else:
                # Looping: keep the pixel frames (on the offload device) for later use.
                resized_pose_images = resized_pose_images.to(offload_device, dtype=vae.dtype)
        bg_latents = None
        if bg_images is not None:
            if bg_images.shape[1] != H or bg_images.shape[2] != W:
                resized_bg_images = common_upscale(bg_images.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1)
            else:
                resized_bg_images = bg_images.permute(3, 0, 1, 2) # C, T, H, W
            # Channel slice happens after the permute, so [:3] selects the first three channels.
            resized_bg_images = (resized_bg_images[:3] * 2 - 1)
        if not looping:
            if bg_images is None:
                # No background supplied: encode all-zero frames instead.
                resized_bg_images = torch.zeros(3, num_frames - num_refs, H, W, device=device, dtype=vae.dtype)
            bg_latents = vae.encode([resized_bg_images.to(device, vae.dtype)], device,tiled=tiled_vae)[0].to(offload_device)
            del resized_bg_images
        elif bg_images is not None:
            resized_bg_images = resized_bg_images.to(offload_device, dtype=vae.dtype)
        if ref_images is not None:
            if ref_images.shape[1] != H or ref_images.shape[2] != W:
                resized_ref_images = common_upscale(ref_images.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1)
            else:
                resized_ref_images = ref_images.permute(3, 0, 1, 2) # C, T, H, W
            resized_ref_images = resized_ref_images[:3] * 2 - 1
            ref_latent = vae.encode([resized_ref_images.to(device, vae.dtype)], device,tiled=tiled_vae)[0]
            # 4-channel mask of ones for the reference frame(s), stacked in front of the latent channels.
            msk = torch.zeros(4, 1, lat_h, lat_w, device=device, dtype=vae.dtype)
            msk[:, :num_refs] = 1
            ref_latent_masked = torch.cat([msk, ref_latent], dim=0).to(offload_device) # 4+C 1 H W
        if mask is None:
            bg_mask = torch.zeros(1, num_frames, lat_h, lat_w, device=offload_device, dtype=vae.dtype)
        else:
            # Invert the user mask; when not looping, pad by repeating its last frame.
            bg_mask = 1 - mask[:num_frames]
            if bg_mask.shape[0] < num_frames and not looping:
                bg_mask = torch.cat([bg_mask, bg_mask[-1:].repeat(num_frames - bg_mask.shape[0], 1, 1)], dim=0)
            bg_mask = common_upscale(bg_mask.unsqueeze(1), lat_w, lat_h, "nearest", "disabled").squeeze(1)
            bg_mask = bg_mask.unsqueeze(-1).permute(3, 0, 1, 2).to(offload_device, vae.dtype) # C, T, H, W
        if bg_images is None and looping:
            bg_mask[:, :num_refs] = 1
        # Repeat the first mask frame 4 times, then fold frames into groups of 4.
        bg_mask_mask_repeated = torch.repeat_interleave(bg_mask[:, 0:1], repeats=4, dim=1) # T, C, H, W
        bg_mask = torch.cat([bg_mask_mask_repeated, bg_mask[:, 1:]], dim=1)
        bg_mask = bg_mask.view(1, bg_mask.shape[1] // 4, 4, lat_h, lat_w) # 1, T, C, H, W
        bg_mask = bg_mask.movedim(1, 2)[0]# C, T, H, W
        # NOTE(review): ref_latent_masked is only assigned in the ref_images branch
        # above, yet both branches below read it - confirm ref_images is effectively
        # required (and that this section is not meant to be nested under that branch).
        if not looping:
            bg_latents_masked = torch.cat([bg_mask[:, :bg_latents.shape[1]], bg_latents], dim=0)
            del bg_mask, bg_latents
            ref_latent = torch.cat([ref_latent_masked, bg_latents_masked], dim=1)
        else:
            ref_latent = ref_latent_masked
        if face_images is not None:
            face_images = face_images[..., :3]
            # Face crops are always brought to 512x512.
            if face_images.shape[1] != 512 or face_images.shape[2] != 512:
                resized_face_images = common_upscale(face_images.movedim(-1, 1), 512, 512, "lanczos", "center").movedim(0, 1)
            else:
                resized_face_images = face_images.permute(3, 0, 1, 2) # B, C, T, H, W
            resized_face_images = (resized_face_images * 2 - 1).unsqueeze(0)
            resized_face_images = resized_face_images.to(offload_device, dtype=vae.dtype)
        seq_len = math.ceil((target_shape[2] * target_shape[3]) / 4 * target_shape[1])
        if force_offload:
            vae.model.to(offload_device)
            mm.soft_empty_cache()
            gc.collect()
        image_embeds = {
            "clip_context": clip_embeds.get("clip_embeds", None) if clip_embeds is not None else None,
            "negative_clip_context": clip_embeds.get("negative_clip_embeds", None) if clip_embeds is not None else None,
            "max_seq_len": seq_len,
            "pose_latents": pose_latents,
            "pose_images": resized_pose_images if pose_images is not None and looping else None,
            "bg_images": resized_bg_images if bg_images is not None and looping else None,
            "ref_masks": bg_mask if mask is not None and looping else None,
            "is_masked": mask is not None,
            "ref_latent": ref_latent,
            "ref_image": resized_ref_images if ref_images is not None else None,
            "face_pixels": resized_face_images if face_images is not None else None,
            "num_frames": num_frames,
            "target_shape": target_shape,
            "frame_window_size": frame_window_size,
            "lat_h": lat_h,
            "lat_w": lat_w,
            "vae": vae,
            "colormatch": colormatch,
            "looping": looping,
            "pose_strength": pose_strength,
            "face_strength": face_strength,
        }
        return (image_embeds,)
class WanVideoEmptyEmbeds:
    """Node that emits image embeds describing only the latent target shape."""

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", )
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}),
            "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}),
            "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}),
        }
        optional = {
            "control_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "control signal for the Fun -model"}),
            "extra_latents": ("LATENT", {"tooltip": "First latent to use for the Pusa -model"}),
        }
        return {"required": required, "optional": optional}

    def process(self, num_frames, width, height, control_embeds=None, extra_latents=None):
        """Return a one-element tuple with the embeds dict.

        The dict holds ``target_shape`` (16 channels, latent frames, latent
        height, latent width, all derived from ``VAE_STRIDE``), ``num_frames``
        and the optional control embeds. When ``extra_latents`` is given it is
        stored as a single-entry list at index 0.
        """
        t_stride, h_stride, w_stride = VAE_STRIDE
        latent_frames = (num_frames - 1) // t_stride + 1
        control = None if control_embeds is None else control_embeds["control_embeds"]

        result = {
            "target_shape": (16, latent_frames, height // h_stride, width // w_stride),
            "num_frames": num_frames,
            "control_embeds": control,
        }
        if extra_latents is not None:
            first_entry = {"samples": extra_latents["samples"], "index": 0}
            result["extra_latents"] = [first_entry]
        return (result,)
class WanVideoAddExtraLatent:
    """Node that appends an extra latent entry (samples + index) to image embeds."""

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "embeds": ("WANVIDIMAGE_EMBEDS",),
            "extra_latents": ("LATENT",),
            "latent_index": ("INT", {"default": 0, "min": -1000, "max": 1000, "step": 1, "tooltip": "Index to insert the extra latents at in latent space"}),
        }
        return {"required": required}

    def add(self, embeds, extra_latents, latent_index):
        """Return a shallow copy of ``embeds`` with one more ``extra_latents`` entry.

        The incoming dict and any list already stored under ``extra_latents``
        are left untouched; a previous non-list value is wrapped into a list.
        """
        entry = {"samples": extra_latents["samples"], "index": latent_index}

        existing = embeds.get("extra_latents", None)
        if existing is None:
            combined = [entry]
        elif isinstance(existing, list):
            combined = [*existing, entry]
        else:
            combined = [existing, entry]

        return ({**embeds, "extra_latents": combined},)
class WanVideoAddLucyEditLatents:
    """Node that attaches latents to image embeds under ``extra_channel_latents``."""

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "embeds": ("WANVIDIMAGE_EMBEDS",),
            "extra_latents": ("LATENT",),
        }
        return {"required": required}

    def add(self, embeds, extra_latents):
        """Return a shallow copy of ``embeds`` with ``extra_channel_latents`` set."""
        return ({**embeds, "extra_channel_latents": extra_latents["samples"]},)
class WanVideoMiniMaxRemoverEmbeds:
    """Node that packs control latents and mask latents into MiniMax remover embeds."""

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", )
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}),
            "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}),
            "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}),
            "latents": ("LATENT", {"tooltip": "Encoded latents to use as control signals"}),
            "mask_latents": ("LATENT", {"tooltip": "Encoded latents to use as mask"}),
        }
        return {"required": required}

    def process(self, num_frames, width, height, latents, mask_latents):
        """Return a one-element tuple with the embeds dict.

        ``target_shape`` is derived from ``VAE_STRIDE``; both latent inputs are
        stored with their leading dimension squeezed.
        """
        t_stride, h_stride, w_stride = VAE_STRIDE
        latent_frames = (num_frames - 1) // t_stride + 1

        result = {
            "target_shape": (16, latent_frames, height // h_stride, width // w_stride),
            "num_frames": num_frames,
            "minimax_latents": latents["samples"].squeeze(0),
            "minimax_mask_latents": mask_latents["samples"].squeeze(0),
        }
        return (result,)
# region phantom
class WanVideoPhantomEmbeds:
    """Node that combines up to four Phantom reference latents into image embeds."""

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", )
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}),
            "phantom_latent_1": ("LATENT", {"tooltip": "reference latents for the phantom model"}),
            "phantom_cfg_scale": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "CFG scale for the extra phantom cond pass"}),
            "phantom_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the phantom model"}),
            "phantom_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the phantom model"}),
        }
        optional = {
            "phantom_latent_2": ("LATENT", {"tooltip": "reference latents for the phantom model"}),
            "phantom_latent_3": ("LATENT", {"tooltip": "reference latents for the phantom model"}),
            "phantom_latent_4": ("LATENT", {"tooltip": "reference latents for the phantom model"}),
            "vace_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "VACE embeds"}),
        }
        return {"required": required, "optional": optional}

    def process(self, num_frames, phantom_cfg_scale, phantom_start_percent, phantom_end_percent, phantom_latent_1, phantom_latent_2=None, phantom_latent_3=None, phantom_latent_4=None, vace_embeds=None):
        """Return a one-element tuple with the Phantom embeds dict.

        The reference latents are squeezed on dim 0 and concatenated along dim 1
        in slot order. ``target_shape`` takes its spatial size from the combined
        latents; selected VACE keys are copied over when ``vace_embeds`` is given.
        """
        combined = phantom_latent_1["samples"].squeeze(0)
        for optional_latent in (phantom_latent_2, phantom_latent_3, phantom_latent_4):
            if optional_latent is None:
                continue
            combined = torch.cat([combined, optional_latent["samples"].squeeze(0)], dim=1)

        # Unpacking also enforces that the combined latents are 4-dimensional.
        _channels, _frames, latent_h, latent_w = combined.shape
        log.info(f"Phantom latents shape: {combined.shape}")

        latent_frames = (num_frames - 1) // VAE_STRIDE[0] + 1
        result = {
            "target_shape": (16, latent_frames, latent_h * 8 // VAE_STRIDE[1], latent_w * 8 // VAE_STRIDE[2]),
            "num_frames": num_frames,
            "phantom_latents": combined,
            "phantom_cfg_scale": phantom_cfg_scale,
            "phantom_start_percent": phantom_start_percent,
            "phantom_end_percent": phantom_end_percent,
        }

        if vace_embeds is not None:
            forwarded_keys = (
                "vace_context",
                "vace_scale",
                "has_ref",
                "vace_start_percent",
                "vace_end_percent",
                "vace_seq_len",
                "additional_vace_inputs",
            )
            result.update({key: vace_embeds[key] for key in forwarded_keys})
        return (result,)
class WanVideoControlEmbeds:
    """Node that wraps encoded control latents into image embeds."""

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", )
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the control signal"}),
            "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the control signal"}),
            "latents": ("LATENT", {"tooltip": "Encoded latents to use as control signals"}),
        }
        optional = {
            "fun_ref_image": ("LATENT", {"tooltip": "Reference latent for the Fun 1.1 -model"}),
        }
        return {"required": required, "optional": optional}

    def process(self, latents, start_percent, end_percent, fun_ref_image=None):
        """Return a one-element tuple with embeds sized from the control latents.

        ``num_frames`` is computed as ``(T - 1) * 4 + 1`` from the latent frame
        count, and ``max_seq_len`` from the latent height/width and frame count.
        Only index 0 along dim 2 of ``fun_ref_image`` is kept.
        """
        control_latents = latents["samples"].squeeze(0)
        _channels, latent_frames, latent_h, latent_w = control_latents.shape
        num_frames = (latent_frames - 1) * 4 + 1
        seq_len = math.ceil((latent_h * latent_w) / 4 * ((num_frames - 1) // 4 + 1))

        if fun_ref_image is None:
            ref_latent = None
        else:
            ref_latent = fun_ref_image["samples"][:, :, 0]

        control = {
            "control_images": control_latents,
            "start_percent": start_percent,
            "end_percent": end_percent,
            "fun_ref_image": ref_latent,
        }
        result = {
            "max_seq_len": seq_len,
            "target_shape": control_latents.shape,
            "num_frames": num_frames,
            "control_embeds": control,
        }
        return (result,)
class WanVideoAddControlEmbeds:
    """Node that sets the ``control_embeds`` entry on existing image embeds."""

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", )
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "embeds": ("WANVIDIMAGE_EMBEDS",),
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the control signal"}),
            "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the control signal"}),
        }
        optional = {
            "latents": ("LATENT", {"tooltip": "Encoded latents to use as control signals"}),
            "fun_ref_image": ("LATENT", {"tooltip": "Reference latent for the Fun 1.1 -model"}),
        }
        return {"required": required, "optional": optional}

    def process(self, embeds, start_percent, end_percent, fun_ref_image=None, latents=None):
        """Return a shallow copy of ``embeds`` whose ``control_embeds`` is replaced.

        Any previous ``control_embeds`` value is overwritten; the input dict is
        not modified. Missing latents / reference image are stored as None.
        """
        if latents is None:
            control_images = None
        else:
            control_images = latents["samples"].squeeze(0)

        if fun_ref_image is None:
            ref_latent = None
        else:
            ref_latent = fun_ref_image["samples"][:, :, 0]

        control = {
            "control_images": control_images,
            "start_percent": start_percent,
            "end_percent": end_percent,
            "fun_ref_image": ref_latent,
        }
        return ({**embeds, "control_embeds": control},)
class WanVideoAddPusaNoise:
    """Node that stores the Pusa noise settings on image embeds."""

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", )
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Adds latent and timestep noise multipliers when using flowmatch_pusa"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "embeds": ("WANVIDIMAGE_EMBEDS",),
            "noise_multipliers": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Noise multipliers for Pusa, can be a list of floats"}),
            "noisy_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "tooltip": "Number steps to apply the extra noise"}),
        }
        return {"required": required}

    def add(self, embeds, noise_multipliers, noisy_steps):
        """Return a shallow copy of ``embeds`` with the two ``pusa_*`` keys set."""
        pusa_settings = {
            "pusa_noise_multipliers": noise_multipliers,
            "pusa_noisy_steps": noisy_steps,
        }
        return ({**embeds, **pusa_settings},)
class WanVideoSLG:
    """Node that parses the skip-layer-guidance block list and step range."""
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "blocks": ("STRING", {"default": "10", "tooltip": "Blocks to skip uncond on, separated by comma, index starts from 0"}),
            "start_percent": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the control signal"}),
            "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the control signal"}),
            },
        }

    RETURN_TYPES = ("SLGARGS", )
    RETURN_NAMES = ("slg_args",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Skips uncond on the selected blocks"

    def process(self, blocks, start_percent, end_percent):
        """Parse ``blocks`` into a list of block indices and bundle the arguments.

        Args:
            blocks: Comma-separated block indices, e.g. ``"9, 10"``. Whitespace
                around entries is ignored, and empty entries (trailing or doubled
                commas, or an empty string) are skipped instead of failing.
            start_percent: Passed through unchanged.
            end_percent: Passed through unchanged.

        Returns:
            A one-element tuple with a dict holding ``blocks`` (list of int),
            ``start_percent`` and ``end_percent``.

        Raises:
            ValueError: if a non-empty entry is not a valid integer.
        """
        tokens = (part.strip() for part in blocks.split(","))
        # Skip blank tokens so inputs like "10," or "" do not crash in int("").
        slg_block_list = [int(token) for token in tokens if token]
        slg_args = {
            "blocks": slg_block_list,
            "start_percent": start_percent,
            "end_percent": end_percent,
        }
        return (slg_args,)
#region VACE
class WanVideoVACEEncode:
    """ComfyUI node that encodes frames, masks and reference images into VACE embeds.

    The frames are split by the mask into an inactive and a reactive part, both
    are VAE-encoded, and the result is concatenated channel-wise with a
    patch-folded version of the mask to form the VACE context.
    """
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "vae": ("WANVAE",),
            "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}),
            "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}),
            "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}),
            "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
            "vace_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the steps to apply VACE"}),
            "vace_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the steps to apply VACE"}),
            },
            "optional": {
                "input_frames": ("IMAGE",),
                "ref_images": ("IMAGE",),
                "input_masks": ("MASK",),
                "prev_vace_embeds": ("WANVIDIMAGE_EMBEDS",),
                "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}),
            },
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", )
    RETURN_NAMES = ("vace_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    def process(self, vae, width, height, num_frames, strength, vace_start_percent, vace_end_percent, input_frames=None, ref_images=None, input_masks=None, prev_vace_embeds=None, tiled_vae=False):
        """Encode the inputs and return the VACE embeds dict in a one-element tuple.

        ``width`` and ``height`` are rounded down to multiples of 16. Missing
        ``input_frames`` are replaced by zeros and missing ``input_masks`` by
        ones. Multiple reference images are joined side by side, padded with
        ones to the target aspect ratio and resized. When ``prev_vace_embeds``
        is given, it (and the inputs it already carried) is recorded under
        ``additional_vace_inputs``.
        """
        width = (width // 16) * 16
        height = (height // 16) * 16
        target_shape = (16, (num_frames - 1) // VAE_STRIDE[0] + 1,
                        height // VAE_STRIDE[1],
                        width // VAE_STRIDE[2])
        # vace context encode
        if input_frames is None:
            input_frames = torch.zeros((1, 3, num_frames, height, width), device=device, dtype=vae.dtype)
        else:
            input_frames = input_frames.clone()[:num_frames, :, :, :3]
            input_frames = common_upscale(input_frames.movedim(-1, 1), width, height, "lanczos", "disabled").movedim(1, -1)
            input_frames = input_frames.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
            input_frames = input_frames * 2 - 1
        if input_masks is None:
            input_masks = torch.ones_like(input_frames, device=device)
        else:
            log.info(f"input_masks shape: {input_masks.shape}")
            input_masks = input_masks[:num_frames]
            input_masks = common_upscale(input_masks.clone().unsqueeze(1), width, height, "nearest-exact", "disabled").squeeze(1)
            input_masks = input_masks.to(vae.dtype).to(device)
            input_masks = input_masks.unsqueeze(-1).unsqueeze(0).permute(0, 4, 1, 2, 3).repeat(1, 3, 1, 1, 1) # B, C, T, H, W
        if ref_images is not None:
            ref_images = ref_images.clone()[..., :3]
            # Create padded image
            if ref_images.shape[0] > 1:
                # Join all reference images side by side (along the width axis).
                ref_images = torch.cat(ref_images.unbind(0), dim=1).unsqueeze(0)
            B, H, W, C = ref_images.shape
            current_aspect = W / H
            target_aspect = width / height
            if current_aspect > target_aspect:
                # Image is wider than target, pad height
                new_h = int(W / target_aspect)
                pad_h = (new_h - H) // 2
                padded = torch.ones(ref_images.shape[0], new_h, W, ref_images.shape[3], device=ref_images.device, dtype=ref_images.dtype)
                padded[:, pad_h:pad_h+H, :, :] = ref_images
                ref_images = padded
            elif current_aspect < target_aspect:
                # Image is taller than target, pad width
                new_w = int(H * target_aspect)
                pad_w = (new_w - W) // 2
                padded = torch.ones(ref_images.shape[0], H, new_w, ref_images.shape[3], device=ref_images.device, dtype=ref_images.dtype)
                padded[:, :, pad_w:pad_w+W, :] = ref_images
                ref_images = padded
            ref_images = common_upscale(ref_images.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
            ref_images = ref_images.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3).unsqueeze(0)
            ref_images = ref_images * 2 - 1
        vae = vae.to(device)
        z0 = self.vace_encode_frames(vae, input_frames, ref_images, masks=input_masks, tiled_vae=tiled_vae)
        m0 = self.vace_encode_masks(input_masks, ref_images)
        z = self.vace_latent(z0, m0)
        vae.to(offload_device)
        vace_input = {
            "vace_context": z,
            "vace_scale": strength,
            "has_ref": ref_images is not None,
            "num_frames": num_frames,
            "target_shape": target_shape,
            "vace_start_percent": vace_start_percent,
            "vace_end_percent": vace_end_percent,
            "vace_seq_len": math.ceil((z[0].shape[2] * z[0].shape[3]) / 4 * z[0].shape[1]),
            "additional_vace_inputs": [],
        }
        if prev_vace_embeds is not None:
            # Copy the earlier chain so the previous embeds dict is not mutated.
            if "additional_vace_inputs" in prev_vace_embeds and prev_vace_embeds["additional_vace_inputs"]:
                vace_input["additional_vace_inputs"] = prev_vace_embeds["additional_vace_inputs"].copy()
            vace_input["additional_vace_inputs"].append(prev_vace_embeds)
        return (vace_input,)

    def vace_encode_frames(self, vae, frames, ref_images, masks=None, tiled_vae=False):
        """VAE-encode ``frames`` (and optional references) into a list of latents.

        With ``masks`` the frames are split into an inactive part
        (``frame * (1 - mask)``) and a reactive part (``frame * mask``); each is
        encoded and the two are concatenated on dim 0. Reference latents are
        prepended along dim 1, padded with zeros on dim 0 in the masked case.
        """
        if ref_images is None:
            ref_images = [None] * len(frames)
        else:
            assert len(frames) == len(ref_images)
        pbar = ProgressBar(len(frames))
        if masks is None:
            latents = vae.encode(frames, device=device, tiled=tiled_vae)
        else:
            inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
            reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
            del frames
            inactive = vae.encode(inactive, device=device, tiled=tiled_vae)
            reactive = vae.encode(reactive, device=device, tiled=tiled_vae)
            latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
            del inactive, reactive
        cat_latents = []
        for latent, refs in zip(latents, ref_images):
            if refs is not None:
                # The encode call is the same with or without masks; only the
                # zero padding that matches the doubled channel count differs.
                ref_latent = vae.encode(refs, device=device, tiled=tiled_vae)
                if masks is not None:
                    ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
                assert all(x.shape[1] == 1 for x in ref_latent)
                latent = torch.cat([*ref_latent, latent], dim=1)
            cat_latents.append(latent)
            pbar.update(1)
        return cat_latents

    def vace_encode_masks(self, masks, ref_images=None):
        """Fold pixel-space masks into latent resolution.

        For each mask the first channel is taken, every spatial patch of
        ``VAE_STRIDE[1]`` x ``VAE_STRIDE[2]`` pixels is moved into the channel
        dimension, and the temporal axis is resized to
        ``(depth + 3) // VAE_STRIDE[0]`` with nearest-exact interpolation. When
        references are present, ``len(refs)`` zero frames are prepended.
        """
        if ref_images is None:
            ref_images = [None] * len(masks)
        else:
            assert len(masks) == len(ref_images)
        result_masks = []
        pbar = ProgressBar(len(masks))
        for mask, refs in zip(masks, ref_images):
            _c, depth, height, width = mask.shape
            new_depth = int((depth + 3) // VAE_STRIDE[0])
            height = 2 * (int(height) // (VAE_STRIDE[1] * 2))
            width = 2 * (int(width) // (VAE_STRIDE[2] * 2))
            # reshape
            mask = mask[0, :, :, :]
            # Width is split with the width stride, matching the lines above and below.
            mask = mask.view(
                depth, height, VAE_STRIDE[1], width, VAE_STRIDE[2]
            )  # depth, height, 8, width, 8
            mask = mask.permute(2, 4, 0, 1, 3)  # 8, 8, depth, height, width
            mask = mask.reshape(
                VAE_STRIDE[1] * VAE_STRIDE[2], depth, height, width
            )  # 8*8, depth, height, width
            # interpolation
            mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
            if refs is not None:
                length = len(refs)
                mask_pad = torch.zeros_like(mask[:, :length, :, :])
                mask = torch.cat((mask_pad, mask), dim=1)
            result_masks.append(mask)
            pbar.update(1)
        return result_masks

    def vace_latent(self, z, m):
        """Concatenate each latent with its mask along dim 0 and return the list."""
        return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
#region context options
class WanVideoContextOptions:
    """Node that bundles the context-window settings into one options dict."""

    RETURN_TYPES = ("WANVIDCONTEXT", )
    RETURN_NAMES = ("context_options",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Context options for WanVideo, allows splitting the video into context windows and attemps blending them for longer generations than the model and memory otherwise would allow."

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "context_schedule": (["uniform_standard", "uniform_looped", "static_standard"],),
            "context_frames": ("INT", {"default": 81, "min": 2, "max": 1000, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ),
            "context_stride": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
            "context_overlap": ("INT", {"default": 16, "min": 4, "max": 100, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
            "freenoise": ("BOOLEAN", {"default": True, "tooltip": "Shuffle the noise"}),
            "verbose": ("BOOLEAN", {"default": False, "tooltip": "Print debug output"}),
        }
        optional = {
            "fuse_method": (["linear", "pyramid"], {"default": "linear", "tooltip": "Window weight function: linear=ramps at edges only, pyramid=triangular weights peaking in middle"}),
            "reference_latent": ("LATENT", {"tooltip": "Image to be used as init for I2V models for windows where first frame is not the actual first frame. Mostly useful with MAGREF model"}),
        }
        return {"required": required, "optional": optional}

    def process(self, context_schedule, context_frames, context_stride, context_overlap, freenoise, verbose, image_cond_start_step=6, image_cond_window_count=2, vae=None, fuse_method="linear", reference_latent=None):
        """Return a one-element tuple with the context options dict.

        ``image_cond_start_step``, ``image_cond_window_count`` and ``vae`` are
        accepted but not used here. Only the ``samples`` entry of
        ``reference_latent`` is kept.
        """
        if reference_latent is None:
            reference_samples = None
        else:
            reference_samples = reference_latent["samples"]

        options = dict(
            context_schedule=context_schedule,
            context_frames=context_frames,
            context_stride=context_stride,
            context_overlap=context_overlap,
            freenoise=freenoise,
            verbose=verbose,
            fuse_method=fuse_method,
            reference_latent=reference_samples,
        )
        return (options,)
class WanVideoFlowEdit:
    """Node that forwards its FlowEdit inputs unchanged as a single args dict."""

    RETURN_TYPES = ("FLOWEDITARGS", )
    RETURN_NAMES = ("flowedit_args",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Flowedit options for WanVideo"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""

        def step_count(default):
            # Non-negative integer step widget.
            return ("INT", {"default": default, "min": 0})

        def cfg_scale():
            # Both CFG widgets share the same range and default.
            return ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01})

        required = {
            "source_embeds": ("WANVIDEOTEXTEMBEDS", ),
            "skip_steps": step_count(4),
            "drift_steps": step_count(0),
            "drift_flow_shift": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 30.0, "step": 0.01}),
            "source_cfg": cfg_scale(),
            "drift_cfg": cfg_scale(),
        }
        optional = {
            "source_image_embeds": ("WANVIDIMAGE_EMBEDS", ),
        }
        return {"required": required, "optional": optional}

    def process(self, **kwargs):
        """Return the received keyword arguments as a one-element tuple."""
        flowedit_args = kwargs
        return (flowedit_args,)
class WanVideoLoopArgs:
    """Node that forwards its looping inputs unchanged as a single args dict."""

    RETURN_TYPES = ("LOOPARGS", )
    RETURN_NAMES = ("loop_args",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Looping through latent shift as shown in https://github.com/YisuiTT/Mobius/"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required inputs."""

        def percent(default, tooltip):
            # Fraction widget in [0, 1] with 0.01 resolution.
            return ("FLOAT", {"default": default, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": tooltip})

        required = {
            "shift_skip": ("INT", {"default": 6, "min": 0, "tooltip": "Skip step of latent shift"}),
            "start_percent": percent(0.0, "Start percent of the looping effect"),
            "end_percent": percent(1.0, "End percent of the looping effect"),
        }
        return {"required": required}

    def process(self, **kwargs):
        """Return the received keyword arguments as a one-element tuple."""
        loop_args = kwargs
        return (loop_args,)
class WanVideoExperimentalArgs:
    """Node that forwards its experimental sampling switches as a single args dict."""

    RETURN_TYPES = ("EXPERIMENTALARGS", )
    RETURN_NAMES = ("exp_args",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Experimental stuff"
    EXPERIMENTAL = True

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required inputs."""

        def switch(tooltip=None):
            # Boolean widget that is off by default; tooltip is optional.
            options = {"default": False}
            if tooltip is not None:
                options["tooltip"] = tooltip
            return ("BOOLEAN", options)

        def fresca_scale(default):
            return ("FLOAT", {"default": default, "min": 0.0, "max": 10.0, "step": 0.01})

        required = {
            "video_attention_split_steps": ("STRING", {
                "default": "",
                "tooltip": "Steps to split self attention when using multiple prompts",
            }),
            "cfg_zero_star": switch("https://github.com/WeichenFan/CFG-Zero-star"),
            "use_zero_init": switch(),
            # NOTE(review): this tooltip is identical to the one on
            # video_attention_split_steps and looks copy-pasted — confirm the
            # intended wording before changing it.
            "zero_star_steps": ("INT", {
                "default": 0, "min": 0,
                "tooltip": "Steps to split self attention when using multiple prompts",
            }),
            "use_fresca": switch("https://github.com/WikiChao/FreSca"),
            "fresca_scale_low": fresca_scale(1.0),
            "fresca_scale_high": fresca_scale(1.25),
            "fresca_freq_cutoff": ("INT", {"default": 20, "min": 0, "max": 10000, "step": 1}),
            "use_tcfg": switch("https://arxiv.org/abs/2503.18137 TCFG: Tangential Damping Classifier-free Guidance. CFG artifacts reduction."),
            "raag_alpha": ("FLOAT", {
                "default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01,
                "tooltip": "Alpha value for RAAG, 1.0 is default, 0.0 is disabled.",
            }),
            "bidirectional_sampling": switch("Enable bidirectional sampling, based on https://github.com/ff2416/WanFM"),
            "temporal_score_rescaling": switch("Enable temporal score rescaling: https://github.com/temporalscorerescaling/TSR/"),
            "tsr_k": ("FLOAT", {
                "default": 0.95, "min": 0.0, "max": 100.0, "step": 0.01,
                "tooltip": "The sampling temperature",
            }),
            "tsr_sigma": ("FLOAT", {
                "default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01,
                "tooltip": "How early TSR steer the sampling process",
            }),
        }
        return {"required": required}

    def process(self, **kwargs):
        """Return the received keyword arguments as a one-element tuple."""
        exp_args = kwargs
        return (exp_args,)
class WanVideoFreeInitArgs:
    """Node that forwards its FreeInit inputs unchanged as a single args dict."""

    RETURN_TYPES = ("FREEINITARGS", )
    RETURN_NAMES = ("freeinit_args",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "https://github.com/TianxingWu/FreeInit; FreeInit, a concise yet effective method to improve temporal consistency of videos generated by diffusion models"
    EXPERIMENTAL = True

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required inputs."""

        def small_int(default, tooltip):
            # Integer widget limited to 1..10.
            return ("INT", {"default": default, "min": 1, "max": 10, "tooltip": tooltip})

        def cutoff(tooltip):
            # Filter cutoff widget, 0..10 with 0.01 resolution.
            return ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": tooltip})

        filter_kinds = ["butterworth", "ideal", "gaussian", "none"]
        required = {
            "freeinit_num_iters": small_int(3, "Number of FreeInit iterations"),
            "freeinit_method": (filter_kinds, {"default": "ideal", "tooltip": "Frequency filter type"}),
            "freeinit_n": small_int(4, "Butterworth filter order (only for butterworth)"),
            "freeinit_d_s": cutoff("Spatial filter cutoff"),
            "freeinit_d_t": cutoff("Temporal filter cutoff"),
        }
        return {"required": required}

    def process(self, **kwargs):
        """Return the received keyword arguments as a one-element tuple."""
        freeinit_args = kwargs
        return (freeinit_args,)
class WanVideoScheduler:  # WIP
    """Node that builds a sampling scheduler and previews its sigma curve.

    The scheduler is created through ``get_scheduler``. When the node has a
    ``unique_id`` and both the ComfyUI server module and matplotlib can be
    imported, a PNG plot of the sigmas is sent to the frontend as progress
    text. The preview is best effort and never fails the node.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required, optional and hidden inputs."""
        return {
            "required": {
                "scheduler": (scheduler_list, {"default": "unipc"}),
                "steps": ("INT", {"default": 30, "min": 1, "tooltip": "Number of steps for the scheduler"}),
                "shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
                "start_step": ("INT", {"default": 0, "min": 0, "tooltip": "Starting step for the scheduler"}),
                "end_step": ("INT", {"default": -1, "min": -1, "tooltip": "Ending step for the scheduler"}),
            },
            "optional": {
                "sigmas": ("SIGMAS", ),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = ("SIGMAS", "INT", "FLOAT", scheduler_list, "INT", "INT",)
    RETURN_NAMES = ("sigmas", "steps", "shift", "scheduler", "start_step", "end_step")
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    EXPERIMENTAL = True

    def process(self, scheduler, steps, start_step, end_step, shift, unique_id, sigmas=None):
        """Create the scheduler and return it together with the pass-through inputs.

        Returns:
            ``(sigmas, steps, shift, scheduler_dict, start_step, end_step)``
            where ``scheduler_dict`` holds the ``"sample_scheduler"`` and
            ``"timesteps"`` produced by ``get_scheduler``. ``sigmas`` is the
            input value, returned unchanged.
        """
        sample_scheduler, timesteps, start_idx, end_idx = get_scheduler(
            scheduler,
            steps,
            start_step, end_step, shift,
            device,
            sigmas=sigmas,
            log_timesteps=True)

        scheduler_dict = {
            "sample_scheduler": sample_scheduler,
            "timesteps": timesteps,
        }

        if unique_id:
            self._send_sigmas_plot(sample_scheduler, start_idx, end_idx, unique_id)

        return (sigmas, steps, shift, scheduler_dict, start_step, end_step)

    def _send_sigmas_plot(self, sample_scheduler, start_idx, end_idx, unique_id):
        """Render the sigma schedule as a PNG and send it to the frontend (best effort)."""
        try:
            from server import PromptServer
            import io
            import base64
            import matplotlib.pyplot as plt
        except Exception:
            # Preview dependencies are optional; skip quietly when unavailable.
            # `Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
            return
        if PromptServer is None:
            return

        try:
            # Plot sigmas and save to a buffer
            sigmas_np = sample_scheduler.full_sigmas.cpu().numpy()
            if not np.isclose(sigmas_np[-1], 0.0, atol=1e-6):
                sigmas_np = np.append(sigmas_np, 0.0)
            num_points = len(sigmas_np)
            x_values = range(0, num_points)

            fig = plt.figure(facecolor='#353535')
            try:
                ax = fig.add_subplot(111)
                ax.set_facecolor('#353535')  # Set axes background color
                ax.plot(x_values, sigmas_np)
                ax.scatter(x_values, sigmas_np, color='white', s=20, zorder=3)  # Small dots at each sigma
                if num_points <= 10:  # Only annotate if few steps
                    for x, y in zip(x_values, sigmas_np):
                        ax.annotate(f"{y:.3f}", (x, y), textcoords="offset points", xytext=(10, 1), ha='center', color='orange', fontsize=12)
                ax.set_xticks(x_values)
                ax.set_title("Sigmas", color='white')  # Title font color
                ax.set_xlabel("Step", color='white')  # X label font color
                ax.set_ylabel("Sigma Value", color='white')  # Y label font color
                ax.tick_params(axis='x', colors='white', labelsize=10)  # X tick color
                ax.tick_params(axis='y', colors='white', labelsize=10)  # Y tick color

                # The end marker is drawn one position after the index
                # returned by get_scheduler.
                end_marker = end_idx + 1
                # Add split point if end_step is defined
                if 0 <= end_marker < num_points - 1:
                    ax.axvline(end_marker, color='red', linestyle='--', linewidth=2, label='end_step split')
                # Add split point if start_step is defined
                if 0 < start_idx < num_points:
                    ax.axvline(start_idx, color='green', linestyle='--', linewidth=2, label='start_step split')
                # NOTE(review): the legend test allows end_marker == num_points - 1
                # while the red line above does not, and the span below is added
                # after the legend so its label is not listed — kept as-is, confirm intent.
                if (0 <= end_marker < num_points) or (0 < start_idx < num_points):
                    ax.legend()
                if start_idx < end_marker and 0 <= start_idx < num_points and 0 < end_marker < num_points:
                    ax.axvspan(start_idx, end_marker, color='lightblue', alpha=0.1, label='Sampled Range')

                # Use the figure's own methods so pyplot's global "current
                # figure" state cannot redirect the layout/save calls.
                fig.tight_layout()
                with io.BytesIO() as buf:
                    fig.savefig(buf, format='png')
                    img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
            finally:
                # Always release the figure, even if plotting failed part-way;
                # pyplot keeps figures alive until they are closed explicitly.
                plt.close(fig)

            # Send as HTML img tag with base64 data
            html_img = f"<img src='data:image/png;base64,{img_base64}' alt='Sigmas Plot' style='max-width:100%; height:100%; overflow:hidden; display:block;'>"
            PromptServer.instance.send_progress_text(html_img, unique_id)
        except Exception as e:
            print("Failed to send sigmas plot:", e)
class WanVideoSchedulerSA_ODE:
    """Node that builds the "sa_ode_stable/lowstep" scheduler with its extra options."""

    RETURN_TYPES = ("SIGMAS", "INT", "FLOAT", scheduler_list, "INT", "INT",)
    RETURN_NAMES = ("sigmas", "steps", "shift", "scheduler", "start_step", "end_step")
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    EXPERIMENTAL = True

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""

        def unit_float(default, tooltip):
            # Value in [0, 1] with 0.001 resolution.
            return ("FLOAT", {"default": default, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": tooltip})

        required = {
            "use_adaptive_order": ("BOOLEAN", {"default": False, "tooltip": "Use adaptive order"}),
            "use_velocity_smoothing": ("BOOLEAN", {"default": True, "tooltip": "Use velocity smoothing"}),
            "convergence_threshold": unit_float(0.15, "Convergence threshold for velocity smoothing"),
            "smoothing_factor": unit_float(0.8, "Smoothing factor for velocity smoothing"),
            "steps": ("INT", {"default": 30, "min": 1, "tooltip": "Number of steps for the scheduler"}),
            "shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
            "start_step": ("INT", {"default": 0, "min": 0, "tooltip": "Starting step for the scheduler"}),
            "end_step": ("INT", {"default": -1, "min": -1, "tooltip": "Ending step for the scheduler"}),
        }
        optional = {
            "sigmas": ("SIGMAS", ),
        }
        return {"required": required, "optional": optional}

    def process(self, steps, start_step, end_step, shift, use_adaptive_order, use_velocity_smoothing, convergence_threshold, smoothing_factor, sigmas=None):
        """Create the scheduler and return it together with the pass-through inputs.

        Returns:
            ``(sigmas, steps, shift, scheduler_dict, start_step, end_step)``
            where ``scheduler_dict`` holds the ``"sample_scheduler"`` and
            ``"timesteps"`` produced by ``get_scheduler``.
        """
        scheduler_kwargs = dict(
            scheduler="sa_ode_stable/lowstep",
            steps=steps,
            start_step=start_step,
            end_step=end_step,
            shift=shift,
            device=device,
            sigmas=sigmas,
            log_timesteps=True,
            use_adaptive_order=use_adaptive_order,
            use_velocity_smoothing=use_velocity_smoothing,
            convergence_threshold=convergence_threshold,
            smoothing_factor=smoothing_factor,
        )
        # The two index values returned by get_scheduler are not needed here.
        sample_scheduler, timesteps, _unused_a, _unused_b = get_scheduler(**scheduler_kwargs)

        scheduler_dict = {
            "sample_scheduler": sample_scheduler,
            "timesteps": timesteps,
        }
        return (sigmas, steps, shift, scheduler_dict, start_step, end_step)
rope_functions = ["default", "comfy", "comfy_chunked"]


class WanVideoRoPEFunction:
    """Node that selects a RoPE implementation and optional NTK scale factors."""

    RETURN_TYPES = (rope_functions, )
    RETURN_NAMES = ("rope_function",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    EXPERIMENTAL = True

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required inputs."""

        def ntk_scale():
            # All three scale widgets share the same range and default.
            return ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01})

        required = {
            "rope_function": (rope_functions, {"default": "comfy"}),
            "ntk_scale_f": ntk_scale(),
            "ntk_scale_h": ntk_scale(),
            "ntk_scale_w": ntk_scale(),
        }
        return {"required": required}

    def process(self, rope_function, ntk_scale_f, ntk_scale_h, ntk_scale_w):
        """Return the plain function name, or a dict when any NTK scale differs from 1.0.

        With all three scales equal to 1.0 the result is ``(rope_function,)``.
        Otherwise it is a one-element tuple holding a dict with the function
        name and the three scale values.
        """
        ntk_scales = {
            "ntk_scale_f": ntk_scale_f,
            "ntk_scale_h": ntk_scale_h,
            "ntk_scale_w": ntk_scale_w,
        }
        if all(scale == 1.0 for scale in ntk_scales.values()):
            return (rope_function,)
        return ({"rope_function": rope_function, **ntk_scales},)
#region VideoDecode
class WanVideoDecode:
    """Node that turns a latent dict into an image batch using the given VAE."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        return {
            "required": {
                "vae": ("WANVAE",),
                "samples": ("LATENT",),
                "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": (
                    "Drastically reduces memory use but will introduce seams at tile stride boundaries. "
                    "The location and number of seams is dictated by the tile stride size. "
                    "The visibility of seams can be controlled by increasing the tile size. "
                    "Seams become less obvious at 1.5x stride and are barely noticeable at 2x stride size. "
                    "Which is to say if you use a stride width of 160, the seams are barely noticeable with a tile width of 320."
                )}),
                "tile_x": ("INT", {"default": 272, "min": 40, "max": 2048, "step": 8, "tooltip": "Tile width in pixels. Smaller values use less VRAM but will make seams more obvious."}),
                "tile_y": ("INT", {"default": 272, "min": 40, "max": 2048, "step": 8, "tooltip": "Tile height in pixels. Smaller values use less VRAM but will make seams more obvious."}),
                "tile_stride_x": ("INT", {"default": 144, "min": 32, "max": 2040, "step": 8, "tooltip": "Tile stride width in pixels. Smaller values use less VRAM but will introduce more seams."}),
                "tile_stride_y": ("INT", {"default": 128, "min": 32, "max": 2040, "step": 8, "tooltip": "Tile stride height in pixels. Smaller values use less VRAM but will introduce more seams."}),
            },
            "optional": {
                "normalization": (["default", "minmax"], {"advanced": True}),
            }
        }

    @classmethod
    def VALIDATE_INPUTS(s, tile_x, tile_y, tile_stride_x, tile_stride_y):
        """Return True when each tile is larger than its stride, else an error message."""
        if tile_x <= tile_stride_x:
            return "Tile width must be larger than the tile stride width."
        if tile_y <= tile_stride_y:
            return "Tile height must be larger than the tile stride height."
        return True

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "decode"
    CATEGORY = "WanVideoWrapper"

    def decode(self, vae, samples, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, normalization="default"):
        """Decode ``samples`` and return a one-element tuple with the images.

        Keys read from ``samples``: ``"video"`` (already-decoded frames, used
        directly when present), ``"samples"`` (latents), and the optional
        ``"end_image"``, ``"has_ref"``, ``"drop_last"`` and ``"looped"``.

        Tile sizes and strides are given in pixels and converted with
        ``vae.upsampling_factor``. Tiling is disabled when ``end_image`` is
        present. ``normalization="minmax"`` rescales by the decoded min/max
        instead of clamping to [-1, 1].
        """
        mm.soft_empty_cache()

        video = samples.get("video", None)
        if video is not None:
            # clamp() returns a new tensor, so the in-place rescale below does
            # not modify the tensor stored in `samples`. Rescaling that tensor
            # in place would apply the transform again on a repeated decode of
            # the same samples dict.
            video = video.clamp(-1.0, 1.0)
            video.add_(1.0).div_(2.0)
            return video.cpu().float(),

        latents = samples["samples"]
        end_image = samples.get("end_image", None)
        has_ref = samples.get("has_ref", False)
        drop_last = samples.get("drop_last", False)
        is_looped = samples.get("looped", False)

        vae.to(device)
        latents = latents.to(device = device, dtype = vae.dtype)
        mm.soft_empty_cache()

        if has_ref:
            latents = latents[:, :, 1:]
        if drop_last:
            latents = latents[:, :, :-1]

        if type(vae).__name__ == "TAEHV":
            images = vae.decode_video(latents.permute(0, 2, 1, 3, 4))[0].permute(1, 0, 2, 3)
            images = torch.clamp(images, 0.0, 1.0)
            images = images.permute(1, 2, 3, 0).cpu().float()
            return (images,)

        if end_image is not None:
            enable_vae_tiling = False
        has_end_image = end_image is not None

        # Convert pixel-space tile settings with the VAE's own factor rather
        # than a hard-coded 8, so the main decode agrees with the looped decode
        # below and with the encode nodes.
        factor = vae.upsampling_factor
        tile_size = (tile_x // factor, tile_y // factor)
        tile_stride = (tile_stride_x // factor, tile_stride_y // factor)

        images = vae.decode(latents, device=device, end_=has_end_image, tiled=enable_vae_tiling, tile_size=tile_size, tile_stride=tile_stride)[0]
        images = images.cpu().float()

        if normalization == "minmax":
            # After sub_ the minimum is 0, so the divisor is the original range.
            images.sub_(images.min()).div_(images.max() - images.min())
        else:
            images.clamp_(-1.0, 1.0)
            images.add_(1.0).div_(2.0)

        if is_looped:
            # Decode the last three and first two latent frames together and
            # splice the result in front of the main clip.
            temp_latents = torch.cat([latents[:, :, -3:]] + [latents[:, :, :2]], dim=2)
            temp_images = vae.decode(temp_latents, device=device, end_=has_end_image, tiled=enable_vae_tiling, tile_size=tile_size, tile_stride=tile_stride)[0]
            temp_images = temp_images.cpu().float()
            # NOTE(review): this segment is always min-max normalised, even when
            # `normalization` is "default" — confirm that is intended.
            temp_images = (temp_images - temp_images.min()) / (temp_images.max() - temp_images.min())
            images = torch.cat([temp_images[:, 9:].to(images), images[:, 5:]], dim=1)

        if end_image is not None:
            images = images[:, 0:-1]

        vae.to(offload_device)
        mm.soft_empty_cache()

        images.clamp_(0.0, 1.0)

        return (images.permute(1, 2, 3, 0),)
#region VideoEncode
class WanVideoEncodeLatentBatch:
    """Node that encodes every image of a batch separately into its own latent."""

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("samples",)
    FUNCTION = "encode"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Encodes a batch of images individually to create a latent video batch where each video is a single frame, useful for I2V init purposes, for example as multiple context window inits"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required inputs."""

        def tile_size():
            return ("INT", {"default": 272, "min": 64, "max": 2048, "step": 1, "tooltip": "Tile size in pixels, smaller values use less VRAM, may introduce more seams"})

        def tile_stride(default):
            return ("INT", {"default": default, "min": 32, "max": 2048, "step": 32, "tooltip": "Tile stride in pixels, smaller values use less VRAM, may introduce more seams"})

        required = {
            "vae": ("WANVAE",),
            "images": ("IMAGE",),
            "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
            "tile_x": tile_size(),
            "tile_y": tile_size(),
            "tile_stride_x": tile_stride(144),
            "tile_stride_y": tile_stride(128),
        }
        return {"required": required}

    def encode(self, vae, images, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, latent_strength=1.0):
        """Encode each image on its own and stack the results.

        Images whose height or width is not a multiple of 16 are resized down
        to the nearest multiple, a fourth channel is dropped, and pixel values
        are mapped with ``x * 2 - 1`` before encoding. Returns a one-element
        tuple holding ``{"samples": stacked_latents}``.
        """
        vae.to(device)

        batch = images.clone()
        _, height, width, _ = batch.shape
        if height % 16 != 0 or width % 16 != 0:
            new_height = (height // 16) * 16
            new_width = (width // 16) * 16
            log.warning(f"Image size {width}x{height} is not divisible by 16, resizing to {new_width}x{new_height}")
            channels_first = batch.movedim(-1, 1)
            batch = common_upscale(channels_first, new_width, new_height, "lanczos", "disabled").movedim(1, -1)

        if batch.shape[-1] == 4:
            batch = batch[..., :3]
        batch = batch.to(vae.dtype).to(device) * 2.0 - 1.0

        # Tile geometry is only forwarded when tiling is actually requested.
        use_tile_geometry = enable_vae_tiling and tile_x is not None

        encoded = []
        for frame in batch:
            # (H, W, C) -> (1, C, 1, H, W)
            frame_5d = frame.unsqueeze(0).unsqueeze(0).permute(0, 4, 1, 2, 3)
            if use_tile_geometry:
                factor = vae.upsampling_factor
                latent = vae.encode(
                    frame_5d,
                    device=device,
                    tiled=enable_vae_tiling,
                    tile_size=(tile_x//factor, tile_y//factor),
                    tile_stride=(tile_stride_x//factor, tile_stride_y//factor),
                )
            else:
                latent = vae.encode(frame_5d, device=device, tiled=enable_vae_tiling)

            if latent_strength != 1.0:
                latent *= latent_strength
            encoded.append(latent.squeeze(0).cpu())

        latents_out = torch.stack(encoded, dim=0)
        log.info(f"WanVideoEncode: Encoded latents shape {latents_out.shape}")

        vae.to(offload_device)
        mm.soft_empty_cache()

        return ({"samples": latents_out},)
class WanVideoEncode:
    """Node that encodes an image batch as one video into a latent dict."""

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("samples",)
    FUNCTION = "encode"
    CATEGORY = "WanVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""

        def tile_size():
            return ("INT", {"default": 272, "min": 64, "max": 2048, "step": 1, "tooltip": "Tile size in pixels, smaller values use less VRAM, may introduce more seams"})

        def tile_stride(default):
            return ("INT", {"default": default, "min": 32, "max": 2048, "step": 32, "tooltip": "Tile stride in pixels, smaller values use less VRAM, may introduce more seams"})

        required = {
            "vae": ("WANVAE",),
            "image": ("IMAGE",),
            "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
            "tile_x": tile_size(),
            "tile_y": tile_size(),
            "tile_stride_x": tile_stride(144),
            "tile_stride_y": tile_stride(128),
        }
        optional = {
            "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of noise augmentation, helpful for leapfusion I2V where some noise can add motion and give sharper results"}),
            "latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for leapfusion I2V where lower values allow for more motion"}),
            "mask": ("MASK", ),
        }
        return {"required": required, "optional": optional}

    def encode(self, vae, image, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, noise_aug_strength=0.0, latent_strength=1.0, mask=None):
        """Encode ``image`` and return ``({"samples": latents, "noise_mask": mask},)``.

        Images whose height or width is not a multiple of 16 are resized down
        to the nearest multiple and a fourth channel is dropped. Noise is
        added when ``noise_aug_strength`` is positive, and the latents are
        multiplied by ``latent_strength`` when it differs from 1.0.
        """
        vae.to(device)

        frames = image.clone()
        _, height, width, _ = frames.shape
        if height % 16 != 0 or width % 16 != 0:
            new_height = (height // 16) * 16
            new_width = (width // 16) * 16
            log.warning(f"Image size {width}x{height} is not divisible by 16, resizing to {new_width}x{new_height}")
            channels_first = frames.movedim(-1, 1)
            frames = common_upscale(channels_first, new_width, new_height, "lanczos", "disabled").movedim(1, -1)

        if frames.shape[-1] == 4:
            frames = frames[..., :3]
        # B, C, T, H, W
        frames = frames.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3)

        if noise_aug_strength > 0.0:
            frames = add_noise_to_reference_video(frames, ratio=noise_aug_strength)

        if isinstance(vae, TAEHV):
            # TAEHV takes B, T, C, H, W and receives the pixels without the x*2-1 mapping.
            taehv_latents = vae.encode_video(frames.permute(0, 2, 1, 3, 4), parallel=False)
            latents = taehv_latents.permute(0, 2, 1, 3, 4)
        else:
            factor = vae.upsampling_factor
            latents = vae.encode(
                frames * 2.0 - 1.0,
                device=device,
                tiled=enable_vae_tiling,
                tile_size=(tile_x//factor, tile_y//factor),
                tile_stride=(tile_stride_x//factor, tile_stride_y//factor),
            )
            # NOTE(review): only the full VAE is offloaded here, mirroring
            # WanVideoDecode where the TAEHV path also returns without
            # offloading — confirm this is the intended nesting.
            vae.to(offload_device)

        if latent_strength != 1.0:
            latents *= latent_strength

        log.info(f"WanVideoEncode: Encoded latents shape {latents.shape}")
        mm.soft_empty_cache()

        return ({"samples": latents, "noise_mask": mask},)
# Registry of this module's nodes: each key is a node identifier and each value
# is the class implementing it. Every key here is spelled the same as the class
# it maps to. Presumably read by ComfyUI's custom-node loader — verify against
# the package __init__.
NODE_CLASS_MAPPINGS = {
"WanVideoDecode": WanVideoDecode,
"WanVideoTextEncode": WanVideoTextEncode,
"WanVideoTextEncodeSingle": WanVideoTextEncodeSingle,
"WanVideoClipVisionEncode": WanVideoClipVisionEncode,
"WanVideoImageToVideoEncode": WanVideoImageToVideoEncode,
"WanVideoEncode": WanVideoEncode,
"WanVideoEncodeLatentBatch": WanVideoEncodeLatentBatch,
"WanVideoEmptyEmbeds": WanVideoEmptyEmbeds,
"WanVideoEnhanceAVideo": WanVideoEnhanceAVideo,
"WanVideoContextOptions": WanVideoContextOptions,
"WanVideoTextEmbedBridge": WanVideoTextEmbedBridge,
"WanVideoFlowEdit": WanVideoFlowEdit,
"WanVideoControlEmbeds": WanVideoControlEmbeds,
"WanVideoSLG": WanVideoSLG,
"WanVideoLoopArgs": WanVideoLoopArgs,
"WanVideoSetBlockSwap": WanVideoSetBlockSwap,
"WanVideoExperimentalArgs": WanVideoExperimentalArgs,
"WanVideoVACEEncode": WanVideoVACEEncode,
"WanVideoPhantomEmbeds": WanVideoPhantomEmbeds,
"WanVideoRealisDanceLatents": WanVideoRealisDanceLatents,
"WanVideoApplyNAG": WanVideoApplyNAG,
"WanVideoMiniMaxRemoverEmbeds": WanVideoMiniMaxRemoverEmbeds,
"WanVideoFreeInitArgs": WanVideoFreeInitArgs,
"WanVideoSetRadialAttention": WanVideoSetRadialAttention,
"WanVideoBlockList": WanVideoBlockList,
"WanVideoTextEncodeCached": WanVideoTextEncodeCached,
"WanVideoAddExtraLatent": WanVideoAddExtraLatent,
"WanVideoScheduler": WanVideoScheduler,
"WanVideoAddStandInLatent": WanVideoAddStandInLatent,
"WanVideoAddControlEmbeds": WanVideoAddControlEmbeds,
"WanVideoAddMTVMotion": WanVideoAddMTVMotion,
"WanVideoRoPEFunction": WanVideoRoPEFunction,
"WanVideoAddPusaNoise": WanVideoAddPusaNoise,
"WanVideoAnimateEmbeds": WanVideoAnimateEmbeds,
"WanVideoAddLucyEditLatents": WanVideoAddLucyEditLatents,
"WanVideoSchedulerSA_ODE": WanVideoSchedulerSA_ODE,
}
# Human-readable titles for the nodes, keyed by node identifier.
NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoDecode": "WanVideo Decode",
    "WanVideoTextEncode": "WanVideo TextEncode",
    "WanVideoTextEncodeSingle": "WanVideo TextEncodeSingle",
    # NOTE(review): no "WanVideoTextImageEncode" key exists in
    # NODE_CLASS_MAPPINGS in this file; entry kept in case it is still needed.
    "WanVideoTextImageEncode": "WanVideo TextImageEncode (IP2V)",
    "WanVideoClipVisionEncode": "WanVideo ClipVision Encode",
    "WanVideoImageToVideoEncode": "WanVideo ImageToVideo Encode",
    "WanVideoEncode": "WanVideo Encode",
    "WanVideoEncodeLatentBatch": "WanVideo Encode Latent Batch",
    "WanVideoEmptyEmbeds": "WanVideo Empty Embeds",
    "WanVideoEnhanceAVideo": "WanVideo Enhance-A-Video",
    "WanVideoContextOptions": "WanVideo Context Options",
    "WanVideoTextEmbedBridge": "WanVideo TextEmbed Bridge",
    "WanVideoFlowEdit": "WanVideo FlowEdit",
    "WanVideoControlEmbeds": "WanVideo Control Embeds",
    "WanVideoSLG": "WanVideo SLG",
    "WanVideoLoopArgs": "WanVideo Loop Args",
    "WanVideoSetBlockSwap": "WanVideo Set BlockSwap",
    "WanVideoExperimentalArgs": "WanVideo Experimental Args",
    "WanVideoVACEEncode": "WanVideo VACE Encode",
    "WanVideoPhantomEmbeds": "WanVideo Phantom Embeds",
    "WanVideoRealisDanceLatents": "WanVideo RealisDance Latents",
    "WanVideoApplyNAG": "WanVideo Apply NAG",
    "WanVideoMiniMaxRemoverEmbeds": "WanVideo MiniMax Remover Embeds",
    "WanVideoFreeInitArgs": "WanVideo Free Init Args",
    "WanVideoSetRadialAttention": "WanVideo Set Radial Attention",
    "WanVideoBlockList": "WanVideo Block List",
    "WanVideoTextEncodeCached": "WanVideo TextEncode Cached",
    "WanVideoAddExtraLatent": "WanVideo Add Extra Latent",
    # Added: the scheduler node is registered as a class but had no title.
    "WanVideoScheduler": "WanVideo Scheduler",
    "WanVideoAddStandInLatent": "WanVideo Add StandIn Latent",
    "WanVideoAddControlEmbeds": "WanVideo Add Control Embeds",
    "WanVideoAddMTVMotion": "WanVideo MTV Crafter Motion",
    "WanVideoRoPEFunction": "WanVideo RoPE Function",
    "WanVideoAddPusaNoise": "WanVideo Add Pusa Noise",
    "WanVideoAnimateEmbeds": "WanVideo Animate Embeds",
    "WanVideoAddLucyEditLatents": "WanVideo Add LucyEdit Latents",
    "WanVideoSchedulerSA_ODE": "WanVideo Scheduler SA-ODE",
}