Update pipeline.py

#2
by linoyts HF Staff - opened
Files changed (1) hide show
  1. pipeline.py +513 -532
pipeline.py CHANGED
@@ -12,50 +12,35 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- """
16
- LTX-2 Audio-to-Video Pipeline with Video Conditioning Support
17
-
18
- This is a modified version of the LTX2AudioToVideoPipeline that adds support for
19
- video conditioning, enabling avatar/face-swap generation workflows.
20
-
21
- Usage:
22
- pipe = DiffusionPipeline.from_pretrained(
23
- "rootonchair/LTX-2-19b-distilled",
24
- custom_pipeline="path/to/this/file",
25
- torch_dtype=torch.bfloat16
26
- )
27
-
28
- # With video conditioning (for avatar/face-swap):
29
- video, audio = pipe(
30
- image=face_image, # The face/appearance to use
31
- video=reference_video, # Video for motion conditioning
32
- audio="path/to/audio.wav", # Audio (or extracted from video)
33
- prompt="head_swap, a person speaking...",
34
- ...
35
- )
36
- """
37
-
38
  import copy
39
  import inspect
 
40
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
41
 
42
  import numpy as np
 
43
  import torch
44
  import torchaudio
45
  import torchaudio.transforms as T
46
- from PIL import Image
47
  from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
48
 
49
  from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
50
- from diffusers.image_processor import PipelineImageInput
51
  from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
52
- from diffusers.models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
 
 
 
53
  from diffusers.models.transformers import LTX2VideoTransformer3DModel
54
  from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
55
- from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
 
 
 
 
56
  from diffusers.utils.torch_utils import randn_tensor
57
  from diffusers.video_processor import VideoProcessor
58
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 
59
  from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors
60
  from diffusers.pipelines.ltx2.pipeline_output import LTX2PipelineOutput
61
  from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
@@ -63,51 +48,86 @@ from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
63
 
64
  if is_torch_xla_available():
65
  import torch_xla.core.xla_model as xm
 
66
  XLA_AVAILABLE = True
67
  else:
68
  XLA_AVAILABLE = False
69
 
70
- logger = logging.get_logger(__name__)
71
-
72
 
73
  EXAMPLE_DOC_STRING = """
74
  Examples:
75
  ```py
76
  >>> import torch
77
- >>> from diffusers import DiffusionPipeline
 
 
78
  >>> from diffusers.utils import load_image
79
 
80
- >>> pipe = DiffusionPipeline.from_pretrained(
81
- ... "rootonchair/LTX-2-19b-distilled",
82
- ... custom_pipeline="pipeline_ltx2_avatar",
83
- ... torch_dtype=torch.bfloat16
 
84
  ... )
85
- >>> pipe.to("cuda")
86
-
87
- >>> # Load face swap LoRA
88
- >>> pipe.load_lora_weights(
89
- ... "Alissonerdx/BFS-Best-Face-Swap-Video",
90
- ... weight_name="ltx-2/head_swap_v1_13500_first_frame.safetensors",
91
  ... )
92
- >>> pipe.fuse_lora(lora_scale=1.1)
93
-
94
- >>> face_image = load_image("face.png")
95
- >>> video, audio = pipe(
96
- ... image=face_image,
97
- ... video="reference_video.mp4", # Motion reference
98
- ... video_conditioning_strength=1.0, # How strongly to follow motion
99
- ... video_conditioning_frame_idx=1, # Frame 0 = face, Frame 1+ = video motion
100
- ... audio="reference_video.mp4", # Audio extracted from video
101
- ... prompt="head_swap, a person speaking naturally",
102
- ... width=512,
103
- ... height=768,
 
104
  ... num_frames=121,
 
 
 
 
105
  ... return_dict=False,
106
  ... )
 
 
 
 
 
 
 
 
 
 
107
  ```
108
  """
109
 
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  def retrieve_latents(
112
  encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
113
  ):
@@ -121,6 +141,7 @@ def retrieve_latents(
121
  raise AttributeError("Could not access latents of provided encoder_output")
122
 
123
 
 
124
  def calculate_shift(
125
  image_seq_len,
126
  base_seq_len: int = 256,
@@ -134,6 +155,7 @@ def calculate_shift(
134
  return mu
135
 
136
 
 
137
  def retrieve_timesteps(
138
  scheduler,
139
  num_inference_steps: Optional[int] = None,
@@ -142,13 +164,37 @@ def retrieve_timesteps(
142
  sigmas: Optional[List[float]] = None,
143
  **kwargs,
144
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  if timesteps is not None and sigmas is not None:
146
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
147
  if timesteps is not None:
148
  accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
149
  if not accepts_timesteps:
150
  raise ValueError(
151
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom timestep schedules."
 
152
  )
153
  scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
154
  timesteps = scheduler.timesteps
@@ -157,7 +203,8 @@ def retrieve_timesteps(
157
  accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
158
  if not accept_sigmas:
159
  raise ValueError(
160
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas schedules."
 
161
  )
162
  scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
163
  timesteps = scheduler.timesteps
@@ -168,7 +215,24 @@ def retrieve_timesteps(
168
  return timesteps, num_inference_steps
169
 
170
 
 
171
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
173
  std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
174
  noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
@@ -176,17 +240,13 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
176
  return noise_cfg
177
 
178
 
179
- class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
180
  r"""
181
- Pipeline for avatar/face-swap video generation with audio and video conditioning.
182
 
183
- This pipeline generates video conditioned on:
184
- - An input image (the face/appearance to use)
185
- - A reference video (for motion/pose conditioning)
186
- - Input audio (for lip-sync)
187
 
188
- This enables avatar generation where the face from the image is animated
189
- to match the motion from the reference video and synced to the audio.
190
  """
191
 
192
  model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
@@ -223,6 +283,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
223
  self.vae_temporal_compression_ratio = (
224
  self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
225
  )
 
226
  self.audio_vae_mel_compression_ratio = (
227
  self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
228
  )
@@ -248,123 +309,8 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
248
  self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
249
  )
250
 
251
- # ==================== Video Conditioning Methods ====================
252
-
253
- def _load_video_frames(
254
- self,
255
- video: Union[str, List[Image.Image], torch.Tensor],
256
- height: int,
257
- width: int,
258
- num_frames: int,
259
- device: torch.device,
260
- dtype: torch.dtype,
261
- ) -> torch.Tensor:
262
- """
263
- Load and preprocess video frames for conditioning.
264
-
265
- Args:
266
- video: Path to video file, list of PIL images, or tensor of frames
267
- height: Target height
268
- width: Target width
269
- num_frames: Number of frames to extract/use
270
- device: Target device
271
- dtype: Target dtype
272
-
273
- Returns:
274
- Tensor of shape (batch, channels, num_frames, height, width)
275
- """
276
- if isinstance(video, str):
277
- # Load from file
278
- frames = self._decode_video_file(video, num_frames)
279
- elif isinstance(video, list):
280
- # List of PIL images
281
- frames = [np.array(img.convert("RGB")) for img in video]
282
- elif isinstance(video, torch.Tensor):
283
- # Already a tensor
284
- if video.ndim == 4: # (F, H, W, C) or (F, C, H, W)
285
- if video.shape[-1] in [1, 3, 4]: # (F, H, W, C)
286
- frames = [video[i].cpu().numpy() for i in range(video.shape[0])]
287
- else: # (F, C, H, W)
288
- frames = [video[i].permute(1, 2, 0).cpu().numpy() for i in range(video.shape[0])]
289
- else:
290
- raise ValueError(f"Unexpected video tensor shape: {video.shape}")
291
- else:
292
- raise TypeError(f"Unsupported video type: {type(video)}")
293
-
294
- # Handle frame count
295
- if len(frames) >= num_frames:
296
- frames = frames[:num_frames]
297
- else:
298
- # Pad with last frame
299
- last_frame = frames[-1]
300
- while len(frames) < num_frames:
301
- frames.append(last_frame)
302
-
303
- # Process each frame
304
- processed_frames = []
305
- for frame in frames:
306
- if isinstance(frame, np.ndarray):
307
- frame = Image.fromarray(frame.astype(np.uint8))
308
-
309
- # Resize to target dimensions
310
- frame = frame.resize((width, height), Image.LANCZOS)
311
- frame = np.array(frame)
312
-
313
- # Normalize to [-1, 1]
314
- frame = (frame.astype(np.float32) / 127.5) - 1.0
315
- processed_frames.append(frame)
316
-
317
- # Stack frames: (F, H, W, C) -> (1, C, F, H, W)
318
- frames_array = np.stack(processed_frames, axis=0) # (F, H, W, C)
319
- frames_tensor = torch.from_numpy(frames_array).permute(3, 0, 1, 2).unsqueeze(0) # (1, C, F, H, W)
320
-
321
- return frames_tensor.to(device=device, dtype=dtype)
322
-
323
- def _decode_video_file(self, video_path: str, max_frames: int) -> List[np.ndarray]:
324
- """Decode video file to list of numpy arrays."""
325
- try:
326
- import av
327
- except ImportError:
328
- raise ImportError("Please install av: pip install av")
329
-
330
- frames = []
331
- container = av.open(video_path)
332
- try:
333
- video_stream = next(s for s in container.streams if s.type == "video")
334
- for frame in container.decode(video_stream):
335
- frames.append(frame.to_rgb().to_ndarray())
336
- if len(frames) >= max_frames:
337
- break
338
- finally:
339
- container.close()
340
-
341
- return frames
342
-
343
- def _encode_video_conditioning(
344
- self,
345
- video: torch.Tensor,
346
- generator: Optional[torch.Generator] = None,
347
- ) -> torch.Tensor:
348
- """
349
- Encode video frames through the VAE to get latents.
350
-
351
- Args:
352
- video: Video tensor of shape (batch, channels, frames, height, width)
353
- generator: Random generator for sampling
354
-
355
- Returns:
356
- Video latents
357
- """
358
- # Encode each frame through VAE
359
- # VAE expects (batch, channels, frames, height, width)
360
- video = video.to(device=self.vae.device, dtype=self.vae.dtype).contiguous()
361
- latents = retrieve_latents(self.vae.encode(video), generator, "argmax")
362
-
363
- return latents
364
-
365
- # ==================== Text Encoding Methods ====================
366
-
367
  @staticmethod
 
368
  def _pack_text_embeds(
369
  text_hidden_states: torch.Tensor,
370
  sequence_lengths: torch.Tensor,
@@ -402,6 +348,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
402
  normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
403
  return normalized_hidden_states
404
 
 
405
  def _get_gemma_prompt_embeds(
406
  self,
407
  prompt: Union[str, List[str]],
@@ -461,6 +408,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
461
 
462
  return prompt_embeds, prompt_attention_mask
463
 
 
464
  def encode_prompt(
465
  self,
466
  prompt: Union[str, List[str]],
@@ -500,11 +448,14 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
500
 
501
  if prompt is not None and type(prompt) is not type(negative_prompt):
502
  raise TypeError(
503
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} != {type(prompt)}."
 
504
  )
505
  elif batch_size != len(negative_prompt):
506
  raise ValueError(
507
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`: {prompt} has batch size {batch_size}."
 
 
508
  )
509
 
510
  negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
@@ -536,13 +487,18 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
536
  k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
537
  ):
538
  raise ValueError(
539
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}"
540
  )
541
 
542
  if prompt is not None and prompt_embeds is not None:
543
- raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.")
 
 
 
544
  elif prompt is None and prompt_embeds is None:
545
- raise ValueError("Provide either `prompt` or `prompt_embeds`.")
 
 
546
  elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
547
  raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
548
 
@@ -552,9 +508,22 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
552
  if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
553
  raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
554
 
555
- # ==================== Latent Packing/Unpacking ====================
556
-
 
 
 
 
 
 
 
 
 
 
 
 
557
  @staticmethod
 
558
  def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
559
  batch_size, num_channels, num_frames, height, width = latents.shape
560
  post_patch_num_frames = num_frames // patch_size_t
@@ -574,6 +543,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
574
  return latents
575
 
576
  @staticmethod
 
577
  def _unpack_latents(
578
  latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
579
  ) -> torch.Tensor:
@@ -592,6 +562,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
592
  return latents
593
 
594
  @staticmethod
 
595
  def _denormalize_latents(
596
  latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
597
  ) -> torch.Tensor:
@@ -600,9 +571,17 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
600
  latents = latents * latents_std / scaling_factor + latents_mean
601
  return latents
602
 
603
- # ==================== Audio Latent Methods ====================
604
-
605
  @staticmethod
 
 
 
 
 
 
 
 
 
 
606
  def _pack_audio_latents(
607
  latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
608
  ) -> torch.Tensor:
@@ -619,6 +598,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
619
  return latents
620
 
621
  @staticmethod
 
622
  def _unpack_audio_latents(
623
  latents: torch.Tensor,
624
  latent_length: int,
@@ -635,29 +615,191 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
635
  return latents
636
 
637
  @staticmethod
638
- def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
 
639
  latents_mean = latents_mean.to(latents.device, latents.dtype)
640
  latents_std = latents_std.to(latents.device, latents.dtype)
641
- return (latents * latents_std) + latents_mean
642
 
643
  @staticmethod
644
- def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
 
645
  latents_mean = latents_mean.to(latents.device, latents.dtype)
646
  latents_std = latents_std.to(latents.device, latents.dtype)
647
- return (latents - latents_mean) / latents_std
648
 
649
- @staticmethod
650
- def _patchify_audio_latents(latents: torch.Tensor) -> torch.Tensor:
651
- batch, channels, time, freq = latents.shape
652
- return latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)
 
 
653
 
654
- @staticmethod
655
- def _unpatchify_audio_latents(latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor:
656
- batch, time, _ = latents.shape
657
- return latents.reshape(batch, time, channels, freq).permute(0, 2, 1, 3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
658
 
659
  def _preprocess_audio(self, audio: Union[str, torch.Tensor], target_sample_rate: int) -> torch.Tensor:
660
- """Process audio to mel spectrogram."""
661
  if isinstance(audio, str):
662
  waveform, sr = torchaudio.load(audio)
663
  else:
@@ -667,12 +809,14 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
667
  if sr != target_sample_rate:
668
  waveform = torchaudio.functional.resample(waveform, sr, target_sample_rate)
669
 
 
 
670
  if waveform.shape[0] == 1:
671
  waveform = waveform.repeat(2, 1)
672
  elif waveform.shape[0] > 2:
673
  waveform = waveform[:2, :]
674
 
675
- waveform = waveform.unsqueeze(0)
676
 
677
  n_fft = 1024
678
  mel_transform = T.MelSpectrogram(
@@ -691,208 +835,86 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
691
  norm="slaney",
692
  )
693
 
694
- mel_spec = mel_transform(waveform)
695
  mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
696
- mel_spec = mel_spec.permute(0, 1, 3, 2).contiguous()
697
-
698
  return mel_spec
699
 
700
- # ==================== Latent Preparation ====================
701
-
702
- def prepare_latents(
703
- self,
704
- image: Optional[torch.Tensor] = None,
705
- video: Optional[torch.Tensor] = None,
706
- video_conditioning_strength: float = 1.0,
707
- video_conditioning_frame_idx: int = 1,
708
- batch_size: int = 1,
709
- num_channels_latents: int = 128,
710
- height: int = 512,
711
- width: int = 704,
712
- num_frames: int = 161,
713
- dtype: Optional[torch.dtype] = None,
714
- device: Optional[torch.device] = None,
715
- generator: Optional[torch.Generator] = None,
716
- latents: Optional[torch.Tensor] = None,
717
- ) -> Tuple[torch.Tensor, torch.Tensor]:
718
- """
719
- Prepare latents for generation with optional video conditioning.
720
-
721
- Args:
722
- image: Input image for frame 0 conditioning
723
- video: Video tensor for motion conditioning
724
- video_conditioning_strength: Strength of video conditioning (0-1)
725
- video_conditioning_frame_idx: Frame index where video conditioning starts.
726
- - 0: Video conditioning replaces all frames including frame 0
727
- - 1: Frame 0 is image-conditioned, frames 1+ are video-conditioned (default for face-swap)
728
- - N: Frames 0 to N-1 are image/noise, frames N+ are video-conditioned
729
- ... other args ...
730
- """
731
- latent_height = height // self.vae_spatial_compression_ratio
732
- latent_width = width // self.vae_spatial_compression_ratio
733
- latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
734
-
735
- shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width)
736
- mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)
737
-
738
- if latents is not None:
739
- conditioning_mask = latents.new_zeros(mask_shape)
740
- conditioning_mask[:, :, 0] = 1.0
741
- conditioning_mask = self._pack_latents(
742
- conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
743
- ).squeeze(-1)
744
- return latents.to(device=device, dtype=dtype), conditioning_mask
745
-
746
- # Initialize conditioning mask (all zeros = fully denoise)
747
- conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype)
748
-
749
- # Initialize latents tensor
750
- init_latents = torch.zeros(shape, device=device, dtype=dtype)
751
-
752
- # Case 1: Video conditioning (motion from reference video)
753
- if video is not None:
754
- # Encode video through VAE
755
- video_latents = self._encode_video_conditioning(video, generator)
756
- video_latents = self._normalize_latents(video_latents, self.vae.latents_mean, self.vae.latents_std)
757
-
758
- # Ensure video latents match target shape
759
- if video_latents.shape[2] < latent_num_frames:
760
- # Pad with last frame
761
- pad_frames = latent_num_frames - video_latents.shape[2]
762
- last_frame = video_latents[:, :, -1:, :, :]
763
- video_latents = torch.cat([video_latents, last_frame.repeat(1, 1, pad_frames, 1, 1)], dim=2)
764
- elif video_latents.shape[2] > latent_num_frames:
765
- video_latents = video_latents[:, :, :latent_num_frames, :, :]
766
-
767
- # Calculate the latent frame index for video conditioning
768
- # video_conditioning_frame_idx is in pixel space, convert to latent space
769
- latent_video_start_idx = video_conditioning_frame_idx // self.vae_temporal_compression_ratio
770
- latent_video_start_idx = min(latent_video_start_idx, latent_num_frames - 1)
771
-
772
- # Apply video conditioning starting from the specified frame index
773
- # Video frames are placed starting at latent_video_start_idx
774
- num_video_frames_to_use = latent_num_frames - latent_video_start_idx
775
- init_latents[:, :, latent_video_start_idx:, :, :] = video_latents[:, :, :num_video_frames_to_use, :, :]
776
-
777
- # Set conditioning mask for video frames
778
- # strength=1.0 means fully conditioned (no denoising), strength=0.0 means fully denoised
779
- conditioning_mask[:, :, latent_video_start_idx:] = video_conditioning_strength
780
-
781
- # Handle image conditioning for frame 0
782
- if image is not None:
783
- if isinstance(generator, list):
784
- image_latents = [
785
- retrieve_latents(
786
- self.vae.encode(
787
- image[i].unsqueeze(0).unsqueeze(2)
788
- .to(device=self.vae.device, dtype=self.vae.dtype)
789
- .contiguous()
790
- ),
791
- generator[i],
792
- "argmax",
793
- )
794
- for i in range(batch_size)
795
- ]
796
-
797
- else:
798
- image_latents = [
799
- retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2).to(device=self.vae.device, dtype=self.vae.dtype).contiguous()), generator, "argmax")
800
-
801
- for img in image
802
- ]
803
- image_latents = torch.cat(image_latents, dim=0).to(dtype)
804
- image_latents = self._normalize_latents(image_latents, self.vae.latents_mean, self.vae.latents_std)
805
-
806
- # Replace frame 0 with image latents (face appearance)
807
- init_latents[:, :, 0:1, :, :] = image_latents
808
- # Frame 0 is fully conditioned
809
- conditioning_mask[:, :, 0] = 1.0
810
-
811
- # If no video conditioning, repeat image for all frames (image-to-video mode)
812
- if video is None:
813
- init_latents = image_latents.repeat(1, 1, latent_num_frames, 1, 1)
814
-
815
- # Generate noise
816
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
817
-
818
- # Blend: conditioned regions keep init_latents, unconditioned regions get noise
819
- latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)
820
-
821
- # Pack for transformer
822
- conditioning_mask = self._pack_latents(
823
- conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
824
- ).squeeze(-1)
825
- latents = self._pack_latents(
826
- latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
827
- )
828
-
829
- return latents, conditioning_mask
830
-
831
  def prepare_audio_latents(
832
  self,
833
  batch_size: int = 1,
834
  num_channels_latents: int = 8,
 
835
  num_mel_bins: int = 64,
836
- num_frames: int = 121,
837
- frame_rate: float = 25.0,
838
- sampling_rate: int = 16000,
839
- hop_length: int = 160,
840
  dtype: Optional[torch.dtype] = None,
841
  device: Optional[torch.device] = None,
842
  generator: Optional[torch.Generator] = None,
843
- audio_input: Optional[Union[str, torch.Tensor]] = None,
844
  latents: Optional[torch.Tensor] = None,
 
845
  ) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
846
- duration_s = num_frames / frame_rate
847
- latents_per_second = (
848
- float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
849
- )
850
- target_length = round(duration_s * latents_per_second)
851
-
852
  if latents is not None:
853
- return latents.to(device=device, dtype=dtype), target_length, None
854
-
855
- latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
 
 
 
 
 
 
856
 
 
857
  if audio_input is not None:
858
- mel_spec = self._preprocess_audio(audio_input, sampling_rate).to(device=device)
859
  mel_spec = mel_spec.to(dtype=self.audio_vae.dtype)
860
- init_latents = self.audio_vae.encode(mel_spec).latent_dist.sample(generator)
861
- init_latents = init_latents.to(dtype=dtype)
862
-
863
- latent_channels = init_latents.shape[1]
864
- latent_freq = init_latents.shape[3]
865
- init_latents_patched = self._patchify_audio_latents(init_latents)
866
- init_latents_patched = self._normalize_audio_latents(
867
- init_latents_patched, self.audio_vae.latents_mean, self.audio_vae.latents_std
868
- )
869
- init_latents = self._unpatchify_audio_latents(init_latents_patched, latent_channels, latent_freq)
870
 
871
- current_len = init_latents.shape[2]
872
- if current_len < target_length:
873
- padding = target_length - current_len
874
- init_latents = torch.nn.functional.pad(init_latents, (0, 0, 0, padding))
875
- elif current_len > target_length:
876
- init_latents = init_latents[:, :, :target_length, :]
877
 
878
- noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
 
 
 
 
 
 
879
 
880
- if init_latents.shape[0] != batch_size:
881
- init_latents = init_latents.repeat(batch_size, 1, 1, 1)
882
- noise = noise.repeat(batch_size, 1, 1, 1)
883
 
884
- packed_noise = self._pack_audio_latents(noise)
 
 
 
 
 
 
 
885
 
886
- return packed_noise, target_length, init_latents
 
 
 
 
 
 
 
 
 
887
 
888
- shape = (batch_size, num_channels_latents, target_length, latent_mel_bins)
889
  latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
890
  latents = self._pack_audio_latents(latents)
 
891
 
892
- return latents, target_length, None
893
 
894
- # ==================== Properties ====================
895
-
896
  @property
897
  def guidance_scale(self):
898
  return self._guidance_scale
@@ -921,37 +943,25 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
921
  def interrupt(self):
922
  return self._interrupt
923
 
924
- def _get_audio_duration(self, audio: Union[str, torch.Tensor], sample_rate: int) -> float:
925
- if isinstance(audio, str):
926
- info = torchaudio.info(audio)
927
- return info.num_frames / info.sample_rate
928
- else:
929
- num_samples = audio.shape[-1]
930
- return num_samples / sample_rate
931
-
932
- # ==================== Main Call ====================
933
-
934
  @torch.no_grad()
935
  @replace_example_docstring(EXAMPLE_DOC_STRING)
936
  def __call__(
937
  self,
938
- image: PipelineImageInput = None,
939
- video: Optional[Union[str, List[Image.Image], torch.Tensor]] = None,
940
- video_conditioning_strength: float = 1.0,
941
- video_conditioning_frame_idx: int = 1,
942
  audio: Optional[Union[str, torch.Tensor]] = None,
943
  prompt: Union[str, List[str]] = None,
944
  negative_prompt: Optional[Union[str, List[str]]] = None,
945
  height: int = 512,
946
  width: int = 768,
947
- num_frames: Optional[int] = None,
948
  max_frames: int = 257,
949
  frame_rate: float = 24.0,
950
  num_inference_steps: int = 40,
951
- timesteps: List[int] = None,
952
  sigmas: Optional[List[float]] = None,
 
953
  guidance_scale: float = 4.0,
954
  guidance_rescale: float = 0.0,
 
955
  num_videos_per_prompt: Optional[int] = 1,
956
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
957
  latents: Optional[torch.Tensor] = None,
@@ -969,50 +979,20 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
969
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
970
  max_sequence_length: int = 1024,
971
  ):
972
- r"""
973
- Generate avatar video with audio and optional video conditioning.
974
-
975
- Args:
976
- image (`PipelineImageInput`):
977
- The input image (face/appearance) to condition frame 0.
978
- video (`str`, `List[PIL.Image]`, or `torch.Tensor`, *optional*):
979
- Reference video for motion conditioning. Can be:
980
- - Path to a video file
981
- - List of PIL Images
982
- - Tensor of shape (F, H, W, C) or (F, C, H, W)
983
- video_conditioning_strength (`float`, *optional*, defaults to 1.0):
984
- How strongly to condition on the reference video (0.0-1.0).
985
- 1.0 = fully conditioned, 0.0 = no conditioning.
986
- video_conditioning_frame_idx (`int`, *optional*, defaults to 1):
987
- Frame index where video conditioning starts (in pixel/frame space).
988
- - 0: Video conditioning replaces all frames including frame 0
989
- - 1: Frame 0 is image-conditioned, frames 1+ are video-conditioned (default for face-swap)
990
- - N: Frames 0 to N-1 are image/noise, frames N+ are video-conditioned
991
- audio (`str` or `torch.Tensor`, *optional*):
992
- Audio for lip-sync. Can be path to audio/video file or waveform tensor.
993
- prompt (`str` or `List[str]`, *optional*):
994
- Text prompt. For face-swap, include "head_swap" trigger.
995
- Examples:
996
-
997
- Returns:
998
- [`LTX2PipelineOutput`] or `tuple`: Generated video and audio.
999
- """
1000
-
1001
  if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1002
  callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1003
 
1004
- # Calculate num_frames from audio duration if not provided
1005
- if num_frames is None:
1006
- if audio is not None:
1007
- audio_duration = self._get_audio_duration(audio, self.audio_sampling_rate)
1008
- calculated_frames = int(audio_duration * frame_rate) + 1
1009
- num_frames = min(calculated_frames, max_frames)
1010
- num_frames = ((num_frames - 1) // self.vae_temporal_compression_ratio) * self.vae_temporal_compression_ratio + 1
1011
- num_frames = max(num_frames, 9)
1012
- logger.info(f"Audio duration: {audio_duration:.2f}s -> num_frames: {num_frames}")
1013
- else:
1014
- num_frames = 121
1015
-
1016
  self.check_inputs(
1017
  prompt=prompt,
1018
  height=height,
@@ -1030,6 +1010,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
1030
  self._interrupt = False
1031
  self._current_timestep = None
1032
 
 
1033
  if prompt is not None and isinstance(prompt, str):
1034
  batch_size = 1
1035
  elif prompt is not None and isinstance(prompt, list):
@@ -1037,9 +1018,16 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
1037
  else:
1038
  batch_size = prompt_embeds.shape[0]
1039
 
 
 
 
 
 
 
 
1040
  device = self._execution_device
1041
 
1042
- # Encode prompts
1043
  (
1044
  prompt_embeds,
1045
  prompt_attention_mask,
@@ -1066,48 +1054,67 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
1066
  prompt_embeds, additive_attention_mask, additive_mask=True
1067
  )
1068
 
1069
- # Preprocess image
1070
- if latents is None and image is not None:
1071
- image = self.video_processor.preprocess(image, height=height, width=width)
1072
- image = image.to(device=device, dtype=self.vae.dtype)
1073
-
1074
-
1075
- # Preprocess video conditioning
1076
- video_tensor = None
1077
- if video is not None:
1078
- video_tensor = self._load_video_frames(
1079
- video=video,
1080
- height=height,
1081
- width=width,
1082
- num_frames=num_frames,
1083
- device=device,
1084
- dtype=self.vae.dtype,
1085
- )
 
 
 
1086
 
1087
- # Prepare latents with video conditioning
1088
  num_channels_latents = self.transformer.config.in_channels
1089
- latents, conditioning_mask = self.prepare_latents(
1090
- image=image,
1091
- video=video_tensor,
1092
- video_conditioning_strength=video_conditioning_strength,
1093
- video_conditioning_frame_idx=video_conditioning_frame_idx,
1094
- batch_size=batch_size * num_videos_per_prompt,
1095
- num_channels_latents=num_channels_latents,
1096
- height=height,
1097
- width=width,
1098
- num_frames=num_frames,
1099
- dtype=torch.float32,
1100
- device=device,
1101
- generator=generator,
1102
- latents=latents,
1103
  )
1104
  if self.do_classifier_free_guidance:
1105
  conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
1106
 
1107
- # Prepare audio latents
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1108
  num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
1109
  latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
1110
-
1111
  num_channels_latents_audio = (
1112
  self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
1113
  )
@@ -1115,30 +1122,20 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
1115
  audio_latents, audio_num_frames, clean_audio_latents = self.prepare_audio_latents(
1116
  batch_size * num_videos_per_prompt,
1117
  num_channels_latents=num_channels_latents_audio,
 
1118
  num_mel_bins=num_mel_bins,
1119
- num_frames=num_frames,
1120
- frame_rate=frame_rate,
1121
- sampling_rate=self.audio_sampling_rate,
1122
- hop_length=self.audio_hop_length,
1123
  dtype=torch.float32,
1124
  device=device,
1125
  generator=generator,
1126
  latents=audio_latents,
1127
  audio_input=audio,
1128
  )
 
 
1129
 
1130
- packed_clean_audio_latents = None
1131
- if clean_audio_latents is not None:
1132
- packed_clean_audio_latents = self._pack_audio_latents(clean_audio_latents)
1133
-
1134
- latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
1135
- latent_height = height // self.vae_spatial_compression_ratio
1136
- latent_width = width // self.vae_spatial_compression_ratio
1137
- video_sequence_length = latent_num_frames * latent_height * latent_width
1138
-
1139
- if sigmas is None:
1140
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
1141
-
1142
  mu = calculate_shift(
1143
  video_sequence_length,
1144
  self.scheduler.config.get("base_image_seq_len", 1024),
@@ -1179,7 +1176,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
1179
  audio_latents.shape[0], audio_num_frames, audio_latents.device
1180
  )
1181
 
1182
- # Denoising loop
1183
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1184
  for i, t in enumerate(timesteps):
1185
  if self.interrupt:
@@ -1187,6 +1184,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
1187
 
1188
  self._current_timestep = t
1189
 
 
1190
  if packed_clean_audio_latents is not None:
1191
  audio_latents_input = packed_clean_audio_latents.to(dtype=prompt_embeds.dtype)
1192
  else:
@@ -1200,7 +1198,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
1200
  audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
1201
 
1202
  timestep = t.expand(latent_model_input.shape[0])
1203
- video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
1204
 
1205
  if packed_clean_audio_latents is not None:
1206
  audio_timestep = torch.zeros_like(timestep)
@@ -1224,6 +1222,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
1224
  audio_num_frames=audio_num_frames,
1225
  video_coords=video_coords,
1226
  audio_coords=audio_coords,
 
1227
  attention_kwargs=attention_kwargs,
1228
  return_dict=False,
1229
  )
@@ -1249,32 +1248,17 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
1249
  noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
1250
  )
1251
 
1252
- noise_pred_video = self._unpack_latents(
1253
- noise_pred_video,
1254
- latent_num_frames,
1255
- latent_height,
1256
- latent_width,
1257
- self.transformer_spatial_patch_size,
1258
- self.transformer_temporal_patch_size,
1259
- )
1260
- latents = self._unpack_latents(
1261
- latents,
1262
- latent_num_frames,
1263
- latent_height,
1264
- latent_width,
1265
- self.transformer_spatial_patch_size,
1266
- self.transformer_temporal_patch_size,
1267
- )
1268
-
1269
- noise_pred_video = noise_pred_video[:, :, 1:]
1270
- noise_latents = latents[:, :, 1:]
1271
- pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0]
1272
 
1273
- latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
1274
- latents = self._pack_latents(
1275
- latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
1276
- )
1277
 
 
1278
  if packed_clean_audio_latents is None:
1279
  audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
1280
 
@@ -1283,6 +1267,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
1283
  for k in callback_on_step_end_tensor_inputs:
1284
  callback_kwargs[k] = locals()[k]
1285
  callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
 
1286
  latents = callback_outputs.pop("latents", latents)
1287
  prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1288
 
@@ -1292,7 +1277,6 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
1292
  if XLA_AVAILABLE:
1293
  xm.mark_step()
1294
 
1295
- # Decode
1296
  latents = self._unpack_latents(
1297
  latents,
1298
  latent_num_frames,
@@ -1305,25 +1289,22 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
1305
  latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
1306
  )
1307
 
1308
- if clean_audio_latents is not None:
1309
- latent_channels = clean_audio_latents.shape[1]
1310
- latent_freq = clean_audio_latents.shape[3]
1311
- audio_patched = self._patchify_audio_latents(clean_audio_latents)
1312
- audio_patched = self._denormalize_audio_latents(
1313
- audio_patched, self.audio_vae.latents_mean, self.audio_vae.latents_std
1314
- )
1315
- audio_latents_for_decode = self._unpatchify_audio_latents(audio_patched, latent_channels, latent_freq)
1316
  else:
1317
- audio_latents_for_decode = self._denormalize_audio_latents(
1318
- audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
1319
- )
1320
- audio_latents_for_decode = self._unpack_audio_latents(
1321
- audio_latents_for_decode, audio_num_frames, num_mel_bins=latent_mel_bins
1322
- )
 
 
1323
 
1324
  if output_type == "latent":
1325
  video = latents
1326
- audio_output = audio_latents_for_decode
1327
  else:
1328
  latents = latents.to(prompt_embeds.dtype)
1329
 
@@ -1348,13 +1329,13 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
1348
  video = self.vae.decode(latents, timestep, return_dict=False)[0]
1349
  video = self.video_processor.postprocess_video(video, output_type=output_type)
1350
 
1351
- audio_latents_for_decode = audio_latents_for_decode.to(self.audio_vae.dtype)
1352
- generated_mel_spectrograms = self.audio_vae.decode(audio_latents_for_decode, return_dict=False)[0]
1353
- audio_output = self.vocoder(generated_mel_spectrograms)
1354
 
1355
  self.maybe_free_model_hooks()
1356
 
1357
  if not return_dict:
1358
- return (video, audio_output)
1359
 
1360
- return LTX2PipelineOutput(frames=video, audio=audio_output)
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  import copy
16
  import inspect
17
+ from dataclasses import dataclass
18
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
 
20
  import numpy as np
21
+ import PIL.Image
22
  import torch
23
  import torchaudio
24
  import torchaudio.transforms as T
 
25
  from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
26
 
27
  from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 
28
  from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
29
+ from diffusers.models.autoencoders import (
30
+ AutoencoderKLLTX2Audio,
31
+ AutoencoderKLLTX2Video,
32
+ )
33
  from diffusers.models.transformers import LTX2VideoTransformer3DModel
34
  from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
35
+ from diffusers.utils import (
36
+ is_torch_xla_available,
37
+ logging,
38
+ replace_example_docstring,
39
+ )
40
  from diffusers.utils.torch_utils import randn_tensor
41
  from diffusers.video_processor import VideoProcessor
42
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
43
+
44
  from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors
45
  from diffusers.pipelines.ltx2.pipeline_output import LTX2PipelineOutput
46
  from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
 
48
 
49
  if is_torch_xla_available():
50
  import torch_xla.core.xla_model as xm
51
+
52
  XLA_AVAILABLE = True
53
  else:
54
  XLA_AVAILABLE = False
55
 
56
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
57
 
58
  EXAMPLE_DOC_STRING = """
59
  Examples:
60
  ```py
61
  >>> import torch
62
+ >>> from diffusers import LTX2ConditionPipeline
63
+ >>> from diffusers.pipelines.ltx2.export_utils import encode_video
64
+ >>> from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition
65
  >>> from diffusers.utils import load_image
66
 
67
+ >>> pipe = LTX2ConditionPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
68
+ >>> pipe.enable_model_cpu_offload()
69
+
70
+ >>> first_image = load_image(
71
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
72
  ... )
73
+ >>> last_image = load_image(
74
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png"
 
 
 
 
75
  ... )
76
+ >>> first_cond = LTX2VideoCondition(frames=first_image, index=0, strength=1.0)
77
+ >>> last_cond = LTX2VideoCondition(frames=last_image, index=-1, strength=1.0)
78
+ >>> conditions = [first_cond, last_cond]
79
+ >>> prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings."
80
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted, static"
81
+
82
+ >>> frame_rate = 24.0
83
+ >>> video = pipe(
84
+ ... conditions=conditions,
85
+ ... prompt=prompt,
86
+ ... negative_prompt=negative_prompt,
87
+ ... width=768,
88
+ ... height=512,
89
  ... num_frames=121,
90
+ ... frame_rate=frame_rate,
91
+ ... num_inference_steps=40,
92
+ ... guidance_scale=4.0,
93
+ ... output_type="np",
94
  ... return_dict=False,
95
  ... )
96
+ >>> video = (video * 255).round().astype("uint8")
97
+ >>> video = torch.from_numpy(video)
98
+
99
+ >>> encode_video(
100
+ ... video[0],
101
+ ... fps=frame_rate,
102
+ ... audio=audio[0].float().cpu(),
103
+ ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000
104
+ ... output_path="video.mp4",
105
+ ... )
106
  ```
107
  """
108
 
109
 
110
+ @dataclass
111
+ class LTX2VideoCondition:
112
+ """
113
+ Defines a single frame-conditioning item for LTX-2 Video - a single frame or a sequence of frames.
114
+
115
+ Attributes:
116
+ frames (`PIL.Image.Image` or `List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`):
117
+ The image (or video) to condition the video on. Accepts any type that can be handled by
118
+ VideoProcessor.preprocess_video.
119
+ index (`int`, defaults to `0`):
120
+ The index at which the image or video will conditionally affect the video generation.
121
+ strength (`float`, defaults to `1.0`):
122
+ The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied.
123
+ """
124
+
125
+ frames: Union[PIL.Image.Image, List[PIL.Image.Image], np.ndarray, torch.Tensor]
126
+ index: int = 0
127
+ strength: float = 1.0
128
+
129
+
130
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
131
  def retrieve_latents(
132
  encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
133
  ):
 
141
  raise AttributeError("Could not access latents of provided encoder_output")
142
 
143
 
144
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
145
  def calculate_shift(
146
  image_seq_len,
147
  base_seq_len: int = 256,
 
155
  return mu
156
 
157
 
158
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
159
  def retrieve_timesteps(
160
  scheduler,
161
  num_inference_steps: Optional[int] = None,
 
164
  sigmas: Optional[List[float]] = None,
165
  **kwargs,
166
  ):
167
+ r"""
168
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
169
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
170
+
171
+ Args:
172
+ scheduler (`SchedulerMixin`):
173
+ The scheduler to get timesteps from.
174
+ num_inference_steps (`int`):
175
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
176
+ must be `None`.
177
+ device (`str` or `torch.device`, *optional*):
178
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
179
+ timesteps (`List[int]`, *optional*):
180
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
181
+ `num_inference_steps` and `sigmas` must be `None`.
182
+ sigmas (`List[float]`, *optional*):
183
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
184
+ `num_inference_steps` and `timesteps` must be `None`.
185
+
186
+ Returns:
187
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
188
+ second element is the number of inference steps.
189
+ """
190
  if timesteps is not None and sigmas is not None:
191
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
192
  if timesteps is not None:
193
  accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
194
  if not accepts_timesteps:
195
  raise ValueError(
196
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
197
+ f" timestep schedules. Please check whether you are using the correct scheduler."
198
  )
199
  scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
200
  timesteps = scheduler.timesteps
 
203
  accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
204
  if not accept_sigmas:
205
  raise ValueError(
206
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
207
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
208
  )
209
  scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
210
  timesteps = scheduler.timesteps
 
215
  return timesteps, num_inference_steps
216
 
217
 
218
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
219
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
220
+ r"""
221
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
222
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
223
+ Flawed](https://huggingface.co/papers/2305.08891).
224
+
225
+ Args:
226
+ noise_cfg (`torch.Tensor`):
227
+ The predicted noise tensor for the guided diffusion process.
228
+ noise_pred_text (`torch.Tensor`):
229
+ The predicted noise tensor for the text-guided diffusion process.
230
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
231
+ A rescale factor applied to the noise predictions.
232
+
233
+ Returns:
234
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
235
+ """
236
  std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
237
  std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
238
  noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
 
240
  return noise_cfg
241
 
242
 
243
+ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
244
  r"""
245
+ Pipeline for video generation which allows image conditions to be inserted at arbitary parts of the video.
246
 
247
+ Reference: https://github.com/Lightricks/LTX-Video
 
 
 
248
 
249
+ TODO
 
250
  """
251
 
252
  model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
 
283
  self.vae_temporal_compression_ratio = (
284
  self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
285
  )
286
+ # TODO: check whether the MEL compression ratio logic here is corrct
287
  self.audio_vae_mel_compression_ratio = (
288
  self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
289
  )
 
309
  self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
310
  )
311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  @staticmethod
313
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds
314
  def _pack_text_embeds(
315
  text_hidden_states: torch.Tensor,
316
  sequence_lengths: torch.Tensor,
 
348
  normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
349
  return normalized_hidden_states
350
 
351
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds
352
  def _get_gemma_prompt_embeds(
353
  self,
354
  prompt: Union[str, List[str]],
 
408
 
409
  return prompt_embeds, prompt_attention_mask
410
 
411
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt
412
  def encode_prompt(
413
  self,
414
  prompt: Union[str, List[str]],
 
448
 
449
  if prompt is not None and type(prompt) is not type(negative_prompt):
450
  raise TypeError(
451
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
452
+ f" {type(prompt)}."
453
  )
454
  elif batch_size != len(negative_prompt):
455
  raise ValueError(
456
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
457
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
458
+ " the batch size of `prompt`."
459
  )
460
 
461
  negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
 
487
  k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
488
  ):
489
  raise ValueError(
490
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
491
  )
492
 
493
  if prompt is not None and prompt_embeds is not None:
494
+ raise ValueError(
495
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
496
+ " only forward one of the two."
497
+ )
498
  elif prompt is None and prompt_embeds is None:
499
+ raise ValueError(
500
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
501
+ )
502
  elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
503
  raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
504
 
 
508
  if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
509
  raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
510
 
511
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
512
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
513
+ raise ValueError(
514
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
515
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
516
+ f" {negative_prompt_embeds.shape}."
517
+ )
518
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
519
+ raise ValueError(
520
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
521
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
522
+ f" {negative_prompt_attention_mask.shape}."
523
+ )
524
+
525
  @staticmethod
526
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents
527
  def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
528
  batch_size, num_channels, num_frames, height, width = latents.shape
529
  post_patch_num_frames = num_frames // patch_size_t
 
543
  return latents
544
 
545
  @staticmethod
546
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents
547
  def _unpack_latents(
548
  latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
549
  ) -> torch.Tensor:
 
562
  return latents
563
 
564
  @staticmethod
565
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
566
  def _denormalize_latents(
567
  latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
568
  ) -> torch.Tensor:
 
571
  latents = latents * latents_std / scaling_factor + latents_mean
572
  return latents
573
 
 
 
574
  @staticmethod
575
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state
576
+ def _create_noised_state(
577
+ latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None
578
+ ):
579
+ noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
580
+ noised_latents = noise_scale * noise + (1 - noise_scale) * latents
581
+ return noised_latents
582
+
583
+ @staticmethod
584
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents
585
  def _pack_audio_latents(
586
  latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
587
  ) -> torch.Tensor:
 
598
  return latents
599
 
600
  @staticmethod
601
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_audio_latents
602
  def _unpack_audio_latents(
603
  latents: torch.Tensor,
604
  latent_length: int,
 
615
  return latents
616
 
617
  @staticmethod
618
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents
619
+ def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
620
  latents_mean = latents_mean.to(latents.device, latents.dtype)
621
  latents_std = latents_std.to(latents.device, latents.dtype)
622
+ return (latents - latents_mean) / latents_std
623
 
624
  @staticmethod
625
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents
626
+ def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
627
  latents_mean = latents_mean.to(latents.device, latents.dtype)
628
  latents_std = latents_std.to(latents.device, latents.dtype)
629
+ return (latents * latents_std) + latents_mean
630
 
631
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx_condition.LTXConditionPipeline.trim_conditioning_sequence
632
+ def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, target_num_frames: int) -> int:
633
+ scale_factor = self.vae_temporal_compression_ratio
634
+ num_frames = min(sequence_num_frames, target_num_frames - start_frame)
635
+ num_frames = (num_frames - 1) // scale_factor * scale_factor + 1
636
+ return num_frames
637
 
638
+ def preprocess_conditions(
639
+ self,
640
+ conditions: Optional[Union[LTX2VideoCondition, List[LTX2VideoCondition]]] = None,
641
+ height: int = 512,
642
+ width: int = 768,
643
+ num_frames: int = 121,
644
+ device: Optional[torch.device] = None,
645
+ index_type: str = "latent",
646
+ ) -> Tuple[List[torch.Tensor], List[float], List[int]]:
647
+ conditioning_frames, conditioning_strengths, conditioning_indices = [], [], []
648
+
649
+ if conditions is None:
650
+ conditions = []
651
+ if isinstance(conditions, LTX2VideoCondition):
652
+ conditions = [conditions]
653
+
654
+ frame_scale_factor = self.vae_temporal_compression_ratio
655
+ latent_num_frames = (num_frames - 1) // frame_scale_factor + 1
656
+ for i, condition in enumerate(conditions):
657
+ if isinstance(condition.frames, PIL.Image.Image):
658
+ video_like_cond = [condition.frames]
659
+ elif isinstance(condition.frames, np.ndarray) and condition.frames.ndim == 3:
660
+ video_like_cond = np.expand_dims(condition.frames, axis=0)
661
+ elif isinstance(condition.frames, torch.Tensor) and condition.frames.ndim == 3:
662
+ video_like_cond = condition.frames.unsqueeze(0)
663
+ else:
664
+ video_like_cond = condition.frames
665
+
666
+ condition_pixels = self.video_processor.preprocess_video(video_like_cond, height, width)
667
+
668
+ latent_start_idx = condition.index
669
+ if latent_start_idx < 0:
670
+ latent_start_idx = latent_start_idx % latent_num_frames
671
+ if latent_start_idx >= latent_num_frames:
672
+ logger.warning(
673
+ f"The starting latent index {latent_start_idx} of condition {i} is too big for the specified number"
674
+ f" of latent frames {latent_num_frames}. This condition will be skipped."
675
+ )
676
+ continue
677
+
678
+ cond_num_frames = condition_pixels.size(2)
679
+ start_idx = max((latent_start_idx - 1) * frame_scale_factor + 1, 0)
680
+ truncated_cond_frames = self.trim_conditioning_sequence(start_idx, cond_num_frames, num_frames)
681
+ condition_pixels = condition_pixels[:, :, :truncated_cond_frames]
682
+
683
+ conditioning_frames.append(condition_pixels.to(dtype=self.vae.dtype, device=device))
684
+ conditioning_strengths.append(condition.strength)
685
+ conditioning_indices.append(latent_start_idx)
686
+
687
+ return conditioning_frames, conditioning_strengths, conditioning_indices
688
+
689
+ def apply_visual_conditioning(
690
+ self,
691
+ latents: torch.Tensor,
692
+ conditioning_mask: torch.Tensor,
693
+ condition_latents: List[torch.Tensor],
694
+ condition_strengths: List[float],
695
+ condition_indices: List[int],
696
+ latent_height: int,
697
+ latent_width: int,
698
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
699
+ clean_latents = torch.zeros_like(latents)
700
+ for cond, strength, latent_idx in zip(condition_latents, condition_strengths, condition_indices):
701
+ num_cond_tokens = cond.size(1)
702
+ start_token_idx = latent_idx * latent_height * latent_width
703
+ end_token_idx = start_token_idx + num_cond_tokens
704
+
705
+ latents[:, start_token_idx:end_token_idx] = cond
706
+ conditioning_mask[:, start_token_idx:end_token_idx] = strength
707
+ clean_latents[:, start_token_idx:end_token_idx] = cond
708
+
709
+ return latents, conditioning_mask, clean_latents
710
+
711
+ def prepare_latents(
712
+ self,
713
+ conditions: Optional[Union[LTX2VideoCondition, List[LTX2VideoCondition]]] = None,
714
+ batch_size: int = 1,
715
+ num_channels_latents: int = 128,
716
+ height: int = 512,
717
+ width: int = 768,
718
+ num_frames: int = 121,
719
+ noise_scale: float = 1.0,
720
+ dtype: Optional[torch.dtype] = None,
721
+ device: Optional[torch.device] = None,
722
+ generator: Optional[torch.Generator] = None,
723
+ latents: Optional[torch.Tensor] = None,
724
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
725
+ latent_height = height // self.vae_spatial_compression_ratio
726
+ latent_width = width // self.vae_spatial_compression_ratio
727
+ latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
728
+
729
+ shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width)
730
+ mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)
731
+
732
+ if latents is not None:
733
+ if latents.ndim == 5:
734
+ latents = self._normalize_latents(
735
+ latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
736
+ )
737
+ else:
738
+ latents = torch.zeros(shape, device=device, dtype=dtype)
739
+
740
+ conditioning_mask = latents.new_zeros(mask_shape)
741
+ if latents.ndim == 5:
742
+ latents = self._pack_latents(
743
+ latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
744
+ )
745
+ conditioning_mask = self._pack_latents(
746
+ conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
747
+ )
748
+
749
+ if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape[:2]:
750
+ raise ValueError(
751
+ f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape[:2] + (num_channels_latents,)}."
752
+ )
753
+
754
+ if isinstance(generator, list):
755
+ logger.warning(
756
+ f"{self.__class__.__name__} does not support using a list of generators. The first generator in the"
757
+ f" list will be used for all (pseudo-)random operations."
758
+ )
759
+ generator = generator[0]
760
+
761
+ condition_frames, condition_strengths, condition_indices = self.preprocess_conditions(
762
+ conditions, height, width, num_frames, device=device
763
+ )
764
+ condition_latents = []
765
+ for condition_tensor in condition_frames:
766
+ condition_latent = retrieve_latents(
767
+ self.vae.encode(condition_tensor), generator=generator, sample_mode="argmax"
768
+ )
769
+ condition_latent = self._normalize_latents(
770
+ condition_latent, self.vae.latents_mean, self.vae.latents_std
771
+ ).to(device=device, dtype=dtype)
772
+ condition_latent = self._pack_latents(
773
+ condition_latent, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
774
+ )
775
+ condition_latents.append(condition_latent)
776
+
777
+ latents, conditioning_mask, clean_latents = self.apply_visual_conditioning(
778
+ latents,
779
+ conditioning_mask,
780
+ condition_latents,
781
+ condition_strengths,
782
+ condition_indices,
783
+ latent_height=latent_height,
784
+ latent_width=latent_width,
785
+ )
786
+
787
+ noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
788
+ scaled_mask = (1.0 - conditioning_mask) * noise_scale
789
+ latents = noise * scaled_mask + latents * (1 - scaled_mask)
790
+
791
+ return latents, conditioning_mask, clean_latents
792
+
793
+ # -------------------- Audio conditioning additions (minimal) --------------------
794
+
795
+ def _get_audio_duration(self, audio: Union[str, torch.Tensor], sample_rate: int) -> float:
796
+ if isinstance(audio, str):
797
+ info = torchaudio.info(audio)
798
+ return info.num_frames / info.sample_rate
799
+ num_samples = audio.shape[-1]
800
+ return num_samples / sample_rate
801
 
802
  def _preprocess_audio(self, audio: Union[str, torch.Tensor], target_sample_rate: int) -> torch.Tensor:
 
803
  if isinstance(audio, str):
804
  waveform, sr = torchaudio.load(audio)
805
  else:
 
809
  if sr != target_sample_rate:
810
  waveform = torchaudio.functional.resample(waveform, sr, target_sample_rate)
811
 
812
+ if waveform.ndim == 1:
813
+ waveform = waveform.unsqueeze(0)
814
  if waveform.shape[0] == 1:
815
  waveform = waveform.repeat(2, 1)
816
  elif waveform.shape[0] > 2:
817
  waveform = waveform[:2, :]
818
 
819
+ waveform = waveform.unsqueeze(0) # [B, 2, samples]
820
 
821
  n_fft = 1024
822
  mel_transform = T.MelSpectrogram(
 
835
  norm="slaney",
836
  )
837
 
838
+ mel_spec = mel_transform(waveform) # [B, 2, mel_bins, T]
839
  mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
840
+ mel_spec = mel_spec.permute(0, 1, 3, 2).contiguous() # [B, 2, T, mel_bins]
 
841
  return mel_spec
842
 
843
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.prepare_audio_latents (modified minimally)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
844
    def prepare_audio_latents(
        self,
        batch_size: int = 1,
        num_channels_latents: int = 8,
        audio_latent_length: int = 1,  # 1 is just a dummy value
        num_mel_bins: int = 64,
        noise_scale: float = 0.0,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
        audio_input: Optional[Union[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
        """
        Prepare the packed audio latents consumed by the transformer.

        Exactly one of three paths runs, in priority order:
        1. `latents` given: pack (if 4D), normalize, apply noise, return them.
        2. `audio_input` given: encode it with the audio VAE into clean
           conditioning latents; also return a freshly sampled noisy tensor.
        3. Neither given: sample pure Gaussian noise of the expected shape.

        Returns:
            - packed noisy audio latents [B, S, D]
            - audio_latent_length
            - packed clean audio latents [B, S, D] if audio_input is provided, else None
        """
        if latents is not None:
            # 4D latents are in [B, C, L, F] layout; pack to the [B, S, D]
            # sequence layout expected downstream.
            if latents.ndim == 4:
                latents = self._pack_audio_latents(latents)
            if latents.ndim != 3:
                raise ValueError(
                    f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]."
                )
            latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std)
            latents = self._create_noised_state(latents, noise_scale, generator)
            return latents.to(device=device, dtype=dtype), audio_latent_length, None

        # If audio input is provided, encode to clean latents and return both clean and a dummy noisy tensor
        if audio_input is not None:
            mel_spec = self._preprocess_audio(audio_input, self.audio_sampling_rate).to(device=device)
            mel_spec = mel_spec.to(dtype=self.audio_vae.dtype)

            clean_4d = self.audio_vae.encode(mel_spec).latent_dist.sample(generator)  # [B, C, L, F]

            # pad/trim to audio_latent_length
            cur_len = clean_4d.shape[2]
            if cur_len < audio_latent_length:
                pad = audio_latent_length - cur_len
                # Zero-pad along the temporal (L) axis only.
                clean_4d = torch.nn.functional.pad(clean_4d, (0, 0, 0, pad))
            elif cur_len > audio_latent_length:
                clean_4d = clean_4d[:, :, :audio_latent_length, :]

            # Broadcast the encoded clip across the requested batch.
            # NOTE(review): assumes the encoder batch dim is 1 whenever it
            # differs from `batch_size`; other mismatches would silently
            # repeat — confirm against callers.
            if clean_4d.shape[0] != batch_size:
                clean_4d = clean_4d.repeat(batch_size, 1, 1, 1)

            clean_packed = self._pack_audio_latents(clean_4d)  # [B, S, D]
            clean_packed = clean_packed.to(dtype=dtype)
            clean_packed = self._normalize_audio_latents(
                clean_packed, self.audio_vae.latents_mean, self.audio_vae.latents_std
            )

            # The noisy companion tensor starts from pure Gaussian noise with
            # the same packed shape as the clean latents.
            noisy = randn_tensor(clean_packed.shape, generator=generator, device=device, dtype=dtype)
            noisy = self._create_noised_state(noisy, noise_scale, generator=None)  # keep same scaling semantics

            return noisy, audio_latent_length, clean_packed

        # Fallback: no latents and no audio — sample pure noise at the latent
        # mel resolution.
        latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
        shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = self._pack_audio_latents(latents)
        return latents, audio_latent_length, None
915
 
916
+ # ------------------------------------------------------------------
917
 
 
 
918
    @property
    def guidance_scale(self):
        """Return the stored classifier-free guidance scale (`self._guidance_scale`)."""
        return self._guidance_scale
 
943
    def interrupt(self):
        """Return the interruption flag checked at each step of the denoising loop."""
        return self._interrupt
945
 
 
 
 
 
 
 
 
 
 
 
946
  @torch.no_grad()
947
  @replace_example_docstring(EXAMPLE_DOC_STRING)
948
  def __call__(
949
  self,
950
+ conditions: Union[LTX2VideoCondition, List[LTX2VideoCondition]] = None,
 
 
 
951
  audio: Optional[Union[str, torch.Tensor]] = None,
952
  prompt: Union[str, List[str]] = None,
953
  negative_prompt: Optional[Union[str, List[str]]] = None,
954
  height: int = 512,
955
  width: int = 768,
956
+ num_frames: Optional[int] = 121,
957
  max_frames: int = 257,
958
  frame_rate: float = 24.0,
959
  num_inference_steps: int = 40,
 
960
  sigmas: Optional[List[float]] = None,
961
+ timesteps: List[int] = None,
962
  guidance_scale: float = 4.0,
963
  guidance_rescale: float = 0.0,
964
+ noise_scale: Optional[float] = None,
965
  num_videos_per_prompt: Optional[int] = 1,
966
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
967
  latents: Optional[torch.Tensor] = None,
 
979
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
980
  max_sequence_length: int = 1024,
981
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
982
  if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
983
  callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
984
 
985
+ # Optional: derive num_frames from audio if user passes num_frames=None
986
+ if num_frames is None and audio is not None:
987
+ audio_duration = self._get_audio_duration(audio, self.audio_sampling_rate)
988
+ calculated_frames = int(audio_duration * frame_rate) + 1
989
+ num_frames = min(calculated_frames, max_frames)
990
+ num_frames = (
991
+ (num_frames - 1) // self.vae_temporal_compression_ratio
992
+ ) * self.vae_temporal_compression_ratio + 1
993
+ num_frames = max(num_frames, 9)
994
+
995
+ # 1. Check inputs. Raise error if not correct
 
996
  self.check_inputs(
997
  prompt=prompt,
998
  height=height,
 
1010
  self._interrupt = False
1011
  self._current_timestep = None
1012
 
1013
+ # 2. Define call parameters
1014
  if prompt is not None and isinstance(prompt, str):
1015
  batch_size = 1
1016
  elif prompt is not None and isinstance(prompt, list):
 
1018
  else:
1019
  batch_size = prompt_embeds.shape[0]
1020
 
1021
+ if conditions is not None and not isinstance(conditions, list):
1022
+ conditions = [conditions]
1023
+
1024
+ # Infer noise scale: first (largest) sigma value if using custom sigmas, else 1.0
1025
+ if noise_scale is None:
1026
+ noise_scale = sigmas[0] if sigmas is not None else 1.0
1027
+
1028
  device = self._execution_device
1029
 
1030
+ # 3. Prepare text embeddings
1031
  (
1032
  prompt_embeds,
1033
  prompt_attention_mask,
 
1054
  prompt_embeds, additive_attention_mask, additive_mask=True
1055
  )
1056
 
1057
+ # 4. Prepare latent variables
1058
+ latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
1059
+ latent_height = height // self.vae_spatial_compression_ratio
1060
+ latent_width = width // self.vae_spatial_compression_ratio
1061
+ if latents is not None:
1062
+ if latents.ndim == 5:
1063
+ logger.info(
1064
+ "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred."
1065
+ )
1066
+ _, _, latent_num_frames, latent_height, latent_width = latents.shape
1067
+ elif latents.ndim == 3:
1068
+ logger.warning(
1069
+ f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
1070
+ f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
1071
+ )
1072
+ else:
1073
+ raise ValueError(
1074
+ f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]."
1075
+ )
1076
+ video_sequence_length = latent_num_frames * latent_height * latent_width
1077
 
 
1078
  num_channels_latents = self.transformer.config.in_channels
1079
+ latents, conditioning_mask, clean_latents = self.prepare_latents(
1080
+ conditions,
1081
+ batch_size * num_videos_per_prompt,
1082
+ num_channels_latents,
1083
+ height,
1084
+ width,
1085
+ num_frames,
1086
+ noise_scale,
1087
+ torch.float32,
1088
+ device,
1089
+ generator,
1090
+ latents,
 
 
1091
  )
1092
  if self.do_classifier_free_guidance:
1093
  conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
1094
 
1095
+ duration_s = num_frames / frame_rate
1096
+ audio_latents_per_second = (
1097
+ self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
1098
+ )
1099
+ audio_num_frames = round(duration_s * audio_latents_per_second)
1100
+ if audio_latents is not None:
1101
+ if audio_latents.ndim == 4:
1102
+ logger.info(
1103
+ "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred."
1104
+ )
1105
+ _, _, audio_num_frames, _ = audio_latents.shape
1106
+ elif audio_latents.ndim == 3:
1107
+ logger.warning(
1108
+ f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
1109
+ f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct."
1110
+ )
1111
+ else:
1112
+ raise ValueError(
1113
+ f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]."
1114
+ )
1115
+
1116
  num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
1117
  latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
 
1118
  num_channels_latents_audio = (
1119
  self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
1120
  )
 
1122
  audio_latents, audio_num_frames, clean_audio_latents = self.prepare_audio_latents(
1123
  batch_size * num_videos_per_prompt,
1124
  num_channels_latents=num_channels_latents_audio,
1125
+ audio_latent_length=audio_num_frames,
1126
  num_mel_bins=num_mel_bins,
1127
+ noise_scale=noise_scale,
 
 
 
1128
  dtype=torch.float32,
1129
  device=device,
1130
  generator=generator,
1131
  latents=audio_latents,
1132
  audio_input=audio,
1133
  )
1134
+ # clean_audio_latents is packed [B,S,D] if present
1135
+ packed_clean_audio_latents = clean_audio_latents
1136
 
1137
+ # 5. Prepare timesteps
1138
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
 
 
 
 
 
 
 
 
 
 
1139
  mu = calculate_shift(
1140
  video_sequence_length,
1141
  self.scheduler.config.get("base_image_seq_len", 1024),
 
1176
  audio_latents.shape[0], audio_num_frames, audio_latents.device
1177
  )
1178
 
1179
+ # 7. Denoising loop
1180
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1181
  for i, t in enumerate(timesteps):
1182
  if self.interrupt:
 
1184
 
1185
  self._current_timestep = t
1186
 
1187
+ # If audio conditioning provided, use clean audio latents directly (packed), and timestep=0
1188
  if packed_clean_audio_latents is not None:
1189
  audio_latents_input = packed_clean_audio_latents.to(dtype=prompt_embeds.dtype)
1190
  else:
 
1198
  audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
1199
 
1200
  timestep = t.expand(latent_model_input.shape[0])
1201
+ video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask.squeeze(-1))
1202
 
1203
  if packed_clean_audio_latents is not None:
1204
  audio_timestep = torch.zeros_like(timestep)
 
1222
  audio_num_frames=audio_num_frames,
1223
  video_coords=video_coords,
1224
  audio_coords=audio_coords,
1225
+ # rope_interpolation_scale=rope_interpolation_scale,
1226
  attention_kwargs=attention_kwargs,
1227
  return_dict=False,
1228
  )
 
1248
  noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
1249
  )
1250
 
1251
+ bsz = noise_pred_video.size(0)
1252
+ sigma = self.scheduler.sigmas[i]
1253
+ denoised_sample = latents - noise_pred_video * sigma
1254
+ denoised_sample_cond = (
1255
+ denoised_sample * (1 - conditioning_mask[:bsz]) + clean_latents.float() * conditioning_mask[:bsz]
1256
+ ).to(noise_pred_video.dtype)
1257
+ denoised_latents_cond = ((latents - denoised_sample_cond) / sigma).to(noise_pred_video.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
1258
 
1259
+ latents = self.scheduler.step(denoised_latents_cond, t, latents, return_dict=False)[0]
 
 
 
1260
 
1261
+ # Only step audio latents if not conditioning on clean audio
1262
  if packed_clean_audio_latents is None:
1263
  audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
1264
 
 
1267
  for k in callback_on_step_end_tensor_inputs:
1268
  callback_kwargs[k] = locals()[k]
1269
  callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1270
+
1271
  latents = callback_outputs.pop("latents", latents)
1272
  prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1273
 
 
1277
  if XLA_AVAILABLE:
1278
  xm.mark_step()
1279
 
 
1280
  latents = self._unpack_latents(
1281
  latents,
1282
  latent_num_frames,
 
1289
  latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
1290
  )
1291
 
1292
+ # Choose audio latents for decode: clean if provided, else denoised
1293
+ if packed_clean_audio_latents is not None:
1294
+ audio_latents_to_decode = packed_clean_audio_latents
 
 
 
 
 
1295
  else:
1296
+ audio_latents_to_decode = audio_latents
1297
+
1298
+ audio_latents_to_decode = self._denormalize_audio_latents(
1299
+ audio_latents_to_decode, self.audio_vae.latents_mean, self.audio_vae.latents_std
1300
+ )
1301
+ audio_latents_to_decode = self._unpack_audio_latents(
1302
+ audio_latents_to_decode, audio_num_frames, num_mel_bins=latent_mel_bins
1303
+ )
1304
 
1305
  if output_type == "latent":
1306
  video = latents
1307
+ audio_out = audio_latents_to_decode
1308
  else:
1309
  latents = latents.to(prompt_embeds.dtype)
1310
 
 
1329
  video = self.vae.decode(latents, timestep, return_dict=False)[0]
1330
  video = self.video_processor.postprocess_video(video, output_type=output_type)
1331
 
1332
+ audio_latents_to_decode = audio_latents_to_decode.to(self.audio_vae.dtype)
1333
+ generated_mel_spectrograms = self.audio_vae.decode(audio_latents_to_decode, return_dict=False)[0]
1334
+ audio_out = self.vocoder(generated_mel_spectrograms)
1335
 
1336
  self.maybe_free_model_hooks()
1337
 
1338
  if not return_dict:
1339
+ return (video, audio_out)
1340
 
1341
+ return LTX2PipelineOutput(frames=video, audio=audio_out)