kaihuac committed on
Commit
c0d9c68
·
verified ·
1 Parent(s): eb1a3d5

Delete unet/models/diffusion_vas/pipeline_diffusion_vas.py

Browse files
unet/models/diffusion_vas/pipeline_diffusion_vas.py DELETED
@@ -1,717 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import inspect
16
- from dataclasses import dataclass
17
- from typing import Callable, Dict, List, Optional, Union
18
-
19
- import numpy as np
20
- import PIL.Image
21
- import torch
22
- from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, CLIPTextModel, CLIPTokenizer
23
-
24
- import diffusers
25
-
26
- from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
27
- from diffusers.models import AutoencoderKLTemporalDecoder
28
-
29
- from diffusers.schedulers import EulerDiscreteScheduler
30
- from diffusers.utils import BaseOutput, logging, replace_example_docstring
31
- from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
32
- from diffusers import DiffusionPipeline
33
-
34
- from .unet_diffusion_vas import UNetSpatioTemporalConditionModel
35
-
36
-
37
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
-
39
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import StableVideoDiffusionPipeline
        >>> from diffusers.utils import load_image, export_to_video

        >>> pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
        >>> pipe.to("cuda")

        >>> image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg")
        >>> image = image.resize((1024, 576))

        >>> frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
        >>> export_to_video(frames, "generated.mp4", fps=7)
        ```
"""
55
-
56
-
57
- def _append_dims(x, target_dims):
58
- """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
59
- dims_to_append = target_dims - x.ndim
60
- if dims_to_append < 0:
61
- raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
62
- return x[(...,) + (None,) * dims_to_append]
63
-
64
-
65
# Adapted from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: str = "np"):
    """Convert a decoded video tensor into the requested output format.

    Args:
        video: Tensor of shape ``(batch, channels, frames, height, width)``.
        processor: Image processor whose ``postprocess`` converts per-frame tensors.
        output_type: One of ``"np"``, ``"pt"`` or ``"pil"``.

    Returns:
        ``np.ndarray`` / ``torch.Tensor`` stacked over the batch, or a nested list
        of PIL images, depending on ``output_type``.

    Raises:
        ValueError: If ``output_type`` is not one of the supported formats.
    """
    # Fail fast on an unsupported output type instead of raising only after all
    # per-batch postprocessing work has been done (and instead of the anti-idiom
    # `not output_type == "pil"`).
    if output_type not in ("np", "pt", "pil"):
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        # (channels, frames, h, w) -> (frames, channels, h, w) so each frame is an image.
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        outputs.append(processor.postprocess(batch_vid, output_type))

    if output_type == "np":
        outputs = np.stack(outputs)
    elif output_type == "pt":
        outputs = torch.stack(outputs)
    # "pil" keeps the nested list of PIL images as returned by postprocess.

    return outputs
85
-
86
-
87
@dataclass
class StableVideoDiffusionPipelineOutput(BaseOutput):
    r"""
    Output class for Stable Video Diffusion pipeline.

    Args:
        frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.FloatTensor`]):
            List of denoised PIL images of length `batch_size` or numpy array or torch tensor
            of shape `(batch_size, num_frames, height, width, num_channels)`.
    """

    # Concrete type depends on the `output_type` requested in the pipeline call
    # ("pil" -> nested list of PIL images, "np" -> ndarray, "pt"/"latent" -> tensor).
    frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.FloatTensor]
99
-
100
-
101
class DiffusionVASPipeline(DiffusionPipeline):
    r"""
    Pipeline to generate video from an input image using Stable Video Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKLTemporalDecoder`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
            Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
        unet ([`UNetSpatioTemporalConditionModel`]):
            A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
        scheduler ([`EulerDiscreteScheduler`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images.
    """

    # Order in which sub-models are moved to the accelerator under CPU offload.
    model_cpu_offload_seq = "image_encoder->unet->vae"
    # Tensor names exposed to `callback_on_step_end` via callback_kwargs.
    _callback_tensor_inputs = ["latents"]
123
-
124
    def __init__(
        self,
        vae: AutoencoderKLTemporalDecoder,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNetSpatioTemporalConditionModel,
        scheduler: EulerDiscreteScheduler,
        feature_extractor: CLIPImageProcessor,
    ):
        """Register the sub-models and build the VAE image processor.

        Args:
            vae: Temporal-decoder VAE used to encode/decode frames.
            image_encoder: CLIP vision model producing conditioning embeddings.
            unet: Spatio-temporal UNet that denoises the latents.
            scheduler: Euler discrete scheduler driving the denoising loop.
            feature_extractor: CLIP image processor preparing encoder inputs.
        """
        super().__init__()

        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

        # Spatial downscaling factor of the VAE: one factor of 2 per downsampling block.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
144
-
145
- # def _encode_prompt(
146
- # self,
147
- # prompt,
148
- # device,
149
- # do_classifier_free_guidance
150
- # ):
151
- #
152
- # dtype = next(self.image_encoder.parameters()).dtype
153
- # prompt = [prompt] if isinstance(prompt, str) else prompt
154
- # text_inputs = self.tokenizer(
155
- # prompt, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
156
- # ).input_ids
157
- #
158
- # text_inputs = text_inputs.to(self.text_encoder.device)
159
- #
160
- # text_embeddings = self.text_encoder(text_inputs, return_dict=False)[0].to(device=device,dtype=dtype)
161
- # if do_classifier_free_guidance:
162
- # negative_text_embeddings = torch.zeros_like(text_embeddings)
163
- # text_embeddings = torch.cat([negative_text_embeddings, text_embeddings])
164
- #
165
- # return text_embeddings
166
-
167
-
168
    def _encode_image(
        self,
        image: PipelineImageInput,
        device: Union[str, torch.device],
        num_videos_per_prompt: int,
        do_classifier_free_guidance: bool,
    ) -> torch.FloatTensor:
        """Encode the conditioning image(s) with the CLIP image encoder.

        Args:
            image: PIL image(s) or a float tensor. NOTE(review): the tensor branch
                applies `(image + 1) / 2`, so tensors look expected in [-1, 1] —
                confirm against callers.
            device: Device to run the encoder on.
            num_videos_per_prompt: Embeddings are tiled once per requested video.
            do_classifier_free_guidance: When True, all-zero (unconditional)
                embeddings are concatenated in front for a single batched CFG pass.

        Returns:
            Image embeddings of shape ``(batch * num_videos_per_prompt, 1, embed_dim)``
            (doubled along the batch when CFG is active).
        """
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.image_processor.pil_to_numpy(image)
            image = self.image_processor.numpy_to_pt(image)

            # We normalize the image before resizing to match with the original implementation.
            # Then we unnormalize it after resizing.
            image = image * 2.0 - 1.0
            image = _resize_with_antialiasing(image, (224, 224))
            image = (image + 1.0) / 2.0

        else:
            image = _resize_with_antialiasing(image, (224, 224))
            image = (image + 1.0) / 2.0

        # Normalize the image for CLIP input (resize/crop/rescale already handled above).
        image = self.feature_extractor(
            images=image,
            do_normalize=True,
            do_center_crop=False,
            do_resize=False,
            do_rescale=False,
            return_tensors="pt",
        ).pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = image_embeddings.shape
        image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
        image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            negative_image_embeddings = torch.zeros_like(image_embeddings)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and conditional embeddings into a
            # single batch to avoid doing two forward passes.
            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])

        return image_embeddings
219
-
220
- def _encode_vae_image(
221
- self,
222
- image: torch.Tensor,
223
- device: Union[str, torch.device],
224
- num_videos_per_prompt: int,
225
- do_classifier_free_guidance: bool,
226
- ):
227
- image = image.to(device=device)
228
- image_latents = self.vae.encode(image).latent_dist.mode()
229
-
230
- if do_classifier_free_guidance:
231
- negative_image_latents = torch.zeros_like(image_latents)
232
-
233
- # For classifier free guidance, we need to do two forward passes.
234
- # Here we concatenate the unconditional and text embeddings into a single batch
235
- # to avoid doing two forward passes
236
- image_latents = torch.cat([negative_image_latents, image_latents])
237
-
238
- # duplicate image_latents for each generation per prompt, using mps friendly method
239
- image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
240
-
241
- return image_latents
242
-
243
    def _get_add_time_ids(
        self,
        fps: int,
        motion_bucket_id: int,
        noise_aug_strength: float,
        dtype: torch.dtype,
        batch_size: int,
        num_videos_per_prompt: int,
        do_classifier_free_guidance: bool,
    ):
        """Build the additional time-embedding ids (fps, motion bucket, noise strength).

        Returns a tensor of shape ``(batch_size * num_videos_per_prompt, 3)``,
        doubled along the batch when classifier-free guidance is active.

        Raises:
            ValueError: If the UNet's add-embedding input width does not equal
                ``addition_time_embed_dim * 3`` (misconfigured model).
        """
        add_time_ids = [fps, motion_bucket_id, noise_aug_strength]

        # Sanity-check that the UNet was configured for exactly these three conditionings.
        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)

        if do_classifier_free_guidance:
            # Duplicate for the unconditional + conditional CFG halves.
            add_time_ids = torch.cat([add_time_ids, add_time_ids])

        return add_time_ids
270
-
271
    def decode_latents(self, latents: torch.FloatTensor, num_frames: int, decode_chunk_size: int = 14):
        """Decode video latents back to pixel space in memory-friendly chunks.

        Args:
            latents: Latents of shape ``(batch, frames, channels, height, width)``.
            num_frames: Frames per video, used to un-flatten the batch afterwards.
            decode_chunk_size: Frames decoded per VAE call (OOM guard).

        Returns:
            float32 tensor of shape ``(batch, channels, frames, height, width)``.
        """
        # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
        latents = latents.flatten(0, 1)

        # Undo the VAE scaling applied at encode time.
        latents = 1 / self.vae.config.scaling_factor * latents

        # torch.compile wraps the module; inspect the original forward's signature.
        forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
        accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())

        # decode decode_chunk_size frames at a time to avoid OOM
        frames = []
        for i in range(0, latents.shape[0], decode_chunk_size):
            num_frames_in = latents[i : i + decode_chunk_size].shape[0]
            decode_kwargs = {}
            if accepts_num_frames:
                # we only pass num_frames_in if it's expected
                decode_kwargs["num_frames"] = num_frames_in

            frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
            frames.append(frame)
        frames = torch.cat(frames, dim=0)

        # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
        frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        frames = frames.float()
        return frames
299
-
300
- def check_inputs(self, image, height, width):
301
- if (
302
- not isinstance(image, torch.Tensor)
303
- and not isinstance(image, PIL.Image.Image)
304
- and not isinstance(image, list)
305
- ):
306
- raise ValueError(
307
- "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
308
- f" {type(image)}"
309
- )
310
-
311
- if height % 8 != 0 or width % 8 != 0:
312
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
313
-
314
    def prepare_latents(
        self,
        batch_size: int,
        num_frames: int,
        num_channels_latents: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: Union[str, torch.device],
        generator: torch.Generator,
        latents: Optional[torch.FloatTensor] = None,
    ):
        """Create (or move to device) the initial noise latents for sampling.

        Raises:
            ValueError: If a list of generators is passed whose length does not
                match ``batch_size``.
        """
        # NOTE(review): channels are halved here, presumably because the caller
        # concatenates conditioning latents along the channel dim before the UNet
        # — confirm against `unet.config.in_channels`.
        shape = (
            batch_size,
            num_frames,
            num_channels_latents // 2,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
347
-
348
    @property
    def guidance_scale(self):
        # Scalar before the denoising loop; __call__ later replaces it with a
        # per-frame tensor (linspace from min to max guidance scale).
        return self._guidance_scale
351
-
352
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        # `guidance_scale` may be a python scalar (before the loop) or a per-frame
        # tensor (inside the loop); handle both.
        if isinstance(self.guidance_scale, (int, float)):
            return self.guidance_scale > 1
        return self.guidance_scale.max() > 1
360
-
361
    @property
    def num_timesteps(self):
        # Number of scheduler timesteps used by the most recent __call__.
        return self._num_timesteps
364
-
365
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        images,
        rgb_images,
        height: int = 576,
        width: int = 1024,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 25,
        min_guidance_scale: float = 1.5,
        max_guidance_scale: float = 1.5,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        return_dict: bool = True,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            images (`torch.FloatTensor`):
                Conditioning frames. The indexing `images[:, i, :, :, :]` below implies a
                5-D tensor of shape `(batch, frames, channels, height, width)`.
            rgb_images (`torch.FloatTensor`):
                Second conditioning stream, same shape as `images`; its VAE latents are
                concatenated alongside those of `images` at each denoising step.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_frames (`int`, *optional*):
                The number of video frames to generate. Defaults to `self.unet.config.num_frames`
                (14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
                expense of slower inference. This parameter is modulated by `strength`.
            min_guidance_scale (`float`, *optional*, defaults to 1.5):
                The minimum guidance scale. Used for the classifier free guidance with first frame.
            max_guidance_scale (`float`, *optional*, defaults to 1.5):
                The maximum guidance scale. Used for the classifier free guidance with last frame.
            fps (`int`, *optional*, defaults to 7):
                Frames per second. The rate at which the generated images shall be exported to a video after generation.
                Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
            motion_bucket_id (`int`, *optional*, defaults to 127):
                Used for conditioning the amount of motion for the generation. The higher the number the more motion
                will be in the video.
            noise_aug_strength (`float`, *optional*, defaults to 0.02):
                The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
            decode_chunk_size (`int`, *optional*):
                The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the expense of more memory usage. By default, the decoder decodes all frames at once for maximal
                quality. For lower memory usage, reduce `decode_chunk_size`.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `pil`, `np` or `pt`.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments:
                `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
                `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
                NOTE(review): the default is a shared mutable list; it is only read here,
                but callers should not mutate it.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.FloatTensor`) is returned.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(images[0], height, width)

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        self._guidance_scale = max_guidance_scale

        # 3. Encode each conditioning frame with CLIP, then concatenate along the batch.
        image_embeddings = [self._encode_image(images[:,i,:,:,:], device, num_videos_per_prompt, self.do_classifier_free_guidance) for i in range(images.shape[1])]
        image_embeddings = torch.cat(image_embeddings, dim=0)

        # NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
        # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
        fps = fps - 1

        # 4. Encode both conditioning streams with the VAE, after noise augmentation.

        images = torch.stack([self.image_processor.preprocess(images[:,i,:,:,:], height=height, width=width).to(device) for i in range(images.shape[1])])
        noise = randn_tensor(images.shape, generator=generator, device=device, dtype=images.dtype)
        images = images + noise_aug_strength * noise

        rgb_images = torch.stack([self.image_processor.preprocess(rgb_images[:,i,:,:,:], height=height, width=width).to(device) for i in range(rgb_images.shape[1])])
        noise = randn_tensor(rgb_images.shape, generator=generator, device=device, dtype=rgb_images.dtype)
        rgb_images = rgb_images + noise_aug_strength * noise

        # fp16 VAEs that require upcasting are temporarily run in fp32 for encoding.
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        image_latents = torch.stack([self._encode_vae_image(
            image,
            device=device,
            num_videos_per_prompt=num_videos_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
        ).to(image_embeddings.dtype) for image in images], dim=1)

        rgb_image_latents = torch.stack([self._encode_vae_image(
            rgb_image,
            device=device,
            num_videos_per_prompt=num_videos_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
        ).to(image_embeddings.dtype) for rgb_image in rgb_images], dim=1)

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # 5. Get Added Time IDs
        added_time_ids = self._get_add_time_ids(
            fps,
            motion_bucket_id,
            noise_aug_strength,
            image_embeddings.dtype,
            batch_size,
            num_videos_per_prompt,
            self.do_classifier_free_guidance,
        )
        added_time_ids = added_time_ids.to(device)

        # 6. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 7. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_frames,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 8. Prepare guidance scale: per-frame ramp from min to max, broadcast to latent rank.
        guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
        guidance_scale = guidance_scale.to(device, latents.dtype)
        guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
        guidance_scale = _append_dims(guidance_scale, latents.ndim)

        self._guidance_scale = guidance_scale

        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Concatenate both conditioning latent streams over the channels dimension.
                latent_model_input = torch.cat([latent_model_input, image_latents, rgb_image_latents], dim=2)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_embeddings,
                    added_time_ids=added_time_ids,
                    return_dict=False
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # The callback may replace the latents (e.g. for editing/guidance tricks).
                    latents = callback_outputs.pop("latents", latents)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if not output_type == "latent":
            # cast back to fp16 if needed
            # NOTE(review): the VAE was already cast back after encoding above; this
            # second cast looks redundant but is harmless.
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
            frames = self.decode_latents(latents, num_frames, decode_chunk_size)
            frames = tensor2vid(frames, self.image_processor, output_type=output_type)
        else:
            frames = latents

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames

        return StableVideoDiffusionPipelineOutput(frames=frames)
606
-
607
-
608
-
609
-
610
-
611
- # resizing utils
612
- # TODO: clean up later
613
def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
    """Resize a (B, C, H, W) tensor to ``size``, Gaussian-blurring first to reduce aliasing."""
    h, w = input.shape[-2:]
    factors = (h / size[0], w / size[1])

    # Per-axis sigma, following skimage's anti-aliasing heuristic:
    # https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
    sigmas = (
        max((factors[0] - 1.0) / 2.0, 0.001),
        max((factors[1] - 1.0) / 2.0, 0.001),
    )

    # Kernel support of ~2 sigma per side (good results at 3 sigma but slow; Pillow
    # uses 1 sigma with two passes):
    # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
    ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))

    # Gaussian kernels must have odd sizes.
    if (ks[0] % 2) == 0:
        ks = ks[0] + 1, ks[1]
    if (ks[1] % 2) == 0:
        ks = ks[0], ks[1] + 1

    blurred = _gaussian_blur2d(input, ks, sigmas)
    return torch.nn.functional.interpolate(blurred, size=size, mode=interpolation, align_corners=align_corners)
640
-
641
-
642
- def _compute_padding(kernel_size):
643
- """Compute padding tuple."""
644
- # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
645
- # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
646
- if len(kernel_size) < 2:
647
- raise AssertionError(kernel_size)
648
- computed = [k - 1 for k in kernel_size]
649
-
650
- # for even kernels we need to do asymmetric padding :(
651
- out_padding = 2 * len(kernel_size) * [0]
652
-
653
- for i in range(len(kernel_size)):
654
- computed_tmp = computed[-(i + 1)]
655
-
656
- pad_front = computed_tmp // 2
657
- pad_rear = computed_tmp - pad_front
658
-
659
- out_padding[2 * i + 0] = pad_front
660
- out_padding[2 * i + 1] = pad_rear
661
-
662
- return out_padding
663
-
664
-
665
def _filter2d(input, kernel):
    """Convolve each channel of ``input`` (B, C, H, W) with per-batch ``kernel``, reflect-padded."""
    b, c, h, w = input.shape

    # Broadcast the kernel across channels and match the input's device/dtype.
    tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
    kh, kw = tmp_kernel.shape[-2:]

    padding_shape: list[int] = _compute_padding([kh, kw])
    padded = torch.nn.functional.pad(input, padding_shape, mode="reflect")

    # Grouped convolution: one group per (batch * channel) slice so every slice
    # is filtered by its own kernel.
    tmp_kernel = tmp_kernel.reshape(-1, 1, kh, kw)
    padded = padded.view(-1, tmp_kernel.size(0), padded.size(-2), padded.size(-1))

    output = torch.nn.functional.conv2d(padded, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
    return output.view(b, c, h, w)
686
-
687
-
688
- def _gaussian(window_size: int, sigma):
689
- if isinstance(sigma, float):
690
- sigma = torch.tensor([[sigma]])
691
-
692
- batch_size = sigma.shape[0]
693
-
694
- x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
695
-
696
- if window_size % 2 == 0:
697
- x = x + 0.5
698
-
699
- gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
700
-
701
- return gauss / gauss.sum(-1, keepdim=True)
702
-
703
-
704
def _gaussian_blur2d(input, kernel_size, sigma):
    """Separable Gaussian blur: a horizontal 1-D pass followed by a vertical one."""
    if isinstance(sigma, tuple):
        sigma = torch.tensor([sigma], dtype=input.dtype)
    else:
        sigma = sigma.to(dtype=input.dtype)

    ky, kx = int(kernel_size[0]), int(kernel_size[1])
    bs = sigma.shape[0]

    # Build the two 1-D kernels (x-axis uses sigma[:, 1], y-axis sigma[:, 0]).
    kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
    kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))

    # Apply as a row filter, then a column filter.
    blurred_x = _filter2d(input, kernel_x[..., None, :])
    return _filter2d(blurred_x, kernel_y[..., None])