# Mirror from https://github.com/kijai/ComfyUI-WanVideoWrapper
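# Diffusion forcing sampler node for WanVideo: builds a per-frame timestep matrix so
# different latent frames can be denoised at different noise levels within one
# sampling loop (used for long and prefix-conditioned video generation).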
import os
import torch
import gc
from ..utils import log, print_memory, fourier_filter, optimized_scale, setup_radial_attention, compile_model
import math
from tqdm import tqdm
from ..wanvideo.modules.model import rope_params
from ..wanvideo.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from ..custom_linear import remove_lora_from_module, set_lora_params
from ..wanvideo.schedulers.scheduling_flow_match_lcm import FlowMatchLCMScheduler
from ..gguf.gguf import set_lora_params_gguf
from einops import rearrange
from ..enhance_a_video.globals import disable_enhance
import comfy.model_management as mm
from comfy.utils import load_torch_file, ProgressBar, common_upscale
from comfy.clip_vision import clip_preprocess, ClipVisionModel
from comfy.cli_args import args, LatentPreviewMethod
from ..nodes_model_loading import load_weights
from ..nodes_sampler import offload_transformer
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
script_directory = os.path.dirname(os.path.abspath(__file__))
def generate_timestep_matrix(
num_frames,
step_template,
base_num_frames,
ar_step=5,
num_pre_ready=0,
casual_block_size=1,
shrink_interval_with_mask=False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
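# Builds the per-frame denoising schedule used by diffusion forcing. Returns:
#   step_matrix      - timestep value assigned to every latent frame at each iteration
#   step_index       - index into step_template for every frame at each iteration
#   step_update_mask - True where a frame should take a scheduler step that iteration
#   valid_interval   - (start, end) frame window fed to the model at each iteration
# ar_step is how many iterations each causal block lags behind the previous one;
# num_pre_ready marks prefix frames that are already clean and are never updated.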
step_matrix, step_index = [], []
update_mask, valid_interval = [], []
num_iterations = len(step_template) + 1
num_frames_block = num_frames // casual_block_size
base_num_frames_block = base_num_frames // casual_block_size
if base_num_frames_block < num_frames_block:
infer_step_num = len(step_template)
gen_block = base_num_frames_block
min_ar_step = infer_step_num / gen_block
assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
# print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
step_template = torch.cat(
[
torch.tensor([999], dtype=torch.int64, device=step_template.device),
step_template.long(),
torch.tensor([0], dtype=torch.int64, device=step_template.device),
]
) # to handle the counter in row works starting from 1
pre_row = torch.zeros(num_frames_block, dtype=torch.long)
if num_pre_ready > 0:
pre_row[: num_pre_ready // casual_block_size] = num_iterations
while not torch.all(pre_row >= (num_iterations - 1)):
new_row = torch.zeros(num_frames_block, dtype=torch.long)
for i in range(num_frames_block):
if i == 0 or pre_row[i - 1] >= (
num_iterations - 1
): # the first frame or the last frame is completely denoised
new_row[i] = pre_row[i] + 1
else:
new_row[i] = new_row[i - 1] - ar_step
new_row = new_row.clamp(0, num_iterations)
update_mask.append(
(new_row != pre_row) & (new_row != num_iterations)
) # False: no need to update, True: need to update
step_index.append(new_row)
step_matrix.append(step_template[new_row])
pre_row = new_row
# for long video we split into several sequences, base_num_frames is set to the model max length (for training)
terminal_flag = base_num_frames_block
if shrink_interval_with_mask:
idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
update_mask = update_mask[0]
update_mask_idx = idx_sequence[update_mask]
last_update_idx = update_mask_idx[-1].item()
terminal_flag = last_update_idx + 1
# for i in range(0, len(update_mask)):
for curr_mask in update_mask:
if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
terminal_flag += 1
valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))
step_update_mask = torch.stack(update_mask, dim=0)
step_index = torch.stack(step_index, dim=0)
step_matrix = torch.stack(step_matrix, dim=0)
if casual_block_size > 1:
step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]
return step_matrix, step_index, step_update_mask, valid_interval
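# Rough intuition (hypothetical numbers): with a 30-step schedule, 21 latent frames
# and ar_step=0, every row of step_matrix gives all frames the same timestep, i.e.
# ordinary synchronous denoising; with ar_step>0 each causal block starts ar_step
# iterations after the previous one, so earlier frames finish denoising first.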
#region Sampler
class WanVideoDiffusionForcingSampler:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("WANVIDEOMODEL",),
"text_embeds": ("WANVIDEOTEXTEMBEDS", ),
"image_embeds": ("WANVIDIMAGE_EMBEDS", ),
"addnoise_condition": ("INT", {"default": 10, "min": 0, "max": 1000, "tooltip": "Improves consistency in long video generation"}),
"fps": ("FLOAT", {"default": 24.0, "min": 1.0, "max": 120.0, "step": 0.01}),
"steps": ("INT", {"default": 30, "min": 1}),
"cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
"shift": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"force_offload": ("BOOLEAN", {"default": True, "tooltip": "Moves the model to the offload device after sampling"}),
"scheduler": (["unipc", "unipc/beta", "euler", "euler/beta", "lcm", "lcm/beta"],
{
"default": 'unipc'
}),
},
"optional": {
"samples": ("LATENT", {"tooltip": "init Latents to use for video2video process"} ),
"prefix_samples": ("LATENT", {"tooltip": "prefix latents"} ),
"denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"cache_args": ("CACHEARGS", ),
"slg_args": ("SLGARGS", ),
"rope_function": (["default", "comfy"], {"default": "comfy", "tooltip": "Comfy's RoPE implementation doesn't use complex numbers and can thus be compiled, that should be a lot faster when using torch.compile"}),
"experimental_args": ("EXPERIMENTALARGS", ),
"unianimate_poses": ("UNIANIMATE_POSE", ),
}
}
RETURN_TYPES = ("LATENT", )
RETURN_NAMES = ("samples",)
FUNCTION = "process"
CATEGORY = "WanVideoWrapper"
def process(self, model, text_embeds, image_embeds, shift, fps, steps, addnoise_condition, cfg, seed, scheduler,
force_offload=True, samples=None, prefix_samples=None, denoise_strength=1.0, slg_args=None, rope_function="default", cache_args=None, teacache_args=None,
experimental_args=None, unianimate_poses=None):
#assert not (context_options and teacache_args), "Context options cannot currently be used together with teacache."
patcher = model
model = model.model
transformer = model.diffusion_model
dtype = model["base_dtype"]
weight_dtype = model["weight_dtype"]
fp8_matmul = model["fp8_matmul"]
gguf_reader = model["gguf_reader"]
control_lora = model["control_lora"]
transformer_options = patcher.model_options.get("transformer_options", None)
merge_loras = transformer_options["merge_loras"]
block_swap_args = transformer_options.get("block_swap_args", None)
if block_swap_args is not None:
transformer.use_non_blocking = block_swap_args.get("use_non_blocking", False)
transformer.blocks_to_swap = block_swap_args.get("blocks_to_swap", 0)
transformer.vace_blocks_to_swap = block_swap_args.get("vace_blocks_to_swap", 0)
transformer.prefetch_blocks = block_swap_args.get("prefetch_blocks", 0)
transformer.block_swap_debug = block_swap_args.get("block_swap_debug", False)
transformer.offload_img_emb = block_swap_args.get("offload_img_emb", False)
transformer.offload_txt_emb = block_swap_args.get("offload_txt_emb", False)
is_5b = transformer.out_dim == 48
vae_upscale_factor = 16 if is_5b else 8
# Load weights
if transformer.patched_linear and gguf_reader is None:
load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype, transformer_load_device=device, block_swap_args=block_swap_args)
if gguf_reader is not None: #handle GGUF
load_weights(transformer, patcher.model["sd"], base_dtype=dtype, transformer_load_device=device, patcher=patcher, gguf=True, reader=gguf_reader, block_swap_args=block_swap_args)
set_lora_params_gguf(transformer, patcher.patches)
transformer.patched_linear = True
elif len(patcher.patches) != 0 and transformer.patched_linear: #handle patched linear layers (unmerged loras, fp8 scaled)
log.info(f"Using {len(patcher.patches)} LoRA weight patches for WanVideo model")
if not merge_loras and fp8_matmul:
raise NotImplementedError("FP8 matmul with unmerged LoRAs is not supported")
set_lora_params(transformer, patcher.patches)
else:
remove_lora_from_module(transformer) #clear possible unmerged lora weights
transformer.lora_scheduling_enabled = transformer_options.get("lora_scheduling_enabled", False)
#torch.compile
if model["auto_cpu_offload"] is False:
transformer = compile_model(transformer, model["compile_args"])
steps = int(steps/denoise_strength)
timesteps = None
if 'unipc' in scheduler:
sample_scheduler = FlowUniPCMultistepScheduler(shift=shift)
sample_scheduler.set_timesteps(steps, device=device, shift=shift, use_beta_sigmas=('beta' in scheduler))
elif 'euler' in scheduler:
sample_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift, use_beta_sigmas=(scheduler == 'euler/beta'))
sample_scheduler.set_timesteps(steps, device=device)
elif 'lcm' in scheduler:
sample_scheduler = FlowMatchLCMScheduler(shift=shift, use_beta_sigmas=(scheduler == 'lcm/beta'))
sample_scheduler.set_timesteps(steps, device=device)
init_timesteps = sample_scheduler.timesteps
if denoise_strength < 1.0:
steps = int(steps * denoise_strength)
timesteps = init_timesteps[-(steps + 1):] # truncate the schedule for partial denoise (video2video)
seed_g = torch.Generator(device=torch.device("cpu"))
seed_g.manual_seed(seed)
clip_fea, clip_fea_neg = None, None
vace_data, vace_context, vace_scale = None, None, None
image_cond = image_embeds.get("image_embeds", None)
target_shape = image_embeds.get("target_shape", None)
if target_shape is None:
raise ValueError("Empty image embeds must be provided for T2V (Text to Video)")
has_ref = image_embeds.get("has_ref", False)
vace_context = image_embeds.get("vace_context", None)
vace_scale = image_embeds.get("vace_scale", None)
if not isinstance(vace_scale, list):
vace_scale = [vace_scale] * (steps+1)
vace_start_percent = image_embeds.get("vace_start_percent", 0.0)
vace_end_percent = image_embeds.get("vace_end_percent", 1.0)
vace_seqlen = image_embeds.get("vace_seq_len", None)
vace_additional_embeds = image_embeds.get("additional_vace_inputs", [])
if vace_context is not None:
vace_data = [
{"context": vace_context,
"scale": vace_scale,
"start": vace_start_percent,
"end": vace_end_percent,
"seq_len": vace_seqlen
}
]
if len(vace_additional_embeds) > 0:
for i in range(len(vace_additional_embeds)):
if vace_additional_embeds[i].get("has_ref", False):
has_ref = True
vace_scale = vace_additional_embeds[i]["vace_scale"]
if not isinstance(vace_scale, list):
vace_scale = [vace_scale] * (steps+1)
vace_data.append({
"context": vace_additional_embeds[i]["vace_context"],
"scale": vace_scale,
"start": vace_additional_embeds[i]["vace_start_percent"],
"end": vace_additional_embeds[i]["vace_end_percent"],
"seq_len": vace_additional_embeds[i]["vace_seq_len"]
})
noise = torch.randn(
target_shape[0],
target_shape[1] + 1 if has_ref else target_shape[1],
target_shape[2],
target_shape[3],
dtype=torch.float32,
device=torch.device("cpu"),
generator=seed_g)
latent_video_length = noise.shape[1]
seq_len = math.ceil((noise.shape[2] * noise.shape[3]) / 4 * noise.shape[1])
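# seq_len is the transformer token count: frames * (H/2) * (W/2), assuming the model
# patchifies each latent frame into 2x2 spatial patches.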
if samples is not None:
input_samples = samples["samples"].squeeze(0).to(noise)
if input_samples.shape[1] != noise.shape[1]:
input_samples = torch.cat([input_samples[:, :1].repeat(1, noise.shape[1] - input_samples.shape[1], 1, 1), input_samples], dim=1)
original_image = input_samples.to(device)
if denoise_strength < 1.0:
latent_timestep = timesteps[:1].to(noise)
noise = noise * latent_timestep / 1000 + (1 - latent_timestep / 1000) * input_samples
mask = samples.get("mask", None)
if mask is not None:
if mask.shape[2] != noise.shape[1]:
mask = torch.cat([torch.zeros(1, noise.shape[0], noise.shape[1] - mask.shape[2], noise.shape[2], noise.shape[3]), mask], dim=2)
latents = noise.to(device)
fps_embeds = None
if hasattr(transformer, "fps_embedding"):
fps = round(fps, 2)
log.info(f"Model has fps embedding, using {fps} fps")
fps_embeds = [fps]
fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
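# The fps embedding appears to distinguish only two frame rates:
# index 0 for 16 fps, index 1 for anything else (e.g. 24 fps).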
prefix_video = prefix_samples["samples"].to(noise) if prefix_samples is not None else None
prefix_video_latent_length = prefix_video.shape[2] if prefix_video is not None else 0
if prefix_video is not None:
log.info(f"Prefix video of length: {prefix_video_latent_length}")
latents[:, :prefix_video_latent_length] = prefix_video[0]
#base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_video_length
base_num_frames = latent_video_length
ar_step = 0
causal_block_size = 1
step_matrix, _, step_update_mask, valid_interval = generate_timestep_matrix(
latent_video_length, init_timesteps, base_num_frames, ar_step, prefix_video_latent_length, causal_block_size
)
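# With ar_step=0 and causal_block_size=1 all frames share the same schedule; any
# prefix latents are passed as num_pre_ready so they count as already denoised
# and are excluded from the update mask.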
sample_schedulers = []
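# One scheduler instance per latent frame: under diffusion forcing, frames can sit at
# different points of the noise schedule during the same loop iteration, so each
# frame's scheduler state has to be tracked separately.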
for _ in range(latent_video_length):
if 'unipc' in scheduler:
sample_scheduler = FlowUniPCMultistepScheduler(shift=shift)
sample_scheduler.set_timesteps(steps, device=device, shift=shift, use_beta_sigmas=('beta' in scheduler))
elif 'euler' in scheduler:
sample_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)
sample_scheduler.set_timesteps(steps, device=device)
elif 'lcm' in scheduler:
sample_scheduler = FlowMatchLCMScheduler(shift=shift, use_beta_sigmas=(scheduler == 'lcm/beta'))
sample_scheduler.set_timesteps(steps, device=device)
sample_schedulers.append(sample_scheduler)
sample_schedulers_counter = [0] * latent_video_length
unianim_data = None
if unianimate_poses is not None:
transformer.dwpose_embedding.to(device)
transformer.randomref_embedding_pose.to(device)
dwpose_data = unianimate_poses["pose"]
dwpose_data = transformer.dwpose_embedding(
(torch.cat([dwpose_data[:,:,:1].repeat(1,1,3,1,1), dwpose_data], dim=2)
).to(device)).to(model["dtype"])
log.info(f"UniAnimate pose embed shape: {dwpose_data.shape}")
if dwpose_data.shape[2] > latent_video_length:
log.warning(f"UniAnimate pose embed length {dwpose_data.shape[2]} is longer than the video length {latent_video_length}, truncating")
dwpose_data = dwpose_data[:,:, :latent_video_length]
elif dwpose_data.shape[2] < latent_video_length:
log.warning(f"UniAnimate pose embed length {dwpose_data.shape[2]} is shorter than the video length {latent_video_length}, padding with last pose")
pad_len = latent_video_length - dwpose_data.shape[2]
pad = dwpose_data[:,:,:1].repeat(1,1,pad_len,1,1)
dwpose_data = torch.cat([dwpose_data, pad], dim=2)
dwpose_data_flat = rearrange(dwpose_data, 'b c f h w -> b (f h w) c').contiguous()
random_ref_dwpose_data = None
if image_cond is not None:
random_ref_dwpose = unianimate_poses.get("ref", None)
if random_ref_dwpose is not None:
random_ref_dwpose_data = transformer.randomref_embedding_pose(
random_ref_dwpose.to(device)
).unsqueeze(2).to(model["dtype"]) # [1, 20, 104, 60]
unianim_data = {
"dwpose": dwpose_data_flat,
"random_ref": random_ref_dwpose_data.squeeze(0) if random_ref_dwpose_data is not None else None,
"strength": unianimate_poses["strength"],
"start_percent": unianimate_poses["start_percent"],
"end_percent": unianimate_poses["end_percent"]
}
disable_enhance() #not sure if this can work, disabling for now to avoid errors if it's enabled by another sampler
freqs = None
transformer.rope_embedder.k = None
transformer.rope_embedder.num_frames = None
if rope_function=="comfy":
transformer.rope_embedder.k = 0
transformer.rope_embedder.num_frames = latent_video_length
else:
d = transformer.dim // transformer.num_heads
freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6), L_test=latent_video_length, k=0),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
],
dim=1)
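# The head dim is split into one temporal and two spatial rotary frequency bands
# (d - 4*(d//6) temporal, 2*(d//6) each for height and width).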
if not isinstance(cfg, list):
cfg = [cfg] * (steps +1)
log.info(f"Seq len: {seq_len}")
pbar = ProgressBar(steps)
if args.preview_method in [LatentPreviewMethod.Auto, LatentPreviewMethod.Latent2RGB]: #default for latent2rgb
from latent_preview import prepare_callback
else:
from ..latent_preview import prepare_callback #custom for tiny VAE previews
callback = prepare_callback(patcher, steps)
#blockswap init
if not transformer.patched_linear:
if block_swap_args is not None:
transformer.use_non_blocking = block_swap_args.get("use_non_blocking", False)
for name, param in transformer.named_parameters():
if "block" not in name:
param.data = param.data.to(device)
if "control_adapter" in name:
param.data = param.data.to(device)
elif block_swap_args["offload_txt_emb"] and "txt_emb" in name:
param.data = param.data.to(offload_device)
elif block_swap_args["offload_img_emb"] and "img_emb" in name:
param.data = param.data.to(offload_device)
transformer.block_swap(
block_swap_args["blocks_to_swap"] - 1 ,
block_swap_args["offload_txt_emb"],
block_swap_args["offload_img_emb"],
vace_blocks_to_swap = block_swap_args.get("vace_blocks_to_swap", None),
prefetch_blocks = block_swap_args.get("prefetch_blocks", 0),
block_swap_debug = block_swap_args.get("block_swap_debug", False),
)
elif model["auto_cpu_offload"]:
for module in transformer.modules():
if hasattr(module, "offload"):
module.offload()
if hasattr(module, "onload"):
module.onload()
for block in transformer.blocks:
block.modulation = torch.nn.Parameter(block.modulation.to(device))
transformer.head.modulation = torch.nn.Parameter(transformer.head.modulation.to(device))
else:
transformer.to(device)
# Initialize Cache if enabled
transformer.enable_teacache = transformer.enable_magcache = False
if teacache_args is not None: #for backward compatibility on old workflows
cache_args = teacache_args
if cache_args is not None:
transformer.cache_device = cache_args["cache_device"]
if cache_args["cache_type"] == "TeaCache":
log.info(f"TeaCache: Using cache device: {transformer.cache_device}")
transformer.teacache_state.clear_all()
transformer.enable_teacache = True
transformer.rel_l1_thresh = cache_args["rel_l1_thresh"]
transformer.teacache_start_step = cache_args["start_step"]
transformer.teacache_end_step = len(init_timesteps)-1 if cache_args["end_step"] == -1 else cache_args["end_step"]
transformer.teacache_use_coefficients = cache_args["use_coefficients"]
transformer.teacache_mode = cache_args["mode"]
elif cache_args["cache_type"] == "MagCache":
log.info(f"MagCache: Using cache device: {transformer.cache_device}")
transformer.magcache_state.clear_all()
transformer.enable_magcache = True
transformer.magcache_start_step = cache_args["start_step"]
transformer.magcache_end_step = len(init_timesteps)-1 if cache_args["end_step"] == -1 else cache_args["end_step"]
transformer.magcache_thresh = cache_args["magcache_thresh"]
transformer.magcache_K = cache_args["magcache_K"]
if slg_args is not None:
transformer.slg_blocks = slg_args["blocks"]
transformer.slg_start_percent = slg_args["start_percent"]
transformer.slg_end_percent = slg_args["end_percent"]
else:
transformer.slg_blocks = None
self.teacache_state = [None, None]
self.teacache_state_source = [None, None]
self.teacache_states_context = []
if transformer.attention_mode == "radial_sage_attention":
setup_radial_attention(transformer, transformer_options, latents, seq_len, latent_video_length)
use_cfg_zero_star, use_fresca = False, False
if experimental_args is not None:
video_attention_split_steps = experimental_args.get("video_attention_split_steps", [])
if video_attention_split_steps:
transformer.video_attention_split_steps = [int(x.strip()) for x in video_attention_split_steps.split(",")]
else:
transformer.video_attention_split_steps = []
use_zero_init = experimental_args.get("use_zero_init", True)
use_cfg_zero_star = experimental_args.get("cfg_zero_star", False)
zero_star_steps = experimental_args.get("zero_star_steps", 0)
use_fresca = experimental_args.get("use_fresca", False)
if use_fresca:
fresca_scale_low = experimental_args.get("fresca_scale_low", 1.0)
fresca_scale_high = experimental_args.get("fresca_scale_high", 1.25)
fresca_freq_cutoff = experimental_args.get("fresca_freq_cutoff", 20)
#region model pred
def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, image_cond=None, clip_fea=None,
vace_data=None, unianim_data=None, teacache_state=None):
with torch.autocast(device_type=mm.get_autocast_device(device), dtype=dtype, enabled=("fp8" in model["quantization"])):
if use_cfg_zero_star and (idx <= zero_star_steps) and use_zero_init:
return latent_model_input*0, None
nonlocal patcher
current_step_percentage = idx / len(init_timesteps)
control_lora_enabled = False
image_cond_input = image_cond
base_params = {
'seq_len': seq_len,
'device': device,
'freqs': freqs,
't': timestep,
'current_step': idx,
'control_lora_enabled': control_lora_enabled,
'vace_data': vace_data,
'unianim_data': unianim_data,
'fps_embeds': fps_embeds,
"nag_params": text_embeds.get("nag_params", {}),
"nag_context": text_embeds.get("nag_prompt_embeds", None),
}
batch_size = 1
if not math.isclose(cfg_scale, 1.0) and len(positive_embeds) > 1:
negative_embeds = negative_embeds * len(positive_embeds)
#cond
noise_pred_cond, teacache_state_cond = transformer(
[z], context=positive_embeds, y=[image_cond_input] if image_cond_input is not None else None,
clip_fea=clip_fea, is_uncond=False, current_step_percentage=current_step_percentage,
pred_id=teacache_state[0] if teacache_state else None,
**base_params
)
noise_pred_cond = noise_pred_cond[0].to(intermediate_device)
if math.isclose(cfg_scale, 1.0):
if use_fresca:
noise_pred_cond = fourier_filter(
noise_pred_cond,
scale_low=fresca_scale_low,
scale_high=fresca_scale_high,
freq_cutoff=fresca_freq_cutoff,
)
return noise_pred_cond, [teacache_state_cond]
#uncond
noise_pred_uncond, teacache_state_uncond = transformer(
[z], context=negative_embeds, clip_fea=clip_fea_neg if clip_fea_neg is not None else clip_fea,
y=[image_cond_input] if image_cond_input is not None else None,
is_uncond=True, current_step_percentage=current_step_percentage,
pred_id=teacache_state[1] if teacache_state else None,
**base_params
)
noise_pred_uncond = noise_pred_uncond[0].to(intermediate_device)
#cfg
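# Plain CFG is uncond + cfg_scale * (cond - uncond); alpha below (CFG-Zero*) rescales
# the unconditional branch per sample, and FreSca optionally filters the guidance
# term in frequency space before it is applied.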
#https://github.com/WeichenFan/CFG-Zero-star/
if use_cfg_zero_star:
alpha = optimized_scale(
noise_pred_cond.view(batch_size, -1),
noise_pred_uncond.view(batch_size, -1)
).view(batch_size, 1, 1, 1)
else:
alpha = 1.0
#https://github.com/WikiChao/FreSca
if use_fresca:
filtered_cond = fourier_filter(
noise_pred_cond - noise_pred_uncond,
scale_low=fresca_scale_low,
scale_high=fresca_scale_high,
freq_cutoff=fresca_freq_cutoff,
)
noise_pred = noise_pred_uncond * alpha + cfg_scale * filtered_cond * alpha
else:
noise_pred = noise_pred_uncond * alpha + cfg_scale * (noise_pred_cond - noise_pred_uncond * alpha)
return noise_pred, [teacache_state_cond, teacache_state_uncond]
log.info(f"Sampling {(latent_video_length-1) * 4 + 1} frames at {latents.shape[3]*8}x{latents.shape[2]*8} with {steps} steps")
intermediate_device = device
#clear memory before sampling
mm.unload_all_models()
mm.soft_empty_cache()
gc.collect()
try:
torch.cuda.reset_peak_memory_stats(device)
except Exception:
pass
#region main loop start
for i, timestep_i in enumerate(tqdm(step_matrix)):
update_mask_i = step_update_mask[i]
valid_interval_i = valid_interval[i]
valid_interval_start, valid_interval_end = valid_interval_i
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].clone()
if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length:
noise_factor = 0.001 * addnoise_condition
timestep_for_noised_condition = addnoise_condition
latent_model_input[:, valid_interval_start:prefix_video_latent_length] = (
latent_model_input[:, valid_interval_start:prefix_video_latent_length] * (1.0 - noise_factor)
+ torch.randn_like(latent_model_input[:, valid_interval_start:prefix_video_latent_length])
* noise_factor
)
timestep[:, valid_interval_start:prefix_video_latent_length] = timestep_for_noised_condition
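# addnoise_condition lightly re-noises the already-clean prefix frames
# (noise_factor = 0.001 * addnoise_condition) and assigns them that small timestep,
# which per the node tooltip improves consistency in long video generation.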
#print("timestep", timestep)
noise_pred, self.teacache_state = predict_with_cfg(
latent_model_input.to(dtype),
cfg[i],
text_embeds["prompt_embeds"],
text_embeds["negative_prompt_embeds"],
timestep, i, image_cond, clip_fea, unianim_data=unianim_data, vace_data=vace_data,
teacache_state=self.teacache_state)
for idx in range(valid_interval_start, valid_interval_end):
if update_mask_i[idx].item():
latents[:, idx] = sample_schedulers[idx].step(
noise_pred[:, idx - valid_interval_start],
timestep_i[idx],
latents[:, idx],
return_dict=False,
generator=seed_g,
)[0]
sample_schedulers_counter[idx] += 1
x0 = latents.unsqueeze(0)
if callback is not None:
callback_latent = (latent_model_input - noise_pred.to(timestep_i[idx].device) * timestep_i[idx] / 1000).detach().permute(1,0,2,3)
callback(i, callback_latent, None, steps)
else:
pbar.update(1)
if teacache_args is not None:
states = transformer.teacache_state.states
state_names = {
0: "conditional",
1: "unconditional"
}
for pred_id, state in states.items():
name = state_names.get(pred_id, f"prediction_{pred_id}")
if 'skipped_steps' in state:
log.info(f"TeaCache skipped: {len(state['skipped_steps'])} {name} steps: {state['skipped_steps']}")
transformer.teacache_state.clear_all()
if force_offload:
if not model["auto_cpu_offload"]:
offload_transformer(transformer)
try:
print_memory(device)
torch.cuda.reset_peak_memory_stats(device)
except Exception:
pass
return ({
"samples": x0.cpu(),
}, )
NODE_CLASS_MAPPINGS = {
"WanVideoDiffusionForcingSampler": WanVideoDiffusionForcingSampler,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"WanVideoDiffusionForcingSampler": "WanVideo Diffusion Forcing Sampler",
}