Upload simple_skyreels_nodes.py
simple_skyreels_nodes.py  ADDED  (+749 -0)
@@ -0,0 +1,749 @@
| 1 |
+
import torch
|
| 2 |
+
import gc
|
| 3 |
+
from .utils import log, print_memory, fourier_filter
|
| 4 |
+
import math
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
from .wanvideo.modules.model import rope_params
|
| 8 |
+
from .wanvideo.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 9 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 10 |
+
from .wanvideo.utils.scheduling_flow_match_lcm import FlowMatchLCMScheduler
|
| 11 |
+
from .nodes import optimized_scale
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
|
| 14 |
+
from .enhance_a_video.globals import disable_enhance
|
| 15 |
+
|
| 16 |
+
from .nodes import WanVideoEncode, WanVideoDecode
|
| 17 |
+
|
| 18 |
+
import comfy.model_management as mm
|
| 19 |
+
import comfy.utils
|
| 20 |
+
from comfy.utils import ProgressBar
|
| 21 |
+
from comfy.cli_args import args, LatentPreviewMethod
|
| 22 |
+
|
| 23 |
+
|

def generate_timestep_matrix(
    num_frames,
    step_template,
    base_num_frames,
    ar_step=5,
    num_pre_ready=0,
    casual_block_size=1,
    shrink_interval_with_mask=False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
    """Build the per-frame denoising schedule used for diffusion forcing.

    Returns one row per sampling iteration: the timestep assigned to each latent
    frame, the matching step indices, an update mask marking which frames are
    actually denoised in that iteration, and the valid frame interval to run the
    model on.
    """
    step_matrix, step_index = [], []
    update_mask, valid_interval = [], []
    num_iterations = len(step_template) + 1
    num_frames_block = num_frames // casual_block_size
    base_num_frames_block = base_num_frames // casual_block_size
    if base_num_frames_block < num_frames_block:
        infer_step_num = len(step_template)
        gen_block = base_num_frames_block
        min_ar_step = infer_step_num / gen_block
        assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
    step_template = torch.cat(
        [
            torch.tensor([999], dtype=torch.int64, device=step_template.device),
            step_template.long(),
            torch.tensor([0], dtype=torch.int64, device=step_template.device),
        ]
    )  # pad the template so the per-frame counters can start from 1
    pre_row = torch.zeros(num_frames_block, dtype=torch.long)
    if num_pre_ready > 0:
        pre_row[: num_pre_ready // casual_block_size] = num_iterations

    while not torch.all(pre_row >= (num_iterations - 1)):
        new_row = torch.zeros(num_frames_block, dtype=torch.long)
        for i in range(num_frames_block):
            if i == 0 or pre_row[i - 1] >= (
                num_iterations - 1
            ):  # the first frame, or a frame whose predecessor is fully denoised
                new_row[i] = pre_row[i] + 1
            else:
                new_row[i] = new_row[i - 1] - ar_step
        new_row = new_row.clamp(0, num_iterations)

        update_mask.append(
            (new_row != pre_row) & (new_row != num_iterations)
        )  # False: no need to update, True: need to update
        step_index.append(new_row)
        step_matrix.append(step_template[new_row])
        pre_row = new_row

    # for long videos we split into several sequences; base_num_frames is the model's max trained length
    terminal_flag = base_num_frames_block
    if shrink_interval_with_mask:
        idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
        update_mask = update_mask[0]
        update_mask_idx = idx_sequence[update_mask]
        last_update_idx = update_mask_idx[-1].item()
        terminal_flag = last_update_idx + 1
    for curr_mask in update_mask:
        if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
            terminal_flag += 1
        valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))

    step_update_mask = torch.stack(update_mask, dim=0)
    step_index = torch.stack(step_index, dim=0)
    step_matrix = torch.stack(step_matrix, dim=0)

    if casual_block_size > 1:
        step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
        step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
        step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
        valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]

    return step_matrix, step_index, step_update_mask, valid_interval

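# Illustration (small hypothetical inputs): for 4 latent frames, a 3-step template of
# [750, 500, 250], base_num_frames=4 and ar_step=1, the returned step_matrix is
#   [750, 999, 999, 999]
#   [500, 750, 999, 999]
#   [250, 500, 750, 999]
#   [  0, 250, 500, 750]
#   [  0,   0, 250, 500]
#   [  0,   0,   0, 250]
# i.e. a diffusion-forcing staircase: 999 means "not started yet", 0 means "already
# fully denoised" (and masked out of updates). With ar_step=0, as the sampler below
# uses, every frame shares the same row and the schedule reduces to ordinary
# full-sequence denoising.
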

class GetImageRangeFromBatch:

    RETURN_TYPES = ("IMAGE", "MASK", )
    FUNCTION = "imagesfrombatch"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """Returns a range of images from a batch."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "start_index": ("INT", {"default": 0, "min": -1, "max": 4096, "step": 1}),
                "num_frames": ("INT", {"default": 1, "min": 1, "max": 4096, "step": 1}),
            },
            "optional": {
                "images": ("IMAGE",),
                "masks": ("MASK",),
            }
        }

    def imagesfrombatch(self, start_index, num_frames, images=None, masks=None):
        chosen_images = None
        chosen_masks = None

        # Process images if provided
        if images is not None:
            if start_index == -1:
                start_index = max(0, len(images) - num_frames)
            if start_index < 0 or start_index >= len(images):
                raise ValueError("Start index is out of range")
            end_index = min(start_index + num_frames, len(images))
            chosen_images = images[start_index:end_index]

        # Process masks if provided
        if masks is not None:
            if start_index == -1:
                start_index = max(0, len(masks) - num_frames)
            if start_index < 0 or start_index >= len(masks):
                raise ValueError("Start index is out of range for masks")
            end_index = min(start_index + num_frames, len(masks))
            chosen_masks = masks[start_index:end_index]

        return (chosen_images, chosen_masks,)

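# Usage note: start_index=-1 selects the last num_frames images of the batch; the
# sampler below relies on this to grab the trailing 17-frame overlap of each decoded
# chunk before re-encoding it as the prefix for the next chunk.
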

class ImageBatch:

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image1": ("IMAGE",), "image2": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "batch"

    CATEGORY = "image"

    def batch(self, image1, image2):
        if image1.shape[1:] != image2.shape[1:]:
            image2 = comfy.utils.common_upscale(image2.movedim(-1, 1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1, -1)
        s = torch.cat((image1, image2), dim=0)
        return (s,)


#region Sampler
class SimpleWanVideoDiffusionForcingSampler:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("WANVIDEOMODEL",),
                "vae": ("WANVAE",),
                "text_embeds": ("WANVIDEOTEXTEMBEDS", ),
                "image_embeds_list": ("WANVIDIMAGE_EMBEDS", ),
                "addnoise_condition": ("INT", {"default": 10, "min": 0, "max": 1000, "tooltip": "Improves consistency in long video generation"}),
                "fps": ("FLOAT", {"default": 24.0, "min": 1.0, "max": 120.0, "step": 0.01}),
                "steps": ("INT", {"default": 30, "min": 1}),
                "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
                "shift": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "force_offload": ("BOOLEAN", {"default": True, "tooltip": "Moves the model to the offload device after sampling"}),
                "scheduler": (["unipc", "unipc/beta", "euler", "euler/beta", "lcm", "lcm/beta"],
                    {
                        "default": 'unipc'
                    }),
            },
            "optional": {
                "samples": ("LATENT", {"tooltip": "Init latents to use for the video2video process"}),
                "prefix_samples": ("LATENT", {"tooltip": "Prefix latents"}),
                "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "teacache_args": ("TEACACHEARGS", ),
                "slg_args": ("SLGARGS", ),
                "rope_function": (["default", "comfy"], {"default": "comfy", "tooltip": "Comfy's RoPE implementation doesn't use complex numbers and can thus be compiled, which should be a lot faster when using torch.compile"}),
                "experimental_args": ("EXPERIMENTALARGS", ),
                "unianimate_poses": ("UNIANIMATE_POSE", ),
            }
        }

    RETURN_TYPES = ("IMAGE", )
    RETURN_NAMES = ("images",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    def process(self, model, vae, text_embeds, image_embeds_list, shift, fps, steps, addnoise_condition, cfg, seed, scheduler,
                force_offload=True, samples=None, prefix_samples=None, denoise_strength=1.0, slg_args=None, rope_function="default", teacache_args=None,
                experimental_args=None, unianimate_poses=None):

        image_range_extractor = GetImageRangeFromBatch()
        decoder = WanVideoDecode()
        encoder = WanVideoEncode()

        video_chunk_list = []
        for image_embeds in image_embeds_list:
            samples = self.sub_process(model, text_embeds, image_embeds, shift, fps, steps, addnoise_condition, cfg, seed, scheduler,
                                       force_offload, samples, prefix_samples, denoise_strength, slg_args, rope_function, teacache_args,
                                       experimental_args, unianimate_poses)

            video_chunk = decoder.decode(vae, samples, enable_vae_tiling=False, tile_x=272, tile_y=272, tile_stride_x=144, tile_stride_y=128)[0]
            video_chunk_list.append(video_chunk)

            # Re-encode the last 17 decoded frames as the prefix latents for the next chunk
            images = image_range_extractor.imagesfrombatch(start_index=-1, num_frames=17, images=video_chunk, masks=None)[0]
            prefix_samples = encoder.encode(vae, images, enable_vae_tiling=False, tile_x=272, tile_y=272, tile_stride_x=144, tile_stride_y=128, noise_aug_strength=0.0, latent_strength=1.0, mask=None)[0]

        # Stitch chunks together, dropping each later chunk's 17-frame overlap
        image_batch_node = ImageBatch()
        combined_video = video_chunk_list[0]
        for i in range(1, len(video_chunk_list)):
            num_frames = image_embeds_list[i].get("num_frames")
            new_video = image_range_extractor.imagesfrombatch(start_index=-1, num_frames=num_frames - 17, images=video_chunk_list[i], masks=None)[0]

            combined_video, = image_batch_node.batch(combined_video, new_video)

        return (combined_video,)

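    # Chunking math (illustration, assuming the default 97-frame chunks): asking for
    # 194 frames yields chunks [97, 97]; the second chunk is conditioned on the last
    # 17 decoded frames of the first, re-encoded to (17 - 1) // 4 + 1 = 5 latent
    # frames, and those 17 overlapping frames are dropped when stitching, so the
    # combined output is 97 + (97 - 17) = 177 frames.
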
    def sub_process(self, model, text_embeds, image_embeds, shift, fps, steps, addnoise_condition, cfg, seed, scheduler,
                    force_offload=True, samples=None, prefix_samples=None, denoise_strength=1.0, slg_args=None, rope_function="default", teacache_args=None,
                    experimental_args=None, unianimate_poses=None):
        patcher = model
        model = model.model
        transformer = model.diffusion_model
        dtype = model["dtype"]
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()

        # Over-schedule first, then trim the schedule below when denoise_strength < 1.0
        steps = int(steps / denoise_strength)

        timesteps = None
        if 'unipc' in scheduler:
            sample_scheduler = FlowUniPCMultistepScheduler(shift=shift)
            sample_scheduler.set_timesteps(steps, device=device, shift=shift, use_beta_sigmas=('beta' in scheduler))
        elif 'euler' in scheduler:
            sample_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift, use_beta_sigmas=(scheduler == 'euler/beta'))
            sample_scheduler.set_timesteps(steps, device=device)
        elif 'lcm' in scheduler:
            sample_scheduler = FlowMatchLCMScheduler(shift=shift, use_beta_sigmas=(scheduler == 'lcm/beta'))
            sample_scheduler.set_timesteps(steps, device=device)

        init_timesteps = sample_scheduler.timesteps

        if denoise_strength < 1.0:
            steps = int(steps * denoise_strength)
            timesteps = init_timesteps[-(steps + 1):]  # slice the full schedule; slicing the still-None `timesteps` here would fail

        seed_g = torch.Generator(device=torch.device("cpu"))
        seed_g.manual_seed(seed)

        clip_fea, clip_fea_neg = None, None
        vace_data, vace_context, vace_scale = None, None, None

        image_cond = image_embeds.get("image_embeds", None)

        target_shape = image_embeds.get("target_shape", None)
        if target_shape is None:
            raise ValueError("Empty image embeds must be provided for T2V (Text to Video)")

        has_ref = image_embeds.get("has_ref", False)
        vace_context = image_embeds.get("vace_context", None)
        vace_scale = image_embeds.get("vace_scale", None)
        vace_start_percent = image_embeds.get("vace_start_percent", 0.0)
        vace_end_percent = image_embeds.get("vace_end_percent", 1.0)
        vace_seqlen = image_embeds.get("vace_seq_len", None)

        vace_additional_embeds = image_embeds.get("additional_vace_inputs", [])
        if vace_context is not None:
            vace_data = [
                {"context": vace_context,
                 "scale": vace_scale,
                 "start": vace_start_percent,
                 "end": vace_end_percent,
                 "seq_len": vace_seqlen
                 }
            ]
            if len(vace_additional_embeds) > 0:
                for i in range(len(vace_additional_embeds)):
                    if vace_additional_embeds[i].get("has_ref", False):
                        has_ref = True
                    vace_data.append({
                        "context": vace_additional_embeds[i]["vace_context"],
                        "scale": vace_additional_embeds[i]["vace_scale"],
                        "start": vace_additional_embeds[i]["vace_start_percent"],
                        "end": vace_additional_embeds[i]["vace_end_percent"],
                        "seq_len": vace_additional_embeds[i]["vace_seq_len"]
                    })

        noise = torch.randn(
            target_shape[0],
            target_shape[1] + 1 if has_ref else target_shape[1],
            target_shape[2],
            target_shape[3],
            dtype=torch.float32,
            device=torch.device("cpu"),
            generator=seed_g)

        latent_video_length = noise.shape[1]
        seq_len = math.ceil((noise.shape[2] * noise.shape[3]) / 4 * noise.shape[1])

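        # Shape check (illustration, assuming 832x480 at 97 frames per chunk): the VAE
        # stride of (4, 8, 8) gives target_shape = (16, 25, 60, 104), and seq_len =
        # ceil(60 * 104 / 4 * 25) = 39000 transformer tokens (one token per 2x2 spatial
        # patch of each latent frame).
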
        if samples is not None:
            input_samples = samples["samples"].squeeze(0).to(noise)
            if input_samples.shape[1] != noise.shape[1]:
                input_samples = torch.cat([input_samples[:, :1].repeat(1, noise.shape[1] - input_samples.shape[1], 1, 1), input_samples], dim=1)
            original_image = input_samples.to(device)
            if denoise_strength < 1.0:
                latent_timestep = timesteps[:1].to(noise)
                noise = noise * latent_timestep / 1000 + (1 - latent_timestep / 1000) * input_samples

            mask = samples.get("mask", None)
            if mask is not None:
                if mask.shape[2] != noise.shape[1]:
                    mask = torch.cat([torch.zeros(1, noise.shape[0], noise.shape[1] - mask.shape[2], noise.shape[2], noise.shape[3]), mask], dim=2)

        latents = noise.to(device)

        fps_embeds = None
        if hasattr(transformer, "fps_embedding"):
            fps = round(fps, 2)
            log.info(f"Model has fps embedding, using {fps} fps")
            fps_embeds = [fps]
            fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]

        prefix_video = prefix_samples["samples"].to(noise) if prefix_samples is not None else None
        prefix_video_latent_length = prefix_video.shape[2] if prefix_video is not None else 0
        if prefix_video is not None:
            log.info(f"Prefix video of length: {prefix_video_latent_length}")
            latents[:, :prefix_video_latent_length] = prefix_video[0]
        #base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_video_length
        base_num_frames = latent_video_length

        ar_step = 0
        causal_block_size = 1
        step_matrix, _, step_update_mask, valid_interval = generate_timestep_matrix(
            latent_video_length, init_timesteps, base_num_frames, ar_step, prefix_video_latent_length, causal_block_size
        )

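        # With ar_step=0 every frame shares the same timestep row (plain full-sequence
        # denoising); prefix latent frames are marked as already "ready" and are
        # therefore excluded from the update mask, so they are never re-denoised.
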
        # One scheduler per latent frame, since each frame can sit at a different
        # position in the noise schedule
        sample_schedulers = []
        for _ in range(latent_video_length):
            if 'unipc' in scheduler:
                sample_scheduler = FlowUniPCMultistepScheduler(shift=shift)
                sample_scheduler.set_timesteps(steps, device=device, shift=shift, use_beta_sigmas=('beta' in scheduler))
            elif 'euler' in scheduler:
                # match the beta-sigma choice used for the initial scheduler above
                sample_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift, use_beta_sigmas=(scheduler == 'euler/beta'))
                sample_scheduler.set_timesteps(steps, device=device)
            elif 'lcm' in scheduler:
                sample_scheduler = FlowMatchLCMScheduler(shift=shift, use_beta_sigmas=(scheduler == 'lcm/beta'))
                sample_scheduler.set_timesteps(steps, device=device)

            sample_schedulers.append(sample_scheduler)
        sample_schedulers_counter = [0] * latent_video_length

        unianim_data = None
        if unianimate_poses is not None:
            transformer.dwpose_embedding.to(device)
            transformer.randomref_embedding_pose.to(device)
            dwpose_data = unianimate_poses["pose"]
            dwpose_data = transformer.dwpose_embedding(
                (torch.cat([dwpose_data[:, :, :1].repeat(1, 1, 3, 1, 1), dwpose_data], dim=2)
                 ).to(device)).to(model["dtype"])
            log.info(f"UniAnimate pose embed shape: {dwpose_data.shape}")
            if dwpose_data.shape[2] > latent_video_length:
                log.warning(f"UniAnimate pose embed length {dwpose_data.shape[2]} is longer than the video length {latent_video_length}, truncating")
                dwpose_data = dwpose_data[:, :, :latent_video_length]
            elif dwpose_data.shape[2] < latent_video_length:
                log.warning(f"UniAnimate pose embed length {dwpose_data.shape[2]} is shorter than the video length {latent_video_length}, padding with last pose")
                pad_len = latent_video_length - dwpose_data.shape[2]
                pad = dwpose_data[:, :, :1].repeat(1, 1, pad_len, 1, 1)
                dwpose_data = torch.cat([dwpose_data, pad], dim=2)
            dwpose_data_flat = rearrange(dwpose_data, 'b c f h w -> b (f h w) c').contiguous()

            random_ref_dwpose_data = None
            if image_cond is not None:
                random_ref_dwpose = unianimate_poses.get("ref", None)
                if random_ref_dwpose is not None:
                    random_ref_dwpose_data = transformer.randomref_embedding_pose(
                        random_ref_dwpose.to(device)
                    ).unsqueeze(2).to(model["dtype"])  # [1, 20, 104, 60]

            unianim_data = {
                "dwpose": dwpose_data_flat,
                "random_ref": random_ref_dwpose_data.squeeze(0) if random_ref_dwpose_data is not None else None,
                "strength": unianimate_poses["strength"],
                "start_percent": unianimate_poses["start_percent"],
                "end_percent": unianimate_poses["end_percent"]
            }

        disable_enhance()  # not sure if this can work, disabling for now to avoid errors if it's enabled by another sampler

        freqs = None
        transformer.rope_embedder.k = None
        transformer.rope_embedder.num_frames = None
        if rope_function == "comfy":
            transformer.rope_embedder.k = 0
            transformer.rope_embedder.num_frames = latent_video_length
        else:
            d = transformer.dim // transformer.num_heads
            freqs = torch.cat([
                rope_params(1024, d - 4 * (d // 6), L_test=latent_video_length, k=0),
                rope_params(1024, 2 * (d // 6)),
                rope_params(1024, 2 * (d // 6))
            ],
            dim=1)

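        # RoPE channel split (illustration, assuming a model with dim=1536 and 12 heads,
        # i.e. d = 128 per head): 128 - 4 * (128 // 6) = 44 temporal channels plus
        # 2 * (128 // 6) = 42 channels for each spatial axis (44 + 42 + 42 = 128).
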
        if not isinstance(cfg, list):
            cfg = [cfg] * (steps + 1)

        log.info(f"Seq len: {seq_len}")

        pbar = ProgressBar(steps)

        # Both preview-method branches resolved to the same helper in the original code,
        # so a single import is used for latent previews here
        from latent_preview import prepare_callback
        callback = prepare_callback(patcher, steps)

        #blockswap init
        block_swap_args = None
        transformer_options = patcher.model_options.get("transformer_options", None)
        if transformer_options is not None:
            block_swap_args = transformer_options.get("block_swap_args", None)

        if block_swap_args is not None:
            transformer.use_non_blocking = block_swap_args.get("use_non_blocking", True)
            for name, param in transformer.named_parameters():
                if "block" not in name:
                    param.data = param.data.to(device)
                elif block_swap_args["offload_txt_emb"] and "txt_emb" in name:
                    param.data = param.data.to(offload_device, non_blocking=transformer.use_non_blocking)
                elif block_swap_args["offload_img_emb"] and "img_emb" in name:
                    param.data = param.data.to(offload_device, non_blocking=transformer.use_non_blocking)

            transformer.block_swap(
                block_swap_args["blocks_to_swap"] - 1,
                block_swap_args["offload_txt_emb"],
                block_swap_args["offload_img_emb"],
                vace_blocks_to_swap=block_swap_args.get("vace_blocks_to_swap", None),
            )

        elif model["auto_cpu_offload"]:
            for module in transformer.modules():
                if hasattr(module, "offload"):
                    module.offload()
                if hasattr(module, "onload"):
                    module.onload()
        elif model["manual_offloading"]:
            transformer.to(device)

        # Initialize TeaCache if enabled
        if teacache_args is not None:
            transformer.enable_teacache = True
            transformer.rel_l1_thresh = teacache_args["rel_l1_thresh"]
            transformer.teacache_start_step = teacache_args["start_step"]
            transformer.teacache_cache_device = teacache_args["cache_device"]
            log.info(f"TeaCache: Using cache device: {transformer.teacache_state.cache_device}")
            transformer.teacache_end_step = len(init_timesteps) - 1 if teacache_args["end_step"] == -1 else teacache_args["end_step"]
            transformer.teacache_use_coefficients = teacache_args["use_coefficients"]
            transformer.teacache_mode = teacache_args["mode"]
            transformer.teacache_state.clear_all()
        else:
            transformer.enable_teacache = False

        if slg_args is not None:
            transformer.slg_blocks = slg_args["blocks"]
            transformer.slg_start_percent = slg_args["start_percent"]
            transformer.slg_end_percent = slg_args["end_percent"]
        else:
            transformer.slg_blocks = None

        self.teacache_state = [None, None]
        self.teacache_state_source = [None, None]
        self.teacache_states_context = []

        use_cfg_zero_star, use_fresca = False, False
        if experimental_args is not None:
            video_attention_split_steps = experimental_args.get("video_attention_split_steps", [])
            if video_attention_split_steps:
                transformer.video_attention_split_steps = [int(x.strip()) for x in video_attention_split_steps.split(",")]
            else:
                transformer.video_attention_split_steps = []
            use_zero_init = experimental_args.get("use_zero_init", True)
            use_cfg_zero_star = experimental_args.get("cfg_zero_star", False)
            zero_star_steps = experimental_args.get("zero_star_steps", 0)

            use_fresca = experimental_args.get("use_fresca", False)
            if use_fresca:
                fresca_scale_low = experimental_args.get("fresca_scale_low", 1.0)
                fresca_scale_high = experimental_args.get("fresca_scale_high", 1.25)
                fresca_freq_cutoff = experimental_args.get("fresca_freq_cutoff", 20)

+
#region model pred
|
| 511 |
+
def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, image_cond=None, clip_fea=None,
|
| 512 |
+
vace_data=None, unianim_data=None, teacache_state=None):
|
| 513 |
+
with torch.autocast(device_type=mm.get_autocast_device(device), dtype=dtype, enabled=("fp8" in model["quantization"])):
|
| 514 |
+
|
| 515 |
+
if use_cfg_zero_star and (idx <= zero_star_steps) and use_zero_init:
|
| 516 |
+
return latent_model_input*0, None
|
| 517 |
+
|
| 518 |
+
nonlocal patcher
|
| 519 |
+
current_step_percentage = idx / len(init_timesteps)
|
| 520 |
+
control_lora_enabled = False
|
| 521 |
+
|
| 522 |
+
image_cond_input = image_cond
|
| 523 |
+
|
| 524 |
+
base_params = {
|
| 525 |
+
'seq_len': seq_len,
|
| 526 |
+
'device': device,
|
| 527 |
+
'freqs': freqs,
|
| 528 |
+
't': timestep,
|
| 529 |
+
'current_step': idx,
|
| 530 |
+
'control_lora_enabled': control_lora_enabled,
|
| 531 |
+
'vace_data': vace_data,
|
| 532 |
+
'unianim_data': unianim_data,
|
| 533 |
+
'fps_embeds': fps_embeds,
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
batch_size = 1
|
| 537 |
+
|
| 538 |
+
if not math.isclose(cfg_scale, 1.0) and len(positive_embeds) > 1:
|
| 539 |
+
negative_embeds = negative_embeds * len(positive_embeds)
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
#cond
|
| 543 |
+
noise_pred_cond, teacache_state_cond = transformer(
|
| 544 |
+
[z], context=positive_embeds, y=[image_cond_input] if image_cond_input is not None else None,
|
| 545 |
+
clip_fea=clip_fea, is_uncond=False, current_step_percentage=current_step_percentage,
|
| 546 |
+
pred_id=teacache_state[0] if teacache_state else None,
|
| 547 |
+
**base_params
|
| 548 |
+
)
|
| 549 |
+
noise_pred_cond = noise_pred_cond[0].to(intermediate_device)
|
| 550 |
+
if math.isclose(cfg_scale, 1.0):
|
| 551 |
+
if use_fresca:
|
| 552 |
+
noise_pred_cond = fourier_filter(
|
| 553 |
+
noise_pred_cond,
|
| 554 |
+
scale_low=fresca_scale_low,
|
| 555 |
+
scale_high=fresca_scale_high,
|
| 556 |
+
freq_cutoff=fresca_freq_cutoff,
|
| 557 |
+
)
|
| 558 |
+
return noise_pred_cond, [teacache_state_cond]
|
| 559 |
+
#uncond
|
| 560 |
+
noise_pred_uncond, teacache_state_uncond = transformer(
|
| 561 |
+
[z], context=negative_embeds, clip_fea=clip_fea_neg if clip_fea_neg is not None else clip_fea,
|
| 562 |
+
y=[image_cond_input] if image_cond_input is not None else None,
|
| 563 |
+
is_uncond=True, current_step_percentage=current_step_percentage,
|
| 564 |
+
pred_id=teacache_state[1] if teacache_state else None,
|
| 565 |
+
**base_params
|
| 566 |
+
)
|
| 567 |
+
noise_pred_uncond = noise_pred_uncond[0].to(intermediate_device)
|
| 568 |
+
|
| 569 |
+
#cfg
|
| 570 |
+
|
| 571 |
+
#https://github.com/WeichenFan/CFG-Zero-star/
|
| 572 |
+
if use_cfg_zero_star:
|
| 573 |
+
alpha = optimized_scale(
|
| 574 |
+
noise_pred_cond.view(batch_size, -1),
|
| 575 |
+
noise_pred_uncond.view(batch_size, -1)
|
| 576 |
+
).view(batch_size, 1, 1, 1)
|
| 577 |
+
else:
|
| 578 |
+
alpha = 1.0
|
| 579 |
+
|
| 580 |
+
#https://github.com/WikiChao/FreSca
|
| 581 |
+
if use_fresca:
|
| 582 |
+
filtered_cond = fourier_filter(
|
| 583 |
+
noise_pred_cond - noise_pred_uncond,
|
| 584 |
+
scale_low=fresca_scale_low,
|
| 585 |
+
scale_high=fresca_scale_high,
|
| 586 |
+
freq_cutoff=fresca_freq_cutoff,
|
| 587 |
+
)
|
| 588 |
+
noise_pred = noise_pred_uncond * alpha + cfg_scale * filtered_cond * alpha
|
| 589 |
+
else:
|
| 590 |
+
noise_pred = noise_pred_uncond * alpha + cfg_scale * (noise_pred_cond - noise_pred_uncond * alpha)
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
return noise_pred, [teacache_state_cond, teacache_state_uncond]
|
| 594 |
+
|
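        # CFG-Zero* note (see the link above): alpha is intended to be the least-squares
        # projection coefficient of the cond prediction onto the uncond prediction, so
        # the uncond branch is rescaled before the usual guidance mix; with use_zero_init
        # the first zero_star_steps return an all-zero prediction instead.
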
        log.info(f"Sampling {(latent_video_length - 1) * 4 + 1} frames at {latents.shape[3] * 8}x{latents.shape[2] * 8} with {steps} steps")

        intermediate_device = device

        #clear memory before sampling
        mm.unload_all_models()
        mm.soft_empty_cache()
        gc.collect()
        try:
            torch.cuda.reset_peak_memory_stats(device)
        except:
            pass

        #region main loop start
        for i, timestep_i in enumerate(tqdm(step_matrix)):
            update_mask_i = step_update_mask[i]
            valid_interval_i = valid_interval[i]
            valid_interval_start, valid_interval_end = valid_interval_i
            timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
            latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].clone()
            if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length:
                # Mix a small amount of fresh noise into the prefix latents and pin their
                # timestep (e.g. addnoise_condition=10 -> 1% noise, timestep 10) to keep
                # long-video continuations consistent
                noise_factor = 0.001 * addnoise_condition
                timestep_for_noised_condition = addnoise_condition
                latent_model_input[:, valid_interval_start:prefix_video_latent_length] = (
                    latent_model_input[:, valid_interval_start:prefix_video_latent_length] * (1.0 - noise_factor)
                    + torch.randn_like(latent_model_input[:, valid_interval_start:prefix_video_latent_length])
                    * noise_factor
                )
                timestep[:, valid_interval_start:prefix_video_latent_length] = timestep_for_noised_condition

            noise_pred, self.teacache_state = predict_with_cfg(
                latent_model_input.to(dtype),
                cfg[i],
                text_embeds["prompt_embeds"],
                text_embeds["negative_prompt_embeds"],
                timestep, i, image_cond, clip_fea, unianim_data=unianim_data, vace_data=vace_data,
                teacache_state=self.teacache_state)

            for idx in range(valid_interval_start, valid_interval_end):
                if update_mask_i[idx].item():
                    latents[:, idx] = sample_schedulers[idx].step(
                        noise_pred[:, idx - valid_interval_start],
                        timestep_i[idx],
                        latents[:, idx],
                        return_dict=False,
                        generator=seed_g,
                    )[0]
                    sample_schedulers_counter[idx] += 1

            x0 = latents.unsqueeze(0)
            if callback is not None:
                callback_latent = (latent_model_input - noise_pred.to(timestep_i[idx].device) * timestep_i[idx] / 1000).detach().permute(1, 0, 2, 3)
                callback(i, callback_latent, None, steps)
            else:
                pbar.update(1)

        if teacache_args is not None:
            states = transformer.teacache_state.states
            state_names = {
                0: "conditional",
                1: "unconditional"
            }
            for pred_id, state in states.items():
                name = state_names.get(pred_id, f"prediction_{pred_id}")
                if 'skipped_steps' in state:
                    log.info(f"TeaCache skipped: {len(state['skipped_steps'])} {name} steps: {state['skipped_steps']}")
            transformer.teacache_state.clear_all()

        if force_offload:
            if model["manual_offloading"]:
                transformer.to(offload_device)
                mm.soft_empty_cache()
                gc.collect()

        try:
            print_memory(device)
            torch.cuda.reset_peak_memory_stats(device)
        except:
            pass

        return {"samples": x0.cpu(),}


class SimpleWanVideoEmptyEmbeds:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}),
            "height": ("INT", {"default": 480, "min": 64, "max": 2048, "step": 8, "tooltip": "Height of the image to encode"}),
            "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}),
        },
            "optional": {
                "control_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "control signal for the Fun model"}),
            }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", )
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    def get_chunk_num_frame_list(self, num_frames):
        # Maximum frames per chunk: 97 for 540P, 121 for 720P.
        # To reduce peak VRAM, lower MAX_FRAMES_PER_CHUNK (e.g. to 77 or 57) while
        # keeping the same total num_frames you want to generate. This may slightly
        # reduce video quality, and it should not be set too small.
        # todo: need to test vram
        MAX_FRAMES_PER_CHUNK = 97

        # Calculate how many complete chunks we need
        full_chunks = num_frames // MAX_FRAMES_PER_CHUNK
        # Calculate the size of the remainder chunk (if any)
        remainder = num_frames % MAX_FRAMES_PER_CHUNK

        # Create the list of chunk sizes
        chunk_num_frames_list = [MAX_FRAMES_PER_CHUNK] * full_chunks
        if remainder > 0:
            chunk_num_frames_list.append(remainder)

        return chunk_num_frames_list

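    # Example: get_chunk_num_frame_list(200) -> [97, 97, 6]; each entry becomes one
    # latent target_shape below and one sub_process call in the sampler.
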
    def process(self, num_frames, width, height, control_embeds=None):
        embeds_list = []
        chunk_num_frames_list = self.get_chunk_num_frame_list(num_frames)

        for i in range(len(chunk_num_frames_list)):
            sub_num_frames = chunk_num_frames_list[i]
            vae_stride = (4, 8, 8)

            target_shape = (16, (sub_num_frames - 1) // vae_stride[0] + 1,
                            height // vae_stride[1],
                            width // vae_stride[2])

            embeds = {
                "target_shape": target_shape,
                "num_frames": sub_num_frames,
                "control_embeds": control_embeds["control_embeds"] if control_embeds is not None else None,
            }
            embeds_list.append(embeds)

        return (embeds_list,)


NODE_CLASS_MAPPINGS = {
    "SimpleWanVideoDiffusionForcingSampler": SimpleWanVideoDiffusionForcingSampler,
    "SimpleWanVideoEmptyEmbeds": SimpleWanVideoEmptyEmbeds
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "SimpleWanVideoDiffusionForcingSampler": "Simple WanVideo Diffusion Forcing Sampler",
    "SimpleWanVideoEmptyEmbeds": "Simple WanVideo Empty Embeds",
}