import os
import gc
import math

import torch
from tqdm import tqdm
from einops import rearrange

from ..utils import log, print_memory, fourier_filter, optimized_scale, setup_radial_attention, compile_model
from ..wanvideo.modules.model import rope_params
from ..wanvideo.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from ..wanvideo.schedulers.scheduling_flow_match_lcm import FlowMatchLCMScheduler
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from ..custom_linear import remove_lora_from_module, set_lora_params
from ..gguf.gguf import set_lora_params_gguf
from ..enhance_a_video.globals import disable_enhance
from ..nodes_model_loading import load_weights
from ..nodes_sampler import offload_transformer

import comfy.model_management as mm
from comfy.utils import load_torch_file, ProgressBar, common_upscale
from comfy.clip_vision import clip_preprocess, ClipVisionModel
from comfy.cli_args import args, LatentPreviewMethod

device = mm.get_torch_device()
offload_device = mm.unet_offload_device()

script_directory = os.path.dirname(os.path.abspath(__file__))

def generate_timestep_matrix(
    num_frames,
    step_template,
    base_num_frames,
    ar_step=5,
    num_pre_ready=0,
    causal_block_size=1,
    shrink_interval_with_mask=False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
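    """Build the staggered denoising schedule used by diffusion forcing.

    Each of the `num_frames` latent frames walks through `step_template`
    (the scheduler's timesteps) at its own pace: frame i may only advance
    freely once frame i-1 is far enough ahead, with `ar_step` controlling
    the stagger, so later frames stay noisier than earlier ones. Frames
    covered by `num_pre_ready` (e.g. prefix-video latents) are treated as
    already denoised. Frames can optionally be grouped into causal blocks
    of `causal_block_size`.

    Returns:
        step_matrix: (iterations, num_frames) timestep for every frame at
            every iteration.
        step_index: (iterations, num_frames) index into the padded template.
        step_update_mask: (iterations, num_frames) bool, True where a frame
            actually takes a denoising step this iteration.
        valid_interval: list of (start, end) frame windows to feed the model
            per iteration when the video is longer than `base_num_frames`.
    """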
    step_matrix, step_index = [], []
    update_mask, valid_interval = [], []
    num_iterations = len(step_template) + 1
    num_frames_block = num_frames // causal_block_size
    base_num_frames_block = base_num_frames // causal_block_size
    if base_num_frames_block < num_frames_block:
        infer_step_num = len(step_template)
        gen_block = base_num_frames_block
        min_ar_step = infer_step_num / gen_block
        assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"

    # Pad the template with sentinel timesteps: 999 for "not started", 0 for "finished".
    step_template = torch.cat(
        [
            torch.tensor([999], dtype=torch.int64, device=step_template.device),
            step_template.long(),
            torch.tensor([0], dtype=torch.int64, device=step_template.device),
        ]
    )
    pre_row = torch.zeros(num_frames_block, dtype=torch.long)
    if num_pre_ready > 0:
        pre_row[: num_pre_ready // causal_block_size] = num_iterations

    while not torch.all(pre_row >= (num_iterations - 1)):
        new_row = torch.zeros(num_frames_block, dtype=torch.long)
        for i in range(num_frames_block):
            # The first frame, or any frame whose predecessor is fully denoised,
            # advances one iteration; otherwise the frame trails its predecessor
            # by ar_step iterations.
            if i == 0 or pre_row[i - 1] >= (num_iterations - 1):
                new_row[i] = pre_row[i] + 1
            else:
                new_row[i] = new_row[i - 1] - ar_step
        new_row = new_row.clamp(0, num_iterations)

        # True where a frame takes a step this iteration (changed and not yet done).
        update_mask.append(
            (new_row != pre_row) & (new_row != num_iterations)
        )
        step_index.append(new_row)
        step_matrix.append(step_template[new_row])
        pre_row = new_row

    # For long videos, slide a window of base_num_frames over the sequence.
    terminal_flag = base_num_frames_block
    if shrink_interval_with_mask:
        idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
        update_mask = update_mask[0]
        update_mask_idx = idx_sequence[update_mask]
        last_update_idx = update_mask_idx[-1].item()
        terminal_flag = last_update_idx + 1

    for curr_mask in update_mask:
        if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
            terminal_flag += 1
        valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))

    step_update_mask = torch.stack(update_mask, dim=0)
    step_index = torch.stack(step_index, dim=0)
    step_matrix = torch.stack(step_matrix, dim=0)

    # Expand block-level rows back to frame-level rows.
    if causal_block_size > 1:
        step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
        step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
        step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
        valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval]

    return step_matrix, step_index, step_update_mask, valid_interval

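# Illustrative example (not executed at import time): with a 30-step template
# over 21 latent frames, no autoregressive stagger (ar_step=0) and no prefix
# frames, every row of step_matrix assigns the same timestep to all frames,
# i.e. the schedule degenerates to ordinary full-sequence denoising:
#
#   template = torch.linspace(999, 1, 30)
#   matrix, index, mask, intervals = generate_timestep_matrix(21, template, 21, ar_step=0)
#   # matrix.shape == (30, 21); intervals == [(0, 21)] * 30
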
class WanVideoDiffusionForcingSampler:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("WANVIDEOMODEL",),
                "text_embeds": ("WANVIDEOTEXTEMBEDS",),
                "image_embeds": ("WANVIDIMAGE_EMBEDS",),
                "addnoise_condition": ("INT", {"default": 10, "min": 0, "max": 1000, "tooltip": "Improves consistency in long video generation"}),
                "fps": ("FLOAT", {"default": 24.0, "min": 1.0, "max": 120.0, "step": 0.01}),
                "steps": ("INT", {"default": 30, "min": 1}),
                "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
                "shift": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "force_offload": ("BOOLEAN", {"default": True, "tooltip": "Moves the model to the offload device after sampling"}),
                "scheduler": (["unipc", "unipc/beta", "euler", "euler/beta", "lcm", "lcm/beta"],
                    {
                        "default": 'unipc'
                    }),
            },
            "optional": {
                "samples": ("LATENT", {"tooltip": "Initial latents to use for the video2video process"}),
                "prefix_samples": ("LATENT", {"tooltip": "Prefix latents"}),
                "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "cache_args": ("CACHEARGS", ),
                "slg_args": ("SLGARGS", ),
                "rope_function": (["default", "comfy"], {"default": "comfy", "tooltip": "Comfy's RoPE implementation doesn't use complex numbers and can thus be compiled; that should be a lot faster when using torch.compile"}),
                "experimental_args": ("EXPERIMENTALARGS", ),
                "unianimate_poses": ("UNIANIMATE_POSE", ),
            }
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("samples",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    def process(self, model, text_embeds, image_embeds, shift, fps, steps, addnoise_condition, cfg, seed, scheduler,
                force_offload=True, samples=None, prefix_samples=None, denoise_strength=1.0, slg_args=None, rope_function="default", cache_args=None, teacache_args=None,
                experimental_args=None, unianimate_poses=None):
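        """Sample a WanVideo latent with a diffusion-forcing schedule: every
        latent frame gets its own scheduler instance and is denoised according
        to the timestep matrix built by generate_timestep_matrix, so prefix
        latents can stay (almost) fixed while the rest of the video is
        generated."""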

        patcher = model
        model = model.model
        transformer = model.diffusion_model

        dtype = model["base_dtype"]
        weight_dtype = model["weight_dtype"]
        fp8_matmul = model["fp8_matmul"]
        gguf_reader = model["gguf_reader"]
        control_lora = model["control_lora"]

        transformer_options = patcher.model_options.get("transformer_options", None)
        merge_loras = transformer_options["merge_loras"]

        block_swap_args = transformer_options.get("block_swap_args", None)
        if block_swap_args is not None:
            transformer.use_non_blocking = block_swap_args.get("use_non_blocking", False)
            transformer.blocks_to_swap = block_swap_args.get("blocks_to_swap", 0)
            transformer.vace_blocks_to_swap = block_swap_args.get("vace_blocks_to_swap", 0)
            transformer.prefetch_blocks = block_swap_args.get("prefetch_blocks", 0)
            transformer.block_swap_debug = block_swap_args.get("block_swap_debug", False)
            transformer.offload_img_emb = block_swap_args.get("offload_img_emb", False)
            transformer.offload_txt_emb = block_swap_args.get("offload_txt_emb", False)

        is_5b = transformer.out_dim == 48
        vae_upscale_factor = 16 if is_5b else 8

        # Load (or reload) the transformer weights, then attach or remove LoRA patches.
        if transformer.patched_linear and gguf_reader is None:
            load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype, transformer_load_device=device, block_swap_args=block_swap_args)

        if gguf_reader is not None:
            load_weights(transformer, patcher.model["sd"], base_dtype=dtype, transformer_load_device=device, patcher=patcher, gguf=True, reader=gguf_reader, block_swap_args=block_swap_args)
            set_lora_params_gguf(transformer, patcher.patches)
            transformer.patched_linear = True
        elif len(patcher.patches) != 0 and transformer.patched_linear:
            log.info(f"Using {len(patcher.patches)} LoRA weight patches for WanVideo model")
            if not merge_loras and fp8_matmul:
                raise NotImplementedError("FP8 matmul with unmerged LoRAs is not supported")
            set_lora_params(transformer, patcher.patches)
        else:
            remove_lora_from_module(transformer)

        transformer.lora_scheduling_enabled = transformer_options.get("lora_scheduling_enabled", False)

        if model["auto_cpu_offload"] is False:
            transformer = compile_model(transformer, model["compile_args"])

        # For video2video, schedule extra steps now so that after truncation below
        # the requested number of steps is actually applied to the input latents.
        steps = int(steps / denoise_strength)

        timesteps = None
        if 'unipc' in scheduler:
            sample_scheduler = FlowUniPCMultistepScheduler(shift=shift)
            sample_scheduler.set_timesteps(steps, device=device, shift=shift, use_beta_sigmas=('beta' in scheduler))
        elif 'euler' in scheduler:
            sample_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift, use_beta_sigmas=(scheduler == 'euler/beta'))
            sample_scheduler.set_timesteps(steps, device=device)
        elif 'lcm' in scheduler:
            sample_scheduler = FlowMatchLCMScheduler(shift=shift, use_beta_sigmas=(scheduler == 'lcm/beta'))
            sample_scheduler.set_timesteps(steps, device=device)

        init_timesteps = sample_scheduler.timesteps

        if denoise_strength < 1.0:
            steps = int(steps * denoise_strength)
            timesteps = init_timesteps[-(steps + 1):]

        seed_g = torch.Generator(device=torch.device("cpu"))
        seed_g.manual_seed(seed)

        clip_fea, clip_fea_neg = None, None
        vace_data, vace_context, vace_scale = None, None, None

        image_cond = image_embeds.get("image_embeds", None)

        target_shape = image_embeds.get("target_shape", None)
        if target_shape is None:
            raise ValueError("Empty image embeds must be provided for T2V (text-to-video)")

        has_ref = image_embeds.get("has_ref", False)
        vace_context = image_embeds.get("vace_context", None)
        vace_scale = image_embeds.get("vace_scale", None)
        if not isinstance(vace_scale, list):
            vace_scale = [vace_scale] * (steps + 1)
        vace_start_percent = image_embeds.get("vace_start_percent", 0.0)
        vace_end_percent = image_embeds.get("vace_end_percent", 1.0)
        vace_seqlen = image_embeds.get("vace_seq_len", None)

        vace_additional_embeds = image_embeds.get("additional_vace_inputs", [])
        if vace_context is not None:
            vace_data = [
                {
                    "context": vace_context,
                    "scale": vace_scale,
                    "start": vace_start_percent,
                    "end": vace_end_percent,
                    "seq_len": vace_seqlen
                }
            ]
            if len(vace_additional_embeds) > 0:
                for i in range(len(vace_additional_embeds)):
                    if vace_additional_embeds[i].get("has_ref", False):
                        has_ref = True
                    vace_scale = vace_additional_embeds[i]["vace_scale"]
                    if not isinstance(vace_scale, list):
                        vace_scale = [vace_scale] * (steps + 1)
                    vace_data.append({
                        "context": vace_additional_embeds[i]["vace_context"],
                        "scale": vace_scale,
                        "start": vace_additional_embeds[i]["vace_start_percent"],
                        "end": vace_additional_embeds[i]["vace_end_percent"],
                        "seq_len": vace_additional_embeds[i]["vace_seq_len"]
                    })

        noise = torch.randn(
            target_shape[0],
            target_shape[1] + 1 if has_ref else target_shape[1],
            target_shape[2],
            target_shape[3],
            dtype=torch.float32,
            device=torch.device("cpu"),
            generator=seed_g)

        latent_video_length = noise.shape[1]
        seq_len = math.ceil((noise.shape[2] * noise.shape[3]) / 4 * noise.shape[1])

        if samples is not None:
            input_samples = samples["samples"].squeeze(0).to(noise)
            if input_samples.shape[1] != noise.shape[1]:
                # Pad the front by repeating the first latent frame so lengths match.
                input_samples = torch.cat([input_samples[:, :1].repeat(1, noise.shape[1] - input_samples.shape[1], 1, 1), input_samples], dim=1)
            original_image = input_samples.to(device)
            if denoise_strength < 1.0:
                latent_timestep = timesteps[:1].to(noise)
                noise = noise * latent_timestep / 1000 + (1 - latent_timestep / 1000) * input_samples

            mask = samples.get("mask", None)
            if mask is not None:
                if mask.shape[2] != noise.shape[1]:
                    mask = torch.cat([torch.zeros(1, noise.shape[0], noise.shape[1] - mask.shape[2], noise.shape[2], noise.shape[3]), mask], dim=2)

        latents = noise.to(device)

        fps_embeds = None
        if hasattr(transformer, "fps_embedding"):
            fps = round(fps, 2)
            log.info(f"Model has fps embedding, using {fps} fps")
            fps_embeds = [fps]
            fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]

        prefix_video = prefix_samples["samples"].to(noise) if prefix_samples is not None else None
        prefix_video_latent_length = prefix_video.shape[2] if prefix_video is not None else 0
        if prefix_video is not None:
            log.info(f"Prefix video of length: {prefix_video_latent_length}")
            latents[:, :prefix_video_latent_length] = prefix_video[0]

        base_num_frames = latent_video_length

        ar_step = 0
        causal_block_size = 1
        step_matrix, _, step_update_mask, valid_interval = generate_timestep_matrix(
            latent_video_length, init_timesteps, base_num_frames, ar_step, prefix_video_latent_length, causal_block_size
        )
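
        # Because different frames can sit at different points of the noise
        # schedule in any given iteration, each latent frame gets its own
        # scheduler instance so its solver state is tracked independently.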
        sample_schedulers = []
        for _ in range(latent_video_length):
            if 'unipc' in scheduler:
                sample_scheduler = FlowUniPCMultistepScheduler(shift=shift)
                sample_scheduler.set_timesteps(steps, device=device, shift=shift, use_beta_sigmas=('beta' in scheduler))
            elif 'euler' in scheduler:
                sample_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)
                sample_scheduler.set_timesteps(steps, device=device)
            elif 'lcm' in scheduler:
                sample_scheduler = FlowMatchLCMScheduler(shift=shift, use_beta_sigmas=(scheduler == 'lcm/beta'))
                sample_scheduler.set_timesteps(steps, device=device)
            sample_schedulers.append(sample_scheduler)
        sample_schedulers_counter = [0] * latent_video_length

        unianim_data = None
        if unianimate_poses is not None:
            transformer.dwpose_embedding.to(device)
            transformer.randomref_embedding_pose.to(device)
            dwpose_data = unianimate_poses["pose"]
            # Repeat the first pose frame 3x at the front so the pose sequence
            # aligns with the temporally compressed latents, then embed.
            dwpose_data = transformer.dwpose_embedding(
                (torch.cat([dwpose_data[:, :, :1].repeat(1, 1, 3, 1, 1), dwpose_data], dim=2)
                 ).to(device)).to(model["dtype"])
            log.info(f"UniAnimate pose embed shape: {dwpose_data.shape}")
            if dwpose_data.shape[2] > latent_video_length:
                log.warning(f"UniAnimate pose embed length {dwpose_data.shape[2]} is longer than the video length {latent_video_length}, truncating")
                dwpose_data = dwpose_data[:, :, :latent_video_length]
            elif dwpose_data.shape[2] < latent_video_length:
                log.warning(f"UniAnimate pose embed length {dwpose_data.shape[2]} is shorter than the video length {latent_video_length}, padding with last pose")
                pad_len = latent_video_length - dwpose_data.shape[2]
                pad = dwpose_data[:, :, -1:].repeat(1, 1, pad_len, 1, 1)
                dwpose_data = torch.cat([dwpose_data, pad], dim=2)
            dwpose_data_flat = rearrange(dwpose_data, 'b c f h w -> b (f h w) c').contiguous()

            random_ref_dwpose_data = None
            if image_cond is not None:
                random_ref_dwpose = unianimate_poses.get("ref", None)
                if random_ref_dwpose is not None:
                    random_ref_dwpose_data = transformer.randomref_embedding_pose(
                        random_ref_dwpose.to(device)
                    ).unsqueeze(2).to(model["dtype"])

            unianim_data = {
                "dwpose": dwpose_data_flat,
                "random_ref": random_ref_dwpose_data.squeeze(0) if random_ref_dwpose_data is not None else None,
                "strength": unianimate_poses["strength"],
                "start_percent": unianimate_poses["start_percent"],
                "end_percent": unianimate_poses["end_percent"]
            }

        disable_enhance()

        # RoPE setup: the "comfy" path configures the model's own embedder,
        # while the default path precomputes complex-valued frequencies here.
        freqs = None
        transformer.rope_embedder.k = None
        transformer.rope_embedder.num_frames = None
        if rope_function == "comfy":
            transformer.rope_embedder.k = 0
            transformer.rope_embedder.num_frames = latent_video_length
        else:
            d = transformer.dim // transformer.num_heads
            freqs = torch.cat([
                rope_params(1024, d - 4 * (d // 6), L_test=latent_video_length, k=0),
                rope_params(1024, 2 * (d // 6)),
                rope_params(1024, 2 * (d // 6))
            ],
                dim=1)
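        # Note: the default path splits the per-head dim as
        # (d - 4*(d//6)) temporal + 2*(d//6) height + 2*(d//6) width,
        # matching the model's 3D RoPE layout.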

        if not isinstance(cfg, list):
            cfg = [cfg] * (steps + 1)

        log.info(f"Seq len: {seq_len}")

        pbar = ProgressBar(steps)

        if args.preview_method in [LatentPreviewMethod.Auto, LatentPreviewMethod.Latent2RGB]:
            from latent_preview import prepare_callback
        else:
            from ..latent_preview import prepare_callback
        callback = prepare_callback(patcher, steps)

        # Place the model: block swap, auto CPU offload, or plain .to(device).
        if not transformer.patched_linear:
            if block_swap_args is not None:
                transformer.use_non_blocking = block_swap_args.get("use_non_blocking", False)
                for name, param in transformer.named_parameters():
                    if "block" not in name:
                        param.data = param.data.to(device)
                    if "control_adapter" in name:
                        param.data = param.data.to(device)
                    elif block_swap_args["offload_txt_emb"] and "txt_emb" in name:
                        param.data = param.data.to(offload_device)
                    elif block_swap_args["offload_img_emb"] and "img_emb" in name:
                        param.data = param.data.to(offload_device)

                transformer.block_swap(
                    block_swap_args["blocks_to_swap"] - 1,
                    block_swap_args["offload_txt_emb"],
                    block_swap_args["offload_img_emb"],
                    vace_blocks_to_swap=block_swap_args.get("vace_blocks_to_swap", None),
                    prefetch_blocks=block_swap_args.get("prefetch_blocks", 0),
                    block_swap_debug=block_swap_args.get("block_swap_debug", False),
                )
            elif model["auto_cpu_offload"]:
                for module in transformer.modules():
                    if hasattr(module, "offload"):
                        module.offload()
                    if hasattr(module, "onload"):
                        module.onload()
                for block in transformer.blocks:
                    block.modulation = torch.nn.Parameter(block.modulation.to(device))
                transformer.head.modulation = torch.nn.Parameter(transformer.head.modulation.to(device))
            else:
                transformer.to(device)

        transformer.enable_teacache = transformer.enable_magcache = False
        if teacache_args is not None:  # legacy argument, treated as cache_args
            cache_args = teacache_args
        if cache_args is not None:
            transformer.cache_device = cache_args["cache_device"]
            if cache_args["cache_type"] == "TeaCache":
                log.info(f"TeaCache: Using cache device: {transformer.cache_device}")
                transformer.teacache_state.clear_all()
                transformer.enable_teacache = True
                transformer.rel_l1_thresh = cache_args["rel_l1_thresh"]
                transformer.teacache_start_step = cache_args["start_step"]
                transformer.teacache_end_step = len(init_timesteps) - 1 if cache_args["end_step"] == -1 else cache_args["end_step"]
                transformer.teacache_use_coefficients = cache_args["use_coefficients"]
                transformer.teacache_mode = cache_args["mode"]
            elif cache_args["cache_type"] == "MagCache":
                log.info(f"MagCache: Using cache device: {transformer.cache_device}")
                transformer.magcache_state.clear_all()
                transformer.enable_magcache = True
                transformer.magcache_start_step = cache_args["start_step"]
                transformer.magcache_end_step = len(init_timesteps) - 1 if cache_args["end_step"] == -1 else cache_args["end_step"]
                transformer.magcache_thresh = cache_args["magcache_thresh"]
                transformer.magcache_K = cache_args["magcache_K"]

        if slg_args is not None:
            transformer.slg_blocks = slg_args["blocks"]
            transformer.slg_start_percent = slg_args["start_percent"]
            transformer.slg_end_percent = slg_args["end_percent"]
        else:
            transformer.slg_blocks = None

        self.teacache_state = [None, None]
        self.teacache_state_source = [None, None]
        self.teacache_states_context = []

        if transformer.attention_mode == "radial_sage_attention":
            setup_radial_attention(transformer, transformer_options, latents, seq_len, latent_video_length)

        use_cfg_zero_star, use_fresca = False, False
        if experimental_args is not None:
            video_attention_split_steps = experimental_args.get("video_attention_split_steps", [])
            if video_attention_split_steps:
                transformer.video_attention_split_steps = [int(x.strip()) for x in video_attention_split_steps.split(",")]
            else:
                transformer.video_attention_split_steps = []
            use_zero_init = experimental_args.get("use_zero_init", True)
            use_cfg_zero_star = experimental_args.get("cfg_zero_star", False)
            zero_star_steps = experimental_args.get("zero_star_steps", 0)

            use_fresca = experimental_args.get("use_fresca", False)
            if use_fresca:
                fresca_scale_low = experimental_args.get("fresca_scale_low", 1.0)
                fresca_scale_high = experimental_args.get("fresca_scale_high", 1.25)
                fresca_freq_cutoff = experimental_args.get("fresca_freq_cutoff", 20)

        def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, image_cond=None, clip_fea=None,
                             vace_data=None, unianim_data=None, teacache_state=None):
            with torch.autocast(device_type=mm.get_autocast_device(device), dtype=dtype, enabled=("fp8" in model["quantization"])):

                # CFG-zero-star zero-init: return an all-zero prediction for the first steps.
                if use_cfg_zero_star and (idx <= zero_star_steps) and use_zero_init:
                    return z * 0, None

                nonlocal patcher
                current_step_percentage = idx / len(init_timesteps)
                control_lora_enabled = False

                image_cond_input = image_cond

                base_params = {
                    'seq_len': seq_len,
                    'device': device,
                    'freqs': freqs,
                    't': timestep,
                    'current_step': idx,
                    'control_lora_enabled': control_lora_enabled,
                    'vace_data': vace_data,
                    'unianim_data': unianim_data,
                    'fps_embeds': fps_embeds,
                    "nag_params": text_embeds.get("nag_params", {}),
                    "nag_context": text_embeds.get("nag_prompt_embeds", None),
                }

                batch_size = 1

                if not math.isclose(cfg_scale, 1.0) and len(positive_embeds) > 1:
                    negative_embeds = negative_embeds * len(positive_embeds)

                # Conditional pass.
                noise_pred_cond, teacache_state_cond = transformer(
                    [z], context=positive_embeds, y=[image_cond_input] if image_cond_input is not None else None,
                    clip_fea=clip_fea, is_uncond=False, current_step_percentage=current_step_percentage,
                    pred_id=teacache_state[0] if teacache_state else None,
                    **base_params
                )
                noise_pred_cond = noise_pred_cond[0].to(intermediate_device)
                if math.isclose(cfg_scale, 1.0):
                    if use_fresca:
                        noise_pred_cond = fourier_filter(
                            noise_pred_cond,
                            scale_low=fresca_scale_low,
                            scale_high=fresca_scale_high,
                            freq_cutoff=fresca_freq_cutoff,
                        )
                    return noise_pred_cond, [teacache_state_cond]

                # Unconditional pass.
                noise_pred_uncond, teacache_state_uncond = transformer(
                    [z], context=negative_embeds, clip_fea=clip_fea_neg if clip_fea_neg is not None else clip_fea,
                    y=[image_cond_input] if image_cond_input is not None else None,
                    is_uncond=True, current_step_percentage=current_step_percentage,
                    pred_id=teacache_state[1] if teacache_state else None,
                    **base_params
                )
                noise_pred_uncond = noise_pred_uncond[0].to(intermediate_device)

                if use_cfg_zero_star:
                    alpha = optimized_scale(
                        noise_pred_cond.view(batch_size, -1),
                        noise_pred_uncond.view(batch_size, -1)
                    ).view(batch_size, 1, 1, 1)
                else:
                    alpha = 1.0

                if use_fresca:
                    filtered_cond = fourier_filter(
                        noise_pred_cond - noise_pred_uncond,
                        scale_low=fresca_scale_low,
                        scale_high=fresca_scale_high,
                        freq_cutoff=fresca_freq_cutoff,
                    )
                    noise_pred = noise_pred_uncond * alpha + cfg_scale * filtered_cond * alpha
                else:
                    noise_pred = noise_pred_uncond * alpha + cfg_scale * (noise_pred_cond - noise_pred_uncond * alpha)

                return noise_pred, [teacache_state_cond, teacache_state_uncond]
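
        # Guidance notes: with cfg_zero_star, `alpha` is the least-squares
        # projection coefficient from optimized_scale, giving
        #   noise_pred = alpha * uncond + cfg * (cond - alpha * uncond);
        # with alpha = 1.0 this reduces to standard classifier-free guidance.
        # FreSca instead applies frequency-split scaling to the guidance
        # difference (cond - uncond) before it is weighted by cfg.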

        log.info(f"Sampling {(latent_video_length - 1) * 4 + 1} frames at {latents.shape[3] * vae_upscale_factor}x{latents.shape[2] * vae_upscale_factor} with {steps} steps")

        intermediate_device = device

        mm.unload_all_models()
        mm.soft_empty_cache()
        gc.collect()
        try:
            torch.cuda.reset_peak_memory_stats(device)
        except Exception:
            pass
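
        # Main denoising loop: each row of step_matrix gives every frame its
        # timestep for this iteration, step_update_mask says which frames
        # actually step, and valid_interval bounds the window fed to the model.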
        for i, timestep_i in enumerate(tqdm(step_matrix)):
            update_mask_i = step_update_mask[i]
            valid_interval_i = valid_interval[i]
            valid_interval_start, valid_interval_end = valid_interval_i
            timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
            latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].clone()
            if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length:
                # Slightly re-noise the clean prefix frames and advertise a small
                # matching timestep to the model; this improves long-video consistency.
                noise_factor = 0.001 * addnoise_condition
                timestep_for_noised_condition = addnoise_condition
                latent_model_input[:, valid_interval_start:prefix_video_latent_length] = (
                    latent_model_input[:, valid_interval_start:prefix_video_latent_length] * (1.0 - noise_factor)
                    + torch.randn_like(latent_model_input[:, valid_interval_start:prefix_video_latent_length])
                    * noise_factor
                )
                timestep[:, valid_interval_start:prefix_video_latent_length] = timestep_for_noised_condition

            noise_pred, self.teacache_state = predict_with_cfg(
                latent_model_input.to(dtype),
                cfg[i],
                text_embeds["prompt_embeds"],
                text_embeds["negative_prompt_embeds"],
                timestep, i, image_cond, clip_fea, unianim_data=unianim_data, vace_data=vace_data,
                teacache_state=self.teacache_state)

            # Step only the frames flagged by the update mask, each with its own scheduler.
            for idx in range(valid_interval_start, valid_interval_end):
                if update_mask_i[idx].item():
                    latents[:, idx] = sample_schedulers[idx].step(
                        noise_pred[:, idx - valid_interval_start],
                        timestep_i[idx],
                        latents[:, idx],
                        return_dict=False,
                        generator=seed_g,
                    )[0]
                    sample_schedulers_counter[idx] += 1

            x0 = latents.unsqueeze(0)
            if callback is not None:
                callback_latent = (latent_model_input - noise_pred.to(timestep_i[idx].device) * timestep_i[idx] / 1000).detach().permute(1, 0, 2, 3)
                callback(i, callback_latent, None, steps)
            else:
                pbar.update(1)

        if teacache_args is not None:
            states = transformer.teacache_state.states
            state_names = {
                0: "conditional",
                1: "unconditional"
            }
            for pred_id, state in states.items():
                name = state_names.get(pred_id, f"prediction_{pred_id}")
                if 'skipped_steps' in state:
                    log.info(f"TeaCache skipped: {len(state['skipped_steps'])} {name} steps: {state['skipped_steps']}")
            transformer.teacache_state.clear_all()

        if force_offload:
            if not model["auto_cpu_offload"]:
                offload_transformer(transformer)

        try:
            print_memory(device)
            torch.cuda.reset_peak_memory_stats(device)
        except Exception:
            pass

        return ({
            "samples": x0.cpu(),
        },)


NODE_CLASS_MAPPINGS = {
    "WanVideoDiffusionForcingSampler": WanVideoDiffusionForcingSampler,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoDiffusionForcingSampler": "WanVideo Diffusion Forcing Sampler",
}