from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from diffusers.image_processor import PipelineImageInput
from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
    _resize_with_antialiasing,
    retrieve_timesteps,
)
from diffusers.schedulers import EulerDiscreteScheduler
from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor


logger = logging.get_logger(__name__)

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from pipeline_svd_masked import StableVideoDiffusionPipelineWithMask
        >>> from diffusers.utils import load_image, export_to_video

        >>> # Load your fine-tuned UNet, VAE, etc.
        >>> pipe = StableVideoDiffusionPipelineWithMask.from_pretrained(
        ...     "path/to/your/finetuned_model", torch_dtype=torch.float16, variant="fp16"
        ... )
        >>> pipe.to("cuda")

        >>> # Load the conditioning image and the mask.
        >>> image = load_image("path/to/your/conditioning_image.png").resize((1024, 576))
        >>> mask = load_image("path/to/your/mask_image.png").resize((1024, 576))

        >>> # The pipeline expects per-frame lists for both inputs.
        >>> frames = pipe(
        ...     image=[image] * 25,
        ...     mask_image=[mask] * 25,
        ...     num_frames=25,
        ...     decode_chunk_size=8,
        ... ).frames[0]

        >>> export_to_video(frames, "generated_video.mp4", fps=7)
        ```
"""


@dataclass
class StableVideoDiffusionPipelineOutput(BaseOutput):
    r"""
    Output class for the custom Stable Video Diffusion pipelines.

    Args:
        frames (`List[List[PIL.Image.Image]]`, `np.ndarray`, or `torch.Tensor`):
            List of denoised PIL images of length `batch_size`, or a numpy array or torch tensor of shape
            `(batch_size, num_frames, height, width, num_channels)`.
    """

    frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor]


class StableVideoDiffusionPipelineWithMask(DiffusionPipeline):
    r"""
    A custom pipeline based on Stable Video Diffusion that accepts an additional mask for conditioning.

    The pipeline is designed for a UNet fine-tuned to accept 12 input channels (4 for noise, 4 for the
    VAE-encoded condition frames, 4 for the VAE-encoded mask frames) or 9 input channels (the mask
    contributes a single binarized, downsampled channel instead).
    """

    model_cpu_offload_seq = "image_encoder->unet->vae"
    _callback_tensor_inputs = ["latents"]

    def __init__(
        self,
        vae: AutoencoderKLTemporalDecoder,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNetSpatioTemporalConditionModel,
        scheduler: EulerDiscreteScheduler,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)

    def _encode_image(
        self,
        image: PipelineImageInput,
        device: Union[str, torch.device],
        num_videos_per_prompt: int,
    ) -> torch.Tensor:
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.video_processor.pil_to_numpy(image)
            image = self.video_processor.numpy_to_pt(image)

        # Normalize to [-1, 1], resize with anti-aliasing to the CLIP resolution, then map back to [0, 1].
        image = image * 2.0 - 1.0
        image = _resize_with_antialiasing(image, (224, 224))
        image = (image + 1.0) / 2.0

        # Resizing and rescaling are handled above; the feature extractor only normalizes.
        image = self.feature_extractor(
            images=image,
            do_normalize=True,
            do_center_crop=False,
            do_resize=False,
            do_rescale=False,
            return_tensors="pt",
        ).pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        # The CLIP embedding is deliberately zeroed out: this fine-tuned UNet is conditioned through the
        # concatenated latents only, so the cross-attention context is a null embedding.
        image_embeddings = torch.zeros_like(image_embeddings)

        return image_embeddings

    def _encode_vae_image(
        self,
        image: torch.Tensor,
        device: Union[str, torch.device],
        num_videos_per_prompt: int,
    ):
        image = image.to(device=device, dtype=self.vae.dtype)
        image_latents = self.vae.encode(image).latent_dist.sample()
        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
        return image_latents

    def _get_add_time_ids(
        self,
        fps: int,
        motion_bucket_id: int,
        noise_aug_strength: float,
        dtype: torch.dtype,
        batch_size: int,
        num_videos_per_prompt: int,
    ):
        add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
                f"but a vector of length {passed_add_embed_dim} was created."
            )
        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
        return add_time_ids

    def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14):
        # (batch, frames, channels, h, w) -> (batch * frames, channels, h, w)
        latents = latents.flatten(0, 1).to(dtype=self.vae.dtype)
        latents = 1 / self.vae.config.scaling_factor * latents

        # Decode in chunks to bound peak memory usage.
        frames = []
        for i in range(0, latents.shape[0], decode_chunk_size):
            num_frames_in = latents[i : i + decode_chunk_size].shape[0]
            frame = self.vae.decode(latents[i : i + decode_chunk_size], num_frames=num_frames_in).sample
            frames.append(frame)
        frames = torch.cat(frames, dim=0)

        # (batch * frames, channels, h, w) -> (batch, channels, frames, h, w)
        frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
        frames = frames.float()
        return frames

    def check_inputs(self, image, height, width):
        if (
            not isinstance(image, torch.Tensor)
            and not isinstance(image, PIL.Image.Image)
            and not isinstance(image, list)
        ):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, or `list` but is {type(image)}"
            )
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    def prepare_latents(
        self,
        batch_size: int,
        num_frames: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: Union[str, torch.device],
        generator: torch.Generator,
        latents: Optional[torch.Tensor] = None,
        initial_latents: Optional[torch.Tensor] = None,
        denoising_strength: float = 1.0,
        timestep: Optional[torch.Tensor] = None,
    ):
        # `denoising_strength` is applied upstream when the timestep schedule is truncated; it is kept
        # here for API symmetry.
        num_channels_latents = self.unet.config.out_channels
        shape = (
            batch_size,
            num_frames,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )

        if initial_latents is not None:
            # Image-to-image style start: noise the provided latents up to `timestep` instead of
            # starting from pure noise.
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            latents = self.scheduler.add_noise(initial_latents, noise, timestep)
        else:
            if latents is None:
                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            else:
                latents = latents.to(device)
            # Scale the initial noise by the scheduler's standard deviation.
            latents = latents * self.scheduler.init_noise_sigma

        return latents

    def _encode_video_vae(
        self,
        video_frames: torch.Tensor,
        device: Union[str, torch.device],
    ):
        video_frames = video_frames.to(device=device, dtype=self.vae.dtype)
        batch_size, num_frames = video_frames.shape[:2]

        # Fold frames into the batch dimension for the VAE encoder.
        video_frames_reshaped = video_frames.reshape(batch_size * num_frames, *video_frames.shape[2:])
        latents = self.vae.encode(video_frames_reshaped).latent_dist.sample()

        # Restore the (batch, frames, ...) layout.
        latents = latents.reshape(batch_size, num_frames, *latents.shape[1:])

        return latents

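    # Channel layout fed to the fine-tuned UNet at each step (this pipeline's convention):
    #   in_channels == 12: [0:4] noisy latents, [4:8] VAE-encoded condition frames, [8:12] VAE-encoded mask frames
    #   in_channels == 9:  [0:4] noisy latents, [4:8] VAE-encoded condition frames, [8:9] binarized downsampled mask
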
    @torch.no_grad()
    def __call__(
        self,
        image: Union[List[PIL.Image.Image], torch.Tensor],
        mask_image: Union[List[PIL.Image.Image], torch.Tensor],
        alpha_matte_image: Optional[Union[List[PIL.Image.Image], torch.Tensor]] = None,
        denoising_strength: float = 0.7,
        height: int = 576,
        width: int = 1024,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 30,
        sigmas: Optional[List[float]] = None,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        mask_noise_strength: float = 0.0,
    ):
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        if num_frames is None:
            num_frames = len(image) if isinstance(image, list) else self.unet.config.num_frames

        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

        self.check_inputs(image, height, width)
        self.check_inputs(mask_image, height, width)
        if alpha_matte_image is not None:
            self.check_inputs(alpha_matte_image, height, width)

        batch_size = 1
        device = self._execution_device
        dtype = self.unet.dtype

        # The first conditioning frame goes through CLIP (the embedding is zeroed in `_encode_image`).
        image_for_clip = image[0]
        image_embeddings = self._encode_image(image_for_clip, device, num_videos_per_prompt)

        # SVD conditions on fps - 1, following the original training setup.
        fps = fps - 1

        image_tensor = self.video_processor.preprocess(image, height=height, width=width).to(device).unsqueeze(0)
        mask_tensor = self.video_processor.preprocess(mask_image, height=height, width=width).to(device).unsqueeze(0)

        # Noise augmentation on the conditioning frames.
        noise = randn_tensor(image_tensor.shape, generator=generator, device=device, dtype=dtype)
        image_tensor = image_tensor + noise_aug_strength * noise

        conditional_latents = self._encode_video_vae(image_tensor, device)
        conditional_latents = conditional_latents / self.vae.config.scaling_factor

        if self.unet.config.in_channels == 12:
            # The mask is VAE-encoded and contributes 4 latent channels.
            mask_latents = self._encode_video_vae(mask_tensor, device)
            mask_latents = mask_latents / self.vae.config.scaling_factor
        elif self.unet.config.in_channels == 9:
            # The mask is binarized (preprocessing maps pixels to [-1, 1], so > 0 selects the foreground)
            # and downsampled to the latent resolution, contributing a single channel.
            mask_tensor_gray = mask_tensor.mean(dim=2, keepdim=True)
            binarized_mask = (mask_tensor_gray > 0.0).to(dtype)
            b, f, c, h, w = binarized_mask.shape
            binarized_mask_reshaped = binarized_mask.reshape(b * f, c, h, w)
            target_size = (height // self.vae_scale_factor, width // self.vae_scale_factor)
            interpolated_mask = F.interpolate(
                binarized_mask_reshaped,
                size=target_size,
                mode="nearest",
            )
            mask_latents = interpolated_mask.reshape(b, f, *interpolated_mask.shape[1:])
        else:
            raise ValueError(
                f"Unsupported number of UNet input channels: {self.unet.config.in_channels}. "
                "This pipeline only supports 9 (for interpolated mask) or 12 (for VAE mask)."
            )

        if mask_noise_strength > 0.0:
            mask_noise = randn_tensor(mask_latents.shape, generator=generator, device=device, dtype=dtype)
            mask_latents = mask_latents + mask_noise_strength * mask_noise

        added_time_ids = self._get_add_time_ids(
            fps, motion_bucket_id, noise_aug_strength, image_embeddings.dtype, batch_size, num_videos_per_prompt
        )
        added_time_ids = added_time_ids.to(device)

        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)

        initial_latents = None

        if alpha_matte_image is not None:
            alpha_matte_tensor = (
                self.video_processor.preprocess(alpha_matte_image, height=height, width=width).to(device).unsqueeze(0)
            )
            initial_latents = self._encode_video_vae(alpha_matte_tensor, device)
            initial_latents = initial_latents / self.vae.config.scaling_factor

            # Image-to-image style schedule truncation: skip the earliest (noisiest) steps.
            t_start = max(num_inference_steps - int(num_inference_steps * denoising_strength), 0)
            timesteps = timesteps[t_start:]
            start_timestep = timesteps[0]
        else:
            start_timestep = timesteps[0]

        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_frames,
            height,
            width,
            dtype,
            device,
            generator,
            latents,
            initial_latents=initial_latents,
            denoising_strength=denoising_strength,
            timestep=start_timestep if initial_latents is not None else None,
        )

        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = self.scheduler.scale_model_input(latents, t)
                # Concatenate along the channel dimension: noise + condition + mask.
                latent_model_input = torch.cat([latent_model_input, conditional_latents, mask_latents], dim=2)

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_embeddings,
                    added_time_ids=added_time_ids,
                    return_dict=False,
                )[0]

                latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        frames = self.decode_latents(latents, num_frames, decode_chunk_size)
        frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames
        return StableVideoDiffusionPipelineOutput(frames=frames)


class StableVideoDiffusionPipelineOnestepWithMask(DiffusionPipeline):
    r"""
    A single-step variant of the masked Stable Video Diffusion pipeline: the UNet is evaluated once at a
    fixed timestep and its output is taken directly as the denoised latents.

    Like `StableVideoDiffusionPipelineWithMask`, it expects a UNet fine-tuned to accept 12 input channels
    (4 for noise, 4 for the VAE-encoded condition frames, 4 for the VAE-encoded mask frames) or 9 input
    channels (the mask contributes a single binarized, downsampled channel instead).
    """

    model_cpu_offload_seq = "image_encoder->unet->vae"
    _callback_tensor_inputs = ["latents"]

    def __init__(
        self,
        vae: AutoencoderKLTemporalDecoder,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNetSpatioTemporalConditionModel,
        scheduler: EulerDiscreteScheduler,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)

    def _encode_image(
        self,
        image: PipelineImageInput,
        device: Union[str, torch.device],
        num_videos_per_prompt: int,
    ) -> torch.Tensor:
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.video_processor.pil_to_numpy(image)
            image = self.video_processor.numpy_to_pt(image)

        # Normalize to [-1, 1], resize with anti-aliasing to the CLIP resolution, then map back to [0, 1].
        image = image * 2.0 - 1.0
        image = _resize_with_antialiasing(image, (224, 224))
        image = (image + 1.0) / 2.0

        # Resizing and rescaling are handled above; the feature extractor only normalizes.
        image = self.feature_extractor(
            images=image,
            do_normalize=True,
            do_center_crop=False,
            do_resize=False,
            do_rescale=False,
            return_tensors="pt",
        ).pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        # The CLIP embedding is deliberately zeroed out: this fine-tuned UNet is conditioned through the
        # concatenated latents only, so the cross-attention context is a null embedding.
        image_embeddings = torch.zeros_like(image_embeddings)

        return image_embeddings

    def _encode_vae_image(
        self,
        image: torch.Tensor,
        device: Union[str, torch.device],
        num_videos_per_prompt: int,
    ):
        image = image.to(device=device, dtype=self.vae.dtype)
        image_latents = self.vae.encode(image).latent_dist.sample()
        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
        return image_latents

    def _get_add_time_ids(
        self,
        fps: int,
        motion_bucket_id: int,
        noise_aug_strength: float,
        dtype: torch.dtype,
        batch_size: int,
        num_videos_per_prompt: int,
    ):
        add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
                f"but a vector of length {passed_add_embed_dim} was created."
            )
        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
        return add_time_ids

    def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14):
        # (batch, frames, channels, h, w) -> (batch * frames, channels, h, w)
        latents = latents.flatten(0, 1).to(dtype=self.vae.dtype)
        latents = 1 / self.vae.config.scaling_factor * latents

        # Decode in chunks to bound peak memory usage.
        frames = []
        for i in range(0, latents.shape[0], decode_chunk_size):
            num_frames_in = latents[i : i + decode_chunk_size].shape[0]
            frame = self.vae.decode(latents[i : i + decode_chunk_size], num_frames=num_frames_in).sample
            frames.append(frame)
        frames = torch.cat(frames, dim=0)

        # (batch * frames, channels, h, w) -> (batch, channels, frames, h, w)
        frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
        frames = frames.float()
        return frames

    def check_inputs(self, image, height, width):
        if (
            not isinstance(image, torch.Tensor)
            and not isinstance(image, PIL.Image.Image)
            and not isinstance(image, list)
        ):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, or `list` but is {type(image)}"
            )
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    def prepare_latents(
        self,
        batch_size: int,
        num_frames: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: Union[str, torch.device],
        generator: torch.Generator,
        latents: Optional[torch.Tensor] = None,
    ):
        num_channels_latents = self.unet.config.out_channels
        shape = (
            batch_size,
            num_frames,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(f"Batch size {batch_size} must match the number of generators ({len(generator)}).")

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # Scale the initial noise by the scheduler's standard deviation.
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def _encode_video_vae(
        self,
        video_frames: torch.Tensor,
        device: Union[str, torch.device],
    ):
        video_frames = video_frames.to(device=device, dtype=self.vae.dtype)
        batch_size, num_frames = video_frames.shape[:2]

        # Fold frames into the batch dimension for the VAE encoder.
        video_frames_reshaped = video_frames.reshape(batch_size * num_frames, *video_frames.shape[2:])
        latents = self.vae.encode(video_frames_reshaped).latent_dist.sample()

        # Restore the (batch, frames, ...) layout.
        latents = latents.reshape(batch_size, num_frames, *latents.shape[1:])

        return latents

    @torch.no_grad()
    def __call__(
        self,
        image: Union[List[PIL.Image.Image], torch.Tensor],
        mask_image: Union[List[PIL.Image.Image], torch.Tensor],
        height: int = 576,
        width: int = 1024,
        num_frames: Optional[int] = None,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.0,
        decode_chunk_size: Optional[int] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        mask_noise_strength: float = 0.0,
    ):
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        if num_frames is None:
            num_frames = len(image) if isinstance(image, list) else self.unet.config.num_frames

        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

        self.check_inputs(image, height, width)
        self.check_inputs(mask_image, height, width)
        if isinstance(image, list) and isinstance(mask_image, list):
            if len(image) != len(mask_image):
                raise ValueError("`image` and `mask_image` must have the same number of frames.")
            if num_frames != len(image):
                logger.warning(
                    f"Mismatch between `num_frames` ({num_frames}) and number of input images ({len(image)})."
                    f" Using {len(image)}."
                )
                num_frames = len(image)

        batch_size = 1
        device = self._execution_device
        dtype = self.unet.dtype

        # The first conditioning frame goes through CLIP (the embedding is zeroed in `_encode_image`).
        image_for_clip = image[0]
        image_embeddings = self._encode_image(image_for_clip, device, num_videos_per_prompt)

        # SVD conditions on fps - 1, following the original training setup.
        fps = fps - 1

        image_tensor = self.video_processor.preprocess(image, height=height, width=width).to(device).unsqueeze(0)
        mask_tensor = self.video_processor.preprocess(mask_image, height=height, width=width).to(device).unsqueeze(0)

        # Noise augmentation on the conditioning frames.
        noise = randn_tensor(image_tensor.shape, generator=generator, device=device, dtype=dtype)
        image_tensor = image_tensor + noise_aug_strength * noise

        conditional_latents = self._encode_video_vae(image_tensor, device)
        conditional_latents = conditional_latents / self.vae.config.scaling_factor

        if self.unet.config.in_channels == 12:
            # The mask is VAE-encoded and contributes 4 latent channels.
            mask_latents = self._encode_video_vae(mask_tensor, device)
            mask_latents = mask_latents / self.vae.config.scaling_factor
        elif self.unet.config.in_channels == 9:
            # The mask is binarized (preprocessing maps pixels to [-1, 1], so > 0 selects the foreground)
            # and downsampled to the latent resolution, contributing a single channel.
            mask_tensor_gray = mask_tensor.mean(dim=2, keepdim=True)
            binarized_mask = (mask_tensor_gray > 0.0).to(dtype)
            b, f, c, h, w = binarized_mask.shape
            binarized_mask_reshaped = binarized_mask.reshape(b * f, c, h, w)
            target_size = (height // self.vae_scale_factor, width // self.vae_scale_factor)
            interpolated_mask = F.interpolate(
                binarized_mask_reshaped,
                size=target_size,
                mode="nearest",
            )
            mask_latents = interpolated_mask.reshape(b, f, *interpolated_mask.shape[1:])
        else:
            raise ValueError(
                f"Unsupported number of UNet input channels: {self.unet.config.in_channels}. "
                "This pipeline only supports 9 (for interpolated mask) or 12 (for VAE mask)."
            )

        if mask_noise_strength > 0.0:
            mask_noise = randn_tensor(mask_latents.shape, generator=generator, device=device, dtype=dtype)
            mask_latents = mask_latents + mask_noise_strength * mask_noise

        added_time_ids = self._get_add_time_ids(
            fps, motion_bucket_id, noise_aug_strength, image_embeddings.dtype, batch_size, num_videos_per_prompt
        )
        added_time_ids = added_time_ids.to(device)

        # Draw the initial noise directly; the one-step model is fed unscaled Gaussian noise.
        num_channels_latents = self.unet.config.out_channels
        shape = (
            batch_size * num_videos_per_prompt,
            num_frames,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # A single UNet evaluation at a fixed timestep, matching the one-step training setup.
        timestep = torch.tensor([1.0], dtype=dtype, device=device)

        latent_model_input = torch.cat([latents, conditional_latents, mask_latents], dim=2)

        noise_pred = self.unet(
            latent_model_input,
            timestep,
            encoder_hidden_states=image_embeddings,
            added_time_ids=added_time_ids,
            return_dict=False,
        )[0]

        # The one-step UNet predicts the denoised latents directly.
        denoised_latents = noise_pred

        frames = self.decode_latents(denoised_latents, num_frames, decode_chunk_size)
        frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames
        return StableVideoDiffusionPipelineOutput(frames=frames)


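# A minimal usage sketch for the one-step variant, following the doc-string convention above. The
# checkpoint path and frame paths are hypothetical placeholders; no `num_inference_steps` is taken
# because the pipeline performs a single UNet evaluation.
ONESTEP_EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from pipeline_svd_masked import StableVideoDiffusionPipelineOnestepWithMask
        >>> from diffusers.utils import load_image, export_to_video

        >>> pipe = StableVideoDiffusionPipelineOnestepWithMask.from_pretrained(
        ...     "path/to/your/onestep_model", torch_dtype=torch.float16, variant="fp16"
        ... )
        >>> pipe.to("cuda")

        >>> cond = [load_image(f"cond/{i:04d}.png").resize((1024, 576)) for i in range(14)]
        >>> masks = [load_image(f"mask/{i:04d}.png").resize((1024, 576)) for i in range(14)]

        >>> frames = pipe(image=cond, mask_image=masks, decode_chunk_size=8).frames[0]
        >>> export_to_video(frames, "onestep_video.mp4", fps=7)
        ```
"""

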
class StableVideoDiffusionPipelineWithCrossAtnnMask(DiffusionPipeline):
    r"""
    A masked Stable Video Diffusion variant in which the per-frame masks are injected through
    cross-attention via a `mask_projector` module, while the condition image is still concatenated to
    the UNet input channels.
    """

    model_cpu_offload_seq = "image_encoder->unet->vae"
    _callback_tensor_inputs = ["latents"]

    def __init__(
        self,
        vae: AutoencoderKLTemporalDecoder,
        unet: UNetSpatioTemporalConditionModel,
        scheduler: EulerDiscreteScheduler,
        mask_projector: torch.nn.Module,
        # The CLIP modules are unused at inference time and kept optional.
        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
        feature_extractor: Optional[CLIPImageProcessor] = None,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            mask_projector=mask_projector,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)

    def _encode_image_vae(self, image: torch.Tensor, device: Union[str, torch.device]):
        image = image.to(device=device, dtype=self.vae.dtype)
        latent = self.vae.encode(image).latent_dist.sample()
        return latent

    def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int):
        latents = latents.flatten(0, 1).to(dtype=self.vae.dtype)
        latents = 1 / self.vae.config.scaling_factor * latents
        frames = []
        for i in range(0, latents.shape[0], decode_chunk_size):
            # The last chunk may be smaller than `decode_chunk_size`.
            num_frames_in = latents[i : i + decode_chunk_size].shape[0]
            frame = self.vae.decode(latents[i : i + decode_chunk_size], num_frames=num_frames_in).sample
            frames.append(frame)

        frames = torch.cat(frames, dim=0)
        frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
        frames = frames.float()
        return frames

    def _encode_video_vae(
        self,
        video_frames: torch.Tensor,
        device: Union[str, torch.device],
    ):
        video_frames = video_frames.to(device=device, dtype=self.vae.dtype)
        batch_size, num_frames = video_frames.shape[:2]

        # Fold frames into the batch dimension for the VAE encoder.
        video_frames_reshaped = video_frames.reshape(batch_size * num_frames, *video_frames.shape[2:])
        latents = self.vae.encode(video_frames_reshaped).latent_dist.sample()

        # Restore the (batch, frames, ...) layout.
        latents = latents.reshape(batch_size, num_frames, *latents.shape[1:])

        return latents

    @torch.no_grad()
    def __call__(
        self,
        image: Union[PIL.Image.Image, torch.Tensor],
        mask_image: List[PIL.Image.Image],
        height: int = 576,
        width: int = 1024,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 25,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.0,
        decode_chunk_size: Optional[int] = 8,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        device = self._execution_device
        dtype = self.unet.dtype
        num_frames = num_frames if num_frames is not None else len(mask_image)
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

        # Encode the single conditioning image and repeat it across all frames so it can be
        # concatenated with the per-frame latents.
        image_tensor = self.video_processor.preprocess(image, height, width).to(device).unsqueeze(0)
        conditional_latents = self._encode_video_vae(image_tensor, device)
        conditional_latents = conditional_latents / self.vae.config.scaling_factor
        conditional_latents = conditional_latents.repeat(1, num_frames, 1, 1, 1)

        # Collapse the masks to a single channel.
        mask_tensor = self.video_processor.preprocess(mask_image, height, width)
        if mask_tensor.shape[1] > 1:
            mask_tensor = mask_tensor.mean(dim=1, keepdim=True)

        # Project the masks into per-frame cross-attention embeddings: one token per frame, with frames
        # folded into the batch dimension, i.e. shape (num_frames, 1, dim).
        mask_for_projection = mask_tensor.to(device, dtype)
        encoder_hidden_states = self.mask_projector(mask_for_projection)
        encoder_hidden_states = encoder_hidden_states.unsqueeze(1)

        # Prepare the initial latents.
        shape = (
            1,
            num_frames,
            self.unet.config.out_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        if noise_aug_strength > 0:
            latents += noise_aug_strength * randn_tensor(
                latents.shape, generator=generator, device=device, dtype=dtype
            )
        latents = latents * self.scheduler.init_noise_sigma

        # Additional conditioning: fps - 1, motion bucket, and a noise augmentation of 0.0.
        added_time_ids = [fps - 1, motion_bucket_id, 0.0]
        added_time_ids = torch.tensor([added_time_ids], dtype=dtype, device=device)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for t in timesteps:
                latent_model_input = self.scheduler.scale_model_input(latents, t)
                # Concatenate the noisy latents with the condition latents along the channel dimension.
                unet_input = torch.cat([latent_model_input, conditional_latents], dim=2)

                noise_pred = self.unet(
                    unet_input, t, encoder_hidden_states=encoder_hidden_states, added_time_ids=added_time_ids
                ).sample

                latents = self.scheduler.step(noise_pred, t, latents).prev_sample
                progress_bar.update()

        frames = self.decode_latents(latents, num_frames, decode_chunk_size)
        frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)

        if not return_dict:
            return (frames,)
        return StableVideoDiffusionPipelineOutput(frames=frames)


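# A minimal usage sketch for the cross-attention variant, following the doc-string convention above.
# Component loading is elided: `vae`, `unet`, `scheduler`, and `mask_projector` are placeholders for
# whatever modules the UNet was fine-tuned with.
CROSS_ATTN_EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from pipeline_svd_masked import StableVideoDiffusionPipelineWithCrossAtnnMask
        >>> from diffusers.utils import load_image, export_to_video

        >>> # vae, unet, scheduler, and mask_projector are loaded elsewhere.
        >>> pipe = StableVideoDiffusionPipelineWithCrossAtnnMask(
        ...     vae=vae, unet=unet, scheduler=scheduler, mask_projector=mask_projector
        ... )
        >>> pipe.to("cuda")

        >>> image = load_image("path/to/conditioning_image.png").resize((1024, 576))
        >>> masks = [load_image(f"mask/{i:04d}.png").resize((1024, 576)) for i in range(14)]

        >>> frames = pipe(image=image, mask_image=masks, num_inference_steps=25).frames[0]
        >>> export_to_video(frames, "crossattn_video.mp4", fps=7)
        ```
"""

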
# Standalone single-step inference helper (script-style). Imports already present at the top of the
# file are not repeated here.
from PIL import Image
from torchvision import transforms


class VideoInferencePipeline:
    """
    A reusable pipeline for single-step video diffusion inference.

    This class encapsulates the models and the core inference logic, separating them from data loading
    and saving, which can vary between tasks.
    """

    def __init__(
        self,
        base_model_path: str,
        unet_checkpoint_path: str,
        device: str = "cuda",
        weight_dtype: torch.dtype = torch.float16,
    ):
        """
        Loads all necessary models into memory.

        Args:
            base_model_path (str): Path to the base Stable Video Diffusion model.
            unet_checkpoint_path (str): Path to the fine-tuned UNet checkpoint.
            device (str): The device to run models on ("cuda" or "cpu").
            weight_dtype (torch.dtype): The precision for model weights (float16 or bfloat16).
        """
        print("--- Initializing Inference Pipeline and Loading Models ---")
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.weight_dtype = weight_dtype

        try:
            self.feature_extractor = CLIPImageProcessor.from_pretrained(
                base_model_path, subfolder="feature_extractor"
            )
            self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                base_model_path, subfolder="image_encoder", variant="fp16"
            )
            self.vae = AutoencoderKLTemporalDecoder.from_pretrained(base_model_path, subfolder="vae", variant="fp16")
            self.unet = UNetSpatioTemporalConditionModel.from_pretrained(unet_checkpoint_path, subfolder="unet")
        except Exception as e:
            raise IOError(f"Fatal error loading models: {e}") from e

        # Inference only: move to the target device and switch to eval mode.
        self.image_encoder.to(self.device, dtype=self.weight_dtype).eval()
        self.vae.to(self.device, dtype=self.weight_dtype).eval()
        self.unet.to(self.device, dtype=self.weight_dtype).eval()

        print(f"--- Models Loaded Successfully on {self.device} ---")

    def run(
        self,
        cond_frames,
        mask_frames,
        seed=42,
        mask_cond_mode="vae",
        fps=7,
        motion_bucket_id=127,
        noise_aug_strength=0.0,
    ):
        """
        Runs the core inference process on a sequence of conditioning and mask frames.

        Args:
            cond_frames (list[Image.Image]): List of PIL images for conditioning.
            mask_frames (list[Image.Image]): List of PIL images for the masks.
            seed (int): Random seed for generation.
            mask_cond_mode (str): How the mask is conditioned ("vae" or "interpolate").
            fps (int): Frames per second to condition the model with.
            motion_bucket_id (int): Motion bucket ID for conditioning.
            noise_aug_strength (float): Noise augmentation strength.

        Returns:
            list[Image.Image]: A list of the generated video frames as PIL Images.
        """
        cond_video_tensor = self._pil_to_tensor(cond_frames).to(self.device)
        mask_video_tensor = self._pil_to_tensor(mask_frames).to(self.device)

        # Expand single-channel masks to three channels so they can be VAE-encoded.
        if mask_video_tensor.shape[2] != 3:
            mask_video_tensor = mask_video_tensor.repeat(1, 1, 3, 1, 1)

        with torch.no_grad():
            # CLIP-encode the first frame, then zero the embedding: the fine-tuned UNet is conditioned
            # through the concatenated latents only.
            first_frame_tensor = cond_video_tensor[:, 0, :, :, :]
            pixel_values_for_clip = self._resize_with_antialiasing(first_frame_tensor, (224, 224))
            pixel_values_for_clip = ((pixel_values_for_clip + 1.0) / 2.0).clamp(0, 1)
            pixel_values = self.feature_extractor(images=pixel_values_for_clip, return_tensors="pt").pixel_values
            image_embeddings = self.image_encoder(pixel_values.to(self.device, dtype=self.weight_dtype)).image_embeds
            encoder_hidden_states = torch.zeros_like(image_embeddings).unsqueeze(1)

            # `_tensor_to_vae_latent` multiplies by the scaling factor; dividing here cancels it,
            # leaving unscaled latents, as in the pipelines above.
            cond_latents = self._tensor_to_vae_latent(cond_video_tensor.to(self.weight_dtype))
            cond_latents = cond_latents / self.vae.config.scaling_factor

            if mask_cond_mode == "vae":
                mask_latents = self._tensor_to_vae_latent(mask_video_tensor.to(self.weight_dtype))
                mask_latents = mask_latents / self.vae.config.scaling_factor
            elif mask_cond_mode == "interpolate":
                target_shape = cond_latents.shape[-2:]
                b, t, c, h, w = mask_video_tensor.shape
                mask_video_reshaped = rearrange(mask_video_tensor, "b t c h w -> (b t) c h w")
                interpolated_mask = F.interpolate(
                    mask_video_reshaped, size=target_shape, mode="bilinear", align_corners=False
                )
                mask_latents = rearrange(interpolated_mask, "(b t) c h w -> b t c h w", b=b)
            else:
                raise ValueError(f"Unknown mask_cond_mode: {mask_cond_mode}")

            # Single-step inference: pure Gaussian noise and a fixed timestep of 1.
            generator = torch.Generator(device=self.device).manual_seed(seed)
            noisy_latents = torch.randn(
                cond_latents.shape, generator=generator, device=self.device, dtype=self.weight_dtype
            )
            timesteps = torch.full((1,), 1.0, device=self.device, dtype=torch.long)
            added_time_ids = self._get_add_time_ids(fps, motion_bucket_id, noise_aug_strength, batch_size=1)

            # Channel-wise concatenation: noise + condition latents + mask latents.
            unet_input = torch.cat([noisy_latents, cond_latents, mask_latents], dim=2)
            pred_latents = self.unet(
                unet_input, timesteps, encoder_hidden_states, added_time_ids=added_time_ids
            ).sample

            # Scale back for decoding and drop the batch dimension.
            pred_latents = (1 / self.vae.config.scaling_factor) * pred_latents.squeeze(0)

            # Decode in chunks of 8 frames to bound peak memory usage.
            frames = []
            for i in range(0, pred_latents.shape[0], 8):
                chunk = pred_latents[i : i + 8]
                decoded_chunk = self.vae.decode(chunk, num_frames=chunk.shape[0]).sample
                frames.append(decoded_chunk)

            video_tensor = torch.cat(frames, dim=0)
            # Map to [0, 1] and collapse to grayscale (replicated to 3 channels), presumably because the
            # fine-tuned model predicts single-channel mattes.
            video_tensor = (video_tensor / 2.0 + 0.5).clamp(0, 1).mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)

        return [transforms.ToPILImage()(frame.cpu()) for frame in video_tensor]

    def _pil_to_tensor(self, frames: list[Image.Image]):
        """Converts a list of PIL images to a video tensor normalized to [-1, 1], shaped (1, f, c, h, w)."""
        video_tensor = torch.stack([transforms.ToTensor()(f) for f in frames]).unsqueeze(0)
        return video_tensor * 2.0 - 1.0

    def _tensor_to_vae_latent(self, t: torch.Tensor):
        """Encodes a video tensor into the VAE's latent space, scaled by the VAE scaling factor."""
        video_length = t.shape[1]
        t = rearrange(t, "b f c h w -> (b f) c h w")
        latents = self.vae.encode(t).latent_dist.sample()
        latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
        return latents * self.vae.config.scaling_factor

    def _get_add_time_ids(self, fps, motion_bucket_id, noise_aug_strength, batch_size):
        """Creates the additional time IDs for conditioning the UNet."""
        add_time_ids_list = [fps, motion_bucket_id, noise_aug_strength]
        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids_list)
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
                f"but a vector of length {passed_add_embed_dim} was created."
            )
        add_time_ids = torch.tensor([add_time_ids_list], dtype=self.weight_dtype, device=self.device)
        return add_time_ids.repeat(batch_size, 1)

    def _resize_with_antialiasing(self, input_tensor, size, interpolation="bicubic", align_corners=True):
        """
        Resizes a tensor with Gaussian anti-aliasing for CLIP input, mirroring the k-diffusion helper
        used by the Stable Video Diffusion pipeline.
        """
        h, w = input_tensor.shape[-2:]
        factors = (h / size[0], w / size[1])

        # Sigmas and kernel sizes for the anti-aliasing blur, derived from the downscaling factors.
        sigmas = (max((factors[0] - 1.0) / 2.0, 0.001), max((factors[1] - 1.0) / 2.0, 0.001))
        ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
        # Kernel sizes must be odd.
        if (ks[0] % 2) == 0:
            ks = ks[0] + 1, ks[1]
        if (ks[1] % 2) == 0:
            ks = ks[0], ks[1] + 1

        def _compute_padding(kernel_size):
            computed = [k - 1 for k in kernel_size]
            out_padding = 2 * len(kernel_size) * [0]
            for i in range(len(kernel_size)):
                computed_tmp = computed[-(i + 1)]
                pad_front = computed_tmp // 2
                pad_rear = computed_tmp - pad_front
                out_padding[2 * i + 0] = pad_front
                out_padding[2 * i + 1] = pad_rear
            return out_padding

        def _filter2d(input_tensor, kernel):
            b, c, h, w = input_tensor.shape
            tmp_kernel = kernel[:, None, ...].to(device=input_tensor.device, dtype=input_tensor.dtype)
            tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
            height, width = tmp_kernel.shape[-2:]
            padding_shape = _compute_padding([height, width])
            input_tensor_padded = F.pad(input_tensor, padding_shape, mode="reflect")
            tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
            input_tensor_padded = input_tensor_padded.view(
                -1, tmp_kernel.size(0), input_tensor_padded.size(-2), input_tensor_padded.size(-1)
            )
            output = F.conv2d(input_tensor_padded, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
            return output.view(b, c, h, w)

        def _gaussian(window_size, sigma):
            if isinstance(sigma, float):
                sigma = torch.tensor([[sigma]])
            x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(
                sigma.shape[0], -1
            )
            if window_size % 2 == 0:
                x = x + 0.5
            gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
            return gauss / gauss.sum(-1, keepdim=True)

        def _gaussian_blur2d(input_tensor, kernel_size, sigma):
            if isinstance(sigma, tuple):
                sigma = torch.tensor([sigma], dtype=input_tensor.dtype)
            else:
                sigma = sigma.to(dtype=input_tensor.dtype)
            ky, kx = int(kernel_size[0]), int(kernel_size[1])
            bs = sigma.shape[0]
            kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
            kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
            out_x = _filter2d(input_tensor, kernel_x[..., None, :])
            return _filter2d(out_x, kernel_y[..., None])

        # Blur first, then resize with the requested interpolation.
        blurred_input = _gaussian_blur2d(input_tensor, ks, sigmas)
        return F.interpolate(blurred_input, size=size, mode=interpolation, align_corners=align_corners)
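

# A minimal, hypothetical driver showing how `VideoInferencePipeline` is meant to be used. The paths,
# frame count, and output handling below are placeholders, not part of the original scripts.
if __name__ == "__main__":
    from pathlib import Path

    cond_frames = [Image.open(f"cond/{i:04d}.png").convert("RGB").resize((1024, 576)) for i in range(14)]
    mask_frames = [Image.open(f"mask/{i:04d}.png").convert("L").resize((1024, 576)) for i in range(14)]

    pipeline = VideoInferencePipeline(
        base_model_path="stabilityai/stable-video-diffusion-img2vid-xt",
        unet_checkpoint_path="path/to/your/finetuned_checkpoint",
    )
    out_frames = pipeline.run(cond_frames, mask_frames, seed=42, mask_cond_mode="vae")

    Path("output").mkdir(exist_ok=True)
    for i, frame in enumerate(out_frames):
        frame.save(f"output/{i:04d}.png")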