pizb committed
Commit 06b9d96 · 1 Parent(s): d33e75e

missing file update

Files changed (1)
  1. pipeline_svd_mask.py +1042 -0
pipeline_svd_mask.py ADDED
@@ -0,0 +1,1042 @@
# pipeline_svd_mask.py

import inspect
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from diffusers.image_processor import PipelineImageInput
from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from diffusers.schedulers import EulerDiscreteScheduler
from diffusers.utils import BaseOutput, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline

# Import necessary helpers from the original SVD pipeline
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
    _append_dims,
    retrieve_timesteps,
    _resize_with_antialiasing,
)
import torch.nn.functional as F
from einops import rearrange


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from pipeline_svd_mask import StableVideoDiffusionPipelineWithMask
        >>> from diffusers.utils import load_image, export_to_video

        >>> # Load your fine-tuned UNet, VAE, etc.
        >>> pipe = StableVideoDiffusionPipelineWithMask.from_pretrained(
        ...     "path/to/your/finetuned_model", torch_dtype=torch.float16, variant="fp16"
        ... )
        >>> pipe.to("cuda")

        >>> # Load the conditioning image and the mask
        >>> image = load_image("path/to/your/conditioning_image.png").resize((1024, 576))
        >>> mask = load_image("path/to/your/mask_image.png").resize((1024, 576))

        >>> # Generate frames
        >>> frames = pipe(
        ...     image=image,
        ...     mask_image=mask,
        ...     num_frames=25,
        ...     decode_chunk_size=8,
        ... ).frames[0]

        >>> export_to_video(frames, "generated_video.mp4", fps=7)
        ```
"""


@dataclass
class StableVideoDiffusionPipelineOutput(BaseOutput):
    r"""
    Output class for the custom Stable Video Diffusion pipeline.

    Args:
        frames (`List[List[PIL.Image.Image]]`, `np.ndarray`, or `torch.Tensor`):
            List of denoised PIL images of length `batch_size`, or a numpy array or torch tensor of shape
            `(batch_size, num_frames, height, width, num_channels)`.
    """

    frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor]


class StableVideoDiffusionPipelineWithMask(DiffusionPipeline):
    r"""
    A custom pipeline based on Stable Video Diffusion that accepts an additional mask for conditioning.
    This pipeline is designed to work with a UNet fine-tuned to accept 12 input channels
    (4 for noise, 4 for the VAE-encoded condition image, 4 for the VAE-encoded mask).
    """

    model_cpu_offload_seq = "image_encoder->unet->vae"
    _callback_tensor_inputs = ["latents"]

    def __init__(
        self,
        vae: AutoencoderKLTemporalDecoder,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNetSpatioTemporalConditionModel,
        scheduler: EulerDiscreteScheduler,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)

    def _encode_image(
        self,
        image: PipelineImageInput,
        device: Union[str, torch.device],
        num_videos_per_prompt: int,
    ) -> torch.Tensor:
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.video_processor.pil_to_numpy(image)
            image = self.video_processor.numpy_to_pt(image)

        # Normalize to [-1, 1], resize with anti-aliasing, then map back to [0, 1] for CLIP
        image = image * 2.0 - 1.0
        image = _resize_with_antialiasing(image, (224, 224))
        image = (image + 1.0) / 2.0

        image = self.feature_extractor(
            images=image,
            do_normalize=True,
            do_center_crop=False,
            do_resize=False,
            do_rescale=False,
            return_tensors="pt",
        ).pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        # As per the training script, the CLIP embedding is zeroed out
        image_embeddings = torch.zeros_like(image_embeddings)

        return image_embeddings

    def _encode_vae_image(
        self,
        image: torch.Tensor,
        device: Union[str, torch.device],
        num_videos_per_prompt: int,
    ):
        image = image.to(device=device, dtype=torch.float16)
        image_latents = self.vae.encode(image).latent_dist.sample()
        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
        return image_latents

    def _get_add_time_ids(
        self,
        fps: int,
        motion_bucket_id: int,
        noise_aug_strength: float,
        dtype: torch.dtype,
        batch_size: int,
        num_videos_per_prompt: int,
    ):
        add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
                f"but a vector of {passed_add_embed_dim} was created."
            )
        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
        return add_time_ids

    def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14):
        latents = latents.flatten(0, 1).to(dtype=torch.float16)
        latents = 1 / self.vae.config.scaling_factor * latents
        frames = []
        for i in range(0, latents.shape[0], decode_chunk_size):
            num_frames_in = latents[i : i + decode_chunk_size].shape[0]
            frame = self.vae.decode(latents[i : i + decode_chunk_size], num_frames=num_frames_in).sample
            frames.append(frame)
        frames = torch.cat(frames, dim=0)
        frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
        frames = frames.float()
        return frames

    def check_inputs(self, image, height, width):
        if (
            not isinstance(image, torch.Tensor)
            and not isinstance(image, PIL.Image.Image)
            and not isinstance(image, list)
        ):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, or `list` but is {type(image)}"
            )
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    def prepare_latents(
        self,
        batch_size: int,
        num_frames: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: Union[str, torch.device],
        generator: torch.Generator,
        latents: Optional[torch.Tensor] = None,
        initial_latents: Optional[torch.Tensor] = None,
        denoising_strength: float = 1.0,
        timestep: Optional[torch.Tensor] = None,
    ):
        num_channels_latents = self.unet.config.out_channels
        shape = (
            batch_size,
            num_frames,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )

        if initial_latents is not None:
            # Noise is added to the initial latents
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            # Get the initial latents at the given timestep
            latents = self.scheduler.add_noise(initial_latents, noise, timestep)
        else:
            # Standard pure-noise generation
            if latents is None:
                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            else:
                latents = latents.to(device)
            # Scale the initial noise by the standard deviation required by the scheduler
            latents = latents * self.scheduler.init_noise_sigma

        return latents

    def _encode_video_vae(
        self,
        video_frames: torch.Tensor,  # Expects (B, F, C, H, W)
        device: Union[str, torch.device],
    ):
        video_frames = video_frames.to(device=device, dtype=self.vae.dtype)
        batch_size, num_frames = video_frames.shape[:2]

        # Reshape for VAE encoding
        video_frames_reshaped = video_frames.reshape(batch_size * num_frames, *video_frames.shape[2:])  # (B*F, C, H, W)
        latents = self.vae.encode(video_frames_reshaped).latent_dist.sample()  # (B*F, C_latent, H_latent, W_latent)

        # Reshape back to video format
        latents = latents.reshape(batch_size, num_frames, *latents.shape[1:])  # (B, F, C_latent, H_latent, W_latent)

        return latents

    @torch.no_grad()
    def __call__(
        self,
        image: Union[List[PIL.Image.Image], torch.Tensor],
        mask_image: Union[List[PIL.Image.Image], torch.Tensor],
        alpha_matte_image: Optional[Union[List[PIL.Image.Image], torch.Tensor]] = None,
        denoising_strength: float = 0.7,
        height: int = 576,
        width: int = 1024,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 30,
        sigmas: Optional[List[float]] = None,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        mask_noise_strength: float = 0.0,
    ):
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        if num_frames is None:
            if isinstance(image, list):
                num_frames = len(image)
            else:
                num_frames = self.unet.config.num_frames

        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

        self.check_inputs(image, height, width)
        self.check_inputs(mask_image, height, width)
        if alpha_matte_image is not None:
            self.check_inputs(alpha_matte_image, height, width)

        batch_size = 1
        device = self._execution_device
        dtype = self.unet.dtype

        image_for_clip = image[0] if isinstance(image, list) else image
        image_embeddings = self._encode_image(image_for_clip, device, num_videos_per_prompt)

        # Following the original SVD convention, the model is conditioned on `fps - 1`
        fps = fps - 1

        image_tensor = self.video_processor.preprocess(image, height=height, width=width).to(device).unsqueeze(0)
        mask_tensor = self.video_processor.preprocess(mask_image, height=height, width=width).to(device).unsqueeze(0)

        noise = randn_tensor(image_tensor.shape, generator=generator, device=device, dtype=dtype)
        image_tensor = image_tensor + noise_aug_strength * noise

        conditional_latents = self._encode_video_vae(image_tensor, device)
        conditional_latents = conditional_latents / self.vae.config.scaling_factor

        if self.unet.config.in_channels == 12:
            mask_latents = self._encode_video_vae(mask_tensor, device)
            mask_latents = mask_latents / self.vae.config.scaling_factor
        elif self.unet.config.in_channels == 9:
            # Binarize the mask and downsample it to the latent resolution instead of VAE-encoding it
            mask_tensor_gray = mask_tensor.mean(dim=2, keepdim=True)
            binarized_mask = (mask_tensor_gray > 0.0).to(dtype)
            b, f, c, h, w = binarized_mask.shape
            binarized_mask_reshaped = binarized_mask.reshape(b * f, c, h, w)
            target_size = (height // self.vae_scale_factor, width // self.vae_scale_factor)
            interpolated_mask = F.interpolate(
                binarized_mask_reshaped,
                size=target_size,
                mode="nearest",
            )
            mask_latents = interpolated_mask.reshape(b, f, *interpolated_mask.shape[1:])
        else:
            raise ValueError(
                f"Unsupported number of UNet input channels: {self.unet.config.in_channels}. "
                "This pipeline only supports 9 (for interpolated mask) or 12 (for VAE mask)."
            )

        if mask_noise_strength > 0.0:
            mask_noise = randn_tensor(mask_latents.shape, generator=generator, device=device, dtype=dtype)
            mask_latents = mask_latents + mask_noise_strength * mask_noise

        added_time_ids = self._get_add_time_ids(
            fps, motion_bucket_id, noise_aug_strength, image_embeddings.dtype, batch_size, num_videos_per_prompt
        )
        added_time_ids = added_time_ids.to(device)

        # --- MODIFIED FOR ALPHA MATTE REFINEMENT ---
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)

        initial_latents = None

        if alpha_matte_image is not None:
            alpha_matte_tensor = (
                self.video_processor.preprocess(alpha_matte_image, height=height, width=width).to(device).unsqueeze(0)
            )
            initial_latents = self._encode_video_vae(alpha_matte_tensor, device)
            initial_latents = initial_latents / self.vae.config.scaling_factor

            # Adjust the number of steps and the timesteps to start from
            t_start = max(num_inference_steps - int(num_inference_steps * denoising_strength), 0)
            timesteps = timesteps[t_start:]
            # We need the first timestep to add the correct amount of noise
            start_timestep = timesteps[0]
        else:
            start_timestep = timesteps[0]  # Not used, but kept for clarity

        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_frames,
            height,
            width,
            dtype,
            device,
            generator,
            latents,
            initial_latents=initial_latents,
            denoising_strength=denoising_strength,
            timestep=start_timestep if initial_latents is not None else None,
        )

        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = self.scheduler.scale_model_input(latents, t)
                # Concatenate noise, condition-image latents, and mask latents along the channel dim
                latent_model_input = torch.cat([latent_model_input, conditional_latents, mask_latents], dim=2)

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_embeddings,
                    added_time_ids=added_time_ids,
                    return_dict=False,
                )[0]

                latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        frames = self.decode_latents(latents, num_frames, decode_chunk_size)
        frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames
        return StableVideoDiffusionPipelineOutput(frames=frames)

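# --------------------------------------------------------------------------
# Usage sketch (illustrative, not executed): refining a coarse alpha matte
# with partial denoising. All paths and frame lists below are hypothetical
# placeholders. When `alpha_matte_image` is given, the pipeline VAE-encodes
# the matte, noises it to the timestep implied by `denoising_strength`, and
# runs only the remaining steps, so lower strengths stay closer to the input.
#
#   pipe = StableVideoDiffusionPipelineWithMask.from_pretrained(
#       "path/to/your/finetuned_model", torch_dtype=torch.float16, variant="fp16"
#   ).to("cuda")
#   frames = pipe(
#       image=cond_frames,                      # list of PIL conditioning frames
#       mask_image=mask_frames,                 # list of PIL mask frames
#       alpha_matte_image=coarse_matte_frames,  # coarse matte to refine
#       denoising_strength=0.5,                 # 1.0 would regenerate from pure noise
#       num_inference_steps=30,
#   ).frames[0]
# --------------------------------------------------------------------------
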
class StableVideoDiffusionPipelineOnestepWithMask(DiffusionPipeline):
    r"""
    A single-step variant of the masked Stable Video Diffusion pipeline: the UNet is evaluated once at a
    fixed timestep and its prediction is decoded directly. It is designed to work with a UNet fine-tuned to
    accept 12 input channels (4 for noise, 4 for the VAE-encoded condition image, 4 for the VAE-encoded mask).
    """

    model_cpu_offload_seq = "image_encoder->unet->vae"
    _callback_tensor_inputs = ["latents"]

    def __init__(
        self,
        vae: AutoencoderKLTemporalDecoder,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNetSpatioTemporalConditionModel,
        scheduler: EulerDiscreteScheduler,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)

    def _encode_image(
        self,
        image: PipelineImageInput,
        device: Union[str, torch.device],
        num_videos_per_prompt: int,
    ) -> torch.Tensor:
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.video_processor.pil_to_numpy(image)
            image = self.video_processor.numpy_to_pt(image)

        image = image * 2.0 - 1.0
        image = _resize_with_antialiasing(image, (224, 224))
        image = (image + 1.0) / 2.0

        image = self.feature_extractor(
            images=image,
            do_normalize=True,
            do_center_crop=False,
            do_resize=False,
            do_rescale=False,
            return_tensors="pt",
        ).pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        # As per the training script, the CLIP embedding is zeroed out
        image_embeddings = torch.zeros_like(image_embeddings)

        return image_embeddings

    def _encode_vae_image(
        self,
        image: torch.Tensor,
        device: Union[str, torch.device],
        num_videos_per_prompt: int,
    ):
        image = image.to(device=device, dtype=torch.float16)
        image_latents = self.vae.encode(image).latent_dist.sample()
        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
        return image_latents

    def _get_add_time_ids(
        self,
        fps: int,
        motion_bucket_id: int,
        noise_aug_strength: float,
        dtype: torch.dtype,
        batch_size: int,
        num_videos_per_prompt: int,
    ):
        add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
                f"but a vector of {passed_add_embed_dim} was created."
            )
        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
        return add_time_ids

    def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14):
        latents = latents.flatten(0, 1).to(dtype=torch.float16)
        latents = 1 / self.vae.config.scaling_factor * latents
        frames = []
        for i in range(0, latents.shape[0], decode_chunk_size):
            num_frames_in = latents[i : i + decode_chunk_size].shape[0]
            frame = self.vae.decode(latents[i : i + decode_chunk_size], num_frames=num_frames_in).sample
            frames.append(frame)
        frames = torch.cat(frames, dim=0)
        frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
        frames = frames.float()
        return frames

    def check_inputs(self, image, height, width):
        if (
            not isinstance(image, torch.Tensor)
            and not isinstance(image, PIL.Image.Image)
            and not isinstance(image, list)
        ):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, or `list` but is {type(image)}"
            )
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    def prepare_latents(
        self,
        batch_size: int,
        num_frames: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: Union[str, torch.device],
        generator: torch.Generator,
        latents: Optional[torch.Tensor] = None,
    ):
        # The number of channels for the initial noise is based on the UNet's out_channels
        num_channels_latents = self.unet.config.out_channels
        shape = (
            batch_size,
            num_frames,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(f"batch size {batch_size} must match the length of the generators {len(generator)}.")

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def _encode_video_vae(
        self,
        video_frames: torch.Tensor,  # Expects (B, F, C, H, W)
        device: Union[str, torch.device],
    ):
        video_frames = video_frames.to(device=device, dtype=self.vae.dtype)
        batch_size, num_frames = video_frames.shape[:2]

        # Reshape for VAE encoding
        video_frames_reshaped = video_frames.reshape(batch_size * num_frames, *video_frames.shape[2:])  # (B*F, C, H, W)
        latents = self.vae.encode(video_frames_reshaped).latent_dist.sample()  # (B*F, C_latent, H_latent, W_latent)

        # Reshape back to video format
        latents = latents.reshape(batch_size, num_frames, *latents.shape[1:])  # (B, F, C_latent, H_latent, W_latent)

        return latents

    @torch.no_grad()
    def __call__(
        self,
        image: Union[List[PIL.Image.Image], torch.Tensor],
        mask_image: Union[List[PIL.Image.Image], torch.Tensor],
        height: int = 576,
        width: int = 1024,
        num_frames: Optional[int] = None,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.0,
        decode_chunk_size: Optional[int] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        mask_noise_strength: float = 0.0,
    ):
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        if num_frames is None:
            if isinstance(image, list):
                num_frames = len(image)
            else:
                num_frames = self.unet.config.num_frames

        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

        self.check_inputs(image, height, width)
        self.check_inputs(mask_image, height, width)
        if isinstance(image, list) and isinstance(mask_image, list):
            if len(image) != len(mask_image):
                raise ValueError("`image` and `mask_image` must have the same number of frames.")
            if num_frames != len(image):
                logger.warning(
                    f"Mismatch between `num_frames` ({num_frames}) and number of input images ({len(image)}). "
                    f"Using {len(image)}."
                )
                num_frames = len(image)

        batch_size = 1
        device = self._execution_device
        dtype = self.unet.dtype

        image_for_clip = image[0] if isinstance(image, list) else image
        image_embeddings = self._encode_image(image_for_clip, device, num_videos_per_prompt)

        # Following the original SVD convention, the model is conditioned on `fps - 1`
        fps = fps - 1

        image_tensor = self.video_processor.preprocess(image, height=height, width=width).to(device).unsqueeze(0)
        mask_tensor = self.video_processor.preprocess(mask_image, height=height, width=width).to(device).unsqueeze(0)

        noise = randn_tensor(image_tensor.shape, generator=generator, device=device, dtype=dtype)
        image_tensor = image_tensor + noise_aug_strength * noise

        conditional_latents = self._encode_video_vae(image_tensor, device)
        conditional_latents = conditional_latents / self.vae.config.scaling_factor

        if self.unet.config.in_channels == 12:
            mask_latents = self._encode_video_vae(mask_tensor, device)
            mask_latents = mask_latents / self.vae.config.scaling_factor
        elif self.unet.config.in_channels == 9:
            mask_tensor_gray = mask_tensor.mean(dim=2, keepdim=True)
            binarized_mask = (mask_tensor_gray > 0.0).to(dtype)
            b, f, c, h, w = binarized_mask.shape
            binarized_mask_reshaped = binarized_mask.reshape(b * f, c, h, w)
            target_size = (height // self.vae_scale_factor, width // self.vae_scale_factor)
            interpolated_mask = F.interpolate(
                binarized_mask_reshaped,
                size=target_size,
                mode="nearest",
            )
            mask_latents = interpolated_mask.reshape(b, f, *interpolated_mask.shape[1:])
        else:
            raise ValueError(
                f"Unsupported number of UNet input channels: {self.unet.config.in_channels}. "
                "This pipeline only supports 9 (for interpolated mask) or 12 (for VAE mask)."
            )

        if mask_noise_strength > 0.0:
            mask_noise = randn_tensor(mask_latents.shape, generator=generator, device=device, dtype=dtype)
            mask_latents = mask_latents + mask_noise_strength * mask_noise

        added_time_ids = self._get_add_time_ids(
            fps, motion_bucket_id, noise_aug_strength, image_embeddings.dtype, batch_size, num_videos_per_prompt
        )
        added_time_ids = added_time_ids.to(device)

        # **MODIFIED FOR SINGLE-STEP**: Prepare initial noise
        num_channels_latents = self.unet.config.out_channels
        shape = (
            batch_size * num_videos_per_prompt,
            num_frames,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # **MODIFIED FOR SINGLE-STEP**: Set a fixed high timestep
        timestep = torch.tensor([1.0], dtype=dtype, device=device)  # Use a high sigma value

        # **MODIFIED FOR SINGLE-STEP**: Single forward pass
        latent_model_input = torch.cat([latents, conditional_latents, mask_latents], dim=2)

        noise_pred = self.unet(
            latent_model_input,
            timestep,
            encoder_hidden_states=image_embeddings,
            added_time_ids=added_time_ids,
            return_dict=False,
        )[0]

        # The model's prediction is the final denoised latent
        denoised_latents = noise_pred

        frames = self.decode_latents(denoised_latents, num_frames, decode_chunk_size)
        frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames
        return StableVideoDiffusionPipelineOutput(frames=frames)

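# --------------------------------------------------------------------------
# Usage sketch (illustrative, not executed): single-step generation. The
# one-step variant skips the scheduler loop entirely; the UNet is queried
# once at a fixed timestep and its prediction is decoded directly, so it is
# only meaningful with a UNet distilled or fine-tuned for one-step
# prediction. The checkpoint path and frame lists are hypothetical.
#
#   pipe = StableVideoDiffusionPipelineOnestepWithMask.from_pretrained(
#       "path/to/your/onestep_model", torch_dtype=torch.float16
#   ).to("cuda")
#   frames = pipe(image=cond_frames, mask_image=mask_frames, decode_chunk_size=8).frames[0]
# --------------------------------------------------------------------------
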
class StableVideoDiffusionPipelineWithCrossAtnnMask(DiffusionPipeline):
    r"""
    A custom pipeline that conditions the UNet on a static appearance image (concatenated in latent space)
    and on a video mask injected through cross-attention via an external `mask_projector` module.
    """

    model_cpu_offload_seq = "image_encoder->unet->vae"
    _callback_tensor_inputs = ["latents"]

    def __init__(
        self,
        vae: AutoencoderKLTemporalDecoder,
        unet: UNetSpatioTemporalConditionModel,
        scheduler: EulerDiscreteScheduler,
        mask_projector: torch.nn.Module,
        # CLIP models are not strictly needed for inference if embeddings are not used
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            mask_projector=mask_projector,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)

    def _encode_image_vae(self, image: torch.Tensor, device: Union[str, torch.device]):
        image = image.to(device=device, dtype=self.vae.dtype)
        latent = self.vae.encode(image).latent_dist.sample()
        return latent

    def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int):
        latents = latents.flatten(0, 1).to(dtype=torch.float16)
        latents = 1 / self.vae.config.scaling_factor * latents
        frames = []
        for i in range(0, latents.shape[0], decode_chunk_size):
            # Use the actual chunk length; the final chunk can be smaller than decode_chunk_size
            num_frames_in = latents[i : i + decode_chunk_size].shape[0]
            frame = self.vae.decode(latents[i : i + decode_chunk_size], num_frames=num_frames_in).sample
            frames.append(frame)

        frames = torch.cat(frames, dim=0)
        frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
        frames = frames.float()
        return frames

    def _encode_video_vae(
        self,
        video_frames: torch.Tensor,  # Expects (B, F, C, H, W)
        device: Union[str, torch.device],
    ):
        video_frames = video_frames.to(device=device, dtype=self.vae.dtype)
        batch_size, num_frames = video_frames.shape[:2]

        # Reshape for VAE encoding
        video_frames_reshaped = video_frames.reshape(batch_size * num_frames, *video_frames.shape[2:])  # (B*F, C, H, W)
        latents = self.vae.encode(video_frames_reshaped).latent_dist.sample()  # (B*F, C_latent, H_latent, W_latent)

        # Reshape back to video format
        latents = latents.reshape(batch_size, num_frames, *latents.shape[1:])  # (B, F, C_latent, H_latent, W_latent)

        return latents

    @torch.no_grad()
    def __call__(
        self,
        image: Union[PIL.Image.Image, torch.Tensor],  # Static image for appearance
        mask_image: List[PIL.Image.Image],  # Video mask for motion
        height: int = 576,
        width: int = 1024,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 25,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.0,  # Noise is added to latents now
        decode_chunk_size: Optional[int] = 8,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        device = self._execution_device
        dtype = self.unet.dtype
        num_frames = num_frames if num_frames is not None else len(mask_image)
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

        # 1. PREPARE STATIC IMAGE CONDITION
        image_tensor = self.video_processor.preprocess(image, height, width).to(device).unsqueeze(0)
        conditional_latents = self._encode_video_vae(image_tensor, device)
        conditional_latents = conditional_latents / self.vae.config.scaling_factor

        # 2. PREPARE MASK MOTION CONDITION
        mask_tensor = self.video_processor.preprocess(mask_image, height, width)
        if mask_tensor.shape[1] > 1:
            # Collapse RGB masks to a single channel
            mask_tensor = mask_tensor.mean(dim=1, keepdim=True)

        # The projector expects a stack of frames: (F, C, H, W)
        mask_for_projection = mask_tensor.to(device, dtype)
        encoder_hidden_states = self.mask_projector(mask_for_projection)
        encoder_hidden_states = encoder_hidden_states.unsqueeze(1)  # (F, 1, D)
        # Add a batch dimension for the UNet
        encoder_hidden_states = encoder_hidden_states.unsqueeze(0)  # (1, F, 1, D)
        # The UNet will handle flattening this to (B*F, 1, D) where B=1.
        # To be safe, we pass it pre-flattened.
        encoder_hidden_states = rearrange(encoder_hidden_states, "b f s d -> (b f) s d")

        # 3. PREPARE LATENTS
        shape = (
            1,
            num_frames,
            self.unet.config.out_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        if noise_aug_strength > 0:
            latents += noise_aug_strength * randn_tensor(latents.shape, generator=generator, device=device, dtype=dtype)
        latents = latents * self.scheduler.init_noise_sigma

        # 4. GET ADDED TIME IDS
        # For this pipeline, batch size is 1
        added_time_ids = [fps - 1, motion_bucket_id, 0.0]  # noise_aug_strength for add_time_ids is 0 for inference
        added_time_ids = torch.tensor([added_time_ids], dtype=dtype, device=device)

        # 5. DENOISING LOOP
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for t in timesteps:
                latent_model_input = self.scheduler.scale_model_input(latents, t)
                unet_input = torch.cat([latent_model_input, conditional_latents], dim=2)

                noise_pred = self.unet(
                    unet_input, t, encoder_hidden_states=encoder_hidden_states, added_time_ids=added_time_ids
                ).sample

                latents = self.scheduler.step(noise_pred, t, latents).prev_sample
                progress_bar.update()

        # 6. DECODE
        frames = self.decode_latents(latents, num_frames, decode_chunk_size)
        frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)

        if not return_dict:
            return (frames,)
        return StableVideoDiffusionPipelineOutput(frames=frames)

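# A minimal sketch of a `mask_projector` compatible with the pipeline above.
# The contract implied by `__call__` is: map a stack of single-channel mask
# frames (F, 1, H, W) to one embedding per frame (F, D), where D matches the
# UNet's cross-attention dimension. This toy conv-pool-linear projector is an
# assumption for illustration, not the module the checkpoint was trained with.
class ToyMaskProjector(torch.nn.Module):
    def __init__(self, embed_dim: int = 1024):
        super().__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            torch.nn.SiLU(),
            torch.nn.AdaptiveAvgPool2d(1),  # (F, 128, 1, 1)
        )
        self.proj = torch.nn.Linear(128, embed_dim)

    def forward(self, masks: torch.Tensor) -> torch.Tensor:
        # masks: (F, 1, H, W) -> per-frame embeddings (F, embed_dim)
        feats = self.features(masks).flatten(1)
        return self.proj(feats)
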
# pipeline.py

import torch
import torch.nn.functional as F
from PIL import Image
from einops import rearrange
from torchvision import transforms
from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection


class VideoInferencePipeline:
    """
    A reusable pipeline for single-step video diffusion inference.

    This class encapsulates the models and the core inference logic,
    separating it from data loading and saving, which can vary between tasks.
    """

    def __init__(
        self,
        base_model_path: str,
        unet_checkpoint_path: str,
        device: str = "cuda",
        weight_dtype: torch.dtype = torch.float16,
    ):
        """
        Loads all necessary models into memory.

        Args:
            base_model_path (str): Path to the base Stable Video Diffusion model.
            unet_checkpoint_path (str): Path to the fine-tuned UNet checkpoint.
            device (str): The device to run models on ('cuda' or 'cpu').
            weight_dtype (torch.dtype): The precision for model weights (float16 or bfloat16).
        """
        print("--- Initializing Inference Pipeline and Loading Models ---")
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.weight_dtype = weight_dtype

        # Load models from pretrained paths
        try:
            self.feature_extractor = CLIPImageProcessor.from_pretrained(base_model_path, subfolder="feature_extractor")
            self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                base_model_path, subfolder="image_encoder", variant="fp16"
            )
            self.vae = AutoencoderKLTemporalDecoder.from_pretrained(base_model_path, subfolder="vae", variant="fp16")
            self.unet = UNetSpatioTemporalConditionModel.from_pretrained(unet_checkpoint_path, subfolder="unet")
        except Exception as e:
            raise IOError(f"Fatal error loading models: {e}") from e

        # Move models to the specified device and set to evaluation mode
        self.image_encoder.to(self.device, dtype=self.weight_dtype).eval()
        self.vae.to(self.device, dtype=self.weight_dtype).eval()
        self.unet.to(self.device, dtype=self.weight_dtype).eval()

        print(f"--- Models Loaded Successfully on {self.device} ---")

    def run(
        self,
        cond_frames,
        mask_frames,
        seed=42,
        mask_cond_mode="vae",
        fps=7,
        motion_bucket_id=127,
        noise_aug_strength=0.0,
    ):
        """
        Runs the core inference process on a sequence of conditioning and mask frames.

        Args:
            cond_frames (list[Image.Image]): List of PIL images for conditioning.
            mask_frames (list[Image.Image]): List of PIL images for the masks.
            seed (int): Random seed for generation.
            mask_cond_mode (str): How the mask is conditioned ("vae" or "interpolate").
            fps (int): Frames per second to condition the model with.
            motion_bucket_id (int): Motion bucket ID for conditioning.
            noise_aug_strength (float): Noise augmentation strength.

        Returns:
            list[Image.Image]: A list of the generated video frames as PIL Images.
        """
        # --- 1. Prepare Tensors ---
        cond_video_tensor = self._pil_to_tensor(cond_frames).to(self.device)
        mask_video_tensor = self._pil_to_tensor(mask_frames).to(self.device)

        if mask_video_tensor.shape[2] != 3:
            mask_video_tensor = mask_video_tensor.repeat(1, 1, 3, 1, 1)

        with torch.no_grad():
            # --- 2. Get CLIP Image Embeddings ---
            first_frame_tensor = cond_video_tensor[:, 0, :, :, :]
            pixel_values_for_clip = self._resize_with_antialiasing(first_frame_tensor, (224, 224))
            pixel_values_for_clip = ((pixel_values_for_clip + 1.0) / 2.0).clamp(0, 1)
            pixel_values = self.feature_extractor(images=pixel_values_for_clip, return_tensors="pt").pixel_values
            image_embeddings = self.image_encoder(pixel_values.to(self.device, dtype=self.weight_dtype)).image_embeds
            # The UNet was trained with zeroed CLIP embeddings
            encoder_hidden_states = torch.zeros_like(image_embeddings).unsqueeze(1)

            # --- 3. Prepare Latents ---
            cond_latents = self._tensor_to_vae_latent(cond_video_tensor.to(self.weight_dtype))
            cond_latents = cond_latents / self.vae.config.scaling_factor

            if mask_cond_mode == "vae":
                mask_latents = self._tensor_to_vae_latent(mask_video_tensor.to(self.weight_dtype))
                mask_latents = mask_latents / self.vae.config.scaling_factor
            elif mask_cond_mode == "interpolate":
                target_shape = cond_latents.shape[-2:]
                b, t, c, h, w = mask_video_tensor.shape
                mask_video_reshaped = rearrange(mask_video_tensor, "b t c h w -> (b t) c h w")
                interpolated_mask = F.interpolate(
                    mask_video_reshaped, size=target_shape, mode="bilinear", align_corners=False
                )
                mask_latents = rearrange(interpolated_mask, "(b t) c h w -> b t c h w", b=b)
            else:
                raise ValueError(f"Unknown mask_cond_mode: {mask_cond_mode}")

            # --- 4. Run UNet Single-Step Inference ---
            generator = torch.Generator(device=self.device).manual_seed(seed)
            noisy_latents = torch.randn(
                cond_latents.shape, generator=generator, device=self.device, dtype=self.weight_dtype
            )
            # Fixed timestep for single-step inference
            timesteps = torch.full((1,), 1.0, device=self.device, dtype=torch.long)
            added_time_ids = self._get_add_time_ids(fps, motion_bucket_id, noise_aug_strength, batch_size=1)

            unet_input = torch.cat([noisy_latents, cond_latents, mask_latents], dim=2)
            pred_latents = self.unet(unet_input, timesteps, encoder_hidden_states, added_time_ids=added_time_ids).sample

            # --- 5. Decode Latents to Video Frames ---
            pred_latents = (1 / self.vae.config.scaling_factor) * pred_latents.squeeze(0)

            frames = []
            # Process in chunks to avoid VRAM issues, especially for long videos
            for i in range(0, pred_latents.shape[0], 8):
                chunk = pred_latents[i : i + 8]
                decoded_chunk = self.vae.decode(chunk, num_frames=chunk.shape[0]).sample
                frames.append(decoded_chunk)

            video_tensor = torch.cat(frames, dim=0)
            # Map to [0, 1] and collapse to grayscale replicated over 3 channels (matte output)
            video_tensor = (video_tensor / 2.0 + 0.5).clamp(0, 1).mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)

        # Return a list of PIL images
        return [transforms.ToPILImage()(frame) for frame in video_tensor]

    def _pil_to_tensor(self, frames: list[Image.Image]):
        """Converts a list of PIL images to a normalized (B, F, C, H, W) video tensor in [-1, 1]."""
        video_tensor = torch.stack([transforms.ToTensor()(f) for f in frames]).unsqueeze(0)
        return video_tensor * 2.0 - 1.0

    def _tensor_to_vae_latent(self, t: torch.Tensor):
        """Encodes a video tensor into the VAE's latent space."""
        video_length = t.shape[1]
        t = rearrange(t, "b f c h w -> (b f) c h w")
        latents = self.vae.encode(t).latent_dist.sample()
        latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
        return latents * self.vae.config.scaling_factor

    def _get_add_time_ids(self, fps, motion_bucket_id, noise_aug_strength, batch_size):
        """Creates the additional time IDs for conditioning the UNet."""
        add_time_ids_list = [fps, motion_bucket_id, noise_aug_strength]
        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids_list)
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
                f"but a vector of {passed_add_embed_dim} was created."
            )
        add_time_ids = torch.tensor([add_time_ids_list], dtype=self.weight_dtype, device=self.device)
        return add_time_ids.repeat(batch_size, 1)

    def _resize_with_antialiasing(self, input_tensor, size, interpolation="bicubic", align_corners=True):
        """
        Resizes a tensor with anti-aliasing for CLIP input, mirroring k-diffusion.
        This is a direct copy of the helper function from the original scripts.
        """
        h, w = input_tensor.shape[-2:]
        factors = (h / size[0], w / size[1])
        sigmas = (max((factors[0] - 1.0) / 2.0, 0.001), max((factors[1] - 1.0) / 2.0, 0.001))
        ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
        # Gaussian kernel sizes must be odd
        if (ks[0] % 2) == 0:
            ks = ks[0] + 1, ks[1]
        if (ks[1] % 2) == 0:
            ks = ks[0], ks[1] + 1

        def _compute_padding(kernel_size):
            computed = [k - 1 for k in kernel_size]
            out_padding = 2 * len(kernel_size) * [0]
            for i in range(len(kernel_size)):
                computed_tmp = computed[-(i + 1)]
                pad_front = computed_tmp // 2
                pad_rear = computed_tmp - pad_front
                out_padding[2 * i + 0] = pad_front
                out_padding[2 * i + 1] = pad_rear
            return out_padding

        def _filter2d(input_tensor, kernel):
            b, c, h, w = input_tensor.shape
            tmp_kernel = kernel[:, None, ...].to(device=input_tensor.device, dtype=input_tensor.dtype)
            tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
            height, width = tmp_kernel.shape[-2:]
            padding_shape = _compute_padding([height, width])
            input_tensor_padded = F.pad(input_tensor, padding_shape, mode="reflect")
            tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
            input_tensor_padded = input_tensor_padded.view(
                -1, tmp_kernel.size(0), input_tensor_padded.size(-2), input_tensor_padded.size(-1)
            )
            output = F.conv2d(input_tensor_padded, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
            return output.view(b, c, h, w)

        def _gaussian(window_size, sigma):
            if isinstance(sigma, float):
                sigma = torch.tensor([[sigma]])
            x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(
                sigma.shape[0], -1
            )
            if window_size % 2 == 0:
                x = x + 0.5
            gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
            return gauss / gauss.sum(-1, keepdim=True)

        def _gaussian_blur2d(input_tensor, kernel_size, sigma):
            if isinstance(sigma, tuple):
                sigma = torch.tensor([sigma], dtype=input_tensor.dtype)
            else:
                sigma = sigma.to(dtype=input_tensor.dtype)
            ky, kx = int(kernel_size[0]), int(kernel_size[1])
            bs = sigma.shape[0]
            kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
            kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
            out_x = _filter2d(input_tensor, kernel_x[..., None, :])
            return _filter2d(out_x, kernel_y[..., None])

        blurred_input = _gaussian_blur2d(input_tensor, ks, sigmas)
        return F.interpolate(blurred_input, size=size, mode=interpolation, align_corners=align_corners)
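
# --------------------------------------------------------------------------
# Usage sketch (illustrative, not executed) for VideoInferencePipeline. The
# model paths, frame locations, and frame count are hypothetical placeholders.
#
#   from PIL import Image
#
#   cond_frames = [Image.open(f"frames/{i:04d}.png").convert("RGB") for i in range(14)]
#   mask_frames = [Image.open(f"masks/{i:04d}.png").convert("L") for i in range(14)]
#
#   pipe = VideoInferencePipeline(
#       base_model_path="path/to/base_svd_model",
#       unet_checkpoint_path="path/to/finetuned_unet_checkpoint",
#   )
#   matte_frames = pipe.run(cond_frames, mask_frames, seed=42, mask_cond_mode="vae")
#   for i, frame in enumerate(matte_frames):
#       frame.save(f"out/matte_{i:04d}.png")
# --------------------------------------------------------------------------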