| """ |
| LTX-2 Audio-to-Video Pipeline with Video Conditioning Support |
| |
| This is a modified version of the LTX2AudioToVideoPipeline that adds support for |
| video conditioning, enabling avatar/face-swap generation workflows. |
| |
| Usage: |
| pipe = DiffusionPipeline.from_pretrained( |
| "rootonchair/LTX-2-19b-distilled", |
| custom_pipeline="path/to/this/file", |
| torch_dtype=torch.bfloat16 |
| ) |
| |
| # With video conditioning (for avatar/face-swap): |
| video, audio = pipe( |
| image=face_image, # The face/appearance to use |
| video=reference_video, # Video for motion conditioning |
| audio="path/to/audio.wav", # Audio (or extracted from video) |
| prompt="head_swap, a person speaking...", |
| ... |
| ) |
| """ |
|
|
| import copy |
| import inspect |
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| import torchaudio |
| import torchaudio.transforms as T |
| from PIL import Image |
| from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast |
|
|
| from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback |
| from diffusers.image_processor import PipelineImageInput |
| from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin |
| from diffusers.models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video |
| from diffusers.models.transformers import LTX2VideoTransformer3DModel |
| from diffusers.schedulers import FlowMatchEulerDiscreteScheduler |
| from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring |
| from diffusers.utils.torch_utils import randn_tensor |
| from diffusers.video_processor import VideoProcessor |
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline |
| from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors |
| from diffusers.pipelines.ltx2.pipeline_output import LTX2PipelineOutput |
| from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder |
|
|
|
|
| if is_torch_xla_available(): |
| import torch_xla.core.xla_model as xm |
| XLA_AVAILABLE = True |
| else: |
| XLA_AVAILABLE = False |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
| EXAMPLE_DOC_STRING = """ |
| Examples: |
| ```py |
| >>> import torch |
| >>> from diffusers import DiffusionPipeline |
| >>> from diffusers.utils import load_image |
| |
| >>> pipe = DiffusionPipeline.from_pretrained( |
| ... "rootonchair/LTX-2-19b-distilled", |
| ... custom_pipeline="pipeline_ltx2_avatar", |
| ... torch_dtype=torch.bfloat16 |
| ... ) |
| >>> pipe.to("cuda") |
| |
| >>> # Load face swap LoRA |
| >>> pipe.load_lora_weights( |
| ... "Alissonerdx/BFS-Best-Face-Swap-Video", |
| ... weight_name="ltx-2/head_swap_v1_13500_first_frame.safetensors", |
| ... ) |
| >>> pipe.fuse_lora(lora_scale=1.1) |
| |
| >>> face_image = load_image("face.png") |
| >>> video, audio = pipe( |
| ... image=face_image, |
| ... video="reference_video.mp4", # Motion reference |
| ... video_conditioning_strength=1.0, # How strongly to follow motion |
| ... video_conditioning_frame_idx=1, # Frame 0 = face, Frame 1+ = video motion |
| ... audio="reference_video.mp4", # Audio extracted from video |
| ... prompt="head_swap, a person speaking naturally", |
| ... width=512, |
| ... height=768, |
| ... num_frames=121, |
| ... return_dict=False, |
| ... ) |
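|
| >>> # Saving the outputs -- a minimal sketch. `export_to_video` lives in diffusers.utils; |
| >>> # the audio tensor layout assumed below (channels, samples) may differ, so adapt as needed. |
| >>> from diffusers.utils import export_to_video |
| >>> import torchaudio |
| >>> export_to_video(video[0], "output.mp4", fps=24) |
| >>> torchaudio.save("output.wav", audio[0].float().cpu(), pipe.audio_sampling_rate) |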
| ``` |
| """ |
|
|
|
|
| def retrieve_latents( |
| encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" |
| ): |
| if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": |
| return encoder_output.latent_dist.sample(generator) |
| elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": |
| return encoder_output.latent_dist.mode() |
| elif hasattr(encoder_output, "latents"): |
| return encoder_output.latents |
| else: |
| raise AttributeError("Could not access latents of provided encoder_output") |
|
|
|
|
| def calculate_shift( |
| image_seq_len, |
| base_seq_len: int = 256, |
| max_seq_len: int = 4096, |
| base_shift: float = 0.5, |
| max_shift: float = 1.15, |
| ): |
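| # Linear interpolation of the flow-matching shift: mu equals base_shift at base_seq_len and |
| # max_shift at max_seq_len; sequence lengths outside that range extrapolate (no clamping). |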
| m = (max_shift - base_shift) / (max_seq_len - base_seq_len) |
| b = base_shift - m * base_seq_len |
| mu = image_seq_len * m + b |
| return mu |
|
|
|
|
| def retrieve_timesteps( |
| scheduler, |
| num_inference_steps: Optional[int] = None, |
| device: Optional[Union[str, torch.device]] = None, |
| timesteps: Optional[List[int]] = None, |
| sigmas: Optional[List[float]] = None, |
| **kwargs, |
| ): |
| if timesteps is not None and sigmas is not None: |
| raise ValueError("Only one of `timesteps` or `sigmas` can be passed.") |
| if timesteps is not None: |
| accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) |
| if not accepts_timesteps: |
| raise ValueError( |
| f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom timestep schedules." |
| ) |
| scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) |
| timesteps = scheduler.timesteps |
| num_inference_steps = len(timesteps) |
| elif sigmas is not None: |
| accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) |
| if not accept_sigmas: |
| raise ValueError( |
| f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas schedules." |
| ) |
| scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) |
| timesteps = scheduler.timesteps |
| num_inference_steps = len(timesteps) |
| else: |
| scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) |
| timesteps = scheduler.timesteps |
| return timesteps, num_inference_steps |
|
|
|
|
| def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): |
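| # Guidance rescale from "Common Diffusion Noise Schedules and Sample Steps are Flawed": |
| # renormalize the CFG output toward the std of the text-conditioned prediction, blended by `guidance_rescale`. |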
| std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) |
| std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) |
| noise_pred_rescaled = noise_cfg * (std_text / std_cfg) |
| noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg |
| return noise_cfg |
|
|
|
|
| class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): |
| r""" |
| Pipeline for avatar/face-swap video generation with audio and video conditioning. |
| |
| This pipeline generates video conditioned on: |
| - An input image (the face/appearance to use) |
| - A reference video (for motion/pose conditioning) |
| - Input audio (for lip-sync) |
| |
| This enables avatar generation where the face from the image is animated |
| to match the motion from the reference video and synced to the audio. |
| """ |
|
|
| model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" |
| _optional_components = [] |
| _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] |
|
|
| def __init__( |
| self, |
| scheduler: FlowMatchEulerDiscreteScheduler, |
| vae: AutoencoderKLLTX2Video, |
| audio_vae: AutoencoderKLLTX2Audio, |
| text_encoder: Gemma3ForConditionalGeneration, |
| tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], |
| connectors: LTX2TextConnectors, |
| transformer: LTX2VideoTransformer3DModel, |
| vocoder: LTX2Vocoder, |
| ): |
| super().__init__() |
|
|
| self.register_modules( |
| vae=vae, |
| audio_vae=audio_vae, |
| text_encoder=text_encoder, |
| tokenizer=tokenizer, |
| connectors=connectors, |
| transformer=transformer, |
| vocoder=vocoder, |
| scheduler=scheduler, |
| ) |
|
|
| self.vae_spatial_compression_ratio = ( |
| self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 |
| ) |
| self.vae_temporal_compression_ratio = ( |
| self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 |
| ) |
| self.audio_vae_mel_compression_ratio = ( |
| self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 |
| ) |
| self.audio_vae_temporal_compression_ratio = ( |
| self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 |
| ) |
| self.transformer_spatial_patch_size = ( |
| self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 |
| ) |
| self.transformer_temporal_patch_size = ( |
| self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1 |
| ) |
|
|
| self.audio_sampling_rate = ( |
| self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 |
| ) |
| self.audio_hop_length = ( |
| self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 |
| ) |
|
|
| self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio, resample="bilinear") |
| self.tokenizer_max_length = ( |
| self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 |
| ) |
|
|
| |
| |
| def _load_video_frames( |
| self, |
| video: Union[str, List[Image.Image], torch.Tensor], |
| height: int, |
| width: int, |
| num_frames: int, |
| device: torch.device, |
| dtype: torch.dtype, |
| ) -> torch.Tensor: |
| """ |
| Load and preprocess video frames for conditioning. |
| |
| Args: |
| video: Path to video file, list of PIL images, or tensor of frames |
| height: Target height |
| width: Target width |
| num_frames: Number of frames to extract/use |
| device: Target device |
| dtype: Target dtype |
| |
| Returns: |
| Tensor of shape (batch, channels, num_frames, height, width) |
| """ |
| if isinstance(video, str): |
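| # Decode frames from a video file on disk (uses PyAV) |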
| |
| frames = self._decode_video_file(video, num_frames) |
| elif isinstance(video, list): |
| |
| frames = [np.array(img.convert("RGB")) for img in video] |
| elif isinstance(video, torch.Tensor): |
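| # Accept tensors shaped (F, H, W, C) or (F, C, H, W) |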
| |
| if video.ndim == 4: |
| if video.shape[-1] in [1, 3, 4]: |
| frames = [video[i].cpu().numpy() for i in range(video.shape[0])] |
| else: |
| frames = [video[i].permute(1, 2, 0).cpu().numpy() for i in range(video.shape[0])] |
| else: |
| raise ValueError(f"Unexpected video tensor shape: {video.shape}") |
| else: |
| raise TypeError(f"Unsupported video type: {type(video)}") |
| |
| |
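| # Truncate extra frames, or pad by repeating the last frame, so the count matches num_frames |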
| if len(frames) >= num_frames: |
| frames = frames[:num_frames] |
| else: |
| |
| last_frame = frames[-1] |
| while len(frames) < num_frames: |
| frames.append(last_frame) |
| |
| |
| processed_frames = [] |
| for frame in frames: |
| if isinstance(frame, np.ndarray): |
| frame = Image.fromarray(frame.astype(np.uint8)) |
| |
| |
| frame = frame.resize((width, height), Image.LANCZOS) |
| frame = np.array(frame) |
| |
| |
| frame = (frame.astype(np.float32) / 127.5) - 1.0 |
| processed_frames.append(frame) |
| |
| |
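| # Stack into (F, H, W, C), permute to (C, F, H, W), and add a batch dimension |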
| frames_array = np.stack(processed_frames, axis=0) |
| frames_tensor = torch.from_numpy(frames_array).permute(3, 0, 1, 2).unsqueeze(0) |
| |
| return frames_tensor.to(device=device, dtype=dtype) |
| |
| def _decode_video_file(self, video_path: str, max_frames: int) -> List[np.ndarray]: |
| """Decode video file to list of numpy arrays.""" |
| try: |
| import av |
| except ImportError: |
| raise ImportError("Please install av: pip install av") |
| |
| frames = [] |
| container = av.open(video_path) |
| try: |
| video_stream = next(s for s in container.streams if s.type == "video") |
| for frame in container.decode(video_stream): |
| frames.append(frame.to_rgb().to_ndarray()) |
| if len(frames) >= max_frames: |
| break |
| finally: |
| container.close() |
| |
| return frames |
| |
| def _encode_video_conditioning( |
| self, |
| video: torch.Tensor, |
| generator: Optional[torch.Generator] = None, |
| ) -> torch.Tensor: |
| """ |
| Encode video frames through the VAE to get latents. |
| |
| Args: |
| video: Video tensor of shape (batch, channels, frames, height, width) |
| generator: Random generator for sampling |
| |
| Returns: |
| Video latents |
| """ |
| |
| |
| video = video.to(self.vae.dtype) |
| latents = retrieve_latents(self.vae.encode(video), generator, "argmax") |
| return latents |
|
|
| |
| |
| @staticmethod |
| def _pack_text_embeds( |
| text_hidden_states: torch.Tensor, |
| sequence_lengths: torch.Tensor, |
| device: Union[str, torch.device], |
| padding_side: str = "left", |
| scale_factor: int = 8, |
| eps: float = 1e-6, |
| ) -> torch.Tensor: |
| batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape |
| original_dtype = text_hidden_states.dtype |
|
|
| token_indices = torch.arange(seq_len, device=device).unsqueeze(0) |
| if padding_side == "right": |
| mask = token_indices < sequence_lengths[:, None] |
| elif padding_side == "left": |
| start_indices = seq_len - sequence_lengths[:, None] |
| mask = token_indices >= start_indices |
| else: |
| raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") |
| mask = mask[:, :, None, None] |
|
|
| masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) |
| num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) |
| masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) |
|
|
| x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) |
| x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) |
|
|
| normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) |
| normalized_hidden_states = normalized_hidden_states * scale_factor |
|
|
| normalized_hidden_states = normalized_hidden_states.flatten(2) |
| mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) |
| normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) |
| normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) |
| return normalized_hidden_states |
|
|
| def _get_gemma_prompt_embeds( |
| self, |
| prompt: Union[str, List[str]], |
| num_videos_per_prompt: int = 1, |
| max_sequence_length: int = 1024, |
| scale_factor: int = 8, |
| device: Optional[torch.device] = None, |
| dtype: Optional[torch.dtype] = None, |
| ): |
| device = device or self._execution_device |
| dtype = dtype or self.text_encoder.dtype |
|
|
| prompt = [prompt] if isinstance(prompt, str) else prompt |
| batch_size = len(prompt) |
|
|
| if getattr(self, "tokenizer", None) is not None: |
| self.tokenizer.padding_side = "left" |
| if self.tokenizer.pad_token is None: |
| self.tokenizer.pad_token = self.tokenizer.eos_token |
|
|
| prompt = [p.strip() for p in prompt] |
| text_inputs = self.tokenizer( |
| prompt, |
| padding="max_length", |
| max_length=max_sequence_length, |
| truncation=True, |
| add_special_tokens=True, |
| return_tensors="pt", |
| ) |
| text_input_ids = text_inputs.input_ids |
| prompt_attention_mask = text_inputs.attention_mask |
| text_input_ids = text_input_ids.to(device) |
| prompt_attention_mask = prompt_attention_mask.to(device) |
|
|
| text_encoder_outputs = self.text_encoder( |
| input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True |
| ) |
| text_encoder_hidden_states = text_encoder_outputs.hidden_states |
| text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) |
| sequence_lengths = prompt_attention_mask.sum(dim=-1) |
|
|
| prompt_embeds = self._pack_text_embeds( |
| text_encoder_hidden_states, |
| sequence_lengths, |
| device=device, |
| padding_side=self.tokenizer.padding_side, |
| scale_factor=scale_factor, |
| ) |
| prompt_embeds = prompt_embeds.to(dtype=dtype) |
|
|
| _, seq_len, _ = prompt_embeds.shape |
| prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) |
| prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) |
|
|
| prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) |
| prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) |
|
|
| return prompt_embeds, prompt_attention_mask |
|
|
| def encode_prompt( |
| self, |
| prompt: Union[str, List[str]], |
| negative_prompt: Optional[Union[str, List[str]]] = None, |
| do_classifier_free_guidance: bool = True, |
| num_videos_per_prompt: int = 1, |
| prompt_embeds: Optional[torch.Tensor] = None, |
| negative_prompt_embeds: Optional[torch.Tensor] = None, |
| prompt_attention_mask: Optional[torch.Tensor] = None, |
| negative_prompt_attention_mask: Optional[torch.Tensor] = None, |
| max_sequence_length: int = 1024, |
| scale_factor: int = 8, |
| device: Optional[torch.device] = None, |
| dtype: Optional[torch.dtype] = None, |
| ): |
| device = device or self._execution_device |
|
|
| prompt = [prompt] if isinstance(prompt, str) else prompt |
| if prompt is not None: |
| batch_size = len(prompt) |
| else: |
| batch_size = prompt_embeds.shape[0] |
|
|
| if prompt_embeds is None: |
| prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( |
| prompt=prompt, |
| num_videos_per_prompt=num_videos_per_prompt, |
| max_sequence_length=max_sequence_length, |
| scale_factor=scale_factor, |
| device=device, |
| dtype=dtype, |
| ) |
|
|
| if do_classifier_free_guidance and negative_prompt_embeds is None: |
| negative_prompt = negative_prompt or "" |
| negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt |
|
|
| if prompt is not None and type(prompt) is not type(negative_prompt): |
| raise TypeError( |
| f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} != {type(prompt)}." |
| ) |
| elif batch_size != len(negative_prompt): |
| raise ValueError( |
| f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`: {prompt} has batch size {batch_size}." |
| ) |
|
|
| negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( |
| prompt=negative_prompt, |
| num_videos_per_prompt=num_videos_per_prompt, |
| max_sequence_length=max_sequence_length, |
| scale_factor=scale_factor, |
| device=device, |
| dtype=dtype, |
| ) |
|
|
| return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask |
|
|
| def check_inputs( |
| self, |
| prompt, |
| height, |
| width, |
| callback_on_step_end_tensor_inputs=None, |
| prompt_embeds=None, |
| negative_prompt_embeds=None, |
| prompt_attention_mask=None, |
| negative_prompt_attention_mask=None, |
| ): |
| if height % 32 != 0 or width % 32 != 0: |
| raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") |
|
|
| if callback_on_step_end_tensor_inputs is not None and not all( |
| k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs |
| ): |
| raise ValueError( |
| f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}" |
| ) |
|
|
| if prompt is not None and prompt_embeds is not None: |
| raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.") |
| elif prompt is None and prompt_embeds is None: |
| raise ValueError("Provide either `prompt` or `prompt_embeds`.") |
| elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): |
| raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") |
|
|
| if prompt_embeds is not None and prompt_attention_mask is None: |
| raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") |
|
|
| if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: |
| raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") |
|
|
| |
| |
| @staticmethod |
| def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: |
| batch_size, num_channels, num_frames, height, width = latents.shape |
| post_patch_num_frames = num_frames // patch_size_t |
| post_patch_height = height // patch_size |
| post_patch_width = width // patch_size |
| latents = latents.reshape( |
| batch_size, |
| -1, |
| post_patch_num_frames, |
| patch_size_t, |
| post_patch_height, |
| patch_size, |
| post_patch_width, |
| patch_size, |
| ) |
| latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) |
| return latents |
|
|
| @staticmethod |
| def _unpack_latents( |
| latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 |
| ) -> torch.Tensor: |
| batch_size = latents.size(0) |
| latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) |
| latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) |
| return latents |
|
|
| @staticmethod |
| def _normalize_latents( |
| latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 |
| ) -> torch.Tensor: |
| latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) |
| latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) |
| latents = (latents - latents_mean) * scaling_factor / latents_std |
| return latents |
|
|
| @staticmethod |
| def _denormalize_latents( |
| latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 |
| ) -> torch.Tensor: |
| latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) |
| latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) |
| latents = latents * latents_std / scaling_factor + latents_mean |
| return latents |
|
|
| |
| |
| @staticmethod |
| def _pack_audio_latents( |
| latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None |
| ) -> torch.Tensor: |
| if patch_size is not None and patch_size_t is not None: |
| batch_size, num_channels, latent_length, latent_mel_bins = latents.shape |
| post_patch_latent_length = latent_length // patch_size_t |
| post_patch_mel_bins = latent_mel_bins // patch_size |
| latents = latents.reshape( |
| batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size |
| ) |
| latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) |
| else: |
| latents = latents.transpose(1, 2).flatten(2, 3) |
| return latents |
|
|
| @staticmethod |
| def _unpack_audio_latents( |
| latents: torch.Tensor, |
| latent_length: int, |
| num_mel_bins: int, |
| patch_size: Optional[int] = None, |
| patch_size_t: Optional[int] = None, |
| ) -> torch.Tensor: |
| if patch_size is not None and patch_size_t is not None: |
| batch_size = latents.size(0) |
| latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size) |
| latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) |
| else: |
| latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) |
| return latents |
|
|
| @staticmethod |
| def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): |
| latents_mean = latents_mean.to(latents.device, latents.dtype) |
| latents_std = latents_std.to(latents.device, latents.dtype) |
| return (latents * latents_std) + latents_mean |
|
|
| @staticmethod |
| def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): |
| latents_mean = latents_mean.to(latents.device, latents.dtype) |
| latents_std = latents_std.to(latents.device, latents.dtype) |
| return (latents - latents_mean) / latents_std |
|
|
| @staticmethod |
| def _patchify_audio_latents(latents: torch.Tensor) -> torch.Tensor: |
| batch, channels, time, freq = latents.shape |
| return latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq) |
|
|
| @staticmethod |
| def _unpatchify_audio_latents(latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor: |
| batch, time, _ = latents.shape |
| return latents.reshape(batch, time, channels, freq).permute(0, 2, 1, 3) |
|
|
| def _preprocess_audio(self, audio: Union[str, torch.Tensor], target_sample_rate: int) -> torch.Tensor: |
| """Process audio to mel spectrogram.""" |
| if isinstance(audio, str): |
| waveform, sr = torchaudio.load(audio) |
| else: |
| waveform = audio |
| sr = target_sample_rate |
|
|
| if sr != target_sample_rate: |
| waveform = torchaudio.functional.resample(waveform, sr, target_sample_rate) |
|
|
| if waveform.shape[0] == 1: |
| waveform = waveform.repeat(2, 1) |
| elif waveform.shape[0] > 2: |
| waveform = waveform[:2, :] |
|
|
| waveform = waveform.unsqueeze(0) |
|
|
| n_fft = 1024 |
| mel_transform = T.MelSpectrogram( |
| sample_rate=target_sample_rate, |
| n_fft=n_fft, |
| win_length=n_fft, |
| hop_length=self.audio_hop_length, |
| f_min=0.0, |
| f_max=target_sample_rate / 2.0, |
| n_mels=self.audio_vae.config.mel_bins, |
| window_fn=torch.hann_window, |
| center=True, |
| pad_mode="reflect", |
| power=1.0, |
| mel_scale="slaney", |
| norm="slaney", |
| ) |
|
|
| mel_spec = mel_transform(waveform) |
| mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5)) |
| mel_spec = mel_spec.permute(0, 1, 3, 2).contiguous() |
|
|
| return mel_spec |
|
|
| |
| |
| def prepare_latents( |
| self, |
| image: Optional[torch.Tensor] = None, |
| video: Optional[torch.Tensor] = None, |
| video_conditioning_strength: float = 1.0, |
| video_conditioning_frame_idx: int = 1, |
| batch_size: int = 1, |
| num_channels_latents: int = 128, |
| height: int = 512, |
| width: int = 704, |
| num_frames: int = 161, |
| dtype: Optional[torch.dtype] = None, |
| device: Optional[torch.device] = None, |
| generator: Optional[torch.Generator] = None, |
| latents: Optional[torch.Tensor] = None, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| """ |
| Prepare latents for generation with optional video conditioning. |
| |
| Args: |
| image: Input image for frame 0 conditioning |
| video: Video tensor for motion conditioning |
| video_conditioning_strength: Strength of video conditioning (0-1) |
| video_conditioning_frame_idx: Frame index where video conditioning starts. |
| - 0: Video conditioning replaces all frames including frame 0 |
| - 1: Frame 0 is image-conditioned, frames 1+ are video-conditioned (default for face-swap) |
| - N: Frames 0 to N-1 are image/noise, frames N+ are video-conditioned |
| ... other args ... |
| """ |
| latent_height = height // self.vae_spatial_compression_ratio |
| latent_width = width // self.vae_spatial_compression_ratio |
| latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 |
|
|
| shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width) |
| mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width) |
|
|
| if latents is not None: |
| conditioning_mask = latents.new_zeros(mask_shape) |
| conditioning_mask[:, :, 0] = 1.0 |
| conditioning_mask = self._pack_latents( |
| conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size |
| ).squeeze(-1) |
| return latents.to(device=device, dtype=dtype), conditioning_mask |
|
|
| |
| conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype) |
| |
| |
| init_latents = torch.zeros(shape, device=device, dtype=dtype) |
| |
| |
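| # Encode the reference video; its latents supply the motion signal for the conditioned frames |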
| if video is not None: |
| |
| video_latents = self._encode_video_conditioning(video, generator) |
| video_latents = self._normalize_latents(video_latents, self.vae.latents_mean, self.vae.latents_std) |
| |
| |
| if video_latents.shape[2] < latent_num_frames: |
| |
| pad_frames = latent_num_frames - video_latents.shape[2] |
| last_frame = video_latents[:, :, -1:, :, :] |
| video_latents = torch.cat([video_latents, last_frame.repeat(1, 1, pad_frames, 1, 1)], dim=2) |
| elif video_latents.shape[2] > latent_num_frames: |
| video_latents = video_latents[:, :, :latent_num_frames, :, :] |
| |
| |
| |
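| # Map the pixel-space conditioning start frame to latent time (the VAE compresses time by vae_temporal_compression_ratio) |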
| latent_video_start_idx = video_conditioning_frame_idx // self.vae_temporal_compression_ratio |
| latent_video_start_idx = min(latent_video_start_idx, latent_num_frames - 1) |
| |
| |
| |
| num_video_frames_to_use = latent_num_frames - latent_video_start_idx |
| init_latents[:, :, latent_video_start_idx:, :, :] = video_latents[:, :, :num_video_frames_to_use, :, :] |
| |
| |
| |
| conditioning_mask[:, :, latent_video_start_idx:] = video_conditioning_strength |
| |
| |
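| # Encode the identity image and pin it to the first latent frame |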
| if image is not None: |
| if isinstance(generator, list): |
| image_latents = [ |
| retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i], "argmax") |
| for i in range(batch_size) |
| ] |
| else: |
| image_latents = [ |
| retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator, "argmax") |
| for img in image |
| ] |
| image_latents = torch.cat(image_latents, dim=0).to(dtype) |
| image_latents = self._normalize_latents(image_latents, self.vae.latents_mean, self.vae.latents_std) |
| |
| |
| init_latents[:, :, 0:1, :, :] = image_latents |
| |
| conditioning_mask[:, :, 0] = 1.0 |
| |
| |
| if video is None: |
| init_latents = image_latents.repeat(1, 1, latent_num_frames, 1, 1) |
|
|
| |
| noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) |
| |
| |
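| # Blend: conditioned positions keep their encoded latents (weighted by the mask/strength), the rest start from pure noise |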
| latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask) |
|
|
| |
| conditioning_mask = self._pack_latents( |
| conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size |
| ).squeeze(-1) |
| latents = self._pack_latents( |
| latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size |
| ) |
|
|
| return latents, conditioning_mask |
|
|
| def prepare_audio_latents( |
| self, |
| batch_size: int = 1, |
| num_channels_latents: int = 8, |
| num_mel_bins: int = 64, |
| num_frames: int = 121, |
| frame_rate: float = 25.0, |
| sampling_rate: int = 16000, |
| hop_length: int = 160, |
| dtype: Optional[torch.dtype] = None, |
| device: Optional[torch.device] = None, |
| generator: Optional[torch.Generator] = None, |
| audio_input: Optional[Union[str, torch.Tensor]] = None, |
| latents: Optional[torch.Tensor] = None, |
| ) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]: |
| duration_s = num_frames / frame_rate |
| latents_per_second = ( |
| float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) |
| ) |
| target_length = round(duration_s * latents_per_second) |
|
|
| if latents is not None: |
| return latents.to(device=device, dtype=dtype), target_length, None |
|
|
| latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio |
|
|
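| # Encode the provided audio into mel-spectrogram latents; these stay clean and act as conditioning during denoising |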
| if audio_input is not None: |
| mel_spec = self._preprocess_audio(audio_input, sampling_rate).to(device=device) |
| mel_spec = mel_spec.to(dtype=self.audio_vae.dtype) |
| init_latents = self.audio_vae.encode(mel_spec).latent_dist.sample(generator) |
| init_latents = init_latents.to(dtype=dtype) |
|
|
| latent_channels = init_latents.shape[1] |
| latent_freq = init_latents.shape[3] |
| init_latents_patched = self._patchify_audio_latents(init_latents) |
| init_latents_patched = self._normalize_audio_latents( |
| init_latents_patched, self.audio_vae.latents_mean, self.audio_vae.latents_std |
| ) |
| init_latents = self._unpatchify_audio_latents(init_latents_patched, latent_channels, latent_freq) |
|
|
| current_len = init_latents.shape[2] |
| if current_len < target_length: |
| padding = target_length - current_len |
| init_latents = torch.nn.functional.pad(init_latents, (0, 0, 0, padding)) |
| elif current_len > target_length: |
| init_latents = init_latents[:, :, :target_length, :] |
|
|
| noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype) |
|
|
| if init_latents.shape[0] != batch_size: |
| init_latents = init_latents.repeat(batch_size, 1, 1, 1) |
| noise = noise.repeat(batch_size, 1, 1, 1) |
|
|
| packed_noise = self._pack_audio_latents(noise) |
|
|
| return packed_noise, target_length, init_latents |
|
|
| shape = (batch_size, num_channels_latents, target_length, latent_mel_bins) |
| latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) |
| latents = self._pack_audio_latents(latents) |
|
|
| return latents, target_length, None |
|
|
| |
| |
| @property |
| def guidance_scale(self): |
| return self._guidance_scale |
|
|
| @property |
| def guidance_rescale(self): |
| return self._guidance_rescale |
|
|
| @property |
| def do_classifier_free_guidance(self): |
| return self._guidance_scale > 1.0 |
|
|
| @property |
| def num_timesteps(self): |
| return self._num_timesteps |
|
|
| @property |
| def current_timestep(self): |
| return self._current_timestep |
|
|
| @property |
| def attention_kwargs(self): |
| return self._attention_kwargs |
|
|
| @property |
| def interrupt(self): |
| return self._interrupt |
|
|
| def _get_audio_duration(self, audio: Union[str, torch.Tensor], sample_rate: int) -> float: |
| if isinstance(audio, str): |
| info = torchaudio.info(audio) |
| return info.num_frames / info.sample_rate |
| else: |
| num_samples = audio.shape[-1] |
| return num_samples / sample_rate |
|
|
| |
| |
| @torch.no_grad() |
| @replace_example_docstring(EXAMPLE_DOC_STRING) |
| def __call__( |
| self, |
| image: PipelineImageInput = None, |
| video: Optional[Union[str, List[Image.Image], torch.Tensor]] = None, |
| video_conditioning_strength: float = 1.0, |
| video_conditioning_frame_idx: int = 1, |
| audio: Optional[Union[str, torch.Tensor]] = None, |
| prompt: Union[str, List[str]] = None, |
| negative_prompt: Optional[Union[str, List[str]]] = None, |
| height: int = 512, |
| width: int = 768, |
| num_frames: Optional[int] = None, |
| max_frames: int = 257, |
| frame_rate: float = 24.0, |
| num_inference_steps: int = 40, |
| timesteps: List[int] = None, |
| sigmas: Optional[List[float]] = None, |
| guidance_scale: float = 4.0, |
| guidance_rescale: float = 0.0, |
| num_videos_per_prompt: Optional[int] = 1, |
| generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
| latents: Optional[torch.Tensor] = None, |
| audio_latents: Optional[torch.Tensor] = None, |
| prompt_embeds: Optional[torch.Tensor] = None, |
| prompt_attention_mask: Optional[torch.Tensor] = None, |
| negative_prompt_embeds: Optional[torch.Tensor] = None, |
| negative_prompt_attention_mask: Optional[torch.Tensor] = None, |
| decode_timestep: Union[float, List[float]] = 0.0, |
| decode_noise_scale: Optional[Union[float, List[float]]] = None, |
| output_type: Optional[str] = "pil", |
| return_dict: bool = True, |
| attention_kwargs: Optional[Dict[str, Any]] = None, |
| callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, |
| callback_on_step_end_tensor_inputs: List[str] = ["latents"], |
| max_sequence_length: int = 1024, |
| ): |
| r""" |
| Generate avatar video with audio and optional video conditioning. |
| |
| Args: |
| image (`PipelineImageInput`): |
| The input image (face/appearance) to condition frame 0. |
| video (`str`, `List[PIL.Image]`, or `torch.Tensor`, *optional*): |
| Reference video for motion conditioning. Can be: |
| - Path to a video file |
| - List of PIL Images |
| - Tensor of shape (F, H, W, C) or (F, C, H, W) |
| video_conditioning_strength (`float`, *optional*, defaults to 1.0): |
| How strongly to condition on the reference video (0.0-1.0). |
| 1.0 = fully conditioned, 0.0 = no conditioning. |
| video_conditioning_frame_idx (`int`, *optional*, defaults to 1): |
| Frame index where video conditioning starts (in pixel/frame space). |
| - 0: Video conditioning replaces all frames including frame 0 |
| - 1: Frame 0 is image-conditioned, frames 1+ are video-conditioned (default for face-swap) |
| - N: Frames 0 to N-1 are image/noise, frames N+ are video-conditioned |
| audio (`str` or `torch.Tensor`, *optional*): |
| Audio for lip-sync. Can be path to audio/video file or waveform tensor. |
| prompt (`str` or `List[str]`, *optional*): |
| Text prompt. For face-swap, include "head_swap" trigger. |
| |
| Examples: |
| |
| Returns: |
| [`LTX2PipelineOutput`] or `tuple`: Generated video and audio. |
| """ |
|
|
| if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): |
| callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs |
|
|
| |
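| # Derive num_frames from the audio duration when not given, snapping to the VAE temporal grid (a multiple of the compression ratio plus one) |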
| if num_frames is None: |
| if audio is not None: |
| audio_duration = self._get_audio_duration(audio, self.audio_sampling_rate) |
| calculated_frames = int(audio_duration * frame_rate) + 1 |
| num_frames = min(calculated_frames, max_frames) |
| num_frames = ((num_frames - 1) // self.vae_temporal_compression_ratio) * self.vae_temporal_compression_ratio + 1 |
| num_frames = max(num_frames, 9) |
| logger.info(f"Audio duration: {audio_duration:.2f}s -> num_frames: {num_frames}") |
| else: |
| num_frames = 121 |
|
|
| self.check_inputs( |
| prompt=prompt, |
| height=height, |
| width=width, |
| callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, |
| prompt_embeds=prompt_embeds, |
| negative_prompt_embeds=negative_prompt_embeds, |
| prompt_attention_mask=prompt_attention_mask, |
| negative_prompt_attention_mask=negative_prompt_attention_mask, |
| ) |
|
|
| self._guidance_scale = guidance_scale |
| self._guidance_rescale = guidance_rescale |
| self._attention_kwargs = attention_kwargs |
| self._interrupt = False |
| self._current_timestep = None |
|
|
| if prompt is not None and isinstance(prompt, str): |
| batch_size = 1 |
| elif prompt is not None and isinstance(prompt, list): |
| batch_size = len(prompt) |
| else: |
| batch_size = prompt_embeds.shape[0] |
|
|
| device = self._execution_device |
|
|
| |
| ( |
| prompt_embeds, |
| prompt_attention_mask, |
| negative_prompt_embeds, |
| negative_prompt_attention_mask, |
| ) = self.encode_prompt( |
| prompt=prompt, |
| negative_prompt=negative_prompt, |
| do_classifier_free_guidance=self.do_classifier_free_guidance, |
| num_videos_per_prompt=num_videos_per_prompt, |
| prompt_embeds=prompt_embeds, |
| negative_prompt_embeds=negative_prompt_embeds, |
| prompt_attention_mask=prompt_attention_mask, |
| negative_prompt_attention_mask=negative_prompt_attention_mask, |
| max_sequence_length=max_sequence_length, |
| device=device, |
| ) |
| if self.do_classifier_free_guidance: |
| prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) |
| prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) |
|
|
| additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 |
| connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( |
| prompt_embeds, additive_attention_mask, additive_mask=True |
| ) |
|
|
| |
| if latents is None and image is not None: |
| image = self.video_processor.preprocess(image, height=height, width=width) |
| image = image.to(device=device, dtype=prompt_embeds.dtype) |
|
|
| |
| video_tensor = None |
| if video is not None: |
| video_tensor = self._load_video_frames( |
| video=video, |
| height=height, |
| width=width, |
| num_frames=num_frames, |
| device=device, |
| dtype=prompt_embeds.dtype, |
| ) |
|
|
| |
| num_channels_latents = self.transformer.config.in_channels |
| latents, conditioning_mask = self.prepare_latents( |
| image=image, |
| video=video_tensor, |
| video_conditioning_strength=video_conditioning_strength, |
| video_conditioning_frame_idx=video_conditioning_frame_idx, |
| batch_size=batch_size * num_videos_per_prompt, |
| num_channels_latents=num_channels_latents, |
| height=height, |
| width=width, |
| num_frames=num_frames, |
| dtype=torch.float32, |
| device=device, |
| generator=generator, |
| latents=latents, |
| ) |
| if self.do_classifier_free_guidance: |
| conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) |
|
|
| |
| num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 |
| latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio |
|
|
| num_channels_latents_audio = ( |
| self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 |
| ) |
|
|
| audio_latents, audio_num_frames, clean_audio_latents = self.prepare_audio_latents( |
| batch_size * num_videos_per_prompt, |
| num_channels_latents=num_channels_latents_audio, |
| num_mel_bins=num_mel_bins, |
| num_frames=num_frames, |
| frame_rate=frame_rate, |
| sampling_rate=self.audio_sampling_rate, |
| hop_length=self.audio_hop_length, |
| dtype=torch.float32, |
| device=device, |
| generator=generator, |
| latents=audio_latents, |
| audio_input=audio, |
| ) |
|
|
| packed_clean_audio_latents = None |
| if clean_audio_latents is not None: |
| packed_clean_audio_latents = self._pack_audio_latents(clean_audio_latents) |
|
|
| latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 |
| latent_height = height // self.vae_spatial_compression_ratio |
| latent_width = width // self.vae_spatial_compression_ratio |
| video_sequence_length = latent_num_frames * latent_height * latent_width |
|
|
| if sigmas is None: |
| sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) |
|
|
| mu = calculate_shift( |
| video_sequence_length, |
| self.scheduler.config.get("base_image_seq_len", 1024), |
| self.scheduler.config.get("max_image_seq_len", 4096), |
| self.scheduler.config.get("base_shift", 0.95), |
| self.scheduler.config.get("max_shift", 2.05), |
| ) |
|
|
| audio_scheduler = copy.deepcopy(self.scheduler) |
| _, _ = retrieve_timesteps( |
| audio_scheduler, |
| num_inference_steps, |
| device, |
| timesteps, |
| sigmas=sigmas, |
| mu=mu, |
| ) |
| timesteps, num_inference_steps = retrieve_timesteps( |
| self.scheduler, |
| num_inference_steps, |
| device, |
| timesteps, |
| sigmas=sigmas, |
| mu=mu, |
| ) |
| num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) |
| self._num_timesteps = len(timesteps) |
|
|
| rope_interpolation_scale = ( |
| self.vae_temporal_compression_ratio / frame_rate, |
| self.vae_spatial_compression_ratio, |
| self.vae_spatial_compression_ratio, |
| ) |
| video_coords = self.transformer.rope.prepare_video_coords( |
| latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate |
| ) |
| audio_coords = self.transformer.audio_rope.prepare_audio_coords( |
| audio_latents.shape[0], audio_num_frames, audio_latents.device |
| ) |
|
|
| |
| with self.progress_bar(total=num_inference_steps) as progress_bar: |
| for i, t in enumerate(timesteps): |
| if self.interrupt: |
| continue |
|
|
| self._current_timestep = t |
|
|
| if packed_clean_audio_latents is not None: |
| audio_latents_input = packed_clean_audio_latents.to(dtype=prompt_embeds.dtype) |
| else: |
| audio_latents_input = audio_latents.to(dtype=prompt_embeds.dtype) |
|
|
| latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents |
| latent_model_input = latent_model_input.to(prompt_embeds.dtype) |
| audio_latent_model_input = ( |
| torch.cat([audio_latents_input] * 2) if self.do_classifier_free_guidance else audio_latents_input |
| ) |
| audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) |
|
|
| timestep = t.expand(latent_model_input.shape[0]) |
| video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) |
|
|
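| # When input audio was provided, feed its clean latents at timestep 0 so the transformer treats audio as fixed conditioning rather than denoising it |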
| if packed_clean_audio_latents is not None: |
| audio_timestep = torch.zeros_like(timestep) |
| else: |
| audio_timestep = timestep |
|
|
| with self.transformer.cache_context("cond_uncond"): |
| noise_pred_video, noise_pred_audio = self.transformer( |
| hidden_states=latent_model_input, |
| audio_hidden_states=audio_latent_model_input, |
| encoder_hidden_states=connector_prompt_embeds, |
| audio_encoder_hidden_states=connector_audio_prompt_embeds, |
| timestep=video_timestep, |
| audio_timestep=audio_timestep, |
| encoder_attention_mask=connector_attention_mask, |
| audio_encoder_attention_mask=connector_attention_mask, |
| num_frames=latent_num_frames, |
| height=latent_height, |
| width=latent_width, |
| fps=frame_rate, |
| audio_num_frames=audio_num_frames, |
| video_coords=video_coords, |
| audio_coords=audio_coords, |
| attention_kwargs=attention_kwargs, |
| return_dict=False, |
| ) |
| noise_pred_video = noise_pred_video.float() |
| noise_pred_audio = noise_pred_audio.float() |
|
|
| if self.do_classifier_free_guidance: |
| noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) |
| noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( |
| noise_pred_video_text - noise_pred_video_uncond |
| ) |
|
|
| noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) |
| noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( |
| noise_pred_audio_text - noise_pred_audio_uncond |
| ) |
|
|
| if self.guidance_rescale > 0: |
| noise_pred_video = rescale_noise_cfg( |
| noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale |
| ) |
| noise_pred_audio = rescale_noise_cfg( |
| noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale |
| ) |
|
|
| noise_pred_video = self._unpack_latents( |
| noise_pred_video, |
| latent_num_frames, |
| latent_height, |
| latent_width, |
| self.transformer_spatial_patch_size, |
| self.transformer_temporal_patch_size, |
| ) |
| latents = self._unpack_latents( |
| latents, |
| latent_num_frames, |
| latent_height, |
| latent_width, |
| self.transformer_spatial_patch_size, |
| self.transformer_temporal_patch_size, |
| ) |
|
|
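| # Hold the first (image-conditioned) latent frame fixed and only apply the scheduler step to the remaining frames |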
| noise_pred_video = noise_pred_video[:, :, 1:] |
| noise_latents = latents[:, :, 1:] |
| pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0] |
|
|
| latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) |
| latents = self._pack_latents( |
| latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size |
| ) |
|
|
| if packed_clean_audio_latents is None: |
| audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] |
|
|
| if callback_on_step_end is not None: |
| callback_kwargs = {} |
| for k in callback_on_step_end_tensor_inputs: |
| callback_kwargs[k] = locals()[k] |
| callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) |
| latents = callback_outputs.pop("latents", latents) |
| prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) |
|
|
| if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
| progress_bar.update() |
|
|
| if XLA_AVAILABLE: |
| xm.mark_step() |
|
|
| |
| latents = self._unpack_latents( |
| latents, |
| latent_num_frames, |
| latent_height, |
| latent_width, |
| self.transformer_spatial_patch_size, |
| self.transformer_temporal_patch_size, |
| ) |
| latents = self._denormalize_latents( |
| latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor |
| ) |
|
|
| if clean_audio_latents is not None: |
| latent_channels = clean_audio_latents.shape[1] |
| latent_freq = clean_audio_latents.shape[3] |
| audio_patched = self._patchify_audio_latents(clean_audio_latents) |
| audio_patched = self._denormalize_audio_latents( |
| audio_patched, self.audio_vae.latents_mean, self.audio_vae.latents_std |
| ) |
| audio_latents_for_decode = self._unpatchify_audio_latents(audio_patched, latent_channels, latent_freq) |
| else: |
| audio_latents_for_decode = self._denormalize_audio_latents( |
| audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std |
| ) |
| audio_latents_for_decode = self._unpack_audio_latents( |
| audio_latents_for_decode, audio_num_frames, num_mel_bins=latent_mel_bins |
| ) |
|
|
| if output_type == "latent": |
| video = latents |
| audio_output = audio_latents_for_decode |
| else: |
| latents = latents.to(prompt_embeds.dtype) |
|
|
| if not self.vae.config.timestep_conditioning: |
| timestep = None |
| else: |
| noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) |
| if not isinstance(decode_timestep, list): |
| decode_timestep = [decode_timestep] * batch_size |
| if decode_noise_scale is None: |
| decode_noise_scale = decode_timestep |
| elif not isinstance(decode_noise_scale, list): |
| decode_noise_scale = [decode_noise_scale] * batch_size |
|
|
| timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) |
| decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ |
| :, None, None, None, None |
| ] |
| latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise |
|
|
| latents = latents.to(self.vae.dtype) |
| video = self.vae.decode(latents, timestep, return_dict=False)[0] |
| video = self.video_processor.postprocess_video(video, output_type=output_type) |
|
|
| audio_latents_for_decode = audio_latents_for_decode.to(self.audio_vae.dtype) |
| generated_mel_spectrograms = self.audio_vae.decode(audio_latents_for_decode, return_dict=False)[0] |
| audio_output = self.vocoder(generated_mel_spectrograms) |
|
|
| self.maybe_free_model_hooks() |
|
|
| if not return_dict: |
| return (video, audio_output) |
|
|
| return LTX2PipelineOutput(frames=video, audio=audio_output) |