Update pipeline.py
#2
by
linoyts HF Staff - opened
- pipeline.py +513 -532
pipeline.py
CHANGED
|
@@ -12,50 +12,35 @@
|
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
|
| 15 |
-
"""
|
| 16 |
-
LTX-2 Audio-to-Video Pipeline with Video Conditioning Support
|
| 17 |
-
|
| 18 |
-
This is a modified version of the LTX2AudioToVideoPipeline that adds support for
|
| 19 |
-
video conditioning, enabling avatar/face-swap generation workflows.
|
| 20 |
-
|
| 21 |
-
Usage:
|
| 22 |
-
pipe = DiffusionPipeline.from_pretrained(
|
| 23 |
-
"rootonchair/LTX-2-19b-distilled",
|
| 24 |
-
custom_pipeline="path/to/this/file",
|
| 25 |
-
torch_dtype=torch.bfloat16
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
# With video conditioning (for avatar/face-swap):
|
| 29 |
-
video, audio = pipe(
|
| 30 |
-
image=face_image, # The face/appearance to use
|
| 31 |
-
video=reference_video, # Video for motion conditioning
|
| 32 |
-
audio="path/to/audio.wav", # Audio (or extracted from video)
|
| 33 |
-
prompt="head_swap, a person speaking...",
|
| 34 |
-
...
|
| 35 |
-
)
|
| 36 |
-
"""
|
| 37 |
-
|
| 38 |
import copy
|
| 39 |
import inspect
|
|
|
|
| 40 |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 41 |
|
| 42 |
import numpy as np
|
|
|
|
| 43 |
import torch
|
| 44 |
import torchaudio
|
| 45 |
import torchaudio.transforms as T
|
| 46 |
-
from PIL import Image
|
| 47 |
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
|
| 48 |
|
| 49 |
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 50 |
-
from diffusers.image_processor import PipelineImageInput
|
| 51 |
from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
|
| 52 |
-
from diffusers.models.autoencoders import
|
|
|
|
|
|
|
|
|
|
| 53 |
from diffusers.models.transformers import LTX2VideoTransformer3DModel
|
| 54 |
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 55 |
-
from diffusers.utils import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
from diffusers.utils.torch_utils import randn_tensor
|
| 57 |
from diffusers.video_processor import VideoProcessor
|
| 58 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
|
|
|
| 59 |
from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors
|
| 60 |
from diffusers.pipelines.ltx2.pipeline_output import LTX2PipelineOutput
|
| 61 |
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
|
@@ -63,51 +48,86 @@ from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
|
| 63 |
|
| 64 |
if is_torch_xla_available():
|
| 65 |
import torch_xla.core.xla_model as xm
|
|
|
|
| 66 |
XLA_AVAILABLE = True
|
| 67 |
else:
|
| 68 |
XLA_AVAILABLE = False
|
| 69 |
|
| 70 |
-
logger = logging.get_logger(__name__)
|
| 71 |
-
|
| 72 |
|
| 73 |
EXAMPLE_DOC_STRING = """
|
| 74 |
Examples:
|
| 75 |
```py
|
| 76 |
>>> import torch
|
| 77 |
-
>>> from diffusers import
|
|
|
|
|
|
|
| 78 |
>>> from diffusers.utils import load_image
|
| 79 |
|
| 80 |
-
>>> pipe =
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
| 84 |
... )
|
| 85 |
-
>>>
|
| 86 |
-
|
| 87 |
-
>>> # Load face swap LoRA
|
| 88 |
-
>>> pipe.load_lora_weights(
|
| 89 |
-
... "Alissonerdx/BFS-Best-Face-Swap-Video",
|
| 90 |
-
... weight_name="ltx-2/head_swap_v1_13500_first_frame.safetensors",
|
| 91 |
... )
|
| 92 |
-
>>>
|
| 93 |
-
|
| 94 |
-
>>>
|
| 95 |
-
>>>
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
...
|
| 101 |
-
... prompt=
|
| 102 |
-
...
|
| 103 |
-
...
|
|
|
|
| 104 |
... num_frames=121,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
... return_dict=False,
|
| 106 |
... )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
```
|
| 108 |
"""
|
| 109 |
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
def retrieve_latents(
|
| 112 |
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 113 |
):
|
|
@@ -121,6 +141,7 @@ def retrieve_latents(
|
|
| 121 |
raise AttributeError("Could not access latents of provided encoder_output")
|
| 122 |
|
| 123 |
|
|
|
|
| 124 |
def calculate_shift(
|
| 125 |
image_seq_len,
|
| 126 |
base_seq_len: int = 256,
|
|
@@ -134,6 +155,7 @@ def calculate_shift(
|
|
| 134 |
return mu
|
| 135 |
|
| 136 |
|
|
|
|
| 137 |
def retrieve_timesteps(
|
| 138 |
scheduler,
|
| 139 |
num_inference_steps: Optional[int] = None,
|
|
@@ -142,13 +164,37 @@ def retrieve_timesteps(
|
|
| 142 |
sigmas: Optional[List[float]] = None,
|
| 143 |
**kwargs,
|
| 144 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
if timesteps is not None and sigmas is not None:
|
| 146 |
-
raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
|
| 147 |
if timesteps is not None:
|
| 148 |
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 149 |
if not accepts_timesteps:
|
| 150 |
raise ValueError(
|
| 151 |
-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom
|
|
|
|
| 152 |
)
|
| 153 |
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 154 |
timesteps = scheduler.timesteps
|
|
@@ -157,7 +203,8 @@ def retrieve_timesteps(
|
|
| 157 |
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 158 |
if not accept_sigmas:
|
| 159 |
raise ValueError(
|
| 160 |
-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom
|
|
|
|
| 161 |
)
|
| 162 |
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 163 |
timesteps = scheduler.timesteps
|
|
@@ -168,7 +215,24 @@ def retrieve_timesteps(
|
|
| 168 |
return timesteps, num_inference_steps
|
| 169 |
|
| 170 |
|
|
|
|
| 171 |
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
| 173 |
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
| 174 |
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
|
@@ -176,17 +240,13 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
|
| 176 |
return noise_cfg
|
| 177 |
|
| 178 |
|
| 179 |
-
class
|
| 180 |
r"""
|
| 181 |
-
Pipeline for
|
| 182 |
|
| 183 |
-
|
| 184 |
-
- An input image (the face/appearance to use)
|
| 185 |
-
- A reference video (for motion/pose conditioning)
|
| 186 |
-
- Input audio (for lip-sync)
|
| 187 |
|
| 188 |
-
|
| 189 |
-
to match the motion from the reference video and synced to the audio.
|
| 190 |
"""
|
| 191 |
|
| 192 |
model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
|
|
@@ -223,6 +283,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 223 |
self.vae_temporal_compression_ratio = (
|
| 224 |
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
|
| 225 |
)
|
|
|
|
| 226 |
self.audio_vae_mel_compression_ratio = (
|
| 227 |
self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
|
| 228 |
)
|
|
@@ -248,123 +309,8 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 248 |
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
|
| 249 |
)
|
| 250 |
|
| 251 |
-
# ==================== Video Conditioning Methods ====================
|
| 252 |
-
|
| 253 |
-
def _load_video_frames(
|
| 254 |
-
self,
|
| 255 |
-
video: Union[str, List[Image.Image], torch.Tensor],
|
| 256 |
-
height: int,
|
| 257 |
-
width: int,
|
| 258 |
-
num_frames: int,
|
| 259 |
-
device: torch.device,
|
| 260 |
-
dtype: torch.dtype,
|
| 261 |
-
) -> torch.Tensor:
|
| 262 |
-
"""
|
| 263 |
-
Load and preprocess video frames for conditioning.
|
| 264 |
-
|
| 265 |
-
Args:
|
| 266 |
-
video: Path to video file, list of PIL images, or tensor of frames
|
| 267 |
-
height: Target height
|
| 268 |
-
width: Target width
|
| 269 |
-
num_frames: Number of frames to extract/use
|
| 270 |
-
device: Target device
|
| 271 |
-
dtype: Target dtype
|
| 272 |
-
|
| 273 |
-
Returns:
|
| 274 |
-
Tensor of shape (batch, channels, num_frames, height, width)
|
| 275 |
-
"""
|
| 276 |
-
if isinstance(video, str):
|
| 277 |
-
# Load from file
|
| 278 |
-
frames = self._decode_video_file(video, num_frames)
|
| 279 |
-
elif isinstance(video, list):
|
| 280 |
-
# List of PIL images
|
| 281 |
-
frames = [np.array(img.convert("RGB")) for img in video]
|
| 282 |
-
elif isinstance(video, torch.Tensor):
|
| 283 |
-
# Already a tensor
|
| 284 |
-
if video.ndim == 4: # (F, H, W, C) or (F, C, H, W)
|
| 285 |
-
if video.shape[-1] in [1, 3, 4]: # (F, H, W, C)
|
| 286 |
-
frames = [video[i].cpu().numpy() for i in range(video.shape[0])]
|
| 287 |
-
else: # (F, C, H, W)
|
| 288 |
-
frames = [video[i].permute(1, 2, 0).cpu().numpy() for i in range(video.shape[0])]
|
| 289 |
-
else:
|
| 290 |
-
raise ValueError(f"Unexpected video tensor shape: {video.shape}")
|
| 291 |
-
else:
|
| 292 |
-
raise TypeError(f"Unsupported video type: {type(video)}")
|
| 293 |
-
|
| 294 |
-
# Handle frame count
|
| 295 |
-
if len(frames) >= num_frames:
|
| 296 |
-
frames = frames[:num_frames]
|
| 297 |
-
else:
|
| 298 |
-
# Pad with last frame
|
| 299 |
-
last_frame = frames[-1]
|
| 300 |
-
while len(frames) < num_frames:
|
| 301 |
-
frames.append(last_frame)
|
| 302 |
-
|
| 303 |
-
# Process each frame
|
| 304 |
-
processed_frames = []
|
| 305 |
-
for frame in frames:
|
| 306 |
-
if isinstance(frame, np.ndarray):
|
| 307 |
-
frame = Image.fromarray(frame.astype(np.uint8))
|
| 308 |
-
|
| 309 |
-
# Resize to target dimensions
|
| 310 |
-
frame = frame.resize((width, height), Image.LANCZOS)
|
| 311 |
-
frame = np.array(frame)
|
| 312 |
-
|
| 313 |
-
# Normalize to [-1, 1]
|
| 314 |
-
frame = (frame.astype(np.float32) / 127.5) - 1.0
|
| 315 |
-
processed_frames.append(frame)
|
| 316 |
-
|
| 317 |
-
# Stack frames: (F, H, W, C) -> (1, C, F, H, W)
|
| 318 |
-
frames_array = np.stack(processed_frames, axis=0) # (F, H, W, C)
|
| 319 |
-
frames_tensor = torch.from_numpy(frames_array).permute(3, 0, 1, 2).unsqueeze(0) # (1, C, F, H, W)
|
| 320 |
-
|
| 321 |
-
return frames_tensor.to(device=device, dtype=dtype)
|
| 322 |
-
|
| 323 |
-
def _decode_video_file(self, video_path: str, max_frames: int) -> List[np.ndarray]:
|
| 324 |
-
"""Decode video file to list of numpy arrays."""
|
| 325 |
-
try:
|
| 326 |
-
import av
|
| 327 |
-
except ImportError:
|
| 328 |
-
raise ImportError("Please install av: pip install av")
|
| 329 |
-
|
| 330 |
-
frames = []
|
| 331 |
-
container = av.open(video_path)
|
| 332 |
-
try:
|
| 333 |
-
video_stream = next(s for s in container.streams if s.type == "video")
|
| 334 |
-
for frame in container.decode(video_stream):
|
| 335 |
-
frames.append(frame.to_rgb().to_ndarray())
|
| 336 |
-
if len(frames) >= max_frames:
|
| 337 |
-
break
|
| 338 |
-
finally:
|
| 339 |
-
container.close()
|
| 340 |
-
|
| 341 |
-
return frames
|
| 342 |
-
|
| 343 |
-
def _encode_video_conditioning(
|
| 344 |
-
self,
|
| 345 |
-
video: torch.Tensor,
|
| 346 |
-
generator: Optional[torch.Generator] = None,
|
| 347 |
-
) -> torch.Tensor:
|
| 348 |
-
"""
|
| 349 |
-
Encode video frames through the VAE to get latents.
|
| 350 |
-
|
| 351 |
-
Args:
|
| 352 |
-
video: Video tensor of shape (batch, channels, frames, height, width)
|
| 353 |
-
generator: Random generator for sampling
|
| 354 |
-
|
| 355 |
-
Returns:
|
| 356 |
-
Video latents
|
| 357 |
-
"""
|
| 358 |
-
# Encode each frame through VAE
|
| 359 |
-
# VAE expects (batch, channels, frames, height, width)
|
| 360 |
-
video = video.to(device=self.vae.device, dtype=self.vae.dtype).contiguous()
|
| 361 |
-
latents = retrieve_latents(self.vae.encode(video), generator, "argmax")
|
| 362 |
-
|
| 363 |
-
return latents
|
| 364 |
-
|
| 365 |
-
# ==================== Text Encoding Methods ====================
|
| 366 |
-
|
| 367 |
@staticmethod
|
|
|
|
| 368 |
def _pack_text_embeds(
|
| 369 |
text_hidden_states: torch.Tensor,
|
| 370 |
sequence_lengths: torch.Tensor,
|
|
@@ -402,6 +348,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 402 |
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
| 403 |
return normalized_hidden_states
|
| 404 |
|
|
|
|
| 405 |
def _get_gemma_prompt_embeds(
|
| 406 |
self,
|
| 407 |
prompt: Union[str, List[str]],
|
|
@@ -461,6 +408,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 461 |
|
| 462 |
return prompt_embeds, prompt_attention_mask
|
| 463 |
|
|
|
|
| 464 |
def encode_prompt(
|
| 465 |
self,
|
| 466 |
prompt: Union[str, List[str]],
|
|
@@ -500,11 +448,14 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 500 |
|
| 501 |
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 502 |
raise TypeError(
|
| 503 |
-
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=
|
|
|
|
| 504 |
)
|
| 505 |
elif batch_size != len(negative_prompt):
|
| 506 |
raise ValueError(
|
| 507 |
-
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:
|
|
|
|
|
|
|
| 508 |
)
|
| 509 |
|
| 510 |
negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
|
|
@@ -536,13 +487,18 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 536 |
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 537 |
):
|
| 538 |
raise ValueError(
|
| 539 |
-
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}"
|
| 540 |
)
|
| 541 |
|
| 542 |
if prompt is not None and prompt_embeds is not None:
|
| 543 |
-
raise ValueError(
|
|
|
|
|
|
|
|
|
|
| 544 |
elif prompt is None and prompt_embeds is None:
|
| 545 |
-
raise ValueError(
|
|
|
|
|
|
|
| 546 |
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 547 |
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 548 |
|
|
@@ -552,9 +508,22 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 552 |
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
| 553 |
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
| 554 |
|
| 555 |
-
|
| 556 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 557 |
@staticmethod
|
|
|
|
| 558 |
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
| 559 |
batch_size, num_channels, num_frames, height, width = latents.shape
|
| 560 |
post_patch_num_frames = num_frames // patch_size_t
|
|
@@ -574,6 +543,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 574 |
return latents
|
| 575 |
|
| 576 |
@staticmethod
|
|
|
|
| 577 |
def _unpack_latents(
|
| 578 |
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
|
| 579 |
) -> torch.Tensor:
|
|
@@ -592,6 +562,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 592 |
return latents
|
| 593 |
|
| 594 |
@staticmethod
|
|
|
|
| 595 |
def _denormalize_latents(
|
| 596 |
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
| 597 |
) -> torch.Tensor:
|
|
@@ -600,9 +571,17 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 600 |
latents = latents * latents_std / scaling_factor + latents_mean
|
| 601 |
return latents
|
| 602 |
|
| 603 |
-
# ==================== Audio Latent Methods ====================
|
| 604 |
-
|
| 605 |
@staticmethod
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 606 |
def _pack_audio_latents(
|
| 607 |
latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
|
| 608 |
) -> torch.Tensor:
|
|
@@ -619,6 +598,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 619 |
return latents
|
| 620 |
|
| 621 |
@staticmethod
|
|
|
|
| 622 |
def _unpack_audio_latents(
|
| 623 |
latents: torch.Tensor,
|
| 624 |
latent_length: int,
|
|
@@ -635,29 +615,191 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 635 |
return latents
|
| 636 |
|
| 637 |
@staticmethod
|
| 638 |
-
|
|
|
|
| 639 |
latents_mean = latents_mean.to(latents.device, latents.dtype)
|
| 640 |
latents_std = latents_std.to(latents.device, latents.dtype)
|
| 641 |
-
return (latents
|
| 642 |
|
| 643 |
@staticmethod
|
| 644 |
-
|
|
|
|
| 645 |
latents_mean = latents_mean.to(latents.device, latents.dtype)
|
| 646 |
latents_std = latents_std.to(latents.device, latents.dtype)
|
| 647 |
-
return (latents
|
| 648 |
|
| 649 |
-
|
| 650 |
-
def
|
| 651 |
-
|
| 652 |
-
|
|
|
|
|
|
|
| 653 |
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 658 |
|
| 659 |
def _preprocess_audio(self, audio: Union[str, torch.Tensor], target_sample_rate: int) -> torch.Tensor:
|
| 660 |
-
"""Process audio to mel spectrogram."""
|
| 661 |
if isinstance(audio, str):
|
| 662 |
waveform, sr = torchaudio.load(audio)
|
| 663 |
else:
|
|
@@ -667,12 +809,14 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 667 |
if sr != target_sample_rate:
|
| 668 |
waveform = torchaudio.functional.resample(waveform, sr, target_sample_rate)
|
| 669 |
|
|
|
|
|
|
|
| 670 |
if waveform.shape[0] == 1:
|
| 671 |
waveform = waveform.repeat(2, 1)
|
| 672 |
elif waveform.shape[0] > 2:
|
| 673 |
waveform = waveform[:2, :]
|
| 674 |
|
| 675 |
-
waveform = waveform.unsqueeze(0)
|
| 676 |
|
| 677 |
n_fft = 1024
|
| 678 |
mel_transform = T.MelSpectrogram(
|
|
@@ -691,208 +835,86 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 691 |
norm="slaney",
|
| 692 |
)
|
| 693 |
|
| 694 |
-
mel_spec = mel_transform(waveform)
|
| 695 |
mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
|
| 696 |
-
mel_spec = mel_spec.permute(0, 1, 3, 2).contiguous()
|
| 697 |
-
|
| 698 |
return mel_spec
|
| 699 |
|
| 700 |
-
#
|
| 701 |
-
|
| 702 |
-
def prepare_latents(
|
| 703 |
-
self,
|
| 704 |
-
image: Optional[torch.Tensor] = None,
|
| 705 |
-
video: Optional[torch.Tensor] = None,
|
| 706 |
-
video_conditioning_strength: float = 1.0,
|
| 707 |
-
video_conditioning_frame_idx: int = 1,
|
| 708 |
-
batch_size: int = 1,
|
| 709 |
-
num_channels_latents: int = 128,
|
| 710 |
-
height: int = 512,
|
| 711 |
-
width: int = 704,
|
| 712 |
-
num_frames: int = 161,
|
| 713 |
-
dtype: Optional[torch.dtype] = None,
|
| 714 |
-
device: Optional[torch.device] = None,
|
| 715 |
-
generator: Optional[torch.Generator] = None,
|
| 716 |
-
latents: Optional[torch.Tensor] = None,
|
| 717 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 718 |
-
"""
|
| 719 |
-
Prepare latents for generation with optional video conditioning.
|
| 720 |
-
|
| 721 |
-
Args:
|
| 722 |
-
image: Input image for frame 0 conditioning
|
| 723 |
-
video: Video tensor for motion conditioning
|
| 724 |
-
video_conditioning_strength: Strength of video conditioning (0-1)
|
| 725 |
-
video_conditioning_frame_idx: Frame index where video conditioning starts.
|
| 726 |
-
- 0: Video conditioning replaces all frames including frame 0
|
| 727 |
-
- 1: Frame 0 is image-conditioned, frames 1+ are video-conditioned (default for face-swap)
|
| 728 |
-
- N: Frames 0 to N-1 are image/noise, frames N+ are video-conditioned
|
| 729 |
-
... other args ...
|
| 730 |
-
"""
|
| 731 |
-
latent_height = height // self.vae_spatial_compression_ratio
|
| 732 |
-
latent_width = width // self.vae_spatial_compression_ratio
|
| 733 |
-
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
| 734 |
-
|
| 735 |
-
shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width)
|
| 736 |
-
mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)
|
| 737 |
-
|
| 738 |
-
if latents is not None:
|
| 739 |
-
conditioning_mask = latents.new_zeros(mask_shape)
|
| 740 |
-
conditioning_mask[:, :, 0] = 1.0
|
| 741 |
-
conditioning_mask = self._pack_latents(
|
| 742 |
-
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
| 743 |
-
).squeeze(-1)
|
| 744 |
-
return latents.to(device=device, dtype=dtype), conditioning_mask
|
| 745 |
-
|
| 746 |
-
# Initialize conditioning mask (all zeros = fully denoise)
|
| 747 |
-
conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype)
|
| 748 |
-
|
| 749 |
-
# Initialize latents tensor
|
| 750 |
-
init_latents = torch.zeros(shape, device=device, dtype=dtype)
|
| 751 |
-
|
| 752 |
-
# Case 1: Video conditioning (motion from reference video)
|
| 753 |
-
if video is not None:
|
| 754 |
-
# Encode video through VAE
|
| 755 |
-
video_latents = self._encode_video_conditioning(video, generator)
|
| 756 |
-
video_latents = self._normalize_latents(video_latents, self.vae.latents_mean, self.vae.latents_std)
|
| 757 |
-
|
| 758 |
-
# Ensure video latents match target shape
|
| 759 |
-
if video_latents.shape[2] < latent_num_frames:
|
| 760 |
-
# Pad with last frame
|
| 761 |
-
pad_frames = latent_num_frames - video_latents.shape[2]
|
| 762 |
-
last_frame = video_latents[:, :, -1:, :, :]
|
| 763 |
-
video_latents = torch.cat([video_latents, last_frame.repeat(1, 1, pad_frames, 1, 1)], dim=2)
|
| 764 |
-
elif video_latents.shape[2] > latent_num_frames:
|
| 765 |
-
video_latents = video_latents[:, :, :latent_num_frames, :, :]
|
| 766 |
-
|
| 767 |
-
# Calculate the latent frame index for video conditioning
|
| 768 |
-
# video_conditioning_frame_idx is in pixel space, convert to latent space
|
| 769 |
-
latent_video_start_idx = video_conditioning_frame_idx // self.vae_temporal_compression_ratio
|
| 770 |
-
latent_video_start_idx = min(latent_video_start_idx, latent_num_frames - 1)
|
| 771 |
-
|
| 772 |
-
# Apply video conditioning starting from the specified frame index
|
| 773 |
-
# Video frames are placed starting at latent_video_start_idx
|
| 774 |
-
num_video_frames_to_use = latent_num_frames - latent_video_start_idx
|
| 775 |
-
init_latents[:, :, latent_video_start_idx:, :, :] = video_latents[:, :, :num_video_frames_to_use, :, :]
|
| 776 |
-
|
| 777 |
-
# Set conditioning mask for video frames
|
| 778 |
-
# strength=1.0 means fully conditioned (no denoising), strength=0.0 means fully denoised
|
| 779 |
-
conditioning_mask[:, :, latent_video_start_idx:] = video_conditioning_strength
|
| 780 |
-
|
| 781 |
-
# Handle image conditioning for frame 0
|
| 782 |
-
if image is not None:
|
| 783 |
-
if isinstance(generator, list):
|
| 784 |
-
image_latents = [
|
| 785 |
-
retrieve_latents(
|
| 786 |
-
self.vae.encode(
|
| 787 |
-
image[i].unsqueeze(0).unsqueeze(2)
|
| 788 |
-
.to(device=self.vae.device, dtype=self.vae.dtype)
|
| 789 |
-
.contiguous()
|
| 790 |
-
),
|
| 791 |
-
generator[i],
|
| 792 |
-
"argmax",
|
| 793 |
-
)
|
| 794 |
-
for i in range(batch_size)
|
| 795 |
-
]
|
| 796 |
-
|
| 797 |
-
else:
|
| 798 |
-
image_latents = [
|
| 799 |
-
retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2).to(device=self.vae.device, dtype=self.vae.dtype).contiguous()), generator, "argmax")
|
| 800 |
-
|
| 801 |
-
for img in image
|
| 802 |
-
]
|
| 803 |
-
image_latents = torch.cat(image_latents, dim=0).to(dtype)
|
| 804 |
-
image_latents = self._normalize_latents(image_latents, self.vae.latents_mean, self.vae.latents_std)
|
| 805 |
-
|
| 806 |
-
# Replace frame 0 with image latents (face appearance)
|
| 807 |
-
init_latents[:, :, 0:1, :, :] = image_latents
|
| 808 |
-
# Frame 0 is fully conditioned
|
| 809 |
-
conditioning_mask[:, :, 0] = 1.0
|
| 810 |
-
|
| 811 |
-
# If no video conditioning, repeat image for all frames (image-to-video mode)
|
| 812 |
-
if video is None:
|
| 813 |
-
init_latents = image_latents.repeat(1, 1, latent_num_frames, 1, 1)
|
| 814 |
-
|
| 815 |
-
# Generate noise
|
| 816 |
-
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 817 |
-
|
| 818 |
-
# Blend: conditioned regions keep init_latents, unconditioned regions get noise
|
| 819 |
-
latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)
|
| 820 |
-
|
| 821 |
-
# Pack for transformer
|
| 822 |
-
conditioning_mask = self._pack_latents(
|
| 823 |
-
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
| 824 |
-
).squeeze(-1)
|
| 825 |
-
latents = self._pack_latents(
|
| 826 |
-
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
| 827 |
-
)
|
| 828 |
-
|
| 829 |
-
return latents, conditioning_mask
|
| 830 |
-
|
| 831 |
def prepare_audio_latents(
|
| 832 |
self,
|
| 833 |
batch_size: int = 1,
|
| 834 |
num_channels_latents: int = 8,
|
|
|
|
| 835 |
num_mel_bins: int = 64,
|
| 836 |
-
|
| 837 |
-
frame_rate: float = 25.0,
|
| 838 |
-
sampling_rate: int = 16000,
|
| 839 |
-
hop_length: int = 160,
|
| 840 |
dtype: Optional[torch.dtype] = None,
|
| 841 |
device: Optional[torch.device] = None,
|
| 842 |
generator: Optional[torch.Generator] = None,
|
| 843 |
-
audio_input: Optional[Union[str, torch.Tensor]] = None,
|
| 844 |
latents: Optional[torch.Tensor] = None,
|
|
|
|
| 845 |
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
if latents is not None:
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 856 |
|
|
|
|
| 857 |
if audio_input is not None:
|
| 858 |
-
mel_spec = self._preprocess_audio(audio_input,
|
| 859 |
mel_spec = mel_spec.to(dtype=self.audio_vae.dtype)
|
| 860 |
-
init_latents = self.audio_vae.encode(mel_spec).latent_dist.sample(generator)
|
| 861 |
-
init_latents = init_latents.to(dtype=dtype)
|
| 862 |
-
|
| 863 |
-
latent_channels = init_latents.shape[1]
|
| 864 |
-
latent_freq = init_latents.shape[3]
|
| 865 |
-
init_latents_patched = self._patchify_audio_latents(init_latents)
|
| 866 |
-
init_latents_patched = self._normalize_audio_latents(
|
| 867 |
-
init_latents_patched, self.audio_vae.latents_mean, self.audio_vae.latents_std
|
| 868 |
-
)
|
| 869 |
-
init_latents = self._unpatchify_audio_latents(init_latents_patched, latent_channels, latent_freq)
|
| 870 |
|
| 871 |
-
|
| 872 |
-
if current_len < target_length:
|
| 873 |
-
padding = target_length - current_len
|
| 874 |
-
init_latents = torch.nn.functional.pad(init_latents, (0, 0, 0, padding))
|
| 875 |
-
elif current_len > target_length:
|
| 876 |
-
init_latents = init_latents[:, :, :target_length, :]
|
| 877 |
|
| 878 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 879 |
|
| 880 |
-
if
|
| 881 |
-
|
| 882 |
-
noise = noise.repeat(batch_size, 1, 1, 1)
|
| 883 |
|
| 884 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 885 |
|
| 886 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 887 |
|
| 888 |
-
shape = (batch_size, num_channels_latents, target_length, latent_mel_bins)
|
| 889 |
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 890 |
latents = self._pack_audio_latents(latents)
|
|
|
|
| 891 |
|
| 892 |
-
|
| 893 |
|
| 894 |
-
# ==================== Properties ====================
|
| 895 |
-
|
| 896 |
@property
|
| 897 |
def guidance_scale(self):
|
| 898 |
return self._guidance_scale
|
|
@@ -921,37 +943,25 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 921 |
def interrupt(self):
|
| 922 |
return self._interrupt
|
| 923 |
|
| 924 |
-
def _get_audio_duration(self, audio: Union[str, torch.Tensor], sample_rate: int) -> float:
|
| 925 |
-
if isinstance(audio, str):
|
| 926 |
-
info = torchaudio.info(audio)
|
| 927 |
-
return info.num_frames / info.sample_rate
|
| 928 |
-
else:
|
| 929 |
-
num_samples = audio.shape[-1]
|
| 930 |
-
return num_samples / sample_rate
|
| 931 |
-
|
| 932 |
-
# ==================== Main Call ====================
|
| 933 |
-
|
| 934 |
@torch.no_grad()
|
| 935 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 936 |
def __call__(
|
| 937 |
self,
|
| 938 |
-
|
| 939 |
-
video: Optional[Union[str, List[Image.Image], torch.Tensor]] = None,
|
| 940 |
-
video_conditioning_strength: float = 1.0,
|
| 941 |
-
video_conditioning_frame_idx: int = 1,
|
| 942 |
audio: Optional[Union[str, torch.Tensor]] = None,
|
| 943 |
prompt: Union[str, List[str]] = None,
|
| 944 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 945 |
height: int = 512,
|
| 946 |
width: int = 768,
|
| 947 |
-
num_frames: Optional[int] =
|
| 948 |
max_frames: int = 257,
|
| 949 |
frame_rate: float = 24.0,
|
| 950 |
num_inference_steps: int = 40,
|
| 951 |
-
timesteps: List[int] = None,
|
| 952 |
sigmas: Optional[List[float]] = None,
|
|
|
|
| 953 |
guidance_scale: float = 4.0,
|
| 954 |
guidance_rescale: float = 0.0,
|
|
|
|
| 955 |
num_videos_per_prompt: Optional[int] = 1,
|
| 956 |
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 957 |
latents: Optional[torch.Tensor] = None,
|
|
@@ -969,50 +979,20 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 969 |
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 970 |
max_sequence_length: int = 1024,
|
| 971 |
):
|
| 972 |
-
r"""
|
| 973 |
-
Generate avatar video with audio and optional video conditioning.
|
| 974 |
-
|
| 975 |
-
Args:
|
| 976 |
-
image (`PipelineImageInput`):
|
| 977 |
-
The input image (face/appearance) to condition frame 0.
|
| 978 |
-
video (`str`, `List[PIL.Image]`, or `torch.Tensor`, *optional*):
|
| 979 |
-
Reference video for motion conditioning. Can be:
|
| 980 |
-
- Path to a video file
|
| 981 |
-
- List of PIL Images
|
| 982 |
-
- Tensor of shape (F, H, W, C) or (F, C, H, W)
|
| 983 |
-
video_conditioning_strength (`float`, *optional*, defaults to 1.0):
|
| 984 |
-
How strongly to condition on the reference video (0.0-1.0).
|
| 985 |
-
1.0 = fully conditioned, 0.0 = no conditioning.
|
| 986 |
-
video_conditioning_frame_idx (`int`, *optional*, defaults to 1):
|
| 987 |
-
Frame index where video conditioning starts (in pixel/frame space).
|
| 988 |
-
- 0: Video conditioning replaces all frames including frame 0
|
| 989 |
-
- 1: Frame 0 is image-conditioned, frames 1+ are video-conditioned (default for face-swap)
|
| 990 |
-
- N: Frames 0 to N-1 are image/noise, frames N+ are video-conditioned
|
| 991 |
-
audio (`str` or `torch.Tensor`, *optional*):
|
| 992 |
-
Audio for lip-sync. Can be path to audio/video file or waveform tensor.
|
| 993 |
-
prompt (`str` or `List[str]`, *optional*):
|
| 994 |
-
Text prompt. For face-swap, include "head_swap" trigger.
|
| 995 |
-
Examples:
|
| 996 |
-
|
| 997 |
-
Returns:
|
| 998 |
-
[`LTX2PipelineOutput`] or `tuple`: Generated video and audio.
|
| 999 |
-
"""
|
| 1000 |
-
|
| 1001 |
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 1002 |
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 1003 |
|
| 1004 |
-
#
|
| 1005 |
-
if num_frames is None:
|
| 1006 |
-
|
| 1007 |
-
|
| 1008 |
-
|
| 1009 |
-
|
| 1010 |
-
|
| 1011 |
-
|
| 1012 |
-
|
| 1013 |
-
|
| 1014 |
-
|
| 1015 |
-
|
| 1016 |
self.check_inputs(
|
| 1017 |
prompt=prompt,
|
| 1018 |
height=height,
|
|
@@ -1030,6 +1010,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 1030 |
self._interrupt = False
|
| 1031 |
self._current_timestep = None
|
| 1032 |
|
|
|
|
| 1033 |
if prompt is not None and isinstance(prompt, str):
|
| 1034 |
batch_size = 1
|
| 1035 |
elif prompt is not None and isinstance(prompt, list):
|
|
@@ -1037,9 +1018,16 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 1037 |
else:
|
| 1038 |
batch_size = prompt_embeds.shape[0]
|
| 1039 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1040 |
device = self._execution_device
|
| 1041 |
|
| 1042 |
-
#
|
| 1043 |
(
|
| 1044 |
prompt_embeds,
|
| 1045 |
prompt_attention_mask,
|
|
@@ -1066,48 +1054,67 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 1066 |
prompt_embeds, additive_attention_mask, additive_mask=True
|
| 1067 |
)
|
| 1068 |
|
| 1069 |
-
#
|
| 1070 |
-
|
| 1071 |
-
|
| 1072 |
-
|
| 1073 |
-
|
| 1074 |
-
|
| 1075 |
-
|
| 1076 |
-
|
| 1077 |
-
|
| 1078 |
-
|
| 1079 |
-
|
| 1080 |
-
|
| 1081 |
-
|
| 1082 |
-
|
| 1083 |
-
|
| 1084 |
-
|
| 1085 |
-
|
|
|
|
|
|
|
|
|
|
| 1086 |
|
| 1087 |
-
# Prepare latents with video conditioning
|
| 1088 |
num_channels_latents = self.transformer.config.in_channels
|
| 1089 |
-
latents, conditioning_mask = self.prepare_latents(
|
| 1090 |
-
|
| 1091 |
-
|
| 1092 |
-
|
| 1093 |
-
|
| 1094 |
-
|
| 1095 |
-
|
| 1096 |
-
|
| 1097 |
-
|
| 1098 |
-
|
| 1099 |
-
|
| 1100 |
-
|
| 1101 |
-
generator=generator,
|
| 1102 |
-
latents=latents,
|
| 1103 |
)
|
| 1104 |
if self.do_classifier_free_guidance:
|
| 1105 |
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
|
| 1106 |
|
| 1107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1108 |
num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
|
| 1109 |
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
|
| 1110 |
-
|
| 1111 |
num_channels_latents_audio = (
|
| 1112 |
self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
|
| 1113 |
)
|
|
@@ -1115,30 +1122,20 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 1115 |
audio_latents, audio_num_frames, clean_audio_latents = self.prepare_audio_latents(
|
| 1116 |
batch_size * num_videos_per_prompt,
|
| 1117 |
num_channels_latents=num_channels_latents_audio,
|
|
|
|
| 1118 |
num_mel_bins=num_mel_bins,
|
| 1119 |
-
|
| 1120 |
-
frame_rate=frame_rate,
|
| 1121 |
-
sampling_rate=self.audio_sampling_rate,
|
| 1122 |
-
hop_length=self.audio_hop_length,
|
| 1123 |
dtype=torch.float32,
|
| 1124 |
device=device,
|
| 1125 |
generator=generator,
|
| 1126 |
latents=audio_latents,
|
| 1127 |
audio_input=audio,
|
| 1128 |
)
|
|
|
|
|
|
|
| 1129 |
|
| 1130 |
-
|
| 1131 |
-
if
|
| 1132 |
-
packed_clean_audio_latents = self._pack_audio_latents(clean_audio_latents)
|
| 1133 |
-
|
| 1134 |
-
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
| 1135 |
-
latent_height = height // self.vae_spatial_compression_ratio
|
| 1136 |
-
latent_width = width // self.vae_spatial_compression_ratio
|
| 1137 |
-
video_sequence_length = latent_num_frames * latent_height * latent_width
|
| 1138 |
-
|
| 1139 |
-
if sigmas is None:
|
| 1140 |
-
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
| 1141 |
-
|
| 1142 |
mu = calculate_shift(
|
| 1143 |
video_sequence_length,
|
| 1144 |
self.scheduler.config.get("base_image_seq_len", 1024),
|
|
@@ -1179,7 +1176,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 1179 |
audio_latents.shape[0], audio_num_frames, audio_latents.device
|
| 1180 |
)
|
| 1181 |
|
| 1182 |
-
# Denoising loop
|
| 1183 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1184 |
for i, t in enumerate(timesteps):
|
| 1185 |
if self.interrupt:
|
|
@@ -1187,6 +1184,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 1187 |
|
| 1188 |
self._current_timestep = t
|
| 1189 |
|
|
|
|
| 1190 |
if packed_clean_audio_latents is not None:
|
| 1191 |
audio_latents_input = packed_clean_audio_latents.to(dtype=prompt_embeds.dtype)
|
| 1192 |
else:
|
|
@@ -1200,7 +1198,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 1200 |
audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
|
| 1201 |
|
| 1202 |
timestep = t.expand(latent_model_input.shape[0])
|
| 1203 |
-
video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
|
| 1204 |
|
| 1205 |
if packed_clean_audio_latents is not None:
|
| 1206 |
audio_timestep = torch.zeros_like(timestep)
|
|
@@ -1224,6 +1222,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 1224 |
audio_num_frames=audio_num_frames,
|
| 1225 |
video_coords=video_coords,
|
| 1226 |
audio_coords=audio_coords,
|
|
|
|
| 1227 |
attention_kwargs=attention_kwargs,
|
| 1228 |
return_dict=False,
|
| 1229 |
)
|
|
@@ -1249,32 +1248,17 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 1249 |
noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
|
| 1250 |
)
|
| 1251 |
|
| 1252 |
-
|
| 1253 |
-
|
| 1254 |
-
|
| 1255 |
-
|
| 1256 |
-
|
| 1257 |
-
|
| 1258 |
-
|
| 1259 |
-
)
|
| 1260 |
-
latents = self._unpack_latents(
|
| 1261 |
-
latents,
|
| 1262 |
-
latent_num_frames,
|
| 1263 |
-
latent_height,
|
| 1264 |
-
latent_width,
|
| 1265 |
-
self.transformer_spatial_patch_size,
|
| 1266 |
-
self.transformer_temporal_patch_size,
|
| 1267 |
-
)
|
| 1268 |
-
|
| 1269 |
-
noise_pred_video = noise_pred_video[:, :, 1:]
|
| 1270 |
-
noise_latents = latents[:, :, 1:]
|
| 1271 |
-
pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0]
|
| 1272 |
|
| 1273 |
-
latents =
|
| 1274 |
-
latents = self._pack_latents(
|
| 1275 |
-
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
| 1276 |
-
)
|
| 1277 |
|
|
|
|
| 1278 |
if packed_clean_audio_latents is None:
|
| 1279 |
audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
|
| 1280 |
|
|
@@ -1283,6 +1267,7 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 1283 |
for k in callback_on_step_end_tensor_inputs:
|
| 1284 |
callback_kwargs[k] = locals()[k]
|
| 1285 |
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
|
|
|
| 1286 |
latents = callback_outputs.pop("latents", latents)
|
| 1287 |
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1288 |
|
|
@@ -1292,7 +1277,6 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 1292 |
if XLA_AVAILABLE:
|
| 1293 |
xm.mark_step()
|
| 1294 |
|
| 1295 |
-
# Decode
|
| 1296 |
latents = self._unpack_latents(
|
| 1297 |
latents,
|
| 1298 |
latent_num_frames,
|
|
@@ -1305,25 +1289,22 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 1305 |
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
| 1306 |
)
|
| 1307 |
|
| 1308 |
-
|
| 1309 |
-
|
| 1310 |
-
|
| 1311 |
-
audio_patched = self._patchify_audio_latents(clean_audio_latents)
|
| 1312 |
-
audio_patched = self._denormalize_audio_latents(
|
| 1313 |
-
audio_patched, self.audio_vae.latents_mean, self.audio_vae.latents_std
|
| 1314 |
-
)
|
| 1315 |
-
audio_latents_for_decode = self._unpatchify_audio_latents(audio_patched, latent_channels, latent_freq)
|
| 1316 |
else:
|
| 1317 |
-
|
| 1318 |
-
|
| 1319 |
-
|
| 1320 |
-
|
| 1321 |
-
|
| 1322 |
-
|
|
|
|
|
|
|
| 1323 |
|
| 1324 |
if output_type == "latent":
|
| 1325 |
video = latents
|
| 1326 |
-
|
| 1327 |
else:
|
| 1328 |
latents = latents.to(prompt_embeds.dtype)
|
| 1329 |
|
|
@@ -1348,13 +1329,13 @@ class LTX2AvatarPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoa
|
|
| 1348 |
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
| 1349 |
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
| 1350 |
|
| 1351 |
-
|
| 1352 |
-
generated_mel_spectrograms = self.audio_vae.decode(
|
| 1353 |
-
|
| 1354 |
|
| 1355 |
self.maybe_free_model_hooks()
|
| 1356 |
|
| 1357 |
if not return_dict:
|
| 1358 |
-
return (video,
|
| 1359 |
|
| 1360 |
-
return LTX2PipelineOutput(frames=video, audio=
|
|
|
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
import copy
|
| 16 |
import inspect
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 19 |
|
| 20 |
import numpy as np
|
| 21 |
+
import PIL.Image
|
| 22 |
import torch
|
| 23 |
import torchaudio
|
| 24 |
import torchaudio.transforms as T
|
|
|
|
| 25 |
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
|
| 26 |
|
| 27 |
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
|
|
|
| 28 |
from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
|
| 29 |
+
from diffusers.models.autoencoders import (
|
| 30 |
+
AutoencoderKLLTX2Audio,
|
| 31 |
+
AutoencoderKLLTX2Video,
|
| 32 |
+
)
|
| 33 |
from diffusers.models.transformers import LTX2VideoTransformer3DModel
|
| 34 |
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 35 |
+
from diffusers.utils import (
|
| 36 |
+
is_torch_xla_available,
|
| 37 |
+
logging,
|
| 38 |
+
replace_example_docstring,
|
| 39 |
+
)
|
| 40 |
from diffusers.utils.torch_utils import randn_tensor
|
| 41 |
from diffusers.video_processor import VideoProcessor
|
| 42 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 43 |
+
|
| 44 |
from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors
|
| 45 |
from diffusers.pipelines.ltx2.pipeline_output import LTX2PipelineOutput
|
| 46 |
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
|
|
|
| 48 |
|
| 49 |
if is_torch_xla_available():
|
| 50 |
import torch_xla.core.xla_model as xm
|
| 51 |
+
|
| 52 |
XLA_AVAILABLE = True
|
| 53 |
else:
|
| 54 |
XLA_AVAILABLE = False
|
| 55 |
|
| 56 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|
|
|
| 57 |
|
| 58 |
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import LTX2ConditionPipeline
        >>> from diffusers.pipelines.ltx2.export_utils import encode_video
        >>> from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition
        >>> from diffusers.utils import load_image

        >>> pipe = LTX2ConditionPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
        >>> pipe.enable_model_cpu_offload()

        >>> first_image = load_image(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
        ... )
        >>> last_image = load_image(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png"
        ... )
        >>> first_cond = LTX2VideoCondition(frames=first_image, index=0, strength=1.0)
        >>> last_cond = LTX2VideoCondition(frames=last_image, index=-1, strength=1.0)
        >>> conditions = [first_cond, last_cond]
        >>> prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings."
        >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted, static"

        >>> frame_rate = 24.0
        >>> # NOTE: unpack both outputs — the audio track is needed for `encode_video` below.
        >>> video, audio = pipe(
        ...     conditions=conditions,
        ...     prompt=prompt,
        ...     negative_prompt=negative_prompt,
        ...     width=768,
        ...     height=512,
        ...     num_frames=121,
        ...     frame_rate=frame_rate,
        ...     num_inference_steps=40,
        ...     guidance_scale=4.0,
        ...     output_type="np",
        ...     return_dict=False,
        ... )
        >>> video = (video * 255).round().astype("uint8")
        >>> video = torch.from_numpy(video)

        >>> encode_video(
        ...     video[0],
        ...     fps=frame_rate,
        ...     audio=audio[0].float().cpu(),
        ...     audio_sample_rate=pipe.vocoder.config.output_sampling_rate,  # should be 24000
        ...     output_path="video.mp4",
        ... )
        ```
"""
|
| 108 |
|
| 109 |
|
| 110 |
+
@dataclass
class LTX2VideoCondition:
    """
    Defines a single frame-conditioning item for LTX-2 Video - a single frame or a sequence of frames.

    Attributes:
        frames (`PIL.Image.Image` or `List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`):
            The image (or video) to condition the video on. Accepts any type that can be handled by
            VideoProcessor.preprocess_video.
        index (`int`, defaults to `0`):
            The index at which the image or video will conditionally affect the video generation.
        strength (`float`, defaults to `1.0`):
            The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied.
    """

    # Conditioning pixels; single images / 3D arrays are promoted to one-frame videos downstream.
    frames: Union[PIL.Image.Image, List[PIL.Image.Image], np.ndarray, torch.Tensor]
    # Latent-frame index where the condition is inserted; negative values wrap (see preprocess_conditions).
    index: int = 0
    # Conditioning strength written into the conditioning mask; 1.0 means fully conditioned.
    strength: float = 1.0
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 131 |
def retrieve_latents(
|
| 132 |
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 133 |
):
|
|
|
|
| 141 |
raise AttributeError("Could not access latents of provided encoder_output")
|
| 142 |
|
| 143 |
|
| 144 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
| 145 |
def calculate_shift(
|
| 146 |
image_seq_len,
|
| 147 |
base_seq_len: int = 256,
|
|
|
|
| 155 |
return mu
|
| 156 |
|
| 157 |
|
| 158 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 159 |
def retrieve_timesteps(
|
| 160 |
scheduler,
|
| 161 |
num_inference_steps: Optional[int] = None,
|
|
|
|
| 164 |
sigmas: Optional[List[float]] = None,
|
| 165 |
**kwargs,
|
| 166 |
):
|
| 167 |
+
r"""
|
| 168 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 169 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
scheduler (`SchedulerMixin`):
|
| 173 |
+
The scheduler to get timesteps from.
|
| 174 |
+
num_inference_steps (`int`):
|
| 175 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 176 |
+
must be `None`.
|
| 177 |
+
device (`str` or `torch.device`, *optional*):
|
| 178 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 179 |
+
timesteps (`List[int]`, *optional*):
|
| 180 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 181 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 182 |
+
sigmas (`List[float]`, *optional*):
|
| 183 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 184 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 188 |
+
second element is the number of inference steps.
|
| 189 |
+
"""
|
| 190 |
if timesteps is not None and sigmas is not None:
|
| 191 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 192 |
if timesteps is not None:
|
| 193 |
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 194 |
if not accepts_timesteps:
|
| 195 |
raise ValueError(
|
| 196 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 197 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 198 |
)
|
| 199 |
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 200 |
timesteps = scheduler.timesteps
|
|
|
|
| 203 |
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 204 |
if not accept_sigmas:
|
| 205 |
raise ValueError(
|
| 206 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 207 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 208 |
)
|
| 209 |
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 210 |
timesteps = scheduler.timesteps
|
|
|
|
| 215 |
return timesteps, num_inference_steps
|
| 216 |
|
| 217 |
|
| 218 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://huggingface.co/papers/2305.08891).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            A rescale factor applied to the noise predictions.

    Returns:
        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
    """
    # Per-sample std over all non-batch dims.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the guided prediction so its std matches the text prediction's (fixes overexposure).
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Blend rescaled and original predictions by `guidance_rescale` to avoid "plain looking" images.
    # (This mixing line was missing from the extracted source; restored from the canonical implementation.)
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
|
| 241 |
|
| 242 |
|
| 243 |
+
class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
|
| 244 |
r"""
|
| 245 |
+
    Pipeline for video generation which allows image conditions to be inserted at arbitrary parts of the video.
|
| 246 |
|
| 247 |
+
Reference: https://github.com/Lightricks/LTX-Video
|
|
|
|
|
|
|
|
|
|
| 248 |
|
| 249 |
+
TODO
|
|
|
|
| 250 |
"""
|
| 251 |
|
| 252 |
model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
|
|
|
|
| 283 |
self.vae_temporal_compression_ratio = (
|
| 284 |
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
|
| 285 |
)
|
| 286 |
+
        # TODO: check whether the MEL compression ratio logic here is correct
|
| 287 |
self.audio_vae_mel_compression_ratio = (
|
| 288 |
self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
|
| 289 |
)
|
|
|
|
| 309 |
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
|
| 310 |
)
|
| 311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
@staticmethod
|
| 313 |
+
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds
|
| 314 |
def _pack_text_embeds(
|
| 315 |
text_hidden_states: torch.Tensor,
|
| 316 |
sequence_lengths: torch.Tensor,
|
|
|
|
| 348 |
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
| 349 |
return normalized_hidden_states
|
| 350 |
|
| 351 |
+
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds
|
| 352 |
def _get_gemma_prompt_embeds(
|
| 353 |
self,
|
| 354 |
prompt: Union[str, List[str]],
|
|
|
|
| 408 |
|
| 409 |
return prompt_embeds, prompt_attention_mask
|
| 410 |
|
| 411 |
+
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt
|
| 412 |
def encode_prompt(
|
| 413 |
self,
|
| 414 |
prompt: Union[str, List[str]],
|
|
|
|
| 448 |
|
| 449 |
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 450 |
raise TypeError(
|
| 451 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 452 |
+
f" {type(prompt)}."
|
| 453 |
)
|
| 454 |
elif batch_size != len(negative_prompt):
|
| 455 |
raise ValueError(
|
| 456 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 457 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 458 |
+
" the batch size of `prompt`."
|
| 459 |
)
|
| 460 |
|
| 461 |
negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
|
|
|
|
| 487 |
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 488 |
):
|
| 489 |
raise ValueError(
|
| 490 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 491 |
)
|
| 492 |
|
| 493 |
if prompt is not None and prompt_embeds is not None:
|
| 494 |
+
raise ValueError(
|
| 495 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 496 |
+
" only forward one of the two."
|
| 497 |
+
)
|
| 498 |
elif prompt is None and prompt_embeds is None:
|
| 499 |
+
raise ValueError(
|
| 500 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 501 |
+
)
|
| 502 |
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 503 |
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 504 |
|
|
|
|
| 508 |
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
| 509 |
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
| 510 |
|
| 511 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 512 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 513 |
+
raise ValueError(
|
| 514 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 515 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 516 |
+
f" {negative_prompt_embeds.shape}."
|
| 517 |
+
)
|
| 518 |
+
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
| 519 |
+
raise ValueError(
|
| 520 |
+
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
| 521 |
+
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
| 522 |
+
f" {negative_prompt_attention_mask.shape}."
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
@staticmethod
|
| 526 |
+
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents
|
| 527 |
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
| 528 |
batch_size, num_channels, num_frames, height, width = latents.shape
|
| 529 |
post_patch_num_frames = num_frames // patch_size_t
|
|
|
|
| 543 |
return latents
|
| 544 |
|
| 545 |
@staticmethod
|
| 546 |
+
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents
|
| 547 |
def _unpack_latents(
|
| 548 |
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
|
| 549 |
) -> torch.Tensor:
|
|
|
|
| 562 |
return latents
|
| 563 |
|
| 564 |
@staticmethod
|
| 565 |
+
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
|
| 566 |
def _denormalize_latents(
|
| 567 |
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
| 568 |
) -> torch.Tensor:
|
|
|
|
| 571 |
latents = latents * latents_std / scaling_factor + latents_mean
|
| 572 |
return latents
|
| 573 |
|
|
|
|
|
|
|
| 574 |
@staticmethod
|
| 575 |
+
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state
|
| 576 |
+
def _create_noised_state(
|
| 577 |
+
latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None
|
| 578 |
+
):
|
| 579 |
+
noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
|
| 580 |
+
noised_latents = noise_scale * noise + (1 - noise_scale) * latents
|
| 581 |
+
return noised_latents
|
| 582 |
+
|
| 583 |
+
@staticmethod
|
| 584 |
+
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents
|
| 585 |
def _pack_audio_latents(
|
| 586 |
latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
|
| 587 |
) -> torch.Tensor:
|
|
|
|
| 598 |
return latents
|
| 599 |
|
| 600 |
@staticmethod
|
| 601 |
+
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_audio_latents
|
| 602 |
def _unpack_audio_latents(
|
| 603 |
latents: torch.Tensor,
|
| 604 |
latent_length: int,
|
|
|
|
| 615 |
return latents
|
| 616 |
|
| 617 |
@staticmethod
|
| 618 |
+
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents
|
| 619 |
+
def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
|
| 620 |
latents_mean = latents_mean.to(latents.device, latents.dtype)
|
| 621 |
latents_std = latents_std.to(latents.device, latents.dtype)
|
| 622 |
+
return (latents - latents_mean) / latents_std
|
| 623 |
|
| 624 |
@staticmethod
|
| 625 |
+
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents
|
| 626 |
+
def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
|
| 627 |
latents_mean = latents_mean.to(latents.device, latents.dtype)
|
| 628 |
latents_std = latents_std.to(latents.device, latents.dtype)
|
| 629 |
+
return (latents * latents_std) + latents_mean
|
| 630 |
|
| 631 |
+
# Copied from diffusers.pipelines.ltx.pipeline_ltx_condition.LTXConditionPipeline.trim_conditioning_sequence
|
| 632 |
+
def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, target_num_frames: int) -> int:
|
| 633 |
+
scale_factor = self.vae_temporal_compression_ratio
|
| 634 |
+
num_frames = min(sequence_num_frames, target_num_frames - start_frame)
|
| 635 |
+
num_frames = (num_frames - 1) // scale_factor * scale_factor + 1
|
| 636 |
+
return num_frames
|
| 637 |
|
| 638 |
+
    def preprocess_conditions(
        self,
        conditions: Optional[Union[LTX2VideoCondition, List[LTX2VideoCondition]]] = None,
        height: int = 512,
        width: int = 768,
        num_frames: int = 121,
        device: Optional[torch.device] = None,
        index_type: str = "latent",
    ) -> Tuple[List[torch.Tensor], List[float], List[int]]:
        """Preprocess conditioning items into pixel tensors plus per-condition strengths and latent start indices.

        Each condition's frames are resized to (height, width), trimmed so the clip fits within `num_frames`
        while staying aligned to the VAE temporal stride, and cast to the VAE dtype on `device`. Conditions
        whose start index falls outside the latent frame range are skipped with a warning.

        NOTE(review): `index_type` is currently unused in this body — confirm whether pixel-space indices
        were intended to be supported.
        """
        conditioning_frames, conditioning_strengths, conditioning_indices = [], [], []

        if conditions is None:
            conditions = []
        if isinstance(conditions, LTX2VideoCondition):
            conditions = [conditions]

        frame_scale_factor = self.vae_temporal_compression_ratio
        latent_num_frames = (num_frames - 1) // frame_scale_factor + 1
        for i, condition in enumerate(conditions):
            # Promote single-frame inputs (PIL image / 3D array / 3D tensor) to a one-frame video.
            if isinstance(condition.frames, PIL.Image.Image):
                video_like_cond = [condition.frames]
            elif isinstance(condition.frames, np.ndarray) and condition.frames.ndim == 3:
                video_like_cond = np.expand_dims(condition.frames, axis=0)
            elif isinstance(condition.frames, torch.Tensor) and condition.frames.ndim == 3:
                video_like_cond = condition.frames.unsqueeze(0)
            else:
                video_like_cond = condition.frames

            condition_pixels = self.video_processor.preprocess_video(video_like_cond, height, width)

            # Negative indices wrap around the latent frame axis (Python-style).
            latent_start_idx = condition.index
            if latent_start_idx < 0:
                latent_start_idx = latent_start_idx % latent_num_frames
            if latent_start_idx >= latent_num_frames:
                logger.warning(
                    f"The starting latent index {latent_start_idx} of condition {i} is too big for the specified number"
                    f" of latent frames {latent_num_frames}. This condition will be skipped."
                )
                continue

            # Convert the latent start index back to pixel-frame space, then trim the clip so it ends
            # within the target video and stays aligned with the temporal compression stride.
            cond_num_frames = condition_pixels.size(2)
            start_idx = max((latent_start_idx - 1) * frame_scale_factor + 1, 0)
            truncated_cond_frames = self.trim_conditioning_sequence(start_idx, cond_num_frames, num_frames)
            condition_pixels = condition_pixels[:, :, :truncated_cond_frames]

            conditioning_frames.append(condition_pixels.to(dtype=self.vae.dtype, device=device))
            conditioning_strengths.append(condition.strength)
            conditioning_indices.append(latent_start_idx)

        return conditioning_frames, conditioning_strengths, conditioning_indices
|
| 688 |
+
|
| 689 |
+
def apply_visual_conditioning(
|
| 690 |
+
self,
|
| 691 |
+
latents: torch.Tensor,
|
| 692 |
+
conditioning_mask: torch.Tensor,
|
| 693 |
+
condition_latents: List[torch.Tensor],
|
| 694 |
+
condition_strengths: List[float],
|
| 695 |
+
condition_indices: List[int],
|
| 696 |
+
latent_height: int,
|
| 697 |
+
latent_width: int,
|
| 698 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 699 |
+
clean_latents = torch.zeros_like(latents)
|
| 700 |
+
for cond, strength, latent_idx in zip(condition_latents, condition_strengths, condition_indices):
|
| 701 |
+
num_cond_tokens = cond.size(1)
|
| 702 |
+
start_token_idx = latent_idx * latent_height * latent_width
|
| 703 |
+
end_token_idx = start_token_idx + num_cond_tokens
|
| 704 |
+
|
| 705 |
+
latents[:, start_token_idx:end_token_idx] = cond
|
| 706 |
+
conditioning_mask[:, start_token_idx:end_token_idx] = strength
|
| 707 |
+
clean_latents[:, start_token_idx:end_token_idx] = cond
|
| 708 |
+
|
| 709 |
+
return latents, conditioning_mask, clean_latents
|
| 710 |
+
|
| 711 |
+
def prepare_latents(
    self,
    conditions: Optional[Union[LTX2VideoCondition, List[LTX2VideoCondition]]] = None,
    batch_size: int = 1,
    num_channels_latents: int = 128,
    height: int = 512,
    width: int = 768,
    num_frames: int = 121,
    noise_scale: float = 1.0,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    generator: Optional[torch.Generator] = None,
    latents: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the packed video latents, conditioning mask, and clean conditioning latents.

    Conditions are encoded through the VAE, packed, and written into the latent
    sequence; everything outside a condition span is then (partially) replaced
    with noise according to `noise_scale`.

    Returns:
        - `latents`: packed [B, S, D] noised latents.
        - `conditioning_mask`: packed mask holding each condition's strength
          over its token span (0 elsewhere).
        - `clean_latents`: packed tensor with only the un-noised condition
          tokens on a zero background.
    """
    # Latent grid derived from the VAE's spatial/temporal compression.
    latent_height = height // self.vae_spatial_compression_ratio
    latent_width = width // self.vae_spatial_compression_ratio
    latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1

    shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width)
    # Mask has a single channel but the same spatio-temporal layout as the latents.
    mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)

    if latents is not None:
        if latents.ndim == 5:
            # User-supplied unpacked latents are assumed to be in VAE space and
            # are normalized here; packed (3D) latents are taken as-is below.
            latents = self._normalize_latents(
                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
            )
    else:
        # Start from zeros: condition tokens are written in, the rest is noised at the end.
        latents = torch.zeros(shape, device=device, dtype=dtype)

    conditioning_mask = latents.new_zeros(mask_shape)
    if latents.ndim == 5:
        # Flatten [B, C, F, H, W] into the transformer's packed [B, S, D] token layout.
        latents = self._pack_latents(
            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
        )
        conditioning_mask = self._pack_latents(
            conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
        )

    if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape[:2]:
        raise ValueError(
            f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape[:2] + (num_channels_latents,)}."
        )

    if isinstance(generator, list):
        # Per-sample generators are not supported here; fall back to the first one.
        logger.warning(
            f"{self.__class__.__name__} does not support using a list of generators. The first generator in the"
            f" list will be used for all (pseudo-)random operations."
        )
        generator = generator[0]

    # Resize/crop the conditioning media and compute their latent start indices.
    condition_frames, condition_strengths, condition_indices = self.preprocess_conditions(
        conditions, height, width, num_frames, device=device
    )
    condition_latents = []
    for condition_tensor in condition_frames:
        # "argmax" sampling gives the deterministic mode of the VAE posterior.
        condition_latent = retrieve_latents(
            self.vae.encode(condition_tensor), generator=generator, sample_mode="argmax"
        )
        condition_latent = self._normalize_latents(
            condition_latent, self.vae.latents_mean, self.vae.latents_std
        ).to(device=device, dtype=dtype)
        condition_latent = self._pack_latents(
            condition_latent, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
        )
        condition_latents.append(condition_latent)

    # Write condition tokens into the sequence and record their strengths in the mask.
    latents, conditioning_mask, clean_latents = self.apply_visual_conditioning(
        latents,
        conditioning_mask,
        condition_latents,
        condition_strengths,
        condition_indices,
        latent_height=latent_height,
        latent_width=latent_width,
    )

    # Blend in noise: unconditioned tokens (mask 0) get `noise_scale` worth of
    # noise; fully conditioned tokens (mask 1) are left untouched.
    noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
    scaled_mask = (1.0 - conditioning_mask) * noise_scale
    latents = noise * scaled_mask + latents * (1 - scaled_mask)

    return latents, conditioning_mask, clean_latents
|
| 792 |
+
|
| 793 |
+
# -------------------- Audio conditioning additions (minimal) --------------------
|
| 794 |
+
|
| 795 |
+
def _get_audio_duration(self, audio: Union[str, torch.Tensor], sample_rate: int) -> float:
|
| 796 |
+
if isinstance(audio, str):
|
| 797 |
+
info = torchaudio.info(audio)
|
| 798 |
+
return info.num_frames / info.sample_rate
|
| 799 |
+
num_samples = audio.shape[-1]
|
| 800 |
+
return num_samples / sample_rate
|
| 801 |
|
| 802 |
def _preprocess_audio(self, audio: Union[str, torch.Tensor], target_sample_rate: int) -> torch.Tensor:
|
|
|
|
| 803 |
if isinstance(audio, str):
|
| 804 |
waveform, sr = torchaudio.load(audio)
|
| 805 |
else:
|
|
|
|
| 809 |
if sr != target_sample_rate:
|
| 810 |
waveform = torchaudio.functional.resample(waveform, sr, target_sample_rate)
|
| 811 |
|
| 812 |
+
if waveform.ndim == 1:
|
| 813 |
+
waveform = waveform.unsqueeze(0)
|
| 814 |
if waveform.shape[0] == 1:
|
| 815 |
waveform = waveform.repeat(2, 1)
|
| 816 |
elif waveform.shape[0] > 2:
|
| 817 |
waveform = waveform[:2, :]
|
| 818 |
|
| 819 |
+
waveform = waveform.unsqueeze(0) # [B, 2, samples]
|
| 820 |
|
| 821 |
n_fft = 1024
|
| 822 |
mel_transform = T.MelSpectrogram(
|
|
|
|
| 835 |
norm="slaney",
|
| 836 |
)
|
| 837 |
|
| 838 |
+
mel_spec = mel_transform(waveform) # [B, 2, mel_bins, T]
|
| 839 |
mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
|
| 840 |
+
mel_spec = mel_spec.permute(0, 1, 3, 2).contiguous() # [B, 2, T, mel_bins]
|
|
|
|
| 841 |
return mel_spec
|
| 842 |
|
| 843 |
+
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.prepare_audio_latents (modified minimally)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 844 |
def prepare_audio_latents(
    self,
    batch_size: int = 1,
    num_channels_latents: int = 8,
    audio_latent_length: int = 1,  # 1 is just a dummy value
    num_mel_bins: int = 64,
    noise_scale: float = 0.0,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    generator: Optional[torch.Generator] = None,
    latents: Optional[torch.Tensor] = None,
    audio_input: Optional[Union[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
    """
    Prepare packed audio latents, optionally from a user-supplied waveform.

    Three mutually exclusive paths, in priority order: user `latents`,
    encoded `audio_input`, or pure random noise.

    Returns:
        - packed noisy audio latents [B, S, D]
        - audio_latent_length
        - packed clean audio latents [B, S, D] if audio_input is provided, else None
    """
    # Path 1: caller supplied latents — normalize, noise them, and return.
    if latents is not None:
        if latents.ndim == 4:
            # [B, C, L, F] -> packed [B, S, D]
            latents = self._pack_audio_latents(latents)
        if latents.ndim != 3:
            raise ValueError(
                f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]."
            )
        latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std)
        latents = self._create_noised_state(latents, noise_scale, generator)
        return latents.to(device=device, dtype=dtype), audio_latent_length, None

    # If audio input is provided, encode to clean latents and return both clean and a dummy noisy tensor
    # Path 2: encode the waveform/mel input through the audio VAE.
    if audio_input is not None:
        mel_spec = self._preprocess_audio(audio_input, self.audio_sampling_rate).to(device=device)
        mel_spec = mel_spec.to(dtype=self.audio_vae.dtype)

        clean_4d = self.audio_vae.encode(mel_spec).latent_dist.sample(generator)  # [B, C, L, F]

        # pad/trim to audio_latent_length
        cur_len = clean_4d.shape[2]
        if cur_len < audio_latent_length:
            pad = audio_latent_length - cur_len
            # Pad only along the temporal (length) dimension.
            clean_4d = torch.nn.functional.pad(clean_4d, (0, 0, 0, pad))
        elif cur_len > audio_latent_length:
            clean_4d = clean_4d[:, :, :audio_latent_length, :]

        # NOTE(review): `repeat(batch_size, ...)` assumes the encoded batch is 1
        # when it differs from `batch_size` — TODO confirm for batched audio inputs.
        if clean_4d.shape[0] != batch_size:
            clean_4d = clean_4d.repeat(batch_size, 1, 1, 1)

        clean_packed = self._pack_audio_latents(clean_4d)  # [B, S, D]
        clean_packed = clean_packed.to(dtype=dtype)
        clean_packed = self._normalize_audio_latents(
            clean_packed, self.audio_vae.latents_mean, self.audio_vae.latents_std
        )

        # The "noisy" tensor is pure Gaussian noise shaped like the clean latents.
        noisy = randn_tensor(clean_packed.shape, generator=generator, device=device, dtype=dtype)
        noisy = self._create_noised_state(noisy, noise_scale, generator=None)  # keep same scaling semantics

        return noisy, audio_latent_length, clean_packed

    # Path 3: no latents and no audio — sample random latents from scratch.
    latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
    shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    latents = self._pack_audio_latents(latents)
    return latents, audio_latent_length, None
| 915 |
|
| 916 |
+
# ------------------------------------------------------------------
|
| 917 |
|
|
|
|
|
|
|
| 918 |
@property
def guidance_scale(self):
    # Classifier-free guidance scale recorded by the most recent pipeline call.
    return self._guidance_scale
|
|
|
|
| 943 |
def interrupt(self):
    # True when early termination of the denoising loop has been requested.
    return self._interrupt
|
| 945 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 946 |
@torch.no_grad()
|
| 947 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 948 |
def __call__(
|
| 949 |
self,
|
| 950 |
+
conditions: Union[LTX2VideoCondition, List[LTX2VideoCondition]] = None,
|
|
|
|
|
|
|
|
|
|
| 951 |
audio: Optional[Union[str, torch.Tensor]] = None,
|
| 952 |
prompt: Union[str, List[str]] = None,
|
| 953 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 954 |
height: int = 512,
|
| 955 |
width: int = 768,
|
| 956 |
+
num_frames: Optional[int] = 121,
|
| 957 |
max_frames: int = 257,
|
| 958 |
frame_rate: float = 24.0,
|
| 959 |
num_inference_steps: int = 40,
|
|
|
|
| 960 |
sigmas: Optional[List[float]] = None,
|
| 961 |
+
timesteps: List[int] = None,
|
| 962 |
guidance_scale: float = 4.0,
|
| 963 |
guidance_rescale: float = 0.0,
|
| 964 |
+
noise_scale: Optional[float] = None,
|
| 965 |
num_videos_per_prompt: Optional[int] = 1,
|
| 966 |
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 967 |
latents: Optional[torch.Tensor] = None,
|
|
|
|
| 979 |
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 980 |
max_sequence_length: int = 1024,
|
| 981 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 982 |
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 983 |
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 984 |
|
| 985 |
+
# Optional: derive num_frames from audio if user passes num_frames=None
|
| 986 |
+
if num_frames is None and audio is not None:
|
| 987 |
+
audio_duration = self._get_audio_duration(audio, self.audio_sampling_rate)
|
| 988 |
+
calculated_frames = int(audio_duration * frame_rate) + 1
|
| 989 |
+
num_frames = min(calculated_frames, max_frames)
|
| 990 |
+
num_frames = (
|
| 991 |
+
(num_frames - 1) // self.vae_temporal_compression_ratio
|
| 992 |
+
) * self.vae_temporal_compression_ratio + 1
|
| 993 |
+
num_frames = max(num_frames, 9)
|
| 994 |
+
|
| 995 |
+
# 1. Check inputs. Raise error if not correct
|
|
|
|
| 996 |
self.check_inputs(
|
| 997 |
prompt=prompt,
|
| 998 |
height=height,
|
|
|
|
| 1010 |
self._interrupt = False
|
| 1011 |
self._current_timestep = None
|
| 1012 |
|
| 1013 |
+
# 2. Define call parameters
|
| 1014 |
if prompt is not None and isinstance(prompt, str):
|
| 1015 |
batch_size = 1
|
| 1016 |
elif prompt is not None and isinstance(prompt, list):
|
|
|
|
| 1018 |
else:
|
| 1019 |
batch_size = prompt_embeds.shape[0]
|
| 1020 |
|
| 1021 |
+
if conditions is not None and not isinstance(conditions, list):
|
| 1022 |
+
conditions = [conditions]
|
| 1023 |
+
|
| 1024 |
+
# Infer noise scale: first (largest) sigma value if using custom sigmas, else 1.0
|
| 1025 |
+
if noise_scale is None:
|
| 1026 |
+
noise_scale = sigmas[0] if sigmas is not None else 1.0
|
| 1027 |
+
|
| 1028 |
device = self._execution_device
|
| 1029 |
|
| 1030 |
+
# 3. Prepare text embeddings
|
| 1031 |
(
|
| 1032 |
prompt_embeds,
|
| 1033 |
prompt_attention_mask,
|
|
|
|
| 1054 |
prompt_embeds, additive_attention_mask, additive_mask=True
|
| 1055 |
)
|
| 1056 |
|
| 1057 |
+
# 4. Prepare latent variables
|
| 1058 |
+
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
| 1059 |
+
latent_height = height // self.vae_spatial_compression_ratio
|
| 1060 |
+
latent_width = width // self.vae_spatial_compression_ratio
|
| 1061 |
+
if latents is not None:
|
| 1062 |
+
if latents.ndim == 5:
|
| 1063 |
+
logger.info(
|
| 1064 |
+
"Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred."
|
| 1065 |
+
)
|
| 1066 |
+
_, _, latent_num_frames, latent_height, latent_width = latents.shape
|
| 1067 |
+
elif latents.ndim == 3:
|
| 1068 |
+
logger.warning(
|
| 1069 |
+
f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
|
| 1070 |
+
f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
|
| 1071 |
+
)
|
| 1072 |
+
else:
|
| 1073 |
+
raise ValueError(
|
| 1074 |
+
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]."
|
| 1075 |
+
)
|
| 1076 |
+
video_sequence_length = latent_num_frames * latent_height * latent_width
|
| 1077 |
|
|
|
|
| 1078 |
num_channels_latents = self.transformer.config.in_channels
|
| 1079 |
+
latents, conditioning_mask, clean_latents = self.prepare_latents(
|
| 1080 |
+
conditions,
|
| 1081 |
+
batch_size * num_videos_per_prompt,
|
| 1082 |
+
num_channels_latents,
|
| 1083 |
+
height,
|
| 1084 |
+
width,
|
| 1085 |
+
num_frames,
|
| 1086 |
+
noise_scale,
|
| 1087 |
+
torch.float32,
|
| 1088 |
+
device,
|
| 1089 |
+
generator,
|
| 1090 |
+
latents,
|
|
|
|
|
|
|
| 1091 |
)
|
| 1092 |
if self.do_classifier_free_guidance:
|
| 1093 |
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
|
| 1094 |
|
| 1095 |
+
duration_s = num_frames / frame_rate
|
| 1096 |
+
audio_latents_per_second = (
|
| 1097 |
+
self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
|
| 1098 |
+
)
|
| 1099 |
+
audio_num_frames = round(duration_s * audio_latents_per_second)
|
| 1100 |
+
if audio_latents is not None:
|
| 1101 |
+
if audio_latents.ndim == 4:
|
| 1102 |
+
logger.info(
|
| 1103 |
+
"Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred."
|
| 1104 |
+
)
|
| 1105 |
+
_, _, audio_num_frames, _ = audio_latents.shape
|
| 1106 |
+
elif audio_latents.ndim == 3:
|
| 1107 |
+
logger.warning(
|
| 1108 |
+
f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
|
| 1109 |
+
f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct."
|
| 1110 |
+
)
|
| 1111 |
+
else:
|
| 1112 |
+
raise ValueError(
|
| 1113 |
+
f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]."
|
| 1114 |
+
)
|
| 1115 |
+
|
| 1116 |
num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
|
| 1117 |
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
|
|
|
|
| 1118 |
num_channels_latents_audio = (
|
| 1119 |
self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
|
| 1120 |
)
|
|
|
|
| 1122 |
audio_latents, audio_num_frames, clean_audio_latents = self.prepare_audio_latents(
|
| 1123 |
batch_size * num_videos_per_prompt,
|
| 1124 |
num_channels_latents=num_channels_latents_audio,
|
| 1125 |
+
audio_latent_length=audio_num_frames,
|
| 1126 |
num_mel_bins=num_mel_bins,
|
| 1127 |
+
noise_scale=noise_scale,
|
|
|
|
|
|
|
|
|
|
| 1128 |
dtype=torch.float32,
|
| 1129 |
device=device,
|
| 1130 |
generator=generator,
|
| 1131 |
latents=audio_latents,
|
| 1132 |
audio_input=audio,
|
| 1133 |
)
|
| 1134 |
+
# clean_audio_latents is packed [B,S,D] if present
|
| 1135 |
+
packed_clean_audio_latents = clean_audio_latents
|
| 1136 |
|
| 1137 |
+
# 5. Prepare timesteps
|
| 1138 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1139 |
mu = calculate_shift(
|
| 1140 |
video_sequence_length,
|
| 1141 |
self.scheduler.config.get("base_image_seq_len", 1024),
|
|
|
|
| 1176 |
audio_latents.shape[0], audio_num_frames, audio_latents.device
|
| 1177 |
)
|
| 1178 |
|
| 1179 |
+
# 7. Denoising loop
|
| 1180 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1181 |
for i, t in enumerate(timesteps):
|
| 1182 |
if self.interrupt:
|
|
|
|
| 1184 |
|
| 1185 |
self._current_timestep = t
|
| 1186 |
|
| 1187 |
+
# If audio conditioning provided, use clean audio latents directly (packed), and timestep=0
|
| 1188 |
if packed_clean_audio_latents is not None:
|
| 1189 |
audio_latents_input = packed_clean_audio_latents.to(dtype=prompt_embeds.dtype)
|
| 1190 |
else:
|
|
|
|
| 1198 |
audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
|
| 1199 |
|
| 1200 |
timestep = t.expand(latent_model_input.shape[0])
|
| 1201 |
+
video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask.squeeze(-1))
|
| 1202 |
|
| 1203 |
if packed_clean_audio_latents is not None:
|
| 1204 |
audio_timestep = torch.zeros_like(timestep)
|
|
|
|
| 1222 |
audio_num_frames=audio_num_frames,
|
| 1223 |
video_coords=video_coords,
|
| 1224 |
audio_coords=audio_coords,
|
| 1225 |
+
# rope_interpolation_scale=rope_interpolation_scale,
|
| 1226 |
attention_kwargs=attention_kwargs,
|
| 1227 |
return_dict=False,
|
| 1228 |
)
|
|
|
|
| 1248 |
noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
|
| 1249 |
)
|
| 1250 |
|
| 1251 |
+
bsz = noise_pred_video.size(0)
|
| 1252 |
+
sigma = self.scheduler.sigmas[i]
|
| 1253 |
+
denoised_sample = latents - noise_pred_video * sigma
|
| 1254 |
+
denoised_sample_cond = (
|
| 1255 |
+
denoised_sample * (1 - conditioning_mask[:bsz]) + clean_latents.float() * conditioning_mask[:bsz]
|
| 1256 |
+
).to(noise_pred_video.dtype)
|
| 1257 |
+
denoised_latents_cond = ((latents - denoised_sample_cond) / sigma).to(noise_pred_video.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1258 |
|
| 1259 |
+
latents = self.scheduler.step(denoised_latents_cond, t, latents, return_dict=False)[0]
|
|
|
|
|
|
|
|
|
|
| 1260 |
|
| 1261 |
+
# Only step audio latents if not conditioning on clean audio
|
| 1262 |
if packed_clean_audio_latents is None:
|
| 1263 |
audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
|
| 1264 |
|
|
|
|
| 1267 |
for k in callback_on_step_end_tensor_inputs:
|
| 1268 |
callback_kwargs[k] = locals()[k]
|
| 1269 |
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1270 |
+
|
| 1271 |
latents = callback_outputs.pop("latents", latents)
|
| 1272 |
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1273 |
|
|
|
|
| 1277 |
if XLA_AVAILABLE:
|
| 1278 |
xm.mark_step()
|
| 1279 |
|
|
|
|
| 1280 |
latents = self._unpack_latents(
|
| 1281 |
latents,
|
| 1282 |
latent_num_frames,
|
|
|
|
| 1289 |
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
| 1290 |
)
|
| 1291 |
|
| 1292 |
+
# Choose audio latents for decode: clean if provided, else denoised
|
| 1293 |
+
if packed_clean_audio_latents is not None:
|
| 1294 |
+
audio_latents_to_decode = packed_clean_audio_latents
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1295 |
else:
|
| 1296 |
+
audio_latents_to_decode = audio_latents
|
| 1297 |
+
|
| 1298 |
+
audio_latents_to_decode = self._denormalize_audio_latents(
|
| 1299 |
+
audio_latents_to_decode, self.audio_vae.latents_mean, self.audio_vae.latents_std
|
| 1300 |
+
)
|
| 1301 |
+
audio_latents_to_decode = self._unpack_audio_latents(
|
| 1302 |
+
audio_latents_to_decode, audio_num_frames, num_mel_bins=latent_mel_bins
|
| 1303 |
+
)
|
| 1304 |
|
| 1305 |
if output_type == "latent":
|
| 1306 |
video = latents
|
| 1307 |
+
audio_out = audio_latents_to_decode
|
| 1308 |
else:
|
| 1309 |
latents = latents.to(prompt_embeds.dtype)
|
| 1310 |
|
|
|
|
| 1329 |
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
| 1330 |
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
| 1331 |
|
| 1332 |
+
audio_latents_to_decode = audio_latents_to_decode.to(self.audio_vae.dtype)
|
| 1333 |
+
generated_mel_spectrograms = self.audio_vae.decode(audio_latents_to_decode, return_dict=False)[0]
|
| 1334 |
+
audio_out = self.vocoder(generated_mel_spectrograms)
|
| 1335 |
|
| 1336 |
self.maybe_free_model_hooks()
|
| 1337 |
|
| 1338 |
if not return_dict:
|
| 1339 |
+
return (video, audio_out)
|
| 1340 |
|
| 1341 |
+
return LTX2PipelineOutput(frames=video, audio=audio_out)
|