# pipeline_svd_masked.py
import inspect
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers.image_processor import PipelineImageInput
from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from diffusers.schedulers import EulerDiscreteScheduler
from diffusers.utils import BaseOutput, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
# Import necessary helpers from the original SVD pipeline
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
_append_dims,
retrieve_timesteps,
_resize_with_antialiasing,
)
import torch.nn.functional as F
from einops import rearrange
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
        >>> import torch
        >>> from pipeline_svd_masked import StableVideoDiffusionPipelineWithMask
        >>> from diffusers.utils import load_image, export_to_video
>>> # Load your fine-tuned UNet, VAE, etc.
>>> pipe = StableVideoDiffusionPipelineWithMask.from_pretrained(
... "path/to/your/finetuned_model", torch_dtype=torch.float16, variant="fp16"
... )
>>> pipe.to("cuda")
>>> # Load the conditioning image and the mask
>>> image = load_image("path/to/your/conditioning_image.png").resize((1024, 576))
>>> mask = load_image("path/to/your/mask_image.png").resize((1024, 576))
>>> # Generate frames
>>> frames = pipe(
... image=image,
... mask_image=mask,
... num_frames=25,
... decode_chunk_size=8
... ).frames[0]
>>> export_to_video(frames, "generated_video.mp4", fps=7)
```
"""
@dataclass
class StableVideoDiffusionPipelineOutput(BaseOutput):
r"""
Output class for the custom Stable Video Diffusion pipeline.
Args:
        frames (`List[List[PIL.Image.Image]]`, `np.ndarray`, or `torch.Tensor`):
List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape
`(batch_size, num_frames, height, width, num_channels)`.
"""
frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor]
class StableVideoDiffusionPipelineWithMask(DiffusionPipeline):
r"""
A custom pipeline based on Stable Video Diffusion that accepts an additional mask for conditioning.
    This pipeline is designed to work with a UNet fine-tuned to accept 12 input channels
    (4 for noise, 4 for the VAE-encoded conditioning image, 4 for the VAE-encoded mask) or,
    alternatively, 9 input channels (the mask is binarized and downsampled to a single
    latent-resolution channel instead of being VAE-encoded).
"""
model_cpu_offload_seq = "image_encoder->unet->vae"
_callback_tensor_inputs = ["latents"]
def __init__(
self,
vae: AutoencoderKLTemporalDecoder,
image_encoder: CLIPVisionModelWithProjection,
unet: UNetSpatioTemporalConditionModel,
scheduler: EulerDiscreteScheduler,
feature_extractor: CLIPImageProcessor,
):
super().__init__()
self.register_modules(
vae=vae,
image_encoder=image_encoder,
unet=unet,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)
def _encode_image(
self,
image: PipelineImageInput,
device: Union[str, torch.device],
num_videos_per_prompt: int,
) -> torch.Tensor:
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.video_processor.pil_to_numpy(image)
image = self.video_processor.numpy_to_pt(image)
image = image * 2.0 - 1.0
image = _resize_with_antialiasing(image, (224, 224))
image = (image + 1.0) / 2.0
image = self.feature_extractor(
images=image,
do_normalize=True,
do_center_crop=False,
do_resize=False,
do_rescale=False,
return_tensors="pt",
).pixel_values
image = image.to(device=device, dtype=dtype)
image_embeddings = self.image_encoder(image).image_embeds
image_embeddings = image_embeddings.unsqueeze(1)
        # The training recipe zeroes out the CLIP image embedding, so the same is done at inference.
        image_embeddings = torch.zeros_like(image_embeddings)
return image_embeddings
def _encode_vae_image(
self,
image: torch.Tensor,
device: Union[str, torch.device],
num_videos_per_prompt: int,
):
        image = image.to(device=device, dtype=self.vae.dtype)
image_latents = self.vae.encode(image).latent_dist.sample()
image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
return image_latents
def _get_add_time_ids(
self,
fps: int,
motion_bucket_id: int,
noise_aug_strength: float,
dtype: torch.dtype,
batch_size: int,
num_videos_per_prompt: int,
):
add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
return add_time_ids
def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14):
        latents = latents.flatten(0, 1).to(dtype=self.vae.dtype)
latents = 1 / self.vae.config.scaling_factor * latents
frames = []
for i in range(0, latents.shape[0], decode_chunk_size):
num_frames_in = latents[i: i + decode_chunk_size].shape[0]
frame = self.vae.decode(latents[i: i + decode_chunk_size], num_frames=num_frames_in).sample
frames.append(frame)
frames = torch.cat(frames, dim=0)
frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
frames = frames.float()
return frames
def check_inputs(self, image, height, width):
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
            raise ValueError(f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, or `list` but is {type(image)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
def prepare_latents(
self,
batch_size: int,
num_frames: int,
height: int,
width: int,
dtype: torch.dtype,
device: Union[str, torch.device],
generator: torch.Generator,
latents: Optional[torch.Tensor] = None,
initial_latents: Optional[torch.Tensor] = None,
denoising_strength: float = 1.0,
timestep: Optional[torch.Tensor] = None,
):
num_channels_latents = self.unet.config.out_channels
shape = (
batch_size,
num_frames,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if initial_latents is not None:
# Noise is added to the initial latents
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# Get the initial latents at the given timestep
latents = self.scheduler.add_noise(initial_latents, noise, timestep)
else:
# Standard pure noise generation
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# Scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def _encode_video_vae(
self,
video_frames: torch.Tensor, # Expects (B, F, C, H, W)
device: Union[str, torch.device],
):
video_frames = video_frames.to(device=device, dtype=self.vae.dtype)
batch_size, num_frames = video_frames.shape[:2]
# Reshape for VAE encoding
video_frames_reshaped = video_frames.reshape(batch_size * num_frames, *video_frames.shape[2:]) # (B*F, C, H, W)
latents = self.vae.encode(video_frames_reshaped).latent_dist.sample() # (B*F, C_latent, H_latent, W_latent)
# Reshape back to video format
latents = latents.reshape(batch_size, num_frames, *latents.shape[1:]) # (B, F, C_latent, H_latent, W_latent)
return latents
@torch.no_grad()
def __call__(
self,
image: Union[List[PIL.Image.Image], torch.Tensor],
mask_image: Union[List[PIL.Image.Image], torch.Tensor],
alpha_matte_image: Optional[Union[List[PIL.Image.Image], torch.Tensor]] = None,
denoising_strength: float = 0.7,
height: int = 576,
width: int = 1024,
num_frames: Optional[int] = None,
num_inference_steps: int = 30,
sigmas: Optional[List[float]] = None,
fps: int = 7,
motion_bucket_id: int = 127,
noise_aug_strength: float = 0.02,
decode_chunk_size: Optional[int] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
mask_noise_strength: float = 0.0,
):
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if num_frames is None:
if isinstance(image, list):
num_frames = len(image)
else:
num_frames = self.unet.config.num_frames
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
self.check_inputs(image, height, width)
self.check_inputs(mask_image, height, width)
        if alpha_matte_image is not None:
self.check_inputs(alpha_matte_image, height, width)
batch_size = 1
device = self._execution_device
dtype = self.unet.dtype
        image_for_clip = image[0] if isinstance(image, list) else image
image_embeddings = self._encode_image(image_for_clip, device, num_videos_per_prompt)
fps = fps - 1
image_tensor = self.video_processor.preprocess(image, height=height, width=width).to(device).unsqueeze(0)
mask_tensor = self.video_processor.preprocess(mask_image, height=height, width=width).to(device).unsqueeze(0)
noise = randn_tensor(image_tensor.shape, generator=generator, device=device, dtype=dtype)
image_tensor = image_tensor + noise_aug_strength * noise
conditional_latents = self._encode_video_vae(image_tensor, device)
conditional_latents = conditional_latents / self.vae.config.scaling_factor
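        # Two mask-conditioning variants are supported, selected by the UNet's input channel count:
        # 12 channels -> the mask video is VAE-encoded like the conditioning frames (4 latent channels);
        # 9 channels  -> the mask is binarized and downsampled to latent resolution (1 extra channel).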
if self.unet.config.in_channels == 12:
mask_latents = self._encode_video_vae(mask_tensor, device)
mask_latents = mask_latents / self.vae.config.scaling_factor
elif self.unet.config.in_channels == 9:
mask_tensor_gray = mask_tensor.mean(dim=2, keepdim=True)
binarized_mask = (mask_tensor_gray > 0.0).to(dtype)
b, f, c, h, w = binarized_mask.shape
binarized_mask_reshaped = binarized_mask.reshape(b * f, c, h, w)
target_size = (height // self.vae_scale_factor, width // self.vae_scale_factor)
interpolated_mask = F.interpolate(
binarized_mask_reshaped,
size=target_size,
mode='nearest',
)
mask_latents = interpolated_mask.reshape(b, f, *interpolated_mask.shape[1:])
else:
raise ValueError(f"Unsupported number of UNet input channels: {self.unet.config.in_channels}.")
if mask_noise_strength > 0.0:
mask_noise = randn_tensor(mask_latents.shape, generator=generator, device=device, dtype=dtype)
mask_latents = mask_latents + mask_noise_strength * mask_noise
added_time_ids = self._get_add_time_ids(
fps, motion_bucket_id, noise_aug_strength, image_embeddings.dtype, batch_size, num_videos_per_prompt
)
added_time_ids = added_time_ids.to(device)
# --- MODIFIED FOR ALPHA MATTE REFINEMENT ---
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)
# self.scheduler.set_timesteps(num_inference_steps, device=device)
# timesteps = self.scheduler.timesteps
initial_latents = None
if alpha_matte_image is not None:
alpha_matte_tensor = self.video_processor.preprocess(alpha_matte_image, height=height, width=width).to(
device).unsqueeze(0)
initial_latents = self._encode_video_vae(alpha_matte_tensor, device)
initial_latents = initial_latents / self.vae.config.scaling_factor
# Adjust the number of steps and the timesteps to start from
t_start = max(num_inference_steps - int(num_inference_steps * denoising_strength), 0)
timesteps = timesteps[t_start:]
# We need the first timestep to add the correct amount of noise
start_timestep = timesteps[0]
else:
start_timestep = timesteps[0] # Not used, but for clarity
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_frames,
height,
width,
dtype,
device,
generator,
latents,
initial_latents=initial_latents,
denoising_strength=denoising_strength,
timestep=start_timestep if initial_latents is not None else None,
)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=len(timesteps)) as progress_bar:
for i, t in enumerate(timesteps):
latent_model_input = self.scheduler.scale_model_input(latents, t)
latent_model_input = torch.cat([latent_model_input, conditional_latents, mask_latents], dim=2)
noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=image_embeddings, added_time_ids=added_time_ids,
return_dict=False
)[0]
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
frames = self.decode_latents(latents, num_frames, decode_chunk_size)
frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)
self.maybe_free_model_hooks()
if not return_dict:
return frames
return StableVideoDiffusionPipelineOutput(frames=frames)
class StableVideoDiffusionPipelineOnestepWithMask(DiffusionPipeline):
r"""
A custom pipeline based on Stable Video Diffusion that accepts an additional mask for conditioning.
    This pipeline is designed to work with a UNet fine-tuned to accept 12 input channels
    (4 for noise, 4 for the VAE-encoded conditioning image, 4 for the VAE-encoded mask) or,
    alternatively, 9 input channels (the mask is binarized and downsampled to a single
    latent-resolution channel instead of being VAE-encoded).
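
    Example (illustrative; the module name, checkpoint path, and frame paths are placeholders):
        ```py
        >>> import torch
        >>> from pipeline_svd_masked import StableVideoDiffusionPipelineOnestepWithMask
        >>> from diffusers.utils import load_image
        >>> pipe = StableVideoDiffusionPipelineOnestepWithMask.from_pretrained(
        ...     "path/to/your/onestep_model", torch_dtype=torch.float16, variant="fp16"
        ... ).to("cuda")
        >>> frames = [load_image(f"frames/{i:05d}.png") for i in range(14)]
        >>> masks = [load_image(f"masks/{i:05d}.png") for i in range(14)]
        >>> output = pipe(image=frames, mask_image=masks, decode_chunk_size=8).frames[0]
        ```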
"""
model_cpu_offload_seq = "image_encoder->unet->vae"
_callback_tensor_inputs = ["latents"]
def __init__(
self,
vae: AutoencoderKLTemporalDecoder,
image_encoder: CLIPVisionModelWithProjection,
unet: UNetSpatioTemporalConditionModel,
scheduler: EulerDiscreteScheduler,
feature_extractor: CLIPImageProcessor,
):
super().__init__()
self.register_modules(
vae=vae,
image_encoder=image_encoder,
unet=unet,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)
def _encode_image(
self,
image: PipelineImageInput,
device: Union[str, torch.device],
num_videos_per_prompt: int,
) -> torch.Tensor:
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.video_processor.pil_to_numpy(image)
image = self.video_processor.numpy_to_pt(image)
image = image * 2.0 - 1.0
image = _resize_with_antialiasing(image, (224, 224))
image = (image + 1.0) / 2.0
image = self.feature_extractor(
images=image,
do_normalize=True,
do_center_crop=False,
do_resize=False,
do_rescale=False,
return_tensors="pt",
).pixel_values
image = image.to(device=device, dtype=dtype)
image_embeddings = self.image_encoder(image).image_embeds
image_embeddings = image_embeddings.unsqueeze(1)
        # The training recipe zeroes out the CLIP image embedding, so the same is done at inference.
        image_embeddings = torch.zeros_like(image_embeddings)
return image_embeddings
def _encode_vae_image(
self,
image: torch.Tensor,
device: Union[str, torch.device],
num_videos_per_prompt: int,
):
        image = image.to(device=device, dtype=self.vae.dtype)
image_latents = self.vae.encode(image).latent_dist.sample()
image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
return image_latents
def _get_add_time_ids(
self,
fps: int,
motion_bucket_id: int,
noise_aug_strength: float,
dtype: torch.dtype,
batch_size: int,
num_videos_per_prompt: int,
):
add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
return add_time_ids
def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14):
        latents = latents.flatten(0, 1).to(dtype=self.vae.dtype)
latents = 1 / self.vae.config.scaling_factor * latents
frames = []
for i in range(0, latents.shape[0], decode_chunk_size):
num_frames_in = latents[i: i + decode_chunk_size].shape[0]
frame = self.vae.decode(latents[i: i + decode_chunk_size], num_frames=num_frames_in).sample
frames.append(frame)
frames = torch.cat(frames, dim=0)
frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
frames = frames.float()
return frames
def check_inputs(self, image, height, width):
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
            raise ValueError(f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, or `list` but is {type(image)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
def prepare_latents(
self,
batch_size: int,
num_frames: int,
height: int,
width: int,
dtype: torch.dtype,
device: Union[str, torch.device],
generator: torch.Generator,
latents: Optional[torch.Tensor] = None,
):
# The number of channels for the initial noise is based on the UNet's out_channels
num_channels_latents = self.unet.config.out_channels
shape = (
batch_size,
num_frames,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(f"batch size {batch_size} must match the length of the generators {len(generator)}.")
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
latents = latents * self.scheduler.init_noise_sigma
return latents
def _encode_video_vae(
self,
video_frames: torch.Tensor, # Expects (B, F, C, H, W)
device: Union[str, torch.device],
):
video_frames = video_frames.to(device=device, dtype=self.vae.dtype)
batch_size, num_frames = video_frames.shape[:2]
# Reshape for VAE encoding
video_frames_reshaped = video_frames.reshape(batch_size * num_frames, *video_frames.shape[2:]) # (B*F, C, H, W)
latents = self.vae.encode(video_frames_reshaped).latent_dist.sample() # (B*F, C_latent, H_latent, W_latent)
# Reshape back to video format
latents = latents.reshape(batch_size, num_frames, *latents.shape[1:]) # (B, F, C_latent, H_latent, W_latent)
return latents
@torch.no_grad()
def __call__(
self,
image: Union[List[PIL.Image.Image], torch.Tensor],
mask_image: Union[List[PIL.Image.Image], torch.Tensor],
height: int = 576,
width: int = 1024,
num_frames: Optional[int] = None,
fps: int = 7,
motion_bucket_id: int = 127,
noise_aug_strength: float = 0.0,
decode_chunk_size: Optional[int] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
mask_noise_strength: float = 0.0,
):
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if num_frames is None:
if isinstance(image, list):
num_frames = len(image)
else:
num_frames = self.unet.config.num_frames
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
self.check_inputs(image, height, width)
self.check_inputs(mask_image, height, width)
if isinstance(image, list) and isinstance(mask_image, list):
if len(image) != len(mask_image):
raise ValueError("`image` and `mask_image` must have the same number of frames.")
if num_frames != len(image):
logger.warning(
f"Mismatch between `num_frames` ({num_frames}) and number of input images ({len(image)}). Using {len(image)}.")
num_frames = len(image)
batch_size = 1
device = self._execution_device
dtype = self.unet.dtype
        image_for_clip = image[0] if isinstance(image, list) else image
image_embeddings = self._encode_image(image_for_clip, device, num_videos_per_prompt)
fps = fps - 1
image_tensor = self.video_processor.preprocess(image, height=height, width=width).to(device).unsqueeze(0)
mask_tensor = self.video_processor.preprocess(mask_image, height=height, width=width).to(
device).unsqueeze(0)
noise = randn_tensor(image_tensor.shape, generator=generator, device=device, dtype=dtype)
image_tensor = image_tensor + noise_aug_strength * noise
conditional_latents = self._encode_video_vae(image_tensor, device)
conditional_latents = conditional_latents / self.vae.config.scaling_factor
if self.unet.config.in_channels == 12:
mask_latents = self._encode_video_vae(mask_tensor, device)
mask_latents = mask_latents / self.vae.config.scaling_factor
elif self.unet.config.in_channels == 9:
mask_tensor_gray = mask_tensor.mean(dim=2, keepdim=True)
binarized_mask = (mask_tensor_gray > 0.0).to(dtype)
b, f, c, h, w = binarized_mask.shape
binarized_mask_reshaped = binarized_mask.reshape(b * f, c, h, w)
target_size = (height // self.vae_scale_factor, width // self.vae_scale_factor)
interpolated_mask = F.interpolate(
binarized_mask_reshaped,
size=target_size,
mode='nearest',
)
mask_latents = interpolated_mask.reshape(b, f, *interpolated_mask.shape[1:])
else:
raise ValueError(
f"Unsupported number of UNet input channels: {self.unet.config.in_channels}. "
"This pipeline only supports 9 (for interpolated mask) or 12 (for VAE mask)."
)
if mask_noise_strength > 0.0:
mask_noise = randn_tensor(mask_latents.shape, generator=generator, device=device, dtype=dtype)
mask_latents = mask_latents + mask_noise_strength * mask_noise
added_time_ids = self._get_add_time_ids(
fps, motion_bucket_id, noise_aug_strength, image_embeddings.dtype, batch_size, num_videos_per_prompt
)
added_time_ids = added_time_ids.to(device)
# **MODIFIED FOR SINGLE-STEP**: Prepare initial noise
num_channels_latents = self.unet.config.out_channels
shape = (
batch_size * num_videos_per_prompt,
num_frames,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# **MODIFIED FOR SINGLE-STEP**: Set a fixed high timestep
timestep = torch.tensor([1.0], dtype=dtype, device=device) # Use a high sigma value
# **MODIFIED FOR SINGLE-STEP**: Single forward pass
latent_model_input = torch.cat([latents, conditional_latents, mask_latents], dim=2)
noise_pred = self.unet(
latent_model_input, timestep, encoder_hidden_states=image_embeddings, added_time_ids=added_time_ids,
return_dict=False
)[0]
# The model's prediction is the final denoised latent
denoised_latents = noise_pred
frames = self.decode_latents(denoised_latents, num_frames, decode_chunk_size)
frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)
self.maybe_free_model_hooks()
if not return_dict:
return frames
return StableVideoDiffusionPipelineOutput(frames=frames)
class StableVideoDiffusionPipelineWithCrossAtnnMask(DiffusionPipeline):
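    r"""
    A custom pipeline based on Stable Video Diffusion that conditions motion through cross-attention:
    per-frame masks are projected by `mask_projector` into the UNet's cross-attention embedding space,
    while a static conditioning image supplies appearance through VAE latents concatenated to the noise.
    """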
model_cpu_offload_seq = "image_encoder->unet->vae"
_callback_tensor_inputs = ["latents"]
def __init__(
self,
vae: AutoencoderKLTemporalDecoder,
unet: UNetSpatioTemporalConditionModel,
scheduler: EulerDiscreteScheduler,
mask_projector: torch.nn.Module,
# CLIP models are not strictly needed for inference if embeddings are not used
image_encoder: CLIPVisionModelWithProjection = None,
feature_extractor: CLIPImageProcessor = None,
):
super().__init__()
self.register_modules(
vae=vae,
unet=unet,
scheduler=scheduler,
mask_projector=mask_projector,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
def _encode_image_vae(self, image: torch.Tensor, device: Union[str, torch.device]):
image = image.to(device=device, dtype=self.vae.dtype)
latent = self.vae.encode(image).latent_dist.sample()
return latent
def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int):
        latents = latents.flatten(0, 1).to(dtype=self.vae.dtype)
latents = 1 / self.vae.config.scaling_factor * latents
frames = []
for i in range(0, latents.shape[0], decode_chunk_size):
            num_frames_in = latents[i: i + decode_chunk_size].shape[0]
            frame = self.vae.decode(latents[i: i + decode_chunk_size], num_frames=num_frames_in).sample
frames.append(frame)
frames = torch.cat(frames, dim=0)
frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
frames = frames.float()
return frames
def _encode_video_vae(
self,
video_frames: torch.Tensor, # Expects (B, F, C, H, W)
device: Union[str, torch.device],
):
video_frames = video_frames.to(device=device, dtype=self.vae.dtype)
batch_size, num_frames = video_frames.shape[:2]
# Reshape for VAE encoding
video_frames_reshaped = video_frames.reshape(batch_size * num_frames, *video_frames.shape[2:]) # (B*F, C, H, W)
latents = self.vae.encode(video_frames_reshaped).latent_dist.sample() # (B*F, C_latent, H_latent, W_latent)
# Reshape back to video format
latents = latents.reshape(batch_size, num_frames, *latents.shape[1:]) # (B, F, C_latent, H_latent, W_latent)
return latents
@torch.no_grad()
def __call__(
self,
image: Union[PIL.Image.Image, torch.Tensor], # Static image for appearance
mask_image: List[PIL.Image.Image], # Video mask for motion
height: int = 576,
width: int = 1024,
num_frames: Optional[int] = None,
num_inference_steps: int = 25,
fps: int = 7,
motion_bucket_id: int = 127,
noise_aug_strength: float = 0.0, # Noise is added to latents now
decode_chunk_size: Optional[int] = 8,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
device = self._execution_device
dtype = self.unet.dtype
num_frames = num_frames if num_frames is not None else len(mask_image)
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
# 1. PREPARE STATIC IMAGE CONDITION
image_tensor = self.video_processor.preprocess(image, height, width).to(device).unsqueeze(0)
conditional_latents = self._encode_video_vae(image_tensor, device)
conditional_latents = conditional_latents / self.vae.config.scaling_factor
# 2. PREPARE MASK MOTION CONDITION
mask_tensor = self.video_processor.preprocess(mask_image, height, width)
if mask_tensor.shape[1] > 1:
mask_tensor = mask_tensor.mean(dim=1, keepdim=True)
        # The mask tensor is already (T, C, H, W); just move it to the projector's device and dtype.
        mask_for_projection = mask_tensor.to(device, dtype)
encoder_hidden_states = self.mask_projector(mask_for_projection)
encoder_hidden_states = encoder_hidden_states.unsqueeze(1) # (T, 1, D)
# Add batch dimension for UNet
encoder_hidden_states = encoder_hidden_states.unsqueeze(0) # (1, T, 1, D)
# The UNet will handle flattening this to (B*T, 1, D) where B=1
# To be safe, we pass it pre-flattened.
encoder_hidden_states = rearrange(encoder_hidden_states, "b f s d -> (b f) s d")
# 3. PREPARE LATENTS
shape = (1, num_frames, self.unet.config.out_channels, height // self.vae_scale_factor,
width // self.vae_scale_factor)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if noise_aug_strength > 0:
latents += noise_aug_strength * randn_tensor(latents.shape, generator=generator, device=device,
dtype=dtype)
latents = latents * self.scheduler.init_noise_sigma
# 4. GET ADDED TIME IDS
# For pipeline, batch size is 1
added_time_ids = [fps - 1, motion_bucket_id, 0.0] # noise_aug_strength for add_time_ids is 0 for inference
added_time_ids = torch.tensor([added_time_ids], dtype=dtype, device=device)
# 5. DENOISING LOOP
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
with self.progress_bar(total=num_inference_steps) as progress_bar:
for t in timesteps:
latent_model_input = self.scheduler.scale_model_input(latents, t)
unet_input = torch.cat([latent_model_input, conditional_latents], dim=2)
noise_pred = self.unet(
unet_input, t, encoder_hidden_states=encoder_hidden_states, added_time_ids=added_time_ids
).sample
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
progress_bar.update()
# 6. DECODE
frames = self.decode_latents(latents, num_frames, decode_chunk_size)
frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)
if not return_dict:
return (frames,)
return StableVideoDiffusionPipelineOutput(frames=frames)
# pipeline.py
import torch
import torch.nn.functional as F
from PIL import Image
from einops import rearrange
from torchvision import transforms
from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
class VideoInferencePipeline:
"""
A reusable pipeline for single-step video diffusion inference.
This class encapsulates the models and the core inference logic,
separating it from data loading and saving, which can vary between tasks.
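
    Example (illustrative; the base-model id, checkpoint path, and frame paths are placeholders):
        >>> from PIL import Image
        >>> pipe = VideoInferencePipeline(
        ...     base_model_path="stabilityai/stable-video-diffusion-img2vid",
        ...     unet_checkpoint_path="path/to/finetuned_checkpoint",
        ... )
        >>> cond = [Image.open(f"frames/{i:05d}.png").convert("RGB") for i in range(14)]
        >>> masks = [Image.open(f"masks/{i:05d}.png").convert("L") for i in range(14)]
        >>> out_frames = pipe.run(cond, masks, mask_cond_mode="vae")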
"""
def __init__(self, base_model_path: str, unet_checkpoint_path: str, device: str = "cuda",
weight_dtype: torch.dtype = torch.float16):
"""
Loads all necessary models into memory.
Args:
base_model_path (str): Path to the base Stable Video Diffusion model.
unet_checkpoint_path (str): Path to the fine-tuned UNet checkpoint.
device (str): The device to run models on ('cuda' or 'cpu').
weight_dtype (torch.dtype): The precision for model weights (float16 or bfloat16).
"""
print("--- Initializing Inference Pipeline and Loading Models ---")
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
self.weight_dtype = weight_dtype
# Load models from pretrained paths
try:
self.feature_extractor = CLIPImageProcessor.from_pretrained(base_model_path, subfolder="feature_extractor")
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_model_path,
subfolder="image_encoder",
variant="fp16")
self.vae = AutoencoderKLTemporalDecoder.from_pretrained(base_model_path, subfolder="vae", variant="fp16")
self.unet = UNetSpatioTemporalConditionModel.from_pretrained(unet_checkpoint_path, subfolder="unet")
except Exception as e:
raise IOError(f"Fatal error loading models: {e}")
# Move models to the specified device and set to evaluation mode
self.image_encoder.to(self.device, dtype=self.weight_dtype).eval()
self.vae.to(self.device, dtype=self.weight_dtype).eval()
self.unet.to(self.device, dtype=self.weight_dtype).eval()
print(f"--- Models Loaded Successfully on {self.device} ---")
def run(self, cond_frames, mask_frames, seed=42, mask_cond_mode="vae", fps=7, motion_bucket_id=127,
noise_aug_strength=0.0):
"""
Runs the core inference process on a sequence of conditioning and mask frames.
Args:
cond_frames (list[Image.Image]): List of PIL images for conditioning.
mask_frames (list[Image.Image]): List of PIL images for the masks.
seed (int): Random seed for generation.
mask_cond_mode (str): How the mask is conditioned ("vae" or "interpolate").
fps (int): Frames per second to condition the model with.
motion_bucket_id (int): Motion bucket ID for conditioning.
noise_aug_strength (float): Noise augmentation strength.
Returns:
list[Image.Image]: A list of the generated video frames as PIL Images.
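        Example (assumes `pipe` was built as in the class docstring; inputs are placeholders):
            >>> frames = pipe.run(cond_frames, mask_frames, seed=0, mask_cond_mode="interpolate")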
"""
# --- 1. Prepare Tensors ---
cond_video_tensor = self._pil_to_tensor(cond_frames).to(self.device)
mask_video_tensor = self._pil_to_tensor(mask_frames).to(self.device)
if mask_video_tensor.shape[2] != 3:
mask_video_tensor = mask_video_tensor.repeat(1, 1, 3, 1, 1)
with torch.no_grad():
# --- 2. Get CLIP Image Embeddings ---
first_frame_tensor = cond_video_tensor[:, 0, :, :, :]
pixel_values_for_clip = self._resize_with_antialiasing(first_frame_tensor, (224, 224))
pixel_values_for_clip = ((pixel_values_for_clip + 1.0) / 2.0).clamp(0, 1)
pixel_values = self.feature_extractor(images=pixel_values_for_clip, return_tensors="pt").pixel_values
image_embeddings = self.image_encoder(pixel_values.to(self.device, dtype=self.weight_dtype)).image_embeds
encoder_hidden_states = torch.zeros_like(image_embeddings).unsqueeze(1)
# --- 3. Prepare Latents ---
cond_latents = self._tensor_to_vae_latent(cond_video_tensor.to(self.weight_dtype))
cond_latents = cond_latents / self.vae.config.scaling_factor
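            # "vae": encode the mask video into the VAE latent space, like the conditioning frames.
            # "interpolate": skip the VAE and simply downsample the raw mask to the latent spatial size.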
if mask_cond_mode == "vae":
mask_latents = self._tensor_to_vae_latent(mask_video_tensor.to(self.weight_dtype))
mask_latents = mask_latents / self.vae.config.scaling_factor
elif mask_cond_mode == "interpolate":
target_shape = cond_latents.shape[-2:]
b, t, c, h, w = mask_video_tensor.shape
mask_video_reshaped = rearrange(mask_video_tensor, "b t c h w -> (b t) c h w")
interpolated_mask = F.interpolate(mask_video_reshaped, size=target_shape, mode='bilinear',
align_corners=False)
mask_latents = rearrange(interpolated_mask, "(b t) c h w -> b t c h w", b=b)
else:
raise ValueError(f"Unknown mask_cond_mode: {mask_cond_mode}")
# --- 4. Run UNet Single-Step Inference ---
generator = torch.Generator(device=self.device).manual_seed(seed)
noisy_latents = torch.randn(cond_latents.shape, generator=generator, device=self.device,
dtype=self.weight_dtype)
timesteps = torch.full((1,), 1.0, device=self.device, dtype=torch.long)
added_time_ids = self._get_add_time_ids(fps, motion_bucket_id, noise_aug_strength, batch_size=1)
unet_input = torch.cat([noisy_latents, cond_latents, mask_latents], dim=2)
pred_latents = self.unet(unet_input, timesteps, encoder_hidden_states, added_time_ids=added_time_ids).sample
# --- 5. Decode Latents to Video Frames ---
pred_latents = (1 / self.vae.config.scaling_factor) * pred_latents.squeeze(0)
frames = []
# Process in chunks to avoid VRAM issues, especially for long videos
for i in range(0, pred_latents.shape[0], 8):
chunk = pred_latents[i: i + 8]
decoded_chunk = self.vae.decode(chunk, num_frames=chunk.shape[0]).sample
frames.append(decoded_chunk)
video_tensor = torch.cat(frames, dim=0)
video_tensor = (video_tensor / 2.0 + 0.5).clamp(0, 1).mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)
# Return a list of PIL images
return [transforms.ToPILImage()(frame) for frame in video_tensor]
def _pil_to_tensor(self, frames: list[Image.Image]):
"""Converts a list of PIL images to a normalized video tensor."""
video_tensor = torch.stack([transforms.ToTensor()(f) for f in frames]).unsqueeze(0)
return video_tensor * 2.0 - 1.0
def _tensor_to_vae_latent(self, t: torch.Tensor):
"""Encodes a video tensor into the VAE's latent space."""
video_length = t.shape[1]
t = rearrange(t, "b f c h w -> (b f) c h w")
latents = self.vae.encode(t).latent_dist.sample()
latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
return latents * self.vae.config.scaling_factor
def _get_add_time_ids(self, fps, motion_bucket_id, noise_aug_strength, batch_size):
"""Creates the additional time IDs for conditioning the UNet."""
add_time_ids_list = [fps, motion_bucket_id, noise_aug_strength]
passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids_list)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created.")
add_time_ids = torch.tensor([add_time_ids_list], dtype=self.weight_dtype, device=self.device)
return add_time_ids.repeat(batch_size, 1)
def _resize_with_antialiasing(self, input_tensor, size, interpolation="bicubic", align_corners=True):
"""
Resizes a tensor with anti-aliasing for CLIP input, mirroring k-diffusion.
        This mirrors the `_resize_with_antialiasing` helper used in the original training scripts.
"""
h, w = input_tensor.shape[-2:]
factors = (h / size[0], w / size[1])
sigmas = (max((factors[0] - 1.0) / 2.0, 0.001), max((factors[1] - 1.0) / 2.0, 0.001))
ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
if (ks[0] % 2) == 0: ks = ks[0] + 1, ks[1]
if (ks[1] % 2) == 0: ks = ks[0], ks[1] + 1
def _compute_padding(kernel_size):
computed = [k - 1 for k in kernel_size]
out_padding = 2 * len(kernel_size) * [0]
for i in range(len(kernel_size)):
computed_tmp = computed[-(i + 1)]
pad_front = computed_tmp // 2
pad_rear = computed_tmp - pad_front
out_padding[2 * i + 0] = pad_front
out_padding[2 * i + 1] = pad_rear
return out_padding
def _filter2d(input_tensor, kernel):
b, c, h, w = input_tensor.shape
tmp_kernel = kernel[:, None, ...].to(device=input_tensor.device, dtype=input_tensor.dtype)
tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
height, width = tmp_kernel.shape[-2:]
padding_shape = _compute_padding([height, width])
input_tensor_padded = F.pad(input_tensor, padding_shape, mode="reflect")
tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
input_tensor_padded = input_tensor_padded.view(-1, tmp_kernel.size(0), input_tensor_padded.size(-2),
input_tensor_padded.size(-1))
output = F.conv2d(input_tensor_padded, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
return output.view(b, c, h, w)
def _gaussian(window_size, sigma):
if isinstance(sigma, float):
sigma = torch.tensor([[sigma]])
x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(
sigma.shape[0], -1)
if window_size % 2 == 0:
x = x + 0.5
gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
return gauss / gauss.sum(-1, keepdim=True)
def _gaussian_blur2d(input_tensor, kernel_size, sigma):
if isinstance(sigma, tuple):
sigma = torch.tensor([sigma], dtype=input_tensor.dtype)
else:
sigma = sigma.to(dtype=input_tensor.dtype)
ky, kx = int(kernel_size[0]), int(kernel_size[1])
bs = sigma.shape[0]
kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
out_x = _filter2d(input_tensor, kernel_x[..., None, :])
return _filter2d(out_x, kernel_y[..., None])
blurred_input = _gaussian_blur2d(input_tensor, ks, sigmas)
return F.interpolate(blurred_input, size=size, mode=interpolation, align_corners=align_corners)