| import math |
| import functools |
| from dataclasses import dataclass, replace |
| from enum import Enum |
| from typing import Optional, Tuple, Callable |
| import numpy as np |
| import torch |
| from einops import rearrange |
| from .ltx2_common import rms_norm, Modality |
| from ..core.attention.attention import attention_forward |
| from ..core import gradient_checkpoint_forward |
|
|
|
|
| def get_timestep_embedding( |
| timesteps: torch.Tensor, |
| embedding_dim: int, |
| flip_sin_to_cos: bool = False, |
| downscale_freq_shift: float = 1, |
| scale: float = 1, |
| max_period: int = 10000, |
| ) -> torch.Tensor: |
| """ |
| This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. |
| Args |
| timesteps (torch.Tensor): |
| a 1-D Tensor of N indices, one per batch element. These may be fractional. |
| embedding_dim (int): |
| the dimension of the output. |
| flip_sin_to_cos (bool): |
| Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) |
| downscale_freq_shift (float): |
| Controls the delta between frequencies between dimensions |
| scale (float): |
| Scaling factor applied to the embeddings. |
| max_period (int): |
| Controls the maximum frequency of the embeddings |
| Returns |
| torch.Tensor: an [N x dim] Tensor of positional embeddings. |
| """ |
| assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" |
|
|
| half_dim = embedding_dim // 2 |
| exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) |
| exponent = exponent / (half_dim - downscale_freq_shift) |
|
|
| emb = torch.exp(exponent) |
| emb = timesteps[:, None].float() * emb[None, :] |
|
|
    # scale embeddings
| emb = scale * emb |
|
|
    # concat sine and cosine embeddings
| emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) |
|
|
    # flip sine and cosine embeddings
| if flip_sin_to_cos: |
| emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) |
|
|
    # zero pad
| if embedding_dim % 2 == 1: |
| emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) |
| return emb |
|
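# Illustrative usage sketch (not executed at import time):
#   >>> t = torch.tensor([0.0, 250.0, 999.0])
#   >>> get_timestep_embedding(t, embedding_dim=256).shape
#   torch.Size([3, 256])
# The first half of the last dimension holds sines and the second half cosines
# (swapped when flip_sin_to_cos=True); odd embedding_dim values are zero-padded.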
|
|
|
| class TimestepEmbedding(torch.nn.Module): |
| def __init__( |
| self, |
| in_channels: int, |
| time_embed_dim: int, |
| out_dim: int | None = None, |
| post_act_fn: str | None = None, |
| cond_proj_dim: int | None = None, |
| sample_proj_bias: bool = True, |
| ): |
| super().__init__() |
|
|
| self.linear_1 = torch.nn.Linear(in_channels, time_embed_dim, sample_proj_bias) |
|
|
| if cond_proj_dim is not None: |
| self.cond_proj = torch.nn.Linear(cond_proj_dim, in_channels, bias=False) |
| else: |
| self.cond_proj = None |
|
|
| self.act = torch.nn.SiLU() |
| time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim |
|
|
| self.linear_2 = torch.nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) |
|
|
        if post_act_fn is None:
            self.post_act = None
        else:
            # Only the no-op case is supported here; fail loudly instead of silently ignoring the argument.
            raise NotImplementedError(f"post_act_fn={post_act_fn!r} is not supported")
|
|
| def forward(self, sample: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor: |
| if condition is not None: |
| sample = sample + self.cond_proj(condition) |
| sample = self.linear_1(sample) |
|
|
| if self.act is not None: |
| sample = self.act(sample) |
|
|
| sample = self.linear_2(sample) |
|
|
| if self.post_act is not None: |
| sample = self.post_act(sample) |
| return sample |
|
|
|
|
| class Timesteps(torch.nn.Module): |
| def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): |
| super().__init__() |
| self.num_channels = num_channels |
| self.flip_sin_to_cos = flip_sin_to_cos |
| self.downscale_freq_shift = downscale_freq_shift |
| self.scale = scale |
|
|
| def forward(self, timesteps: torch.Tensor) -> torch.Tensor: |
| t_emb = get_timestep_embedding( |
| timesteps, |
| self.num_channels, |
| flip_sin_to_cos=self.flip_sin_to_cos, |
| downscale_freq_shift=self.downscale_freq_shift, |
| scale=self.scale, |
| ) |
| return t_emb |
|
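# Sketch of the usual pairing (illustrative values): a sinusoidal projection followed by the MLP.
#   >>> proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
#   >>> embedder = TimestepEmbedding(in_channels=256, time_embed_dim=1024)
#   >>> embedder(proj(torch.tensor([10.0, 500.0]))).shape
#   torch.Size([2, 1024])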
|
|
|
| class PixArtAlphaCombinedTimestepSizeEmbeddings(torch.nn.Module): |
| """ |
| For PixArt-Alpha. |
| Reference: |
| https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 |
| """ |
|
|
| def __init__( |
| self, |
| embedding_dim: int, |
| size_emb_dim: int, |
| ): |
| super().__init__() |
|
|
| self.outdim = size_emb_dim |
| self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) |
| self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) |
|
|
| def forward( |
| self, |
| timestep: torch.Tensor, |
| hidden_dtype: torch.dtype, |
| ) -> torch.Tensor: |
| timesteps_proj = self.time_proj(timestep) |
| timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) |
| return timesteps_emb |
|
|
|
|
| class PerturbationType(Enum): |
| """Types of attention perturbations for STG (Spatio-Temporal Guidance).""" |
|
|
| SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn" |
| SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn" |
| SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn" |
| SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn" |
|
|
|
|
| @dataclass(frozen=True) |
| class Perturbation: |
| """A single perturbation specifying which attention type to skip and in which blocks.""" |
|
|
| type: PerturbationType |
| blocks: list[int] | None |
|
|
| def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool: |
| if self.type != perturbation_type: |
| return False |
|
|
| if self.blocks is None: |
| return True |
|
|
| return block in self.blocks |
|
|
|
|
| @dataclass(frozen=True) |
| class PerturbationConfig: |
| """Configuration holding a list of perturbations for a single sample.""" |
|
|
| perturbations: list[Perturbation] | None |
|
|
| def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool: |
| if self.perturbations is None: |
| return False |
|
|
| return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) |
|
|
| @staticmethod |
| def empty() -> "PerturbationConfig": |
| return PerturbationConfig([]) |
|
|
|
|
| @dataclass(frozen=True) |
| class BatchedPerturbationConfig: |
| """Perturbation configurations for a batch, with utilities for generating attention masks.""" |
|
|
| perturbations: list[PerturbationConfig] |
|
|
| def mask( |
| self, perturbation_type: PerturbationType, block: int, device, dtype: torch.dtype |
| ) -> torch.Tensor: |
| mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype) |
| for batch_idx, perturbation in enumerate(self.perturbations): |
| if perturbation.is_perturbed(perturbation_type, block): |
| mask[batch_idx] = 0 |
|
|
| return mask |
|
|
| def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor: |
| mask = self.mask(perturbation_type, block, values.device, values.dtype) |
| return mask.view(mask.numel(), *([1] * len(values.shape[1:]))) |
|
|
| def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool: |
| return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) |
|
|
| def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool: |
| return all(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) |
|
|
| @staticmethod |
| def empty(batch_size: int) -> "BatchedPerturbationConfig": |
| return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)]) |
|
|
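# Illustrative STG setup (a sketch): perturb the first sample's video self-attention in blocks 0-1.
#   >>> cfg = PerturbationConfig([Perturbation(PerturbationType.SKIP_VIDEO_SELF_ATTN, blocks=[0, 1])])
#   >>> batched = BatchedPerturbationConfig([cfg, PerturbationConfig.empty()])
#   >>> batched.mask(PerturbationType.SKIP_VIDEO_SELF_ATTN, block=0, device="cpu", dtype=torch.float32)
#   tensor([0., 1.])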
|
|
|
|
# Shift/scale/gate for self-attention plus shift/scale/gate for the feed-forward block.
ADALN_NUM_BASE_PARAMS = 6
# Extra shift/scale/gate used for text cross-attention when cross-attention AdaLN is enabled.
ADALN_NUM_CROSS_ATTN_PARAMS = 3
|
|
|
|
| def adaln_embedding_coefficient(cross_attention_adaln: bool) -> int: |
| """Total number of AdaLN parameters per block.""" |
| return ADALN_NUM_BASE_PARAMS + (ADALN_NUM_CROSS_ATTN_PARAMS if cross_attention_adaln else 0) |
|
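# For example, adaln_embedding_coefficient(False) == 6 and adaln_embedding_coefficient(True) == 9,
# matching the slice(0, 3) / slice(3, 6) / slice(6, 9) lookups in BasicAVTransformerBlock.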
|
|
|
| class AdaLayerNormSingle(torch.nn.Module): |
| r""" |
| Norm layer adaptive layer norm single (adaLN-single). |
| As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). |
| Parameters: |
| embedding_dim (`int`): The size of each embedding vector. |
| use_additional_conditions (`bool`): To use additional conditions for normalization or not. |
| """ |
|
|
| def __init__(self, embedding_dim: int, embedding_coefficient: int = 6): |
| super().__init__() |
|
|
| self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( |
| embedding_dim, |
| size_emb_dim=embedding_dim // 3, |
| ) |
|
|
| self.silu = torch.nn.SiLU() |
| self.linear = torch.nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True) |
|
|
| def forward( |
| self, |
| timestep: torch.Tensor, |
| hidden_dtype: Optional[torch.dtype] = None, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype) |
| return self.linear(self.silu(embedded_timestep)), embedded_timestep |
|
|
|
|
| class LTXRopeType(Enum): |
| INTERLEAVED = "interleaved" |
| SPLIT = "split" |
|
|
|
|
| def apply_rotary_emb( |
| input_tensor: torch.Tensor, |
| freqs_cis: Tuple[torch.Tensor, torch.Tensor], |
| rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, |
| ) -> torch.Tensor: |
| if rope_type == LTXRopeType.INTERLEAVED: |
| return apply_interleaved_rotary_emb(input_tensor, *freqs_cis) |
| elif rope_type == LTXRopeType.SPLIT: |
| return apply_split_rotary_emb(input_tensor, *freqs_cis) |
| else: |
| raise ValueError(f"Invalid rope type: {rope_type}") |
|
|
|
|
|
|
| def apply_interleaved_rotary_emb( |
| input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor |
| ) -> torch.Tensor: |
| t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) |
| t1, t2 = t_dup.unbind(dim=-1) |
| t_dup = torch.stack((-t2, t1), dim=-1) |
| input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") |
|
|
| out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs |
|
|
| return out |
|
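# For each consecutive pair (x1, x2) along the last dimension, the interleaved form computes the
# 2-D rotation (x1, x2) -> (x1 * cos - x2 * sin, x2 * cos + x1 * sin), i.e. standard RoPE with
# adjacent channels forming a rotation pair.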
|
|
|
| def apply_split_rotary_emb( |
| input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor |
| ) -> torch.Tensor: |
| needs_reshape = False |
| if input_tensor.ndim != 4 and cos_freqs.ndim == 4: |
| b, h, t, _ = cos_freqs.shape |
| input_tensor = input_tensor.reshape(b, t, h, -1).swapaxes(1, 2) |
| needs_reshape = True |
|
|
| split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2) |
| first_half_input = split_input[..., :1, :] |
| second_half_input = split_input[..., 1:, :] |
|
|
| output = split_input * cos_freqs.unsqueeze(-2) |
| first_half_output = output[..., :1, :] |
| second_half_output = output[..., 1:, :] |
|
|
| first_half_output.addcmul_(-sin_freqs.unsqueeze(-2), second_half_input) |
| second_half_output.addcmul_(sin_freqs.unsqueeze(-2), first_half_input) |
|
|
| output = rearrange(output, "... d r -> ... (d r)") |
| if needs_reshape: |
| output = output.swapaxes(1, 2).reshape(b, t, -1) |
|
|
| return output |
|
|
|
|
| @functools.lru_cache(maxsize=5) |
| def generate_freq_grid_np( |
| positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int |
| ) -> torch.Tensor: |
| theta = positional_embedding_theta |
| start = 1 |
| end = theta |
|
|
| n_elem = 2 * positional_embedding_max_pos_count |
| pow_indices = np.power( |
| theta, |
| np.linspace( |
| np.log(start) / np.log(theta), |
| np.log(end) / np.log(theta), |
| inner_dim // n_elem, |
| dtype=np.float64, |
| ), |
| ) |
| return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32) |
|
|
|
|
| @functools.lru_cache(maxsize=5) |
| def generate_freq_grid_pytorch( |
| positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int |
| ) -> torch.Tensor: |
| theta = positional_embedding_theta |
| start = 1 |
| end = theta |
| n_elem = 2 * positional_embedding_max_pos_count |
|
|
| indices = theta ** ( |
| torch.linspace( |
| math.log(start, theta), |
| math.log(end, theta), |
| inner_dim // n_elem, |
| dtype=torch.float32, |
| ) |
| ) |
| indices = indices.to(dtype=torch.float32) |
|
|
| indices = indices * math.pi / 2 |
|
|
| return indices |
|
|
|
|
| def get_fractional_positions(indices_grid: torch.Tensor, max_pos: list[int]) -> torch.Tensor: |
| n_pos_dims = indices_grid.shape[1] |
| assert n_pos_dims == len(max_pos), ( |
| f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})" |
| ) |
| fractional_positions = torch.stack( |
| [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)], |
| dim=-1, |
| ) |
| return fractional_positions |
|
|
|
|
| def generate_freqs( |
| indices: torch.Tensor, indices_grid: torch.Tensor, max_pos: list[int], use_middle_indices_grid: bool |
| ) -> torch.Tensor: |
| if use_middle_indices_grid: |
| assert len(indices_grid.shape) == 4 |
| assert indices_grid.shape[-1] == 2 |
| indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1] |
| indices_grid = (indices_grid_start + indices_grid_end) / 2.0 |
| elif len(indices_grid.shape) == 4: |
| indices_grid = indices_grid[..., 0] |
|
|
| fractional_positions = get_fractional_positions(indices_grid, max_pos) |
| indices = indices.to(device=fractional_positions.device) |
|
|
| freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) |
| return freqs |
|
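# Note: positions are normalised by max_pos and mapped to [-1, 1] via (fraction * 2 - 1) before
# being multiplied by the frequency grid, so the rotary phase depends on the position relative to
# the configured maximum rather than on absolute token indices.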
|
|
|
| def split_freqs_cis(freqs: torch.Tensor, pad_size: int, num_attention_heads: int) -> tuple[torch.Tensor, torch.Tensor]: |
| cos_freq = freqs.cos() |
| sin_freq = freqs.sin() |
|
|
| if pad_size != 0: |
| cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) |
| sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) |
|
|
        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
        sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
|
    # reshape to a per-head layout: (b, t, dim / 2) -> (b, heads, t, dim / (2 * heads))
| b = cos_freq.shape[0] |
| t = cos_freq.shape[1] |
|
|
| cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1) |
| sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1) |
|
|
| cos_freq = torch.swapaxes(cos_freq, 1, 2) |
| sin_freq = torch.swapaxes(sin_freq, 1, 2) |
| return cos_freq, sin_freq |
|
|
|
|
| def interleaved_freqs_cis(freqs: torch.Tensor, pad_size: int) -> tuple[torch.Tensor, torch.Tensor]: |
| cos_freq = freqs.cos().repeat_interleave(2, dim=-1) |
| sin_freq = freqs.sin().repeat_interleave(2, dim=-1) |
| if pad_size != 0: |
| cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) |
| sin_padding = torch.zeros_like(cos_freq[:, :, :pad_size]) |
| cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) |
| sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) |
| return cos_freq, sin_freq |
|
|
|
|
| def precompute_freqs_cis( |
| indices_grid: torch.Tensor, |
| dim: int, |
| out_dtype: torch.dtype, |
| theta: float = 10000.0, |
| max_pos: list[int] | None = None, |
| use_middle_indices_grid: bool = False, |
| num_attention_heads: int = 32, |
| rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, |
    freq_grid_generator: Callable[[float, int, int], torch.Tensor] = generate_freq_grid_pytorch,
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| if max_pos is None: |
| max_pos = [20, 2048, 2048] |
|
|
| indices = freq_grid_generator(theta, indices_grid.shape[1], dim) |
| freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) |
|
|
| if rope_type == LTXRopeType.SPLIT: |
| expected_freqs = dim // 2 |
| current_freqs = freqs.shape[-1] |
| pad_size = expected_freqs - current_freqs |
| cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads) |
| else: |
        # Interleaved RoPE: leading channels not covered by the frequency grid stay unrotated (cos=1, sin=0).
| n_elem = 2 * indices_grid.shape[1] |
| cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) |
| return cos_freq.to(out_dtype), sin_freq.to(out_dtype) |
|
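# Shape sketch (assuming an indices_grid of shape (B, n_dims, T) and the SPLIT rope type):
# generate_freqs produces (B, T, n_dims * (dim // (2 * n_dims))) phases, split_freqs_cis left-pads
# them up to dim // 2 and reshapes to (B, num_attention_heads, T, dim // (2 * heads)) so the
# cos/sin tensors can be applied per head in apply_split_rotary_emb.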
|
|
|
| class Attention(torch.nn.Module): |
| def __init__( |
| self, |
| query_dim: int, |
| context_dim: int | None = None, |
| heads: int = 8, |
| dim_head: int = 64, |
| norm_eps: float = 1e-6, |
| rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, |
| apply_gated_attention: bool = False, |
| ) -> None: |
| super().__init__() |
| self.rope_type = rope_type |
|
|
| inner_dim = dim_head * heads |
| context_dim = query_dim if context_dim is None else context_dim |
|
|
| self.heads = heads |
| self.dim_head = dim_head |
|
|
| self.q_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) |
| self.k_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) |
|
|
| self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True) |
| self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True) |
| self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True) |
|
|
        # Optional per-head gating of the attention output.
| if apply_gated_attention: |
| self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True) |
| else: |
| self.to_gate_logits = None |
|
|
| self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity()) |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| context: torch.Tensor | None = None, |
| mask: torch.Tensor | None = None, |
| pe: torch.Tensor | None = None, |
| k_pe: torch.Tensor | None = None, |
| perturbation_mask: torch.Tensor | None = None, |
| all_perturbed: bool = False, |
| ) -> torch.Tensor: |
| q = self.to_q(x) |
| context = x if context is None else context |
| k = self.to_k(context) |
| v = self.to_v(context) |
|
|
| q = self.q_norm(q) |
| k = self.k_norm(k) |
|
|
| if pe is not None: |
| q = apply_rotary_emb(q, pe, self.rope_type) |
| k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type) |
|
|
        # split heads: (b, s, h * d) -> (b, s, h, d)
| q = q.unflatten(-1, (self.heads, self.dim_head)) |
| k = k.unflatten(-1, (self.heads, self.dim_head)) |
| v = v.unflatten(-1, (self.heads, self.dim_head)) |
|
|
        if all_perturbed:
            # STG perturbation applies to every sample: skip the attention mixing and fall back to
            # the value path (assumed "skip attention" behavior; the projections are still applied).
            out = v
        else:
            out = attention_forward(
                q=q,
                k=k,
                v=v,
                q_pattern="b s n d",
                k_pattern="b s n d",
                v_pattern="b s n d",
                out_pattern="b s n d",
                attn_mask=mask,
            )
            if perturbation_mask is not None:
                # Per-sample STG perturbation: samples with mask == 0 fall back to the value path.
                pm = perturbation_mask.unsqueeze(-1)
                out = out * pm + v * (1.0 - pm)

        # merge heads back: (b, s, h, d) -> (b, s, h * d)
        out = out.flatten(2, 3)
|
|
        # optional per-head output gating
        if self.to_gate_logits is not None:
            gate_logits = self.to_gate_logits(x)
            b, t, _ = out.shape
            # split back into heads so each head is gated independently
            out = out.view(b, t, self.heads, self.dim_head)
            # gates lie in (0, 2); a gate of exactly 1 leaves the head unchanged
            gates = 2.0 * torch.sigmoid(gate_logits)
            out = out * gates.unsqueeze(-1)
            # merge heads again
            out = out.view(b, t, self.heads * self.dim_head)
|
|
| return self.to_out(out) |
|
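# Minimal usage sketch (illustrative sizes; relies on the packaged attention_forward backend):
#   >>> attn = Attention(query_dim=512, heads=8, dim_head=64)
#   >>> x = torch.randn(2, 16, 512)
#   >>> attn(x).shape   # self-attention: context defaults to x
#   torch.Size([2, 16, 512])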
|
|
|
| class PixArtAlphaTextProjection(torch.nn.Module): |
| """ |
| Projects caption embeddings. Also handles dropout for classifier-free guidance. |
| Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py |
| """ |
|
|
| def __init__(self, in_features: int, hidden_size: int, out_features: int | None = None, act_fn: str = "gelu_tanh"): |
| super().__init__() |
| if out_features is None: |
| out_features = hidden_size |
| self.linear_1 = torch.nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) |
| if act_fn == "gelu_tanh": |
| self.act_1 = torch.nn.GELU(approximate="tanh") |
| elif act_fn == "silu": |
| self.act_1 = torch.nn.SiLU() |
| else: |
| raise ValueError(f"Unknown activation function: {act_fn}") |
| self.linear_2 = torch.nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) |
|
|
| def forward(self, caption: torch.Tensor) -> torch.Tensor: |
| hidden_states = self.linear_1(caption) |
| hidden_states = self.act_1(hidden_states) |
| hidden_states = self.linear_2(hidden_states) |
| return hidden_states |
|
|
| @dataclass(frozen=True) |
| class TransformerArgs: |
| x: torch.Tensor |
| context: torch.Tensor |
| context_mask: torch.Tensor |
| timesteps: torch.Tensor |
| embedded_timestep: torch.Tensor |
| positional_embeddings: torch.Tensor |
| cross_positional_embeddings: torch.Tensor | None |
| cross_scale_shift_timestep: torch.Tensor | None |
| cross_gate_timestep: torch.Tensor | None |
| enabled: bool |
| prompt_timestep: torch.Tensor | None = None |
    self_attention_mask: torch.Tensor | None = None
|
|
|
|
| class TransformerArgsPreprocessor: |
| def __init__( |
| self, |
| patchify_proj: torch.nn.Linear, |
| adaln: AdaLayerNormSingle, |
| inner_dim: int, |
| max_pos: list[int], |
| num_attention_heads: int, |
| use_middle_indices_grid: bool, |
| timestep_scale_multiplier: int, |
| double_precision_rope: bool, |
| positional_embedding_theta: float, |
| rope_type: LTXRopeType, |
| caption_projection: torch.nn.Module | None = None, |
| prompt_adaln: AdaLayerNormSingle | None = None, |
| ) -> None: |
| self.patchify_proj = patchify_proj |
| self.adaln = adaln |
| self.inner_dim = inner_dim |
| self.max_pos = max_pos |
| self.num_attention_heads = num_attention_heads |
| self.use_middle_indices_grid = use_middle_indices_grid |
| self.timestep_scale_multiplier = timestep_scale_multiplier |
| self.double_precision_rope = double_precision_rope |
| self.positional_embedding_theta = positional_embedding_theta |
| self.rope_type = rope_type |
| self.caption_projection = caption_projection |
| self.prompt_adaln = prompt_adaln |
|
|
| def _prepare_timestep( |
| self, timestep: torch.Tensor, adaln: AdaLayerNormSingle, batch_size: int, hidden_dtype: torch.dtype |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| """Prepare timestep embeddings.""" |
| timestep_scaled = timestep * self.timestep_scale_multiplier |
| timestep, embedded_timestep = adaln( |
| timestep_scaled.flatten(), |
| hidden_dtype=hidden_dtype, |
| ) |
        # reshape to (batch, num_timestep_tokens, dim)
| timestep = timestep.view(batch_size, -1, timestep.shape[-1]) |
| embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) |
| return timestep, embedded_timestep |
|
|
| def _prepare_context( |
| self, |
| context: torch.Tensor, |
| x: torch.Tensor, |
| ) -> torch.Tensor: |
| """Prepare context for transformer blocks.""" |
| if self.caption_projection is not None: |
| context = self.caption_projection(context) |
| batch_size = x.shape[0] |
| return context.view(batch_size, -1, x.shape[-1]) |
|
|
| def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None: |
| """Prepare attention mask.""" |
| if attention_mask is None or torch.is_floating_point(attention_mask): |
| return attention_mask |
|
|
| return (attention_mask - 1).to(x_dtype).reshape( |
| (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) |
| ) * torch.finfo(x_dtype).max |
|
|
| def _prepare_self_attention_mask( |
| self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype |
| ) -> torch.Tensor | None: |
| """Prepare self-attention mask by converting [0,1] values to additive log-space bias. |
| Input shape: (B, T, T) with values in [0, 1]. |
| Output shape: (B, 1, T, T) with 0.0 for full attention and a large negative value |
| for masked positions. |
| Positions with attention_mask <= 0 are fully masked (mapped to the dtype's minimum |
| representable value). Strictly positive entries are converted via log-space for |
| smooth attenuation, with small values clamped for numerical stability. |
| Returns None if input is None (no masking). |
| """ |
| if attention_mask is None: |
| return None |
|
|
        # Start from a fully masked bias and re-enable the allowed positions below.
        # finfo.tiny is used as the clamp floor so that log() stays finite for very small
        # but positive mask values.
| finfo = torch.finfo(x_dtype) |
| eps = finfo.tiny |
|
|
| bias = torch.full_like(attention_mask, finfo.min, dtype=x_dtype) |
| positive = attention_mask > 0 |
| if positive.any(): |
| bias[positive] = torch.log(attention_mask[positive].clamp(min=eps)).to(x_dtype) |
|
|
| return bias.unsqueeze(1) |
|
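    # Sketch of the conversion above (with a hypothetical preprocessor instance `preproc`):
    #   >>> m = torch.ones(1, 3, 3); m[0, 0, 2] = 0.0
    #   >>> bias = preproc._prepare_self_attention_mask(m, torch.float32)
    #   >>> bias.shape
    #   torch.Size([1, 1, 3, 3])
    # The forbidden position gets torch.finfo(torch.float32).min, allowed positions get 0.0.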
|
| def _prepare_positional_embeddings( |
| self, |
| positions: torch.Tensor, |
| inner_dim: int, |
| max_pos: list[int], |
| use_middle_indices_grid: bool, |
| num_attention_heads: int, |
| x_dtype: torch.dtype, |
| ) -> torch.Tensor: |
| """Prepare positional embeddings.""" |
| freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch |
| pe = precompute_freqs_cis( |
| positions, |
| dim=inner_dim, |
| out_dtype=x_dtype, |
| theta=self.positional_embedding_theta, |
| max_pos=max_pos, |
| use_middle_indices_grid=use_middle_indices_grid, |
| num_attention_heads=num_attention_heads, |
| rope_type=self.rope_type, |
| freq_grid_generator=freq_grid_generator, |
| ) |
| return pe |
|
|
| def prepare( |
| self, |
| modality: Modality, |
| cross_modality: Modality | None = None, |
| ) -> TransformerArgs: |
| x = self.patchify_proj(modality.latent) |
| batch_size = x.shape[0] |
| timestep, embedded_timestep = self._prepare_timestep( |
| modality.timesteps, self.adaln, batch_size, modality.latent.dtype |
| ) |
| prompt_timestep = None |
| if self.prompt_adaln is not None: |
| prompt_timestep, _ = self._prepare_timestep( |
| modality.sigma, self.prompt_adaln, batch_size, modality.latent.dtype |
| ) |
| context = self._prepare_context(modality.context, x) |
| attention_mask = self._prepare_attention_mask(modality.context_mask, modality.latent.dtype) |
| pe = self._prepare_positional_embeddings( |
| positions=modality.positions, |
| inner_dim=self.inner_dim, |
| max_pos=self.max_pos, |
| use_middle_indices_grid=self.use_middle_indices_grid, |
| num_attention_heads=self.num_attention_heads, |
| x_dtype=modality.latent.dtype, |
| ) |
| self_attention_mask = self._prepare_self_attention_mask(modality.attention_mask, modality.latent.dtype) |
| return TransformerArgs( |
| x=x, |
| context=context, |
| context_mask=attention_mask, |
| timesteps=timestep, |
| embedded_timestep=embedded_timestep, |
| positional_embeddings=pe, |
| cross_positional_embeddings=None, |
| cross_scale_shift_timestep=None, |
| cross_gate_timestep=None, |
| enabled=modality.enabled, |
| prompt_timestep=prompt_timestep, |
| self_attention_mask=self_attention_mask, |
| ) |
|
|
|
|
| class MultiModalTransformerArgsPreprocessor: |
| def __init__( |
| self, |
| patchify_proj: torch.nn.Linear, |
| adaln: AdaLayerNormSingle, |
| cross_scale_shift_adaln: AdaLayerNormSingle, |
| cross_gate_adaln: AdaLayerNormSingle, |
| inner_dim: int, |
| max_pos: list[int], |
| num_attention_heads: int, |
| cross_pe_max_pos: int, |
| use_middle_indices_grid: bool, |
| audio_cross_attention_dim: int, |
| timestep_scale_multiplier: int, |
| double_precision_rope: bool, |
| positional_embedding_theta: float, |
| rope_type: LTXRopeType, |
| av_ca_timestep_scale_multiplier: int, |
| caption_projection: torch.nn.Module | None = None, |
| prompt_adaln: AdaLayerNormSingle | None = None, |
| ) -> None: |
| self.simple_preprocessor = TransformerArgsPreprocessor( |
| patchify_proj=patchify_proj, |
| adaln=adaln, |
| inner_dim=inner_dim, |
| max_pos=max_pos, |
| num_attention_heads=num_attention_heads, |
| use_middle_indices_grid=use_middle_indices_grid, |
| timestep_scale_multiplier=timestep_scale_multiplier, |
| double_precision_rope=double_precision_rope, |
| positional_embedding_theta=positional_embedding_theta, |
| rope_type=rope_type, |
| caption_projection=caption_projection, |
| prompt_adaln=prompt_adaln, |
| ) |
| self.cross_scale_shift_adaln = cross_scale_shift_adaln |
| self.cross_gate_adaln = cross_gate_adaln |
| self.cross_pe_max_pos = cross_pe_max_pos |
| self.audio_cross_attention_dim = audio_cross_attention_dim |
| self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier |
|
|
| def prepare( |
| self, |
| modality: Modality, |
| cross_modality: Modality | None = None, |
| ) -> TransformerArgs: |
| transformer_args = self.simple_preprocessor.prepare(modality) |
| if cross_modality is None: |
| return transformer_args |
|
|
| if cross_modality.sigma.numel() > 1: |
| if cross_modality.sigma.shape[0] != modality.timesteps.shape[0]: |
| raise ValueError("Cross modality sigma must have the same batch size as the modality") |
| if cross_modality.sigma.ndim != 1: |
| raise ValueError("Cross modality sigma must be a 1D tensor") |
|
|
| cross_timestep = cross_modality.sigma.view( |
| modality.timesteps.shape[0], 1, *[1] * len(modality.timesteps.shape[2:]) |
| ) |
|
|
| cross_pe = self.simple_preprocessor._prepare_positional_embeddings( |
| positions=modality.positions[:, 0:1, :], |
| inner_dim=self.audio_cross_attention_dim, |
| max_pos=[self.cross_pe_max_pos], |
| use_middle_indices_grid=True, |
| num_attention_heads=self.simple_preprocessor.num_attention_heads, |
| x_dtype=modality.latent.dtype, |
| ) |
|
|
| cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep( |
| timestep=cross_timestep, |
| timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, |
| batch_size=transformer_args.x.shape[0], |
| hidden_dtype=modality.latent.dtype, |
| ) |
|
|
| return replace( |
| transformer_args, |
| cross_positional_embeddings=cross_pe, |
| cross_scale_shift_timestep=cross_scale_shift_timestep, |
| cross_gate_timestep=cross_gate_timestep, |
| ) |
|
|
| def _prepare_cross_attention_timestep( |
| self, |
| timestep: torch.Tensor | None, |
| timestep_scale_multiplier: int, |
| batch_size: int, |
| hidden_dtype: torch.dtype, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| """Prepare cross attention timestep embeddings.""" |
| timestep = timestep * timestep_scale_multiplier |
|
|
| av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier |
|
|
| scale_shift_timestep, _ = self.cross_scale_shift_adaln( |
| timestep.flatten(), |
| hidden_dtype=hidden_dtype, |
| ) |
| scale_shift_timestep = scale_shift_timestep.view(batch_size, -1, scale_shift_timestep.shape[-1]) |
| gate_noise_timestep, _ = self.cross_gate_adaln( |
| timestep.flatten() * av_ca_factor, |
| hidden_dtype=hidden_dtype, |
| ) |
| gate_noise_timestep = gate_noise_timestep.view(batch_size, -1, gate_noise_timestep.shape[-1]) |
|
|
| return scale_shift_timestep, gate_noise_timestep |
|
|
|
|
| @dataclass |
| class TransformerConfig: |
| dim: int |
| heads: int |
| d_head: int |
| context_dim: int |
| apply_gated_attention: bool = False |
| cross_attention_adaln: bool = False |
|
|
|
|
| class BasicAVTransformerBlock(torch.nn.Module): |
| def __init__( |
| self, |
| idx: int, |
| video: TransformerConfig | None = None, |
| audio: TransformerConfig | None = None, |
| rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, |
| norm_eps: float = 1e-6, |
| ): |
| super().__init__() |
|
|
| self.idx = idx |
| if video is not None: |
| self.attn1 = Attention( |
| query_dim=video.dim, |
| heads=video.heads, |
| dim_head=video.d_head, |
| context_dim=None, |
| rope_type=rope_type, |
| norm_eps=norm_eps, |
| apply_gated_attention=video.apply_gated_attention, |
| ) |
| self.attn2 = Attention( |
| query_dim=video.dim, |
| context_dim=video.context_dim, |
| heads=video.heads, |
| dim_head=video.d_head, |
| rope_type=rope_type, |
| norm_eps=norm_eps, |
| apply_gated_attention=video.apply_gated_attention, |
| ) |
| self.ff = FeedForward(video.dim, dim_out=video.dim) |
| video_sst_size = adaln_embedding_coefficient(video.cross_attention_adaln) |
| self.scale_shift_table = torch.nn.Parameter(torch.empty(video_sst_size, video.dim)) |
|
|
| if audio is not None: |
| self.audio_attn1 = Attention( |
| query_dim=audio.dim, |
| heads=audio.heads, |
| dim_head=audio.d_head, |
| context_dim=None, |
| rope_type=rope_type, |
| norm_eps=norm_eps, |
| apply_gated_attention=audio.apply_gated_attention, |
| ) |
| self.audio_attn2 = Attention( |
| query_dim=audio.dim, |
| context_dim=audio.context_dim, |
| heads=audio.heads, |
| dim_head=audio.d_head, |
| rope_type=rope_type, |
| norm_eps=norm_eps, |
| apply_gated_attention=audio.apply_gated_attention, |
| ) |
| self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim) |
| audio_sst_size = adaln_embedding_coefficient(audio.cross_attention_adaln) |
| self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(audio_sst_size, audio.dim)) |
|
|
| if audio is not None and video is not None: |
            # audio-to-video cross-attention: video tokens attend to audio tokens
| self.audio_to_video_attn = Attention( |
| query_dim=video.dim, |
| context_dim=audio.dim, |
| heads=audio.heads, |
| dim_head=audio.d_head, |
| rope_type=rope_type, |
| norm_eps=norm_eps, |
| apply_gated_attention=video.apply_gated_attention, |
| ) |
|
|
            # video-to-audio cross-attention: audio tokens attend to video tokens
| self.video_to_audio_attn = Attention( |
| query_dim=audio.dim, |
| context_dim=video.dim, |
| heads=audio.heads, |
| dim_head=audio.d_head, |
| rope_type=rope_type, |
| norm_eps=norm_eps, |
| apply_gated_attention=audio.apply_gated_attention, |
| ) |
|
|
| self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim)) |
| self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim)) |
|
|
| self.cross_attention_adaln = (video is not None and video.cross_attention_adaln) or ( |
| audio is not None and audio.cross_attention_adaln |
| ) |
|
|
| if self.cross_attention_adaln and video is not None: |
| self.prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, video.dim)) |
| if self.cross_attention_adaln and audio is not None: |
| self.audio_prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, audio.dim)) |
|
|
| self.norm_eps = norm_eps |
|
|
| def get_ada_values( |
| self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice |
| ) -> tuple[torch.Tensor, ...]: |
| num_ada_params = scale_shift_table.shape[0] |
|
|
| ada_values = ( |
| scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype) |
| + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :] |
| ).unbind(dim=2) |
| return ada_values |
|
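    # get_ada_values above: the table has shape (num_params, dim) and the timestep embedding has
    # shape (batch, tokens, num_params * dim); it is reshaped to (batch, tokens, num_params, dim),
    # the (broadcast) table is added, and the selected parameters are unbound into one
    # (batch, tokens, dim) tensor each.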
|
| def get_av_ca_ada_values( |
| self, |
| scale_shift_table: torch.Tensor, |
| batch_size: int, |
| scale_shift_timestep: torch.Tensor, |
| gate_timestep: torch.Tensor, |
| scale_shift_indices: slice, |
| num_scale_shift_values: int = 4, |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| scale_shift_ada_values = self.get_ada_values( |
| scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, scale_shift_indices |
| ) |
| gate_ada_values = self.get_ada_values( |
| scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None) |
| ) |
|
|
| scale, shift = (t.squeeze(2) for t in scale_shift_ada_values) |
| (gate,) = (t.squeeze(2) for t in gate_ada_values) |
|
|
| return scale, shift, gate |
|
|
| def _apply_text_cross_attention( |
| self, |
| x: torch.Tensor, |
| context: torch.Tensor, |
| attn: Attention, |
| scale_shift_table: torch.Tensor, |
| prompt_scale_shift_table: torch.Tensor | None, |
| timestep: torch.Tensor, |
| prompt_timestep: torch.Tensor | None, |
| context_mask: torch.Tensor | None, |
| cross_attention_adaln: bool = False, |
| ) -> torch.Tensor: |
| """Apply text cross-attention, with optional AdaLN modulation.""" |
| if cross_attention_adaln: |
| shift_q, scale_q, gate = self.get_ada_values(scale_shift_table, x.shape[0], timestep, slice(6, 9)) |
| return apply_cross_attention_adaln( |
| x, |
| context, |
| attn, |
| shift_q, |
| scale_q, |
| gate, |
| prompt_scale_shift_table, |
| prompt_timestep, |
| context_mask, |
| self.norm_eps, |
| ) |
| return attn(rms_norm(x, eps=self.norm_eps), context=context, mask=context_mask) |
|
|
| def forward( |
| self, |
| video: TransformerArgs | None, |
| audio: TransformerArgs | None, |
| perturbations: BatchedPerturbationConfig | None = None, |
| ) -> tuple[TransformerArgs | None, TransformerArgs | None]: |
| if video is None and audio is None: |
| raise ValueError("At least one of video or audio must be provided") |
|
|
| batch_size = (video or audio).x.shape[0] |
|
|
| if perturbations is None: |
| perturbations = BatchedPerturbationConfig.empty(batch_size) |
|
|
| vx = video.x if video is not None else None |
| ax = audio.x if audio is not None else None |
|
|
| run_vx = video is not None and video.enabled and vx.numel() > 0 |
| run_ax = audio is not None and audio.enabled and ax.numel() > 0 |
|
|
| run_a2v = run_vx and (audio is not None and ax.numel() > 0) |
| run_v2a = run_ax and (video is not None and vx.numel() > 0) |
|
|
| if run_vx: |
| vshift_msa, vscale_msa, vgate_msa = self.get_ada_values( |
| self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3) |
| ) |
| norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa |
| del vshift_msa, vscale_msa |
|
|
| all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx) |
| none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx) |
| v_mask = ( |
| perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx) |
| if not all_perturbed and not none_perturbed |
| else None |
| ) |
| vx = ( |
| vx |
| + self.attn1( |
| norm_vx, |
| pe=video.positional_embeddings, |
| mask=video.self_attention_mask, |
| perturbation_mask=v_mask, |
| all_perturbed=all_perturbed, |
| ) |
| * vgate_msa |
| ) |
| del vgate_msa, norm_vx, v_mask |
| vx = vx + self._apply_text_cross_attention( |
| vx, |
| video.context, |
| self.attn2, |
| self.scale_shift_table, |
| getattr(self, "prompt_scale_shift_table", None), |
| video.timesteps, |
| video.prompt_timestep, |
| video.context_mask, |
| cross_attention_adaln=self.cross_attention_adaln, |
| ) |
|
|
| if run_ax: |
| ashift_msa, ascale_msa, agate_msa = self.get_ada_values( |
| self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3) |
| ) |
|
|
| norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa |
| del ashift_msa, ascale_msa |
| all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx) |
| none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx) |
| a_mask = ( |
| perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax) |
| if not all_perturbed and not none_perturbed |
| else None |
| ) |
| ax = ( |
| ax |
| + self.audio_attn1( |
| norm_ax, |
| pe=audio.positional_embeddings, |
| mask=audio.self_attention_mask, |
| perturbation_mask=a_mask, |
| all_perturbed=all_perturbed, |
| ) |
| * agate_msa |
| ) |
| del agate_msa, norm_ax, a_mask |
| ax = ax + self._apply_text_cross_attention( |
| ax, |
| audio.context, |
| self.audio_attn2, |
| self.audio_scale_shift_table, |
| getattr(self, "audio_prompt_scale_shift_table", None), |
| audio.timesteps, |
| audio.prompt_timestep, |
| audio.context_mask, |
| cross_attention_adaln=self.cross_attention_adaln, |
| ) |
|
|
        # audio <-> video cross-attention, modulated by the cross-attention AdaLN timesteps
| if run_a2v or run_v2a: |
| vx_norm3 = rms_norm(vx, eps=self.norm_eps) |
| ax_norm3 = rms_norm(ax, eps=self.norm_eps) |
|
|
| if run_a2v and not perturbations.all_in_batch(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx): |
| scale_ca_video_a2v, shift_ca_video_a2v, gate_out_a2v = self.get_av_ca_ada_values( |
| self.scale_shift_table_a2v_ca_video, |
| vx.shape[0], |
| video.cross_scale_shift_timestep, |
| video.cross_gate_timestep, |
| slice(0, 2), |
| ) |
| vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v |
| del scale_ca_video_a2v, shift_ca_video_a2v |
|
|
| scale_ca_audio_a2v, shift_ca_audio_a2v, _ = self.get_av_ca_ada_values( |
| self.scale_shift_table_a2v_ca_audio, |
| ax.shape[0], |
| audio.cross_scale_shift_timestep, |
| audio.cross_gate_timestep, |
| slice(0, 2), |
| ) |
| ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v |
| del scale_ca_audio_a2v, shift_ca_audio_a2v |
| a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx) |
| vx = vx + ( |
| self.audio_to_video_attn( |
| vx_scaled, |
| context=ax_scaled, |
| pe=video.cross_positional_embeddings, |
| k_pe=audio.cross_positional_embeddings, |
| ) |
| * gate_out_a2v |
| * a2v_mask |
| ) |
| del gate_out_a2v, a2v_mask, vx_scaled, ax_scaled |
|
|
| if run_v2a and not perturbations.all_in_batch(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx): |
| scale_ca_audio_v2a, shift_ca_audio_v2a, gate_out_v2a = self.get_av_ca_ada_values( |
| self.scale_shift_table_a2v_ca_audio, |
| ax.shape[0], |
| audio.cross_scale_shift_timestep, |
| audio.cross_gate_timestep, |
| slice(2, 4), |
| ) |
| ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a |
| del scale_ca_audio_v2a, shift_ca_audio_v2a |
| scale_ca_video_v2a, shift_ca_video_v2a, _ = self.get_av_ca_ada_values( |
| self.scale_shift_table_a2v_ca_video, |
| vx.shape[0], |
| video.cross_scale_shift_timestep, |
| video.cross_gate_timestep, |
| slice(2, 4), |
| ) |
| vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a |
| del scale_ca_video_v2a, shift_ca_video_v2a |
| v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax) |
| ax = ax + ( |
| self.video_to_audio_attn( |
| ax_scaled, |
| context=vx_scaled, |
| pe=audio.cross_positional_embeddings, |
| k_pe=video.cross_positional_embeddings, |
| ) |
| * gate_out_v2a |
| * v2a_mask |
| ) |
| del gate_out_v2a, v2a_mask, ax_scaled, vx_scaled |
|
|
| del vx_norm3, ax_norm3 |
|
|
| if run_vx: |
| vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values( |
| self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6) |
| ) |
| vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp |
| vx = vx + self.ff(vx_scaled) * vgate_mlp |
|
|
| del vshift_mlp, vscale_mlp, vgate_mlp, vx_scaled |
|
|
| if run_ax: |
| ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values( |
| self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6) |
| ) |
| ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp |
| ax = ax + self.audio_ff(ax_scaled) * agate_mlp |
|
|
| del ashift_mlp, ascale_mlp, agate_mlp, ax_scaled |
|
|
| return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None |
|
|
|
|
| def apply_cross_attention_adaln( |
| x: torch.Tensor, |
| context: torch.Tensor, |
| attn: Attention, |
| q_shift: torch.Tensor, |
| q_scale: torch.Tensor, |
| q_gate: torch.Tensor, |
| prompt_scale_shift_table: torch.Tensor, |
| prompt_timestep: torch.Tensor, |
| context_mask: torch.Tensor | None = None, |
| norm_eps: float = 1e-6, |
| ) -> torch.Tensor: |
| batch_size = x.shape[0] |
| shift_kv, scale_kv = ( |
| prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) |
| + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1) |
| ).unbind(dim=2) |
| attn_input = rms_norm(x, eps=norm_eps) * (1 + q_scale) + q_shift |
| encoder_hidden_states = context * (1 + scale_kv) + shift_kv |
| return attn(attn_input, context=encoder_hidden_states, mask=context_mask) * q_gate |
|
|
|
|
| class GELUApprox(torch.nn.Module): |
| def __init__(self, dim_in: int, dim_out: int) -> None: |
| super().__init__() |
| self.proj = torch.nn.Linear(dim_in, dim_out) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return torch.nn.functional.gelu(self.proj(x), approximate="tanh") |
|
|
|
|
| class FeedForward(torch.nn.Module): |
| def __init__(self, dim: int, dim_out: int, mult: int = 4) -> None: |
| super().__init__() |
| inner_dim = int(dim * mult) |
| project_in = GELUApprox(dim, inner_dim) |
|
|
| self.net = torch.nn.Sequential(project_in, torch.nn.Identity(), torch.nn.Linear(inner_dim, dim_out)) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return self.net(x) |
|
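# Sketch: FeedForward(dim=512, dim_out=512) expands to 4 * 512 = 2048 hidden units through a
# tanh-approximated GELU projection and projects back to 512; the Identity slot appears to be a
# placeholder (e.g. where dropout would normally go).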
|
|
|
| class LTXModelType(Enum): |
| AudioVideo = "ltx av model" |
| VideoOnly = "ltx video only model" |
| AudioOnly = "ltx audio only model" |
|
|
| def is_video_enabled(self) -> bool: |
| return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly) |
|
|
| def is_audio_enabled(self) -> bool: |
| return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly) |
|
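# For example, LTXModelType.AudioVideo.is_video_enabled() and .is_audio_enabled() are both True,
# while LTXModelType.VideoOnly.is_audio_enabled() is False.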
|
|
|
| class LTXModel(torch.nn.Module): |
| """ |
| LTX model transformer implementation. |
| This class implements the transformer blocks for the LTX model. |
| """ |
| _repeated_blocks = ["BasicAVTransformerBlock"] |
|
|
| def __init__( |
| self, |
| *, |
| model_type: LTXModelType = LTXModelType.AudioVideo, |
| num_attention_heads: int = 32, |
| attention_head_dim: int = 128, |
| in_channels: int = 128, |
| out_channels: int = 128, |
| num_layers: int = 48, |
| cross_attention_dim: int = 4096, |
| norm_eps: float = 1e-06, |
| caption_channels: int = 3840, |
| positional_embedding_theta: float = 10000.0, |
        positional_embedding_max_pos: list[int] | None = None,
| timestep_scale_multiplier: int = 1000, |
| use_middle_indices_grid: bool = True, |
| audio_num_attention_heads: int = 32, |
| audio_attention_head_dim: int = 64, |
| audio_in_channels: int = 128, |
| audio_out_channels: int = 128, |
| audio_cross_attention_dim: int = 2048, |
        audio_positional_embedding_max_pos: list[int] | None = None,
| av_ca_timestep_scale_multiplier: int = 1000, |
| rope_type: LTXRopeType = LTXRopeType.SPLIT, |
| double_precision_rope: bool = True, |
| apply_gated_attention: bool = False, |
| cross_attention_adaln: bool = False, |
| ): |
| super().__init__() |
| self._enable_gradient_checkpointing = False |
| self.use_middle_indices_grid = use_middle_indices_grid |
| self.rope_type = rope_type |
| self.double_precision_rope = double_precision_rope |
| self.timestep_scale_multiplier = timestep_scale_multiplier |
| self.positional_embedding_theta = positional_embedding_theta |
| self.model_type = model_type |
| self.cross_attention_adaln = cross_attention_adaln |
| cross_pe_max_pos = None |
| if model_type.is_video_enabled(): |
| if positional_embedding_max_pos is None: |
| positional_embedding_max_pos = [20, 2048, 2048] |
| self.positional_embedding_max_pos = positional_embedding_max_pos |
| self.num_attention_heads = num_attention_heads |
| self.inner_dim = num_attention_heads * attention_head_dim |
| self._init_video( |
| in_channels=in_channels, |
| out_channels=out_channels, |
| caption_channels=caption_channels, |
| norm_eps=norm_eps, |
| ) |
|
|
| if model_type.is_audio_enabled(): |
| if audio_positional_embedding_max_pos is None: |
| audio_positional_embedding_max_pos = [20] |
| self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos |
| self.audio_num_attention_heads = audio_num_attention_heads |
| self.audio_inner_dim = self.audio_num_attention_heads * audio_attention_head_dim |
| self._init_audio( |
| in_channels=audio_in_channels, |
| out_channels=audio_out_channels, |
| caption_channels=caption_channels, |
| norm_eps=norm_eps, |
| ) |
|
|
| if model_type.is_video_enabled() and model_type.is_audio_enabled(): |
| cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]) |
| self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier |
| self.audio_cross_attention_dim = audio_cross_attention_dim |
| self._init_audio_video(num_scale_shift_values=4) |
|
|
| self._init_preprocessors(cross_pe_max_pos) |
        # The transformer blocks are shared: each block holds both the video and audio paths.
| self._init_transformer_blocks( |
| num_layers=num_layers, |
| attention_head_dim=attention_head_dim if model_type.is_video_enabled() else 0, |
| cross_attention_dim=cross_attention_dim, |
| audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0, |
| audio_cross_attention_dim=audio_cross_attention_dim, |
| norm_eps=norm_eps, |
| apply_gated_attention=apply_gated_attention, |
| ) |
|
|
| @property |
| def _adaln_embedding_coefficient(self) -> int: |
| return adaln_embedding_coefficient(self.cross_attention_adaln) |
|
|
| def _init_video( |
| self, |
| in_channels: int, |
| out_channels: int, |
| caption_channels: int, |
| norm_eps: float, |
| ) -> None: |
| """Initialize video-specific components.""" |
        # input patchify projection and timestep AdaLN
| self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True) |
| self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=self._adaln_embedding_coefficient) |
| self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None |
|
|
        # text (caption) conditioning projection
| if caption_channels is not None: |
| self.caption_projection = PixArtAlphaTextProjection( |
| in_features=caption_channels, |
| hidden_size=self.inner_dim, |
| ) |
|
|
        # output head: final modulation table, norm and projection
| self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim)) |
| self.norm_out = torch.nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=norm_eps) |
| self.proj_out = torch.nn.Linear(self.inner_dim, out_channels) |
|
|
| def _init_audio( |
| self, |
| in_channels: int, |
| out_channels: int, |
| caption_channels: int, |
| norm_eps: float, |
| ) -> None: |
| """Initialize audio-specific components.""" |
|
|
        # input patchify projection and timestep AdaLN
| self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True) |
|
|
| self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=self._adaln_embedding_coefficient) |
| self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None |
|
|
        # text (caption) conditioning projection
| if caption_channels is not None: |
| self.audio_caption_projection = PixArtAlphaTextProjection( |
| in_features=caption_channels, |
| hidden_size=self.audio_inner_dim, |
| ) |
|
|
        # output head: final modulation table, norm and projection
| self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim)) |
| self.audio_norm_out = torch.nn.LayerNorm(self.audio_inner_dim, elementwise_affine=False, eps=norm_eps) |
| self.audio_proj_out = torch.nn.Linear(self.audio_inner_dim, out_channels) |
|
|
| def _init_audio_video( |
| self, |
| num_scale_shift_values: int, |
| ) -> None: |
| """Initialize audio-video cross-attention components.""" |
| self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( |
| self.inner_dim, |
| embedding_coefficient=num_scale_shift_values, |
| ) |
|
|
| self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle( |
| self.audio_inner_dim, |
| embedding_coefficient=num_scale_shift_values, |
| ) |
|
|
| self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle( |
| self.inner_dim, |
| embedding_coefficient=1, |
| ) |
|
|
| self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle( |
| self.audio_inner_dim, |
| embedding_coefficient=1, |
| ) |
|
|
| def _init_preprocessors( |
| self, |
| cross_pe_max_pos: int | None = None, |
| ) -> None: |
| """Initialize preprocessors for LTX.""" |
|
|
| if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled(): |
| self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor( |
| patchify_proj=self.patchify_proj, |
| adaln=self.adaln_single, |
| cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single, |
| cross_gate_adaln=self.av_ca_a2v_gate_adaln_single, |
| inner_dim=self.inner_dim, |
| max_pos=self.positional_embedding_max_pos, |
| num_attention_heads=self.num_attention_heads, |
| cross_pe_max_pos=cross_pe_max_pos, |
| use_middle_indices_grid=self.use_middle_indices_grid, |
| audio_cross_attention_dim=self.audio_cross_attention_dim, |
| timestep_scale_multiplier=self.timestep_scale_multiplier, |
| double_precision_rope=self.double_precision_rope, |
| positional_embedding_theta=self.positional_embedding_theta, |
| rope_type=self.rope_type, |
| av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier, |
| caption_projection=getattr(self, "caption_projection", None), |
| prompt_adaln=getattr(self, "prompt_adaln_single", None), |
| ) |
| self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor( |
| patchify_proj=self.audio_patchify_proj, |
| adaln=self.audio_adaln_single, |
| cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single, |
| cross_gate_adaln=self.av_ca_v2a_gate_adaln_single, |
| inner_dim=self.audio_inner_dim, |
| max_pos=self.audio_positional_embedding_max_pos, |
| num_attention_heads=self.audio_num_attention_heads, |
| cross_pe_max_pos=cross_pe_max_pos, |
| use_middle_indices_grid=self.use_middle_indices_grid, |
| audio_cross_attention_dim=self.audio_cross_attention_dim, |
| timestep_scale_multiplier=self.timestep_scale_multiplier, |
| double_precision_rope=self.double_precision_rope, |
| positional_embedding_theta=self.positional_embedding_theta, |
| rope_type=self.rope_type, |
| av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier, |
| caption_projection=getattr(self, "audio_caption_projection", None), |
| prompt_adaln=getattr(self, "audio_prompt_adaln_single", None), |
| ) |
| elif self.model_type.is_video_enabled(): |
| self.video_args_preprocessor = TransformerArgsPreprocessor( |
| patchify_proj=self.patchify_proj, |
| adaln=self.adaln_single, |
| inner_dim=self.inner_dim, |
| max_pos=self.positional_embedding_max_pos, |
| num_attention_heads=self.num_attention_heads, |
| use_middle_indices_grid=self.use_middle_indices_grid, |
| timestep_scale_multiplier=self.timestep_scale_multiplier, |
| double_precision_rope=self.double_precision_rope, |
| positional_embedding_theta=self.positional_embedding_theta, |
| rope_type=self.rope_type, |
| caption_projection=getattr(self, "caption_projection", None), |
| prompt_adaln=getattr(self, "prompt_adaln_single", None), |
| ) |
| elif self.model_type.is_audio_enabled(): |
| self.audio_args_preprocessor = TransformerArgsPreprocessor( |
| patchify_proj=self.audio_patchify_proj, |
| adaln=self.audio_adaln_single, |
| inner_dim=self.audio_inner_dim, |
| max_pos=self.audio_positional_embedding_max_pos, |
| num_attention_heads=self.audio_num_attention_heads, |
| use_middle_indices_grid=self.use_middle_indices_grid, |
| timestep_scale_multiplier=self.timestep_scale_multiplier, |
| double_precision_rope=self.double_precision_rope, |
| positional_embedding_theta=self.positional_embedding_theta, |
| rope_type=self.rope_type, |
| caption_projection=getattr(self, "audio_caption_projection", None), |
| prompt_adaln=getattr(self, "audio_prompt_adaln_single", None), |
| ) |
|
|
| def _init_transformer_blocks( |
| self, |
| num_layers: int, |
| attention_head_dim: int, |
| cross_attention_dim: int, |
| audio_attention_head_dim: int, |
| audio_cross_attention_dim: int, |
| norm_eps: float, |
| apply_gated_attention: bool, |
| ) -> None: |
| """Initialize transformer blocks for LTX.""" |
| video_config = ( |
| TransformerConfig( |
| dim=self.inner_dim, |
| heads=self.num_attention_heads, |
| d_head=attention_head_dim, |
| context_dim=cross_attention_dim, |
| apply_gated_attention=apply_gated_attention, |
| cross_attention_adaln=self.cross_attention_adaln, |
| ) |
| if self.model_type.is_video_enabled() |
| else None |
| ) |
| audio_config = ( |
| TransformerConfig( |
| dim=self.audio_inner_dim, |
| heads=self.audio_num_attention_heads, |
| d_head=audio_attention_head_dim, |
| context_dim=audio_cross_attention_dim, |
| apply_gated_attention=apply_gated_attention, |
| cross_attention_adaln=self.cross_attention_adaln, |
| ) |
| if self.model_type.is_audio_enabled() |
| else None |
| ) |
| self.transformer_blocks = torch.nn.ModuleList( |
| [ |
| BasicAVTransformerBlock( |
| idx=idx, |
| video=video_config, |
| audio=audio_config, |
| rope_type=self.rope_type, |
| norm_eps=norm_eps, |
| ) |
| for idx in range(num_layers) |
| ] |
| ) |
|
|
| def set_gradient_checkpointing(self, enable: bool) -> None: |
| """Enable or disable gradient checkpointing for transformer blocks. |
| Gradient checkpointing trades compute for memory by recomputing activations |
| during the backward pass instead of storing them. This can significantly |
| reduce memory usage at the cost of ~20-30% slower training. |
| Args: |
| enable: Whether to enable gradient checkpointing |
| """ |
| self._enable_gradient_checkpointing = enable |
|
|
| def _process_transformer_blocks( |
| self, |
| video: TransformerArgs | None, |
| audio: TransformerArgs | None, |
| perturbations: BatchedPerturbationConfig, |
| use_gradient_checkpointing: bool = False, |
| use_gradient_checkpointing_offload: bool = False, |
| ) -> tuple[TransformerArgs, TransformerArgs]: |
| """Process transformer blocks for LTXAV.""" |
|
|
        # Run each block sequentially, optionally under (offloaded) gradient checkpointing.
| for block in self.transformer_blocks: |
| video, audio = gradient_checkpoint_forward( |
| block, |
| use_gradient_checkpointing, |
| use_gradient_checkpointing_offload, |
| video=video, |
| audio=audio, |
| perturbations=perturbations, |
| ) |
|
|
| return video, audio |
|
|
| def _process_output( |
| self, |
| scale_shift_table: torch.Tensor, |
| norm_out: torch.nn.LayerNorm, |
| proj_out: torch.nn.Linear, |
| x: torch.Tensor, |
| embedded_timestep: torch.Tensor, |
| ) -> torch.Tensor: |
| """Process output for LTXV.""" |
        # final AdaLN: combine the learned table with the embedded timestep to get shift/scale
| scale_shift_values = ( |
| scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] |
| ) |
| shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] |
|
|
| x = norm_out(x) |
| x = x * (1 + scale) + shift |
| x = proj_out(x) |
| return x |
|
|
| def _forward( |
| self, |
| video: Modality | None, |
| audio: Modality | None, |
| perturbations: BatchedPerturbationConfig, |
| use_gradient_checkpointing: bool = False, |
| use_gradient_checkpointing_offload: bool = False, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| """ |
| Forward pass for LTX models. |
| Returns: |
| Processed output tensors |
| """ |
| if not self.model_type.is_video_enabled() and video is not None: |
| raise ValueError("Video is not enabled for this model") |
| if not self.model_type.is_audio_enabled() and audio is not None: |
| raise ValueError("Audio is not enabled for this model") |
|
|
| video_args = self.video_args_preprocessor.prepare(video, audio) if video is not None else None |
| audio_args = self.audio_args_preprocessor.prepare(audio, video) if audio is not None else None |
        # run the shared stack of AV transformer blocks
| video_out, audio_out = self._process_transformer_blocks( |
| video=video_args, |
| audio=audio_args, |
| perturbations=perturbations, |
| use_gradient_checkpointing=use_gradient_checkpointing, |
| use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, |
| ) |
|
|
        # project each modality back to its output channel space
| vx = ( |
| self._process_output( |
| self.scale_shift_table, self.norm_out, self.proj_out, video_out.x, video_out.embedded_timestep |
| ) |
| if video_out is not None |
| else None |
| ) |
| ax = ( |
| self._process_output( |
| self.audio_scale_shift_table, |
| self.audio_norm_out, |
| self.audio_proj_out, |
| audio_out.x, |
| audio_out.embedded_timestep, |
| ) |
| if audio_out is not None |
| else None |
| ) |
| return vx, ax |
|
|
    def forward(
        self,
        video_latents, video_positions, video_context, video_timesteps,
        audio_latents, audio_positions, audio_context, audio_timesteps,
        sigma,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
    ):
        # Rebuild the preprocessors so they reference the current submodules (e.g. after loading weights).
        cross_pe_max_pos = None
        if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():
            cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
        self._init_preprocessors(cross_pe_max_pos)
        # Either modality may be absent; checkpointing also honors set_gradient_checkpointing().
        video = Modality(video_latents, sigma, video_timesteps, video_positions, video_context) if video_latents is not None else None
        audio = Modality(audio_latents, sigma, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None
        vx, ax = self._forward(video=video, audio=audio, perturbations=None, use_gradient_checkpointing=use_gradient_checkpointing or self._enable_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload)
        return vx, ax
|
|