# build-tools/diffusers/models/transformers/transformer_ltx2.py
# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Any
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import BaseOutput, apply_lora_scale, is_torch_version, logging
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings, PixArtAlphaTextProjection
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
cos, sin = freqs
x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
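# A quick sanity sketch for the interleaved layout (illustrative values, not executed at import):
# channels are rotated in adjacent (even, odd) pairs, and a zero rotation angle is the identity:
#
#     x = torch.randn(2, 16, 8)                            # [B, S, C], C even
#     cos, sin = torch.ones_like(x), torch.zeros_like(x)   # cos/sin already repeat_interleave'd to C
#     assert torch.allclose(apply_interleaved_rotary_emb(x, (cos, sin)), x)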
def apply_split_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
cos, sin = freqs
x_dtype = x.dtype
needs_reshape = False
if x.ndim != 4 and cos.ndim == 4:
# cos is (b, h, t, r) -> reshape x to (b, h, t, dim_per_head)
b, h, t, _ = cos.shape
x = x.reshape(b, t, h, -1).swapaxes(1, 2)
needs_reshape = True
# Split last dim (2*r) into (d=2, r)
last = x.shape[-1]
if last % 2 != 0:
raise ValueError(f"Expected x.shape[-1] to be even for split rotary, got {last}.")
r = last // 2
# (..., 2, r)
split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float
first_x = split_x[..., :1, :] # (..., 1, r)
second_x = split_x[..., 1:, :] # (..., 1, r)
cos_u = cos.unsqueeze(-2) # broadcast to (..., 1, r) against (..., 2, r)
sin_u = sin.unsqueeze(-2)
out = split_x * cos_u
first_out = out[..., :1, :]
second_out = out[..., 1:, :]
first_out.addcmul_(-sin_u, second_x)
second_out.addcmul_(sin_u, first_x)
out = out.reshape(*out.shape[:-2], last)
if needs_reshape:
out = out.swapaxes(1, 2).reshape(b, t, -1)
out = out.to(dtype=x_dtype)
return out
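# A matching sketch for the split layout (illustrative, not executed at import): the last dim is
# treated as two stacked halves of size r = C // 2 rotated against each other, and `cos`/`sin` carry
# r frequencies each (optionally pre-shaped to [B, H, T, r] for per-head application):
#
#     x = torch.randn(2, 16, 8)                                # [B, S, C], C = 2 * r
#     cos, sin = torch.ones(2, 16, 4), torch.zeros(2, 16, 4)   # [B, S, r]
#     assert torch.allclose(apply_split_rotary_emb(x, (cos, sin)), x)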
@dataclass
class AudioVisualModelOutput(BaseOutput):
r"""
    Holds the output of an audiovisual model that produces both visual (e.g. video) and audio outputs.
Args:
sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
The hidden states output conditioned on the `encoder_hidden_states` input, representing the visual output
of the model. This is typically a video (spatiotemporal) output.
audio_sample (`torch.Tensor` of shape `(batch_size, TODO)`):
The audio output of the audiovisual model.
"""
sample: "torch.Tensor" # noqa: F821
audio_sample: "torch.Tensor" # noqa: F821
class LTX2AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3) and adapted by the LTX-2.0
model. In particular, the number of modulation parameters to be calculated is now configurable.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_mod_params (`int`, *optional*, defaults to `6`):
The number of modulation parameters which will be calculated in the first return argument. The default of 6
is standard, but sometimes we may want to have a different (usually smaller) number of modulation
parameters.
use_additional_conditions (`bool`, *optional*, defaults to `False`):
Whether to use additional conditions for normalization or not.
"""
def __init__(self, embedding_dim: int, num_mod_params: int = 6, use_additional_conditions: bool = False):
super().__init__()
self.num_mod_params = num_mod_params
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, self.num_mod_params * embedding_dim, bias=True)
def forward(
self,
timestep: torch.Tensor,
added_cond_kwargs: dict[str, torch.Tensor] | None = None,
batch_size: int | None = None,
hidden_dtype: torch.dtype | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
# No modulation happening here.
added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep
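# Shape sketch (illustrative, not executed at import): for `embedding_dim=D` and a flattened timestep
# of shape `(N,)` (e.g. N = batch_size * seq_len), `forward` returns the modulation tensor of shape
# `(N, num_mod_params * D)` and the raw timestep embedding of shape `(N, D)`; callers below reshape the
# first back to `(batch_size, -1, num_mod_params * D)` before splitting it into scale/shift/gate params:
#
#     ada_ln = LTX2AdaLayerNormSingle(embedding_dim=64, num_mod_params=6)
#     mod, emb = ada_ln(torch.arange(4).float())           # mod: (4, 384), emb: (4, 64)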
class LTX2AudioVideoAttnProcessor:
r"""
Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0 model.
Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so that we can
support audio-to-video (a2v) and video-to-audio (v2a) cross attention.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if is_torch_version("<", "2.0"):
raise ValueError(
"LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
)
def __call__(
self,
attn: "LTX2Attention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.norm_q(query)
key = attn.norm_k(key)
if query_rotary_emb is not None:
if attn.rope_type == "interleaved":
query = apply_interleaved_rotary_emb(query, query_rotary_emb)
key = apply_interleaved_rotary_emb(
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
)
elif attn.rope_type == "split":
query = apply_split_rotary_emb(query, query_rotary_emb)
key = apply_split_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
r"""
Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key
RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention.
"""
_default_processor_cls = LTX2AudioVideoAttnProcessor
_available_processors = [LTX2AudioVideoAttnProcessor]
def __init__(
self,
query_dim: int,
heads: int = 8,
kv_heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = True,
cross_attention_dim: int | None = None,
out_bias: bool = True,
qk_norm: str = "rms_norm_across_heads",
norm_eps: float = 1e-6,
norm_elementwise_affine: bool = True,
rope_type: str = "interleaved",
processor=None,
):
super().__init__()
if qk_norm != "rms_norm_across_heads":
raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
self.head_dim = dim_head
self.inner_dim = dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
self.query_dim = query_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = query_dim
self.heads = heads
self.rope_type = rope_type
self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        unused_kwargs = [k for k in kwargs if k not in attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
hidden_states = self.processor(
self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs
)
return hidden_states
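# Construction sketch (illustrative, tiny dims): self-attention layers pass `cross_attention_dim=None`,
# while the a2v / v2a layers pass the other modality's dim and supply separate query/key RoPE at call
# time:
#
#     attn = LTX2Attention(query_dim=64, heads=4, kv_heads=4, dim_head=16)
#     y = attn(torch.randn(1, 32, 64))                     # same shape out: (1, 32, 64)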
class LTX2VideoTransformerBlock(nn.Module):
r"""
Transformer block used in [LTX-2.0](https://huggingface.co/Lightricks/LTX-Video).
Args:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
qk_norm (`str`, defaults to `"rms_norm"`):
The normalization layer to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
cross_attention_dim: int,
audio_dim: int,
audio_num_attention_heads: int,
        audio_attention_head_dim: int,
audio_cross_attention_dim: int,
qk_norm: str = "rms_norm_across_heads",
activation_fn: str = "gelu-approximate",
attention_bias: bool = True,
attention_out_bias: bool = True,
eps: float = 1e-6,
elementwise_affine: bool = False,
rope_type: str = "interleaved",
):
super().__init__()
# 1. Self-Attention (video and audio)
self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.attn1 = LTX2Attention(
query_dim=dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
bias=attention_bias,
cross_attention_dim=None,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
)
self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
self.audio_attn1 = LTX2Attention(
query_dim=audio_dim,
heads=audio_num_attention_heads,
kv_heads=audio_num_attention_heads,
dim_head=audio_attention_head_dim,
bias=attention_bias,
cross_attention_dim=None,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
)
# 2. Prompt Cross-Attention
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.attn2 = LTX2Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
)
self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
self.audio_attn2 = LTX2Attention(
query_dim=audio_dim,
cross_attention_dim=audio_cross_attention_dim,
heads=audio_num_attention_heads,
kv_heads=audio_num_attention_heads,
dim_head=audio_attention_head_dim,
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
)
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
# Audio-to-Video (a2v) Attention --> Q: Video; K,V: Audio
self.audio_to_video_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.audio_to_video_attn = LTX2Attention(
query_dim=dim,
cross_attention_dim=audio_dim,
heads=audio_num_attention_heads,
kv_heads=audio_num_attention_heads,
dim_head=audio_attention_head_dim,
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
)
# Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
self.video_to_audio_norm = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
self.video_to_audio_attn = LTX2Attention(
query_dim=audio_dim,
cross_attention_dim=dim,
heads=audio_num_attention_heads,
kv_heads=audio_num_attention_heads,
dim_head=audio_attention_head_dim,
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
)
# 4. Feedforward layers
self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.ff = FeedForward(dim, activation_fn=activation_fn)
self.audio_norm3 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn)
# 5. Per-Layer Modulation Parameters
# Self-Attention / Feedforward AdaLayerNorm-Zero mod params
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5)
# Per-layer a2v, v2a Cross-Attention mod params
self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))
self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim))
def forward(
self,
hidden_states: torch.Tensor,
audio_hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
audio_encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
temb_audio: torch.Tensor,
temb_ca_scale_shift: torch.Tensor,
temb_ca_audio_scale_shift: torch.Tensor,
temb_ca_gate: torch.Tensor,
temb_ca_audio_gate: torch.Tensor,
video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
ca_audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
encoder_attention_mask: torch.Tensor | None = None,
audio_encoder_attention_mask: torch.Tensor | None = None,
a2v_cross_attention_mask: torch.Tensor | None = None,
v2a_cross_attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
batch_size = hidden_states.size(0)
# 1. Video and Audio Self-Attention
norm_hidden_states = self.norm1(hidden_states)
num_ada_params = self.scale_shift_table.shape[0]
ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
batch_size, temb.size(1), num_ada_params, -1
)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
attn_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
query_rotary_emb=video_rotary_emb,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
num_audio_ada_params = self.audio_scale_shift_table.shape[0]
audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape(
batch_size, temb_audio.size(1), num_audio_ada_params, -1
)
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
audio_ada_values.unbind(dim=2)
)
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
attn_audio_hidden_states = self.audio_attn1(
hidden_states=norm_audio_hidden_states,
encoder_hidden_states=None,
query_rotary_emb=audio_rotary_emb,
)
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
# 2. Video and Audio Cross-Attention with the text embeddings
norm_hidden_states = self.norm2(hidden_states)
attn_hidden_states = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
query_rotary_emb=None,
attention_mask=encoder_attention_mask,
)
hidden_states = hidden_states + attn_hidden_states
norm_audio_hidden_states = self.audio_norm2(audio_hidden_states)
attn_audio_hidden_states = self.audio_attn2(
norm_audio_hidden_states,
encoder_hidden_states=audio_encoder_hidden_states,
query_rotary_emb=None,
attention_mask=audio_encoder_attention_mask,
)
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
norm_hidden_states = self.audio_to_video_norm(hidden_states)
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
# Combine global and per-layer cross attention modulation parameters
# Video
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
        video_ca_scale_shift_table = (
            video_per_layer_ca_scale_shift[None, None].to(temb_ca_scale_shift.dtype)
            + temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1)
        ).unbind(dim=2)
        video_ca_gate = (
            video_per_layer_ca_gate[None, None].to(temb_ca_gate.dtype)
            + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1)
        ).unbind(dim=2)
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table
a2v_gate = video_ca_gate[0].squeeze(2)
# Audio
audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
        audio_ca_scale_shift_table = (
            audio_per_layer_ca_scale_shift[None, None].to(temb_ca_audio_scale_shift.dtype)
            + temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1)
        ).unbind(dim=2)
        audio_ca_gate = (
            audio_per_layer_ca_gate[None, None].to(temb_ca_audio_gate.dtype)
            + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1)
        ).unbind(dim=2)
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table
v2a_gate = audio_ca_gate[0].squeeze(2)
# Audio-to-Video Cross Attention: Q: Video; K,V: Audio
mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(
2
)
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
1 + audio_a2v_ca_scale.squeeze(2)
) + audio_a2v_ca_shift.squeeze(2)
a2v_attn_hidden_states = self.audio_to_video_attn(
mod_norm_hidden_states,
encoder_hidden_states=mod_norm_audio_hidden_states,
query_rotary_emb=ca_video_rotary_emb,
key_rotary_emb=ca_audio_rotary_emb,
attention_mask=a2v_cross_attention_mask,
)
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
# Video-to-Audio Cross Attention: Q: Audio; K,V: Video
mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(
2
)
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
1 + audio_v2a_ca_scale.squeeze(2)
) + audio_v2a_ca_shift.squeeze(2)
v2a_attn_hidden_states = self.video_to_audio_attn(
mod_norm_audio_hidden_states,
encoder_hidden_states=mod_norm_hidden_states,
query_rotary_emb=ca_audio_rotary_emb,
key_rotary_emb=ca_video_rotary_emb,
attention_mask=v2a_cross_attention_mask,
)
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
# 4. Feedforward
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + ff_output * gate_mlp
norm_audio_hidden_states = self.audio_norm3(audio_hidden_states) * (1 + audio_scale_mlp) + audio_shift_mlp
audio_ff_output = self.audio_ff(norm_audio_hidden_states)
audio_hidden_states = audio_hidden_states + audio_ff_output * audio_gate_mlp
return hidden_states, audio_hidden_states
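# Modulation input shapes expected by `forward` (summary of the reshapes above): `temb` and
# `temb_audio` carry 6 AdaLN params each and are reshaped to (B, S or 1, 6, dim); the cross-attention
# `temb_ca_*_scale_shift` tensors carry 4 params -> (B, S or 1, 4, dim); and the `temb_ca_*_gate`
# tensors carry a single gate -> (B, S or 1, 1, dim).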
class LTX2AudioVideoRotaryPosEmbed(nn.Module):
"""
Video and audio rotary positional embeddings (RoPE) for the LTX-2.0 model.
Args:
causal_offset (`int`, *optional*, defaults to `1`):
Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where the VAE
treats the very first frame differently), but could also be 0 (for non-causal modeling).
"""
def __init__(
self,
dim: int,
patch_size: int = 1,
patch_size_t: int = 1,
base_num_frames: int = 20,
base_height: int = 2048,
base_width: int = 2048,
sampling_rate: int = 16000,
hop_length: int = 160,
scale_factors: tuple[int, ...] = (8, 32, 32),
theta: float = 10000.0,
causal_offset: int = 1,
modality: str = "video",
double_precision: bool = True,
rope_type: str = "interleaved",
num_attention_heads: int = 32,
) -> None:
super().__init__()
self.dim = dim
self.patch_size = patch_size
self.patch_size_t = patch_size_t
if rope_type not in ["interleaved", "split"]:
raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.")
self.rope_type = rope_type
self.base_num_frames = base_num_frames
self.num_attention_heads = num_attention_heads
# Video-specific
self.base_height = base_height
self.base_width = base_width
# Audio-specific
self.sampling_rate = sampling_rate
self.hop_length = hop_length
self.audio_latents_per_second = float(sampling_rate) / float(hop_length) / float(scale_factors[0])
self.scale_factors = scale_factors
self.theta = theta
self.causal_offset = causal_offset
self.modality = modality
if self.modality not in ["video", "audio"]:
raise ValueError(f"Modality {modality} is not supported. Supported modalities are `video` and `audio`.")
self.double_precision = double_precision
def prepare_video_coords(
self,
batch_size: int,
num_frames: int,
height: int,
width: int,
device: torch.device,
fps: float = 24.0,
) -> torch.Tensor:
"""
Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original pixel
space video grid (num_frames, height, width). This will ultimately have shape (batch_size, 3, num_patches, 2)
where
- axis 1 (size 3) enumerates (frame, height, width) dimensions (e.g. idx 0 corresponds to frames)
- axis 3 (size 2) stores `[start, end)` indices within each dimension
Args:
batch_size (`int`):
Batch size of the video latents.
num_frames (`int`):
Number of latent frames in the video latents.
height (`int`):
Latent height of the video latents.
width (`int`):
Latent width of the video latents.
            device (`torch.device`):
                Device on which to create the video grid.
            fps (`float`, *optional*, defaults to `24.0`):
                Frames per second of the video, used to convert temporal patch coordinates into seconds.
Returns:
`torch.Tensor`:
Per-dimension patch boundaries tensor of shape [batch_size, 3, num_patches, 2].
"""
# 1. Generate grid coordinates for each spatiotemporal dimension (frames, height, width)
# Always compute rope in fp32
grid_f = torch.arange(start=0, end=num_frames, step=self.patch_size_t, dtype=torch.float32, device=device)
grid_h = torch.arange(start=0, end=height, step=self.patch_size, dtype=torch.float32, device=device)
grid_w = torch.arange(start=0, end=width, step=self.patch_size, dtype=torch.float32, device=device)
# indexing='ij' ensures that the dimensions are kept in order as (frames, height, width)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
grid = torch.stack(grid, dim=0) # [3, N_F, N_H, N_W], where e.g. N_F is the number of temporal patches
# 2. Get the patch boundaries with respect to the latent video grid
patch_size = (self.patch_size_t, self.patch_size, self.patch_size)
patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device)
patch_ends = grid + patch_size_delta.view(3, 1, 1, 1)
# Combine the start (grid) and end (patch_ends) coordinates along new trailing dimension
latent_coords = torch.stack([grid, patch_ends], dim=-1) # [3, N_F, N_H, N_W, 2]
# Reshape to (batch_size, 3, num_patches, 2)
latent_coords = latent_coords.flatten(1, 3)
latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1)
# 3. Calculate the pixel space patch boundaries from the latent boundaries.
scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device)
# Broadcast the VAE scale factors such that they are compatible with latent_coords's shape
broadcast_shape = [1] * latent_coords.ndim
broadcast_shape[1] = -1 # This is the (frame, height, width) dim
# Apply per-axis scaling to convert latent coordinates to pixel space coordinates
pixel_coords = latent_coords * scale_tensor.view(*broadcast_shape)
        # As the VAE temporal stride for the first frame is 1 instead of self.scale_factors[0], we need to shift
        # and clamp to keep the first-frame timestamps causal and non-negative.
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + self.causal_offset - self.scale_factors[0]).clamp(min=0)
# Scale the temporal coordinates by the video FPS
pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps
return pixel_coords
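    # Worked example (illustrative): with scale_factors=(8, 32, 32), causal_offset=1, patch sizes of 1, and
    # fps=24.0, a latent frame index f maps to the start time max(8 * f + 1 - 8, 0) / 24 seconds, so
    # f=0 -> 0.0 s and f=1 -> 1/24 s, matching the causal VAE's stride-1 treatment of the first frame.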
def prepare_audio_coords(
self,
batch_size: int,
num_frames: int,
device: torch.device,
shift: int = 0,
) -> torch.Tensor:
"""
        Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent frame.
        This will ultimately have shape (batch_size, 1, num_patches, 2) where
        - axis 1 (size 1) represents the temporal dimension
        - axis 3 (size 2) stores `[start, end)` indices within each dimension
Args:
batch_size (`int`):
Batch size of the audio latents.
num_frames (`int`):
Number of latent frames in the audio latents.
device (`torch.device`):
Device on which to create the audio grid.
shift (`int`, *optional*, defaults to `0`):
Offset on the latent indices. Different shift values correspond to different overlapping windows with
respect to the same underlying latent grid.
Returns:
`torch.Tensor`:
Per-dimension patch boundaries tensor of shape [batch_size, 1, num_patches, 2].
"""
# 1. Generate coordinates in the frame (time) dimension.
# Always compute rope in fp32
grid_f = torch.arange(
start=shift, end=num_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device
)
        # 2. Calculate start timestamps in seconds with respect to the original spectrogram grid
audio_scale_factor = self.scale_factors[0]
# Scale back to mel spectrogram space
grid_start_mel = grid_f * audio_scale_factor
# Handle first frame causal offset, ensuring non-negative timestamps
grid_start_mel = (grid_start_mel + self.causal_offset - audio_scale_factor).clip(min=0)
# Convert mel bins back into seconds
grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate
        # 3. Calculate end timestamps in seconds with respect to the original spectrogram grid
grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
grid_end_mel = (grid_end_mel + self.causal_offset - audio_scale_factor).clip(min=0)
grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate
audio_coords = torch.stack([grid_start_s, grid_end_s], dim=-1) # [num_patches, 2]
audio_coords = audio_coords.unsqueeze(0).expand(batch_size, -1, -1) # [batch_size, num_patches, 2]
audio_coords = audio_coords.unsqueeze(1) # [batch_size, 1, num_patches, 2]
return audio_coords
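    # Worked example (illustrative): with sampling_rate=16000, hop_length=160, scale_factors=(4,), and
    # causal_offset=1, one latent frame covers 4 * 160 / 16000 = 0.04 s. Latent frame 0 spans
    # [0.0, 0.01) s (its end is (1 * 4 + 1 - 4) * 160 / 16000) and frame 1 spans [0.01, 0.05) s,
    # mirroring the causal first-frame handling of the video branch.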
def prepare_coords(self, *args, **kwargs):
if self.modality == "video":
return self.prepare_video_coords(*args, **kwargs)
elif self.modality == "audio":
return self.prepare_audio_coords(*args, **kwargs)
def forward(
self, coords: torch.Tensor, device: str | torch.device | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
device = device or coords.device
# Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn)
num_pos_dims = coords.shape[1]
# 1. If the coords are patch boundaries [start, end), use the midpoint of these boundaries as the patch
# position index
if coords.ndim == 4:
coords_start, coords_end = coords.chunk(2, dim=-1)
coords = (coords_start + coords_end) / 2.0
coords = coords.squeeze(-1) # [B, num_pos_dims, num_patches]
# 2. Get coordinates as a fraction of the base data shape
if self.modality == "video":
max_positions = (self.base_num_frames, self.base_height, self.base_width)
elif self.modality == "audio":
max_positions = (self.base_num_frames,)
# [B, num_pos_dims, num_patches] --> [B, num_patches, num_pos_dims]
grid = torch.stack([coords[:, i] / max_positions[i] for i in range(num_pos_dims)], dim=-1).to(device)
# Number of spatiotemporal dimensions (3 for video, 1 for audio and cross attn) times 2 for cos, sin
num_rope_elems = num_pos_dims * 2
# 3. Create a 1D grid of frequencies for RoPE
freqs_dtype = torch.float64 if self.double_precision else torch.float32
pow_indices = torch.pow(
self.theta,
torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device),
)
freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)
        # 4. Tensor-vector outer product between the pos ids tensor of shape (B, num_patches, num_pos_dims) and the
        # freqs vector of shape (self.dim // num_rope_elems,)
        freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs  # [B, num_patches, num_pos_dims, self.dim // num_rope_elems]
freqs = freqs.transpose(-1, -2).flatten(2) # [B, num_patches, self.dim // 2]
# 5. Get real, interleaved (cos, sin) frequencies, padded to self.dim
# TODO: consider implementing this as a utility and reuse in `connectors.py`.
# src/diffusers/pipelines/ltx2/connectors.py
if self.rope_type == "interleaved":
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
if self.dim % num_rope_elems != 0:
cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems])
sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % num_rope_elems])
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
elif self.rope_type == "split":
expected_freqs = self.dim // 2
current_freqs = freqs.shape[-1]
pad_size = expected_freqs - current_freqs
cos_freq = freqs.cos()
sin_freq = freqs.sin()
if pad_size != 0:
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
                cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
                sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
# Reshape freqs to be compatible with multi-head attention
b = cos_freq.shape[0]
t = cos_freq.shape[1]
cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1)
sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1)
cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
return cos_freqs, sin_freqs
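# Output layout sketch: for `rope_type="interleaved"` the returned cos/sin have shape
# [B, num_patches, dim] and are consumed by `apply_interleaved_rotary_emb` before the heads are split
# out; for `rope_type="split"` they have shape [B, num_heads, num_patches, dim // (2 * num_heads)] and
# are consumed per-head by `apply_split_rotary_emb`.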
class LTX2VideoTransformer3DModel(
ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
):
r"""
A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
Args:
in_channels (`int`, defaults to `128`):
The number of channels in the input.
out_channels (`int`, defaults to `128`):
The number of channels in the output.
patch_size (`int`, defaults to `1`):
The size of the spatial patches to use in the patch embedding layer.
patch_size_t (`int`, defaults to `1`):
The size of the tmeporal patches to use in the patch embedding layer.
num_attention_heads (`int`, defaults to `32`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
cross_attention_dim (`int`, defaults to `2048 `):
The number of channels for cross attention heads.
num_layers (`int`, defaults to `28`):
The number of layers of Transformer blocks to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
The normalization layer to use.
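
    Example (a minimal, illustrative smoke test with deliberately tiny, non-checkpoint dimensions; note that
    `cross_attention_dim` must match the video inner dim `num_attention_heads * attention_head_dim`, and likewise
    `audio_cross_attention_dim` must match the audio inner dim):

    ```py
    >>> import torch

    >>> model = LTX2VideoTransformer3DModel(
    ...     in_channels=8,
    ...     out_channels=8,
    ...     num_attention_heads=2,
    ...     attention_head_dim=8,
    ...     cross_attention_dim=16,
    ...     audio_in_channels=4,
    ...     audio_out_channels=4,
    ...     audio_num_attention_heads=2,
    ...     audio_attention_head_dim=8,
    ...     audio_cross_attention_dim=16,
    ...     caption_channels=32,
    ...     num_layers=2,
    ... )
    >>> b, f, h, w, fa = 1, 2, 4, 4, 8
    >>> out = model(
    ...     hidden_states=torch.randn(b, f * h * w, 8),
    ...     audio_hidden_states=torch.randn(b, fa, 4),
    ...     encoder_hidden_states=torch.randn(b, 5, 32),
    ...     audio_encoder_hidden_states=torch.randn(b, 5, 32),
    ...     timestep=torch.full((b, f * h * w), 500.0),
    ...     audio_timestep=torch.full((b,), 500.0),
    ...     num_frames=f,
    ...     height=h,
    ...     width=w,
    ...     audio_num_frames=fa,
    ... )
    >>> out.sample.shape, out.audio_sample.shape
    (torch.Size([1, 32, 8]), torch.Size([1, 8, 4]))
    ```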
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["LTX2VideoTransformerBlock"]
_cp_plan = {
"": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
},
"rope": {
0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
},
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}
@register_to_config
def __init__(
self,
in_channels: int = 128, # Video Arguments
out_channels: int | None = 128,
patch_size: int = 1,
patch_size_t: int = 1,
num_attention_heads: int = 32,
attention_head_dim: int = 128,
cross_attention_dim: int = 4096,
vae_scale_factors: tuple[int, int, int] = (8, 32, 32),
pos_embed_max_pos: int = 20,
base_height: int = 2048,
base_width: int = 2048,
audio_in_channels: int = 128, # Audio Arguments
audio_out_channels: int | None = 128,
audio_patch_size: int = 1,
audio_patch_size_t: int = 1,
audio_num_attention_heads: int = 32,
audio_attention_head_dim: int = 64,
audio_cross_attention_dim: int = 2048,
audio_scale_factor: int = 4,
audio_pos_embed_max_pos: int = 20,
audio_sampling_rate: int = 16000,
audio_hop_length: int = 160,
num_layers: int = 48, # Shared arguments
activation_fn: str = "gelu-approximate",
qk_norm: str = "rms_norm_across_heads",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
caption_channels: int = 3840,
attention_bias: bool = True,
attention_out_bias: bool = True,
rope_theta: float = 10000.0,
rope_double_precision: bool = True,
causal_offset: int = 1,
timestep_scale_multiplier: int = 1000,
cross_attn_timestep_scale_multiplier: int = 1000,
rope_type: str = "interleaved",
) -> None:
super().__init__()
out_channels = out_channels or in_channels
audio_out_channels = audio_out_channels or audio_in_channels
inner_dim = num_attention_heads * attention_head_dim
audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
# 1. Patchification input projections
self.proj_in = nn.Linear(in_channels, inner_dim)
self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim)
# 2. Prompt embeddings
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels, hidden_size=audio_inner_dim
)
# 3. Timestep Modulation Params and Embedding
# 3.1. Global Timestep Modulation Parameters (except for cross-attention) and timestep + size embedding
# time_embed and audio_time_embed calculate both the timestep embedding and (global) modulation parameters
self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False)
self.audio_time_embed = LTX2AdaLayerNormSingle(
audio_inner_dim, num_mod_params=6, use_additional_conditions=False
)
# 3.2. Global Cross Attention Modulation Parameters
# Used in the audio-to-video and video-to-audio cross attention layers as a global set of modulation params,
# which are then further modified by per-block modulaton params in each transformer block.
# There are 2 sets of scale/shift parameters for each modality, 1 each for audio-to-video (a2v) and
# video-to-audio (v2a) cross attention
self.av_cross_attn_video_scale_shift = LTX2AdaLayerNormSingle(
inner_dim, num_mod_params=4, use_additional_conditions=False
)
self.av_cross_attn_audio_scale_shift = LTX2AdaLayerNormSingle(
audio_inner_dim, num_mod_params=4, use_additional_conditions=False
)
# Gate param for audio-to-video (a2v) cross attn (where the video is the queries (Q) and the audio is the keys
# and values (KV))
self.av_cross_attn_video_a2v_gate = LTX2AdaLayerNormSingle(
inner_dim, num_mod_params=1, use_additional_conditions=False
)
# Gate param for video-to-audio (v2a) cross attn (where the audio is the queries (Q) and the video is the keys
# and values (KV))
self.av_cross_attn_audio_v2a_gate = LTX2AdaLayerNormSingle(
audio_inner_dim, num_mod_params=1, use_additional_conditions=False
)
# 3.3. Output Layer Scale/Shift Modulation parameters
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5)
# 4. Rotary Positional Embeddings (RoPE)
# Self-Attention
self.rope = LTX2AudioVideoRotaryPosEmbed(
dim=inner_dim,
patch_size=patch_size,
patch_size_t=patch_size_t,
base_num_frames=pos_embed_max_pos,
base_height=base_height,
base_width=base_width,
scale_factors=vae_scale_factors,
theta=rope_theta,
causal_offset=causal_offset,
modality="video",
double_precision=rope_double_precision,
rope_type=rope_type,
num_attention_heads=num_attention_heads,
)
self.audio_rope = LTX2AudioVideoRotaryPosEmbed(
dim=audio_inner_dim,
patch_size=audio_patch_size,
patch_size_t=audio_patch_size_t,
base_num_frames=audio_pos_embed_max_pos,
sampling_rate=audio_sampling_rate,
hop_length=audio_hop_length,
            scale_factors=(audio_scale_factor,),
theta=rope_theta,
causal_offset=causal_offset,
modality="audio",
double_precision=rope_double_precision,
rope_type=rope_type,
num_attention_heads=audio_num_attention_heads,
)
# Audio-to-Video, Video-to-Audio Cross-Attention
cross_attn_pos_embed_max_pos = max(pos_embed_max_pos, audio_pos_embed_max_pos)
self.cross_attn_rope = LTX2AudioVideoRotaryPosEmbed(
dim=audio_cross_attention_dim,
patch_size=patch_size,
patch_size_t=patch_size_t,
base_num_frames=cross_attn_pos_embed_max_pos,
base_height=base_height,
base_width=base_width,
theta=rope_theta,
causal_offset=causal_offset,
modality="video",
double_precision=rope_double_precision,
rope_type=rope_type,
num_attention_heads=num_attention_heads,
)
self.cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed(
dim=audio_cross_attention_dim,
patch_size=audio_patch_size,
patch_size_t=audio_patch_size_t,
base_num_frames=cross_attn_pos_embed_max_pos,
sampling_rate=audio_sampling_rate,
hop_length=audio_hop_length,
theta=rope_theta,
causal_offset=causal_offset,
modality="audio",
double_precision=rope_double_precision,
rope_type=rope_type,
num_attention_heads=audio_num_attention_heads,
)
# 5. Transformer Blocks
self.transformer_blocks = nn.ModuleList(
[
LTX2VideoTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=cross_attention_dim,
audio_dim=audio_inner_dim,
audio_num_attention_heads=audio_num_attention_heads,
audio_attention_head_dim=audio_attention_head_dim,
audio_cross_attention_dim=audio_cross_attention_dim,
qk_norm=qk_norm,
activation_fn=activation_fn,
attention_bias=attention_bias,
attention_out_bias=attention_out_bias,
eps=norm_eps,
elementwise_affine=norm_elementwise_affine,
rope_type=rope_type,
)
for _ in range(num_layers)
]
)
# 6. Output layers
self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, out_channels)
self.audio_norm_out = nn.LayerNorm(audio_inner_dim, eps=1e-6, elementwise_affine=False)
self.audio_proj_out = nn.Linear(audio_inner_dim, audio_out_channels)
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
audio_hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
audio_encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
audio_timestep: torch.LongTensor | None = None,
encoder_attention_mask: torch.Tensor | None = None,
audio_encoder_attention_mask: torch.Tensor | None = None,
num_frames: int | None = None,
height: int | None = None,
width: int | None = None,
fps: float = 24.0,
audio_num_frames: int | None = None,
video_coords: torch.Tensor | None = None,
audio_coords: torch.Tensor | None = None,
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
    ) -> AudioVisualModelOutput | tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for LTX-2.0 audiovisual video transformer.
Args:
hidden_states (`torch.Tensor`):
Input patchified video latents of shape `(batch_size, num_video_tokens, in_channels)`.
audio_hidden_states (`torch.Tensor`):
Input patchified audio latents of shape `(batch_size, num_audio_tokens, audio_in_channels)`.
encoder_hidden_states (`torch.Tensor`):
Input video text embeddings of shape `(batch_size, text_seq_len, self.config.caption_channels)`.
audio_encoder_hidden_states (`torch.Tensor`):
Input audio text embeddings of shape `(batch_size, text_seq_len, self.config.caption_channels)`.
timestep (`torch.Tensor`):
Input timestep of shape `(batch_size, num_video_tokens)`. These should already be scaled by
`self.config.timestep_scale_multiplier`.
audio_timestep (`torch.Tensor`, *optional*):
Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for audio modulation
params. This is only used by certain pipelines such as the I2V pipeline.
encoder_attention_mask (`torch.Tensor`, *optional*):
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`.
audio_encoder_attention_mask (`torch.Tensor`, *optional*):
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling.
num_frames (`int`, *optional*):
The number of latent video frames. Used if calculating the video coordinates for RoPE.
height (`int`, *optional*):
The latent video height. Used if calculating the video coordinates for RoPE.
width (`int`, *optional*):
The latent video width. Used if calculating the video coordinates for RoPE.
            fps (`float`, *optional*, defaults to `24.0`):
                The desired frames per second of the generated video. Used if calculating the video coordinates for
                RoPE.
            audio_num_frames (`int`, *optional*):
                The number of latent audio frames. Used if calculating the audio coordinates for RoPE.
video_coords (`torch.Tensor`, *optional*):
The video coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape
`(batch_size, 3, num_video_tokens, 2)`. If not supplied, this will be calculated inside `forward`.
audio_coords (`torch.Tensor`, *optional*):
The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape
`(batch_size, 1, num_audio_tokens, 2)`. If not supplied, this will be calculated inside `forward`.
attention_kwargs (`dict[str, Any]`, *optional*):
Optional dict of keyword args to be passed to the attention processor.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a dict-like structured output of type `AudioVisualModelOutput` or a tuple.
Returns:
`AudioVisualModelOutput` or `tuple`:
If `return_dict` is `True`, returns a structured output of type `AudioVisualModelOutput`, otherwise a
`tuple` is returned where the first element is the denoised video latent patch sequence and the second
element is the denoised audio latent patch sequence.
"""
# Determine timestep for audio.
audio_timestep = audio_timestep if audio_timestep is not None else timestep
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2:
audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0
audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)
batch_size = hidden_states.size(0)
# 1. Prepare RoPE positional embeddings
if video_coords is None:
video_coords = self.rope.prepare_video_coords(
batch_size, num_frames, height, width, hidden_states.device, fps=fps
)
if audio_coords is None:
audio_coords = self.audio_rope.prepare_audio_coords(
batch_size, audio_num_frames, audio_hidden_states.device
)
video_rotary_emb = self.rope(video_coords, device=hidden_states.device)
audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device)
video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device)
audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(
audio_coords[:, 0:1, :], device=audio_hidden_states.device
)
# 2. Patchify input projections
hidden_states = self.proj_in(hidden_states)
audio_hidden_states = self.audio_proj_in(audio_hidden_states)
# 3. Prepare timestep embeddings and modulation parameters
timestep_cross_attn_gate_scale_factor = (
self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier
)
# 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters
# temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer
# modulation with scale_shift_table (and similarly for audio)
temb, embedded_timestep = self.time_embed(
timestep.flatten(),
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
temb = temb.view(batch_size, -1, temb.size(-1))
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
temb_audio, audio_embedded_timestep = self.audio_time_embed(
audio_timestep.flatten(),
batch_size=batch_size,
hidden_dtype=audio_hidden_states.dtype,
)
temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1))
audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1))
# 3.2. Prepare global modality cross attention modulation parameters
video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
timestep.flatten(),
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
timestep.flatten() * timestep_cross_attn_gate_scale_factor,
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(
batch_size, -1, video_cross_attn_scale_shift.shape[-1]
)
video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])
audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
audio_timestep.flatten(),
batch_size=batch_size,
hidden_dtype=audio_hidden_states.dtype,
)
audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
batch_size=batch_size,
hidden_dtype=audio_hidden_states.dtype,
)
audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view(
batch_size, -1, audio_cross_attn_scale_shift.shape[-1]
)
audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
# 4. Prepare prompt embeddings
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1))
# 5. Run transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, audio_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
audio_hidden_states,
encoder_hidden_states,
audio_encoder_hidden_states,
temb,
temb_audio,
video_cross_attn_scale_shift,
audio_cross_attn_scale_shift,
video_cross_attn_a2v_gate,
audio_cross_attn_v2a_gate,
video_rotary_emb,
audio_rotary_emb,
video_cross_attn_rotary_emb,
audio_cross_attn_rotary_emb,
encoder_attention_mask,
audio_encoder_attention_mask,
)
else:
hidden_states, audio_hidden_states = block(
hidden_states=hidden_states,
audio_hidden_states=audio_hidden_states,
encoder_hidden_states=encoder_hidden_states,
audio_encoder_hidden_states=audio_encoder_hidden_states,
temb=temb,
temb_audio=temb_audio,
temb_ca_scale_shift=video_cross_attn_scale_shift,
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
temb_ca_gate=video_cross_attn_a2v_gate,
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
video_rotary_emb=video_rotary_emb,
audio_rotary_emb=audio_rotary_emb,
ca_video_rotary_emb=video_cross_attn_rotary_emb,
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
encoder_attention_mask=encoder_attention_mask,
audio_encoder_attention_mask=audio_encoder_attention_mask,
)
# 6. Output layers (including unpatchification)
scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
output = self.proj_out(hidden_states)
audio_scale_shift_values = self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None]
audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1]
audio_hidden_states = self.audio_norm_out(audio_hidden_states)
audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
audio_output = self.audio_proj_out(audio_hidden_states)
if not return_dict:
return (output, audio_output)
return AudioVisualModelOutput(sample=output, audio_sample=audio_output)