sin30 committed
Commit 6c858a3 · verified · 1 Parent(s): 991ded8

Upload simple_skyreels_nodes.py

Files changed (1)
  1. simple_skyreels_nodes.py +749 -0
simple_skyreels_nodes.py ADDED
@@ -0,0 +1,749 @@
import torch
import gc
from .utils import log, print_memory, fourier_filter
import math
from tqdm import tqdm

from .wanvideo.modules.model import rope_params
from .wanvideo.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from .wanvideo.utils.scheduling_flow_match_lcm import FlowMatchLCMScheduler
from .nodes import optimized_scale
from einops import rearrange

from .enhance_a_video.globals import disable_enhance

from .nodes import WanVideoEncode, WanVideoDecode

import comfy.model_management as mm
import comfy.utils
from comfy.utils import ProgressBar
from comfy.cli_args import args, LatentPreviewMethod


def generate_timestep_matrix(
    num_frames,
    step_template,
    base_num_frames,
    ar_step=5,
    num_pre_ready=0,
    casual_block_size=1,
    shrink_interval_with_mask=False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
    step_matrix, step_index = [], []
    update_mask, valid_interval = [], []
    num_iterations = len(step_template) + 1
    num_frames_block = num_frames // casual_block_size
    base_num_frames_block = base_num_frames // casual_block_size
    if base_num_frames_block < num_frames_block:
        infer_step_num = len(step_template)
        gen_block = base_num_frames_block
        min_ar_step = infer_step_num / gen_block
        assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
    # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
    step_template = torch.cat(
        [
            torch.tensor([999], dtype=torch.int64, device=step_template.device),
            step_template.long(),
            torch.tensor([0], dtype=torch.int64, device=step_template.device),
        ]
    )  # to handle the counter in row works starting from 1
    pre_row = torch.zeros(num_frames_block, dtype=torch.long)
    if num_pre_ready > 0:
        pre_row[: num_pre_ready // casual_block_size] = num_iterations

    while not torch.all(pre_row >= (num_iterations - 1)):
        new_row = torch.zeros(num_frames_block, dtype=torch.long)
        for i in range(num_frames_block):
            if i == 0 or pre_row[i - 1] >= (
                num_iterations - 1
            ):  # the first frame, or the previous frame is completely denoised
                new_row[i] = pre_row[i] + 1
            else:
                new_row[i] = new_row[i - 1] - ar_step
        new_row = new_row.clamp(0, num_iterations)

        update_mask.append(
            (new_row != pre_row) & (new_row != num_iterations)
        )  # False: no need to update, True: need to update
        step_index.append(new_row)
        step_matrix.append(step_template[new_row])
        pre_row = new_row

    # for long video we split into several sequences, base_num_frames is set to the model max length (for training)
    terminal_flag = base_num_frames_block
    if shrink_interval_with_mask:
        idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
        update_mask = update_mask[0]
        update_mask_idx = idx_sequence[update_mask]
        last_update_idx = update_mask_idx[-1].item()
        terminal_flag = last_update_idx + 1
    # for i in range(0, len(update_mask)):
    for curr_mask in update_mask:
        if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
            terminal_flag += 1
        valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))

    step_update_mask = torch.stack(update_mask, dim=0)
    step_index = torch.stack(step_index, dim=0)
    step_matrix = torch.stack(step_matrix, dim=0)

    if casual_block_size > 1:
        step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
        step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
        step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
        valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]

    return step_matrix, step_index, step_update_mask, valid_interval
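
# Editorial note (not part of the upstream SkyReels code): each row of step_matrix assigns a
# (possibly different) timestep to every latent frame, and step_update_mask marks which frames
# actually take a scheduler step on that row. With ar_step=0 (as the sampler below uses) all
# frames share the same timestep per row, i.e. plain synchronous denoising; with ar_step>0 later
# frames lag earlier ones, giving the staggered diffusion-forcing schedule. Frames supplied as
# prefix latents (num_pre_ready) start fully denoised and are never updated.
# Sketch, assuming 3 steps [t0, t1, t2], 4 frames, no prefix, ar_step=1:
#   row 1 timesteps ~ [t0, 999, 999, 999] -> row 2 ~ [t1, t0, 999, 999] -> ... until every frame reaches 0.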


class GetImageRangeFromBatch:

    RETURN_TYPES = ("IMAGE", "MASK", )
    FUNCTION = "imagesfrombatch"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """Returns a range of images from a batch."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "start_index": ("INT", {"default": 0, "min": -1, "max": 4096, "step": 1}),
                "num_frames": ("INT", {"default": 1, "min": 1, "max": 4096, "step": 1}),
            },
            "optional": {
                "images": ("IMAGE",),
                "masks": ("MASK",),
            }
        }

    def imagesfrombatch(self, start_index, num_frames, images=None, masks=None):
        chosen_images = None
        chosen_masks = None

        # Process images if provided
        if images is not None:
            if start_index == -1:
                start_index = max(0, len(images) - num_frames)
            if start_index < 0 or start_index >= len(images):
                raise ValueError("Start index is out of range")
            end_index = min(start_index + num_frames, len(images))
            chosen_images = images[start_index:end_index]

        # Process masks if provided
        if masks is not None:
            if start_index == -1:
                start_index = max(0, len(masks) - num_frames)
            if start_index < 0 or start_index >= len(masks):
                raise ValueError("Start index is out of range for masks")
            end_index = min(start_index + num_frames, len(masks))
            chosen_masks = masks[start_index:end_index]

        return (chosen_images, chosen_masks,)


class ImageBatch:

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image1": ("IMAGE",), "image2": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "batch"

    CATEGORY = "image"

    def batch(self, image1, image2):
        if image1.shape[1:] != image2.shape[1:]:
            image2 = comfy.utils.common_upscale(image2.movedim(-1, 1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1, -1)
        s = torch.cat((image1, image2), dim=0)
        return (s,)


#region Sampler
class SimpleWanVideoDiffusionForcingSampler:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("WANVIDEOMODEL",),
                "vae": ("WANVAE",),
                "text_embeds": ("WANVIDEOTEXTEMBEDS", ),
                "image_embeds_list": ("WANVIDIMAGE_EMBEDS", ),
                "addnoise_condition": ("INT", {"default": 10, "min": 0, "max": 1000, "tooltip": "Improves consistency in long video generation"}),
                "fps": ("FLOAT", {"default": 24.0, "min": 1.0, "max": 120.0, "step": 0.01}),
                "steps": ("INT", {"default": 30, "min": 1}),
                "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
                "shift": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "force_offload": ("BOOLEAN", {"default": True, "tooltip": "Moves the model to the offload device after sampling"}),
                "scheduler": (["unipc", "unipc/beta", "euler", "euler/beta", "lcm", "lcm/beta"],
                    {
                        "default": 'unipc'
                    }),
            },
            "optional": {
                "samples": ("LATENT", {"tooltip": "Initial latents to use for the video2video process"}),
                "prefix_samples": ("LATENT", {"tooltip": "Prefix latents"}),
                "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "teacache_args": ("TEACACHEARGS", ),
                "slg_args": ("SLGARGS", ),
                "rope_function": (["default", "comfy"], {"default": "comfy", "tooltip": "Comfy's RoPE implementation doesn't use complex numbers and can thus be compiled; that should be a lot faster when using torch.compile"}),
                "experimental_args": ("EXPERIMENTALARGS", ),
                "unianimate_poses": ("UNIANIMATE_POSE", ),
            }
        }

    RETURN_TYPES = ("IMAGE", )
    RETURN_NAMES = ("images",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    def process(self, model, vae, text_embeds, image_embeds_list, shift, fps, steps, addnoise_condition, cfg, seed, scheduler,
                force_offload=True, samples=None, prefix_samples=None, denoise_strength=1.0, slg_args=None, rope_function="default", teacache_args=None,
                experimental_args=None, unianimate_poses=None):

        image_range_extractor = GetImageRangeFromBatch()
        decoder = WanVideoDecode()
        encoder = WanVideoEncode()

        video_chunk_list = []
        for image_embeds in image_embeds_list:
            samples = self.sub_process(model, text_embeds, image_embeds, shift, fps, steps, addnoise_condition, cfg, seed, scheduler,
                                       force_offload, samples, prefix_samples, denoise_strength, slg_args, rope_function, teacache_args,
                                       experimental_args, unianimate_poses)

            video_chunk = decoder.decode(vae, samples, enable_vae_tiling=False, tile_x=272, tile_y=272, tile_stride_x=144, tile_stride_y=128)[0]
            video_chunk_list.append(video_chunk)

            # Re-encode the last 17 frames of this chunk as prefix latents for the next chunk
            images = image_range_extractor.imagesfrombatch(start_index=-1, num_frames=17, images=video_chunk, masks=None)[0]
            prefix_samples = encoder.encode(vae, images, enable_vae_tiling=False, tile_x=272, tile_y=272, tile_stride_x=144, tile_stride_y=128, noise_aug_strength=0.0, latent_strength=1.0, mask=None)[0]

        image_batch_node = ImageBatch()
        combined_video = video_chunk_list[0]
        for i in range(1, len(video_chunk_list)):
            num_frames = image_embeds_list[i].get("num_frames")
            new_video = image_range_extractor.imagesfrombatch(start_index=-1, num_frames=num_frames - 17, images=video_chunk_list[i], masks=None)[0]

            combined_video, = image_batch_node.batch(combined_video, new_video)

        return (combined_video,)
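
    # Overview (editorial note, not in the upstream SkyReels code): process() generates the video in
    # chunks whose sizes come from SimpleWanVideoEmptyEmbeds.get_chunk_num_frame_list(). After each
    # chunk is decoded, its last 17 frames are re-encoded and fed to the next chunk as prefix latents
    # ((17 - 1) // 4 + 1 = 5 latent frames), so consecutive chunks overlap; when stitching, only the
    # last (num_frames - 17) frames of each later chunk are appended, dropping the duplicated overlap.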


    def sub_process(self, model, text_embeds, image_embeds, shift, fps, steps, addnoise_condition, cfg, seed, scheduler,
                    force_offload=True, samples=None, prefix_samples=None, denoise_strength=1.0, slg_args=None, rope_function="default", teacache_args=None,
                    experimental_args=None, unianimate_poses=None):
        #assert not (context_options and teacache_args), "Context options cannot currently be used together with teacache."
        patcher = model
        model = model.model
        transformer = model.diffusion_model
        dtype = model["dtype"]
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()

        steps = int(steps / denoise_strength)

        timesteps = None
        if 'unipc' in scheduler:
            sample_scheduler = FlowUniPCMultistepScheduler(shift=shift)
            sample_scheduler.set_timesteps(steps, device=device, shift=shift, use_beta_sigmas=('beta' in scheduler))
        elif 'euler' in scheduler:
            sample_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift, use_beta_sigmas=(scheduler == 'euler/beta'))
            sample_scheduler.set_timesteps(steps, device=device)
        elif 'lcm' in scheduler:
            sample_scheduler = FlowMatchLCMScheduler(shift=shift, use_beta_sigmas=(scheduler == 'lcm/beta'))
            sample_scheduler.set_timesteps(steps, device=device)


        init_timesteps = sample_scheduler.timesteps

        if denoise_strength < 1.0:
            steps = int(steps * denoise_strength)
            timesteps = init_timesteps[-(steps + 1):]  # keep only the tail of the schedule for partial denoising

        seed_g = torch.Generator(device=torch.device("cpu"))
        seed_g.manual_seed(seed)

        clip_fea, clip_fea_neg = None, None
        vace_data, vace_context, vace_scale = None, None, None

        image_cond = image_embeds.get("image_embeds", None)

        target_shape = image_embeds.get("target_shape", None)
        if target_shape is None:
            raise ValueError("Empty image embeds must be provided for T2V (Text to Video)")

        has_ref = image_embeds.get("has_ref", False)
        vace_context = image_embeds.get("vace_context", None)
        vace_scale = image_embeds.get("vace_scale", None)
        vace_start_percent = image_embeds.get("vace_start_percent", 0.0)
        vace_end_percent = image_embeds.get("vace_end_percent", 1.0)
        vace_seqlen = image_embeds.get("vace_seq_len", None)

        vace_additional_embeds = image_embeds.get("additional_vace_inputs", [])
        if vace_context is not None:
            vace_data = [
                {"context": vace_context,
                 "scale": vace_scale,
                 "start": vace_start_percent,
                 "end": vace_end_percent,
                 "seq_len": vace_seqlen
                 }
            ]
            if len(vace_additional_embeds) > 0:
                for i in range(len(vace_additional_embeds)):
                    if vace_additional_embeds[i].get("has_ref", False):
                        has_ref = True
                    vace_data.append({
                        "context": vace_additional_embeds[i]["vace_context"],
                        "scale": vace_additional_embeds[i]["vace_scale"],
                        "start": vace_additional_embeds[i]["vace_start_percent"],
                        "end": vace_additional_embeds[i]["vace_end_percent"],
                        "seq_len": vace_additional_embeds[i]["vace_seq_len"]
                    })

        noise = torch.randn(
            target_shape[0],
            target_shape[1] + 1 if has_ref else target_shape[1],
            target_shape[2],
            target_shape[3],
            dtype=torch.float32,
            device=torch.device("cpu"),
            generator=seed_g)

        latent_video_length = noise.shape[1]
        seq_len = math.ceil((noise.shape[2] * noise.shape[3]) / 4 * noise.shape[1])



        if samples is not None:
            input_samples = samples["samples"].squeeze(0).to(noise)
            if input_samples.shape[1] != noise.shape[1]:
                input_samples = torch.cat([input_samples[:, :1].repeat(1, noise.shape[1] - input_samples.shape[1], 1, 1), input_samples], dim=1)
            original_image = input_samples.to(device)
            if denoise_strength < 1.0:
                latent_timestep = timesteps[:1].to(noise)
                noise = noise * latent_timestep / 1000 + (1 - latent_timestep / 1000) * input_samples

            mask = samples.get("mask", None)
            if mask is not None:
                if mask.shape[2] != noise.shape[1]:
                    mask = torch.cat([torch.zeros(1, noise.shape[0], noise.shape[1] - mask.shape[2], noise.shape[2], noise.shape[3]), mask], dim=2)

        latents = noise.to(device)

        fps_embeds = None
        if hasattr(transformer, "fps_embedding"):
            fps = round(fps, 2)
            log.info(f"Model has fps embedding, using {fps} fps")
            fps_embeds = [fps]
            fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]  # embedding index: 0 for 16 fps, 1 otherwise

        prefix_video = prefix_samples["samples"].to(noise) if prefix_samples is not None else None
        prefix_video_latent_length = prefix_video.shape[2] if prefix_video is not None else 0
        if prefix_video is not None:
            log.info(f"Prefix video of length: {prefix_video_latent_length}")
            latents[:, :prefix_video_latent_length] = prefix_video[0]
        #base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_video_length
        base_num_frames = latent_video_length

        ar_step = 0
        causal_block_size = 1
        step_matrix, _, step_update_mask, valid_interval = generate_timestep_matrix(
            latent_video_length, init_timesteps, base_num_frames, ar_step, prefix_video_latent_length, causal_block_size
        )

        # One scheduler per latent frame: under diffusion forcing each frame can sit at its own
        # timestep, so each frame needs independent scheduler state.
        sample_schedulers = []
        for _ in range(latent_video_length):
            if 'unipc' in scheduler:
                sample_scheduler = FlowUniPCMultistepScheduler(shift=shift)
                sample_scheduler.set_timesteps(steps, device=device, shift=shift, use_beta_sigmas=('beta' in scheduler))
            elif 'euler' in scheduler:
                sample_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)
                sample_scheduler.set_timesteps(steps, device=device)
            elif 'lcm' in scheduler:
                sample_scheduler = FlowMatchLCMScheduler(shift=shift, use_beta_sigmas=(scheduler == 'lcm/beta'))
                sample_scheduler.set_timesteps(steps, device=device)

            sample_schedulers.append(sample_scheduler)
        sample_schedulers_counter = [0] * latent_video_length

        unianim_data = None
        if unianimate_poses is not None:
            transformer.dwpose_embedding.to(device)
            transformer.randomref_embedding_pose.to(device)
            dwpose_data = unianimate_poses["pose"]
            dwpose_data = transformer.dwpose_embedding(
                (torch.cat([dwpose_data[:,:,:1].repeat(1,1,3,1,1), dwpose_data], dim=2)
                 ).to(device)).to(model["dtype"])
            log.info(f"UniAnimate pose embed shape: {dwpose_data.shape}")
            if dwpose_data.shape[2] > latent_video_length:
                log.warning(f"UniAnimate pose embed length {dwpose_data.shape[2]} is longer than the video length {latent_video_length}, truncating")
                dwpose_data = dwpose_data[:,:, :latent_video_length]
            elif dwpose_data.shape[2] < latent_video_length:
                log.warning(f"UniAnimate pose embed length {dwpose_data.shape[2]} is shorter than the video length {latent_video_length}, padding with last pose")
                pad_len = latent_video_length - dwpose_data.shape[2]
                pad = dwpose_data[:,:,:1].repeat(1,1,pad_len,1,1)
                dwpose_data = torch.cat([dwpose_data, pad], dim=2)
            dwpose_data_flat = rearrange(dwpose_data, 'b c f h w -> b (f h w) c').contiguous()

            random_ref_dwpose_data = None
            if image_cond is not None:
                random_ref_dwpose = unianimate_poses.get("ref", None)
                if random_ref_dwpose is not None:
                    random_ref_dwpose_data = transformer.randomref_embedding_pose(
                        random_ref_dwpose.to(device)
                    ).unsqueeze(2).to(model["dtype"])  # [1, 20, 104, 60]

            unianim_data = {
                "dwpose": dwpose_data_flat,
                "random_ref": random_ref_dwpose_data.squeeze(0) if random_ref_dwpose_data is not None else None,
                "strength": unianimate_poses["strength"],
                "start_percent": unianimate_poses["start_percent"],
                "end_percent": unianimate_poses["end_percent"]
            }

        disable_enhance()  # not sure if this can work, disabling for now to avoid errors if it's enabled by another sampler

        freqs = None
        transformer.rope_embedder.k = None
        transformer.rope_embedder.num_frames = None
        if rope_function == "comfy":
            transformer.rope_embedder.k = 0
            transformer.rope_embedder.num_frames = latent_video_length
        else:
            d = transformer.dim // transformer.num_heads
            freqs = torch.cat([
                rope_params(1024, d - 4 * (d // 6), L_test=latent_video_length, k=0),
                rope_params(1024, 2 * (d // 6)),
                rope_params(1024, 2 * (d // 6))
            ],
            dim=1)

        if not isinstance(cfg, list):
            cfg = [cfg] * (steps + 1)

        log.info(f"Seq len: {seq_len}")

        pbar = ProgressBar(steps)

        if args.preview_method in [LatentPreviewMethod.Auto, LatentPreviewMethod.Latent2RGB]:  # default for latent2rgb
            from latent_preview import prepare_callback
        else:
            from latent_preview import prepare_callback  # custom for tiny VAE previews
        callback = prepare_callback(patcher, steps)

        #blockswap init
        transformer_options = patcher.model_options.get("transformer_options", None)
        if transformer_options is not None:
            block_swap_args = transformer_options.get("block_swap_args", None)

            if block_swap_args is not None:
                transformer.use_non_blocking = block_swap_args.get("use_non_blocking", True)
                for name, param in transformer.named_parameters():
                    if "block" not in name:
                        param.data = param.data.to(device)
                    elif block_swap_args["offload_txt_emb"] and "txt_emb" in name:
                        param.data = param.data.to(offload_device, non_blocking=transformer.use_non_blocking)
                    elif block_swap_args["offload_img_emb"] and "img_emb" in name:
                        param.data = param.data.to(offload_device, non_blocking=transformer.use_non_blocking)

                transformer.block_swap(
                    block_swap_args["blocks_to_swap"] - 1,
                    block_swap_args["offload_txt_emb"],
                    block_swap_args["offload_img_emb"],
                    vace_blocks_to_swap=block_swap_args.get("vace_blocks_to_swap", None),
                )

            elif model["auto_cpu_offload"]:
                for module in transformer.modules():
                    if hasattr(module, "offload"):
                        module.offload()
                    if hasattr(module, "onload"):
                        module.onload()
            elif model["manual_offloading"]:
                transformer.to(device)

        # Initialize TeaCache if enabled
        if teacache_args is not None:
            transformer.enable_teacache = True
            transformer.rel_l1_thresh = teacache_args["rel_l1_thresh"]
            transformer.teacache_start_step = teacache_args["start_step"]
            transformer.teacache_cache_device = teacache_args["cache_device"]
            log.info(f"TeaCache: Using cache device: {transformer.teacache_state.cache_device}")
            transformer.teacache_end_step = len(init_timesteps) - 1 if teacache_args["end_step"] == -1 else teacache_args["end_step"]
            transformer.teacache_use_coefficients = teacache_args["use_coefficients"]
            transformer.teacache_mode = teacache_args["mode"]
            transformer.teacache_state.clear_all()
        else:
            transformer.enable_teacache = False

        if slg_args is not None:
            transformer.slg_blocks = slg_args["blocks"]
            transformer.slg_start_percent = slg_args["start_percent"]
            transformer.slg_end_percent = slg_args["end_percent"]
        else:
            transformer.slg_blocks = None

        self.teacache_state = [None, None]
        self.teacache_state_source = [None, None]
        self.teacache_states_context = []


        use_cfg_zero_star, use_fresca = False, False
        if experimental_args is not None:
            video_attention_split_steps = experimental_args.get("video_attention_split_steps", [])
            if video_attention_split_steps:
                transformer.video_attention_split_steps = [int(x.strip()) for x in video_attention_split_steps.split(",")]
            else:
                transformer.video_attention_split_steps = []
            use_zero_init = experimental_args.get("use_zero_init", True)
            use_cfg_zero_star = experimental_args.get("cfg_zero_star", False)
            zero_star_steps = experimental_args.get("zero_star_steps", 0)

            use_fresca = experimental_args.get("use_fresca", False)
            if use_fresca:
                fresca_scale_low = experimental_args.get("fresca_scale_low", 1.0)
                fresca_scale_high = experimental_args.get("fresca_scale_high", 1.25)
                fresca_freq_cutoff = experimental_args.get("fresca_freq_cutoff", 20)

        #region model pred
        def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, image_cond=None, clip_fea=None,
                             vace_data=None, unianim_data=None, teacache_state=None):
            with torch.autocast(device_type=mm.get_autocast_device(device), dtype=dtype, enabled=("fp8" in model["quantization"])):

                if use_cfg_zero_star and (idx <= zero_star_steps) and use_zero_init:
                    return latent_model_input * 0, None

                nonlocal patcher
                current_step_percentage = idx / len(init_timesteps)
                control_lora_enabled = False

                image_cond_input = image_cond

                base_params = {
                    'seq_len': seq_len,
                    'device': device,
                    'freqs': freqs,
                    't': timestep,
                    'current_step': idx,
                    'control_lora_enabled': control_lora_enabled,
                    'vace_data': vace_data,
                    'unianim_data': unianim_data,
                    'fps_embeds': fps_embeds,
                }

                batch_size = 1

                if not math.isclose(cfg_scale, 1.0) and len(positive_embeds) > 1:
                    negative_embeds = negative_embeds * len(positive_embeds)


                #cond
                noise_pred_cond, teacache_state_cond = transformer(
                    [z], context=positive_embeds, y=[image_cond_input] if image_cond_input is not None else None,
                    clip_fea=clip_fea, is_uncond=False, current_step_percentage=current_step_percentage,
                    pred_id=teacache_state[0] if teacache_state else None,
                    **base_params
                )
                noise_pred_cond = noise_pred_cond[0].to(intermediate_device)
                if math.isclose(cfg_scale, 1.0):
                    if use_fresca:
                        noise_pred_cond = fourier_filter(
                            noise_pred_cond,
                            scale_low=fresca_scale_low,
                            scale_high=fresca_scale_high,
                            freq_cutoff=fresca_freq_cutoff,
                        )
                    return noise_pred_cond, [teacache_state_cond]
                #uncond
                noise_pred_uncond, teacache_state_uncond = transformer(
                    [z], context=negative_embeds, clip_fea=clip_fea_neg if clip_fea_neg is not None else clip_fea,
                    y=[image_cond_input] if image_cond_input is not None else None,
                    is_uncond=True, current_step_percentage=current_step_percentage,
                    pred_id=teacache_state[1] if teacache_state else None,
                    **base_params
                )
                noise_pred_uncond = noise_pred_uncond[0].to(intermediate_device)

                #cfg

                #https://github.com/WeichenFan/CFG-Zero-star/
                if use_cfg_zero_star:
                    alpha = optimized_scale(
                        noise_pred_cond.view(batch_size, -1),
                        noise_pred_uncond.view(batch_size, -1)
                    ).view(batch_size, 1, 1, 1)
                else:
                    alpha = 1.0

                #https://github.com/WikiChao/FreSca
                if use_fresca:
                    filtered_cond = fourier_filter(
                        noise_pred_cond - noise_pred_uncond,
                        scale_low=fresca_scale_low,
                        scale_high=fresca_scale_high,
                        freq_cutoff=fresca_freq_cutoff,
                    )
                    noise_pred = noise_pred_uncond * alpha + cfg_scale * filtered_cond * alpha
                else:
                    noise_pred = noise_pred_uncond * alpha + cfg_scale * (noise_pred_cond - noise_pred_uncond * alpha)


                return noise_pred, [teacache_state_cond, teacache_state_uncond]
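
        # Editorial note on the guidance blend above: with alpha == 1 this is standard classifier-free
        # guidance, uncond + cfg * (cond - uncond). When cfg_zero_star is enabled, alpha rescales the
        # unconditional branch per CFG-Zero* (https://github.com/WeichenFan/CFG-Zero-star), and FreSca
        # optionally frequency-filters the (cond - uncond) difference before it is scaled.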

        log.info(f"Sampling {(latent_video_length-1) * 4 + 1} frames at {latents.shape[3]*8}x{latents.shape[2]*8} with {steps} steps")

        intermediate_device = device

        #clear memory before sampling
        mm.unload_all_models()
        mm.soft_empty_cache()
        gc.collect()
        try:
            torch.cuda.reset_peak_memory_stats(device)
        except:
            pass

        #region main loop start
        for i, timestep_i in enumerate(tqdm(step_matrix)):
            update_mask_i = step_update_mask[i]
            valid_interval_i = valid_interval[i]
            valid_interval_start, valid_interval_end = valid_interval_i
            timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
            latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].clone()
            if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length:
                noise_factor = 0.001 * addnoise_condition
                timestep_for_noised_condition = addnoise_condition
                latent_model_input[:, valid_interval_start:prefix_video_latent_length] = (
                    latent_model_input[:, valid_interval_start:prefix_video_latent_length] * (1.0 - noise_factor)
                    + torch.randn_like(latent_model_input[:, valid_interval_start:prefix_video_latent_length])
                    * noise_factor
                )
                timestep[:, valid_interval_start:prefix_video_latent_length] = timestep_for_noised_condition
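            # Editorial note: the prefix (context) latents are kept almost clean, only a small amount of
            # noise (noise_factor = 0.001 * addnoise_condition) is mixed back in and their timestep is set
            # to the small constant addnoise_condition, so the model treats them as conditioning rather
            # than frames to denoise; per the input tooltip this improves long-video consistency.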


            #print("timestep", timestep)
            noise_pred, self.teacache_state = predict_with_cfg(
                latent_model_input.to(dtype),
                cfg[i],
                text_embeds["prompt_embeds"],
                text_embeds["negative_prompt_embeds"],
                timestep, i, image_cond, clip_fea, unianim_data=unianim_data, vace_data=vace_data,
                teacache_state=self.teacache_state)

            for idx in range(valid_interval_start, valid_interval_end):
                if update_mask_i[idx].item():
                    latents[:, idx] = sample_schedulers[idx].step(
                        noise_pred[:, idx - valid_interval_start],
                        timestep_i[idx],
                        latents[:, idx],
                        return_dict=False,
                        generator=seed_g,
                    )[0]
                    sample_schedulers_counter[idx] += 1

            x0 = latents.unsqueeze(0)
            if callback is not None:
                callback_latent = (latent_model_input - noise_pred.to(timestep_i[idx].device) * timestep_i[idx] / 1000).detach().permute(1,0,2,3)
                callback(i, callback_latent, None, steps)
            else:
                pbar.update(1)

        if teacache_args is not None:
            states = transformer.teacache_state.states
            state_names = {
                0: "conditional",
                1: "unconditional"
            }
            for pred_id, state in states.items():
                name = state_names.get(pred_id, f"prediction_{pred_id}")
                if 'skipped_steps' in state:
                    log.info(f"TeaCache skipped: {len(state['skipped_steps'])} {name} steps: {state['skipped_steps']}")
            transformer.teacache_state.clear_all()

        if force_offload:
            if model["manual_offloading"]:
                transformer.to(offload_device)
            mm.soft_empty_cache()
            gc.collect()

        try:
            print_memory(device)
            torch.cuda.reset_peak_memory_stats(device)
        except:
            pass

        return {"samples": x0.cpu()}


class SimpleWanVideoEmptyEmbeds:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}),
            "height": ("INT", {"default": 480, "min": 64, "max": 2048, "step": 8, "tooltip": "Height of the image to encode"}),
            "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}),
            },
            "optional": {
                "control_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "Control signal for the Fun model"}),
            }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", )
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    def get_chunk_num_frame_list(self, num_frames):
        # Define the maximum chunk size as a constant
        # 97 for 540P, 121 for 720P
        # To reduce peak VRAM, lower the per-chunk frame count (e.g. to 77 or 57)
        # while keeping the same total number of frames (num_frames) you want to generate.
        # This may slightly reduce video quality, and it should not be set too small.
        # todo: need to test vram
        MAX_FRAMES_PER_CHUNK = 97

        # Calculate how many complete chunks we need
        full_chunks = num_frames // MAX_FRAMES_PER_CHUNK
        # Calculate the size of the remainder chunk (if any)
        remainder = num_frames % MAX_FRAMES_PER_CHUNK

        # Create the list of chunk sizes
        chunk_num_frames_list = [MAX_FRAMES_PER_CHUNK] * full_chunks
        if remainder > 0:
            chunk_num_frames_list.append(remainder)

        return chunk_num_frames_list
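
    # Worked example (editorial): with MAX_FRAMES_PER_CHUNK = 97, num_frames = 177 yields [97, 80]
    # and num_frames = 81 yields [81]; process() below builds one set of empty embeds per chunk.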

    def process(self, num_frames, width, height, control_embeds=None):
        embeds_list = []
        chunk_num_frames_list = self.get_chunk_num_frame_list(num_frames)

        for sub_num_frames in chunk_num_frames_list:
            vae_stride = (4, 8, 8)

            target_shape = (16, (sub_num_frames - 1) // vae_stride[0] + 1,
                            height // vae_stride[1],
                            width // vae_stride[2])

            embeds = {
                "target_shape": target_shape,
                "num_frames": sub_num_frames,
                "control_embeds": control_embeds["control_embeds"] if control_embeds is not None else None,
            }
            embeds_list.append(embeds)

        return (embeds_list,)


NODE_CLASS_MAPPINGS = {
    "SimpleWanVideoDiffusionForcingSampler": SimpleWanVideoDiffusionForcingSampler,
    "SimpleWanVideoEmptyEmbeds": SimpleWanVideoEmptyEmbeds
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "SimpleWanVideoDiffusionForcingSampler": "Simple WanVideo Diffusion Forcing Sampler",
    "SimpleWanVideoEmptyEmbeds": "Simple WanVideo Empty Embeds",
}