Instructions to use vidfom/Ltx-3 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- llama-cpp-python
How to use vidfom/Ltx-3 with llama-cpp-python:
# !pip install llama-cpp-python
from llama_cpp import Llama

llm = Llama.from_pretrained(
    repo_id="vidfom/Ltx-3",
    filename="ComfyUI/models/text_encoders/gemma-3-12b-it-qat-UD-Q4_K_XL.gguf",
)

# No input example has been defined for this model task; replace the placeholder prompt.
llm.create_chat_completion(
    messages=[{"role": "user", "content": "Your prompt here"}]
)
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- llama.cpp
How to use vidfom/Ltx-3 with llama.cpp:
Install from brew
brew install llama.cpp
# Start a local OpenAI-compatible server with a web UI:
llama-server -hf vidfom/Ltx-3:UD-Q4_K_XL
# Run inference directly in the terminal:
llama-cli -hf vidfom/Ltx-3:UD-Q4_K_XL
Install from WinGet (Windows)
winget install llama.cpp
# Start a local OpenAI-compatible server with a web UI:
llama-server -hf vidfom/Ltx-3:UD-Q4_K_XL
# Run inference directly in the terminal:
llama-cli -hf vidfom/Ltx-3:UD-Q4_K_XL
Use pre-built binary
# Download pre-built binary from:
# https://github.com/ggerganov/llama.cpp/releases
# Start a local OpenAI-compatible server with a web UI:
./llama-server -hf vidfom/Ltx-3:UD-Q4_K_XL
# Run inference directly in the terminal:
./llama-cli -hf vidfom/Ltx-3:UD-Q4_K_XL
Build from source code
git clone https://github.com/ggerganov/llama.cpp.git
cd llama.cpp
cmake -B build
cmake --build build -j --target llama-server llama-cli
# Start a local OpenAI-compatible server with a web UI:
./build/bin/llama-server -hf vidfom/Ltx-3:UD-Q4_K_XL
# Run inference directly in the terminal:
./build/bin/llama-cli -hf vidfom/Ltx-3:UD-Q4_K_XL
Use Docker
docker model run hf.co/vidfom/Ltx-3:UD-Q4_K_XL
- LM Studio
- Jan
- Ollama
How to use vidfom/Ltx-3 with Ollama:
ollama run hf.co/vidfom/Ltx-3:UD-Q4_K_XL
- Unsloth Studio
How to use vidfom/Ltx-3 with Unsloth Studio:
Install Unsloth Studio (macOS, Linux, WSL)
curl -fsSL https://unsloth.ai/install.sh | sh
# Run unsloth studio
unsloth studio -H 0.0.0.0 -p 8888
# Then open http://localhost:8888 in your browser
# Search for vidfom/Ltx-3 to start chatting
Install Unsloth Studio (Windows)
irm https://unsloth.ai/install.ps1 | iex
# Run unsloth studio
unsloth studio -H 0.0.0.0 -p 8888
# Then open http://localhost:8888 in your browser
# Search for vidfom/Ltx-3 to start chatting
Using HuggingFace Spaces for Unsloth
# No setup required
# Open https://huggingface.co/spaces/unsloth/studio in your browser
# Search for vidfom/Ltx-3 to start chatting
- Docker Model Runner
How to use vidfom/Ltx-3 with Docker Model Runner:
docker model run hf.co/vidfom/Ltx-3:UD-Q4_K_XL
- Lemonade
How to use vidfom/Ltx-3 with Lemonade:
Pull the model
# Download Lemonade from https://lemonade-server.ai/
lemonade pull vidfom/Ltx-3:UD-Q4_K_XL
Run and chat with the model
lemonade run user.Ltx-3-UD-Q4_K_XL
List all available models
lemonade list
from abc import ABC, abstractmethod
from enum import Enum
import functools
import logging
import math
from typing import Dict, Optional, Tuple

from einops import rearrange
import numpy as np
import torch
from torch import nn

import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
import comfy.ops  # used by NormSingleLinearTextProjection (comfy.ops.disable_weight_init)

from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords

logger = logging.getLogger(__name__)


def _log_base(x, base):
    return np.log(x) / np.log(base)
class LTXRopeType(str, Enum):
    INTERLEAVED = "interleaved"
    SPLIT = "split"

    KEY = "rope_type"

    @classmethod
    def from_dict(cls, kwargs, default=None):
        if default is None:
            default = cls.INTERLEAVED
        return cls(kwargs.get(cls.KEY, default))


class LTXFrequenciesPrecision(str, Enum):
    FLOAT32 = "float32"
    FLOAT64 = "float64"

    KEY = "frequencies_precision"

    @classmethod
    def from_dict(cls, kwargs, default=None):
        if default is None:
            default = cls.FLOAT32
        return cls(kwargs.get(cls.KEY, default))
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Create sinusoidal timestep embeddings. This matches the implementation in Denoising Diffusion Probabilistic Models.

    Args:
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
        downscale_freq_shift (float):
            Shift subtracted from `half_dim` in the exponent denominator; controls the spacing between the frequencies of adjacent dimensions.
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the minimum frequency of the embeddings (the lowest frequency is about 1 / max_period).
    Returns:
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
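The following is a dependency-free sketch of what `get_timestep_embedding` computes for one scalar timestep. It assumes plain Python lists instead of tensors, and the helper name `sinusoidal_embedding` is hypothetical. It uses the same frequency formula, the same sin/cos concatenation, the optional flip and the zero pad for odd sizes as the function above.

```python
import math


def sinusoidal_embedding(t, embedding_dim, flip_sin_to_cos=False,
                         downscale_freq_shift=1.0, scale=1.0, max_period=10000):
    """Hypothetical pure-Python mirror of get_timestep_embedding for one timestep."""
    half_dim = embedding_dim // 2
    # Geometric frequency ladder from 1 down to roughly 1 / max_period.
    freqs = [
        math.exp(-math.log(max_period) * i / (half_dim - downscale_freq_shift))
        for i in range(half_dim)
    ]
    args = [scale * t * f for f in freqs]
    sin_part = [math.sin(a) for a in args]
    cos_part = [math.cos(a) for a in args]
    # flip_sin_to_cos=True puts the cosine half first.
    emb = cos_part + sin_part if flip_sin_to_cos else sin_part + cos_part
    if embedding_dim % 2 == 1:
        emb.append(0.0)  # zero pad so the output has exactly embedding_dim entries
    return emb


# Settings used by PixArtAlphaCombinedTimestepSizeEmbeddings: 256 channels, flip, shift 0.
emb = sinusoidal_embedding(0.0, 256, flip_sin_to_cos=True, downscale_freq_shift=0)
print(len(emb), emb[0], emb[128])  # 256 1.0 0.0
```

With the flip enabled, a zero timestep produces ones (cosines) followed by zeros (sines). Each sin/cos pair always lies on the unit circle.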
class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()

        self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)

        if cond_proj_dim is not None:
            self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
        else:
            self.cond_proj = None

        self.act = nn.SiLU()

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = operations.Linear(
            time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device
        )

        if post_act_fn is None:
            self.post_act = None
        # else:
        #     self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
        return t_emb


class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(
        self,
        embedding_dim,
        size_emb_dim,
        use_additional_conditions: bool = False,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(
            in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations
        )

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
        return timesteps_emb
class AdaLayerNormSingle(nn.Module):
    r"""
    Adaptive layer norm single (adaLN-single) layer.

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        embedding_coefficient (`int`): Number of `embedding_dim`-sized modulation vectors produced per timestep.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(
        self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim,
            size_emb_dim=embedding_dim // 3,
            use_additional_conditions=use_additional_conditions,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.silu = nn.SiLU()
        self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # No modulation happening here.
        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep


class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings with a two-layer MLP (Linear, activation, Linear).

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(
        self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None
    ):
        super().__init__()
        if out_features is None:
            out_features = hidden_size
        self.linear_1 = operations.Linear(
            in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device
        )
        if act_fn == "gelu_tanh":
            self.act_1 = nn.GELU(approximate="tanh")
        elif act_fn == "silu":
            self.act_1 = nn.SiLU()
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.linear_2 = operations.Linear(
            in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device
        )

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class NormSingleLinearTextProjection(nn.Module):
    """Text projection for 20B models: single linear with RMSNorm (no activation)."""

    def __init__(
        self, in_features, hidden_size, dtype=None, device=None, operations=None
    ):
        super().__init__()
        if operations is None:
            operations = comfy.ops.disable_weight_init
        self.in_norm = operations.RMSNorm(
            in_features, eps=1e-6, elementwise_affine=False
        )
        self.linear_1 = operations.Linear(
            in_features, hidden_size, bias=True, dtype=dtype, device=device
        )
        self.hidden_size = hidden_size
        self.in_features = in_features

    def forward(self, caption):
        caption = self.in_norm(caption)
        caption = caption * (self.hidden_size / self.in_features) ** 0.5
        return self.linear_1(caption)
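Before its single linear layer, `NormSingleLinearTextProjection` RMS-normalizes the caption and multiplies it by `sqrt(hidden_size / in_features)`. The sketch below illustrates the effect. It uses a hypothetical pure-Python `rms_norm` with no learnable weights (matching `elementwise_affine=False`) and `eps=1e-6` as in the constructor. Every caption vector ends up with an RMS of about `sqrt(hidden_size / in_features)`, so its squared L2 norm is about `hidden_size`.

```python
import math


def rms_norm(values, eps=1e-6):
    """RMSNorm without learnable weights: x / sqrt(mean(x^2) + eps)."""
    mean_square = sum(v * v for v in values) / len(values)
    return [v / math.sqrt(mean_square + eps) for v in values]


def rescale_caption(values, hidden_size):
    """Hypothetical mirror of the pre-linear part of NormSingleLinearTextProjection.forward."""
    normed = rms_norm(values)
    factor = (hidden_size / len(values)) ** 0.5  # len(values) plays the role of in_features
    return [v * factor for v in normed]


caption = [3.0, -1.0, 2.0, 0.5]          # in_features = 4 (illustrative)
out = rescale_caption(caption, hidden_size=16)
rms = math.sqrt(sum(v * v for v in out) / len(out))
print(round(rms, 4))  # 2.0
```

The input scale therefore does not matter. The vector's magnitude is fixed by the target width before the projection.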
class GELU_approx(nn.Module):
    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
        super().__init__()
        self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)

    def forward(self, x):
        return torch.nn.functional.gelu(self.proj(x), approximate="tanh")


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        inner_dim = int(dim * mult)
        project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)

        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
        )

    def forward(self, x):
        return self.net(x)
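`GELU_approx` calls `torch.nn.functional.gelu(..., approximate="tanh")`. `FeedForward` accepts a `glu` flag, but the code above never reads it, so the block is always Linear, tanh-GELU, Dropout, Linear. The minimal pure-Python comparison below checks how close the tanh formula is to the exact erf-based GELU. Both helpers are hypothetical, and `dim = 2048` is an illustrative width, not a value taken from an LTX config.

```python
import math


def gelu_tanh(x):
    """GELU with the tanh approximation (the formula behind approximate="tanh")."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def gelu_exact(x):
    """Reference GELU: x * Phi(x), written with the error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


# Compare the two curves on a grid over [-6, 6].
grid = [i / 10 for i in range(-60, 61)]
max_err = max(abs(gelu_tanh(x) - gelu_exact(x)) for x in grid)
print(max_err < 1e-3)  # True

# FeedForward widens to int(dim * mult) before projecting back to dim_out.
dim, mult = 2048, 4
print(int(dim * mult))  # 8192
```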
def apply_rotary_emb(input_tensor, freqs_cis):
    cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1]
    split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False
    return (
        apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs)
        if split_pe else
        apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs)
    )


def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs):  # TODO: remove duplicate funcs and pick the best/fastest one
    t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
    t1, t2 = t_dup.unbind(dim=-1)
    t_dup = torch.stack((-t2, t1), dim=-1)
    input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")

    out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs

    return out


def apply_split_rotary_emb(input_tensor, cos, sin):
    needs_reshape = False
    if input_tensor.ndim != 4 and cos.ndim == 4:
        B, H, T, _ = cos.shape
        input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2)
        needs_reshape = True
    split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
    first_half_input = split_input[..., :1, :]
    second_half_input = split_input[..., 1:, :]
    output = split_input * cos.unsqueeze(-2)
    first_half_output = output[..., :1, :]
    second_half_output = output[..., 1:, :]
    first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input)
    second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input)
    output = rearrange(output, "... d r -> ... (d r)")
    return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
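The two RoPE layouts differ only in which channels are paired before rotation. `apply_interleaved_rotary_emb` rotates adjacent pairs `(x[2i], x[2i+1])`, while `apply_split_rotary_emb` rotates `(x[i], x[i + D/2])`. The sketch below uses hypothetical pure-Python helpers on a single head vector, with no batch, head or token axes. It checks that both layouts are norm-preserving rotations and that they agree up to a channel permutation.

```python
import math


def rope_interleaved(x, angles):
    """Rotate adjacent pairs (x[2i], x[2i+1]) by angles[i]."""
    out = list(x)
    for i, a in enumerate(angles):
        c, s = math.cos(a), math.sin(a)
        x0, x1 = x[2 * i], x[2 * i + 1]
        out[2 * i] = x0 * c - x1 * s
        out[2 * i + 1] = x1 * c + x0 * s
    return out


def rope_split(x, angles):
    """Rotate pairs (x[i], x[i + half]) by angles[i] (first-half / second-half layout)."""
    half = len(x) // 2
    out = list(x)
    for i, a in enumerate(angles):
        c, s = math.cos(a), math.sin(a)
        x0, x1 = x[i], x[i + half]
        out[i] = x0 * c - x1 * s
        out[i + half] = x1 * c + x0 * s
    return out


def norm(v):
    return math.sqrt(sum(t * t for t in v))


x = [1.0, 2.0, 3.0, 4.0]
angles = [0.3, 1.1]
a = rope_interleaved(x, angles)
b = rope_split(x, angles)

# Both layouts are rotations, so the vector length is unchanged.
print(abs(norm(a) - norm(x)) < 1e-12, abs(norm(b) - norm(x)) < 1e-12)  # True True

# Split pairs (x[0], x[2]) and (x[1], x[3]); moving them into adjacent slots
# turns the split rotation into the interleaved one.
perm = [0, 2, 1, 3]
c = rope_interleaved([x[p] for p in perm], angles)
print(all(abs(c[i] - b[perm[i]]) < 1e-12 for i in range(4)))  # True
```

The layouts are mathematically equivalent up to this permutation. This is why checkpoints record `rope_type` (read by `LTXRopeType.from_dict`) and must be run with the layout they were trained with.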
class CrossAttention(nn.Module):
    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        attn_precision=None,
        apply_gated_attention=False,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim
        self.attn_precision = attn_precision

        self.heads = heads
        self.dim_head = dim_head

        self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
        self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)

        self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

        # Optional per-head gating
        if apply_gated_attention:
            self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device)
        else:
            self.to_gate_logits = None

        self.to_out = nn.Sequential(
            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
        q = self.to_q(x)
        context = x if context is None else context
        k = self.to_k(context)
        v = self.to_v(context)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if pe is not None:
            q = apply_rotary_emb(q, pe)
            k = apply_rotary_emb(k, pe if k_pe is None else k_pe)

        if mask is None:
            out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
        else:
            out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)

        # Apply per-head gating if enabled
        if self.to_gate_logits is not None:
            gate_logits = self.to_gate_logits(x)  # (B, T, H)
            b, t, _ = out.shape
            out = out.view(b, t, self.heads, self.dim_head)
            gates = 2.0 * torch.sigmoid(gate_logits)  # zero-init -> identity
            out = out * gates.unsqueeze(-1)
            out = out.view(b, t, self.heads * self.dim_head)

        return self.to_out(out)
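When `apply_gated_attention=True`, each head's output is multiplied by `2 * sigmoid(logit)`, a factor in (0, 2). A logit of zero therefore leaves the head unchanged, which is what the `zero-init -> identity` comment refers to. Below is a hypothetical pure-Python sketch of that per-head scaling for one token, ignoring the batch and token axes.

```python
import math


def gate_heads(head_outputs, gate_logits):
    """Scale each head's output vector by 2 * sigmoid(logit), as in CrossAttention.forward."""
    gated = []
    for head, logit in zip(head_outputs, gate_logits):
        g = 2.0 / (1.0 + math.exp(-logit))  # 2 * sigmoid(logit)
        gated.append([g * v for v in head])
    return gated


heads = [[1.0, -2.0], [0.5, 0.25], [3.0, 0.0]]  # 3 heads, dim_head = 2
print(gate_heads(heads, [0.0, 0.0, 0.0]) == heads)  # True: zero logits -> identity
print(gate_heads(heads, [-50.0, 0.0, 50.0])[0][0] < 1e-6)  # True: a very negative logit mutes a head
```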
# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate)
ADALN_BASE_PARAMS_COUNT = 6
ADALN_CROSS_ATTN_PARAMS_COUNT = 9


class BasicTransformerBlock(nn.Module):
    def __init__(
        self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.attn_precision = attn_precision
        self.cross_attention_adaln = cross_attention_adaln
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            context_dim=None,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)

        self.attn2 = CrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
        self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype))

        if cross_attention_adaln:
            self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))

    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2)

        x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa

        if self.cross_attention_adaln:
            shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2)
            x += apply_cross_attention_adaln(
                x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca,
                self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options,
            )
        else:
            x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)

        y = comfy.ldm.common_dit.rms_norm(x)
        y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
        x.addcmul_(self.ff(y), gate_mlp)

        return x
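Each `BasicTransformerBlock` turns one `adaln_single` output per token into shift, scale and gate vectors. It reshapes that output into `(num_params, dim)` rows and adds the learned `scale_shift_table`. The minimal sketch below replays this for a single token using plain lists. The helper names are hypothetical, and it assumes `comfy.ldm.common_dit.rms_norm` behaves like a weightless RMSNorm with an eps of 1e-6. Note that `torch.addcmul(y, y, scale_mlp).add_(shift_mlp)` in the MLP branch is the same `y * (1 + scale) + shift` modulation.

```python
import math


def rms_norm(x, eps=1e-6):
    ms = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(ms + eps) for v in x]


def split_modulation(table, timestep_vec, dim):
    """Add the learned table to one token's timestep vector, split into rows of size dim."""
    rows = [timestep_vec[i * dim:(i + 1) * dim] for i in range(len(table))]
    return [[t + e for t, e in zip(trow, erow)] for trow, erow in zip(table, rows)]


def modulate(x, shift, scale):
    """rms_norm(x) * (1 + scale) + shift: the pre-attention / pre-MLP modulation."""
    return [n * (1.0 + sc) + sh for n, sc, sh in zip(rms_norm(x), scale, shift)]


def gated_residual(x, branch_out, gate):
    """x + branch_out * gate: the gated residual update."""
    return [xi + bi * g for xi, bi, g in zip(x, branch_out, gate)]


dim = 2
table = [[0.0] * dim for _ in range(6)]   # ADALN_BASE_PARAMS_COUNT rows
timestep_vec = [0.0] * (6 * dim)          # one token's adaln_single output
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = split_modulation(table, timestep_vec, dim)

x = [3.0, 4.0]
h = modulate(x, shift_msa, scale_msa)
print([round(v, 4) for v in h])  # [0.8485, 1.1314]
print(gated_residual(x, [10.0, 10.0], gate_msa))  # [3.0, 4.0]
```

With a zero table and a zero timestep, the attention input is just the normalized token. The zero gate then cancels whatever the branch returns.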
def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype):
    """Compute a single global prompt timestep for cross-attention ADaLN.

    Uses the max across tokens (matching JAX max_per_segment) and broadcasts
    over text tokens. Returns None when *adaln_module* is None.
    """
    if adaln_module is None:
        return None
    ts_input = (
        timestep_scaled.max(dim=1, keepdim=True).values.flatten()
        if timestep_scaled.dim() > 1
        else timestep_scaled.flatten()
    )
    prompt_ts, _ = adaln_module(
        ts_input,
        {"resolution": None, "aspect_ratio": None},
        batch_size=batch_size,
        hidden_dtype=hidden_dtype,
    )
    return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1])
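Before embedding, `compute_prompt_timestep` reduces the per-token timesteps to one value per sample by taking the max over the token axis. Below is a hypothetical pure-Python version of that reduction, using the default `timestep_scale_multiplier` of 1000. The example batch, in which some tokens sit at timestep 0, is illustrative only.

```python
def prompt_timestep_inputs(timestep, multiplier=1000.0):
    """Hypothetical mirror of the reduction in compute_prompt_timestep.

    timestep: per-sample lists of per-token timesteps (a 2-D batch).
    Returns one scaled value per sample: the max over that sample's tokens.
    """
    return [max(t * multiplier for t in row) for row in timestep]


# Sample 0: two tokens at 0.0 and two at 0.75; sample 1: all tokens at 0.5.
batch = [[0.0, 0.0, 0.75, 0.75], [0.5, 0.5, 0.5, 0.5]]
print(prompt_timestep_inputs(batch))  # [750.0, 500.0]
```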
def apply_cross_attention_adaln(
    x, context, attn, q_shift, q_scale, q_gate,
    prompt_scale_shift_table, prompt_timestep,
    attention_mask=None, transformer_options={},
):
    """Apply cross-attention with ADaLN modulation (shift/scale/gate on Q, shift/scale on KV).

    Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so
    that both regular tensors and CompressedTimestep are supported.
    """
    batch_size = x.shape[0]
    shift_kv, scale_kv = (
        prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
        + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
    ).unbind(dim=2)
    attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift
    encoder_hidden_states = context * (1 + scale_kv) + shift_kv
    return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate


def get_fractional_positions(indices_grid, max_pos):
    n_pos_dims = indices_grid.shape[1]
    assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
    fractional_positions = torch.stack(
        [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
        axis=-1,
    )
    return fractional_positions


def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _=None):
    theta = positional_embedding_theta
    start = 1
    end = theta

    n_elem = 2 * positional_embedding_max_pos_count
    pow_indices = np.power(
        theta,
        np.linspace(
            _log_base(start, theta),
            _log_base(end, theta),
            inner_dim // n_elem,
            dtype=np.float64,
        ),
    )
    return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)


def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device):
    theta = positional_embedding_theta
    start = 1
    end = theta
    n_elem = 2 * positional_embedding_max_pos_count

    indices = theta ** (
        torch.linspace(
            math.log(start, theta),
            math.log(end, theta),
            inner_dim // n_elem,
            device=device,
            dtype=torch.float32,
        )
    )
    indices = indices.to(dtype=torch.float32)

    indices = indices * math.pi / 2

    return indices
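Both frequency-grid generators produce `inner_dim // (2 * n_pos_dims)` frequencies spaced geometrically from `pi/2` to `theta * pi/2`, because `linspace(log_theta(1), log_theta(theta), n)` is simply `linspace(0, 1, n)`. The pure-Python sketch below uses a hypothetical helper and computes in float64, like `generate_freq_grid_np`, rather than float32. It uses the default `theta=10000`, three position axes, and an illustrative `inner_dim` of 4096.

```python
import math


def freq_grid(theta, n_pos_dims, inner_dim):
    """Hypothetical pure-Python mirror of generate_freq_grid_pytorch / generate_freq_grid_np."""
    n = inner_dim // (2 * n_pos_dims)
    # linspace(log_theta(1), log_theta(theta), n) is linspace(0, 1, n).
    exponents = [i / (n - 1) for i in range(n)] if n > 1 else [0.0]
    return [theta ** e * math.pi / 2 for e in exponents]


freqs = freq_grid(10000.0, 3, 4096)
print(len(freqs))  # 682
print(math.isclose(freqs[0], math.pi / 2), math.isclose(freqs[-1], 10000.0 * math.pi / 2))  # True True
```

The two generators differ only in precision. `LTXFrequenciesPrecision.FLOAT64` selects the NumPy path, which builds the grid in float64 before casting the result to float32.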
def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid):
    if use_middle_indices_grid:
        assert len(indices_grid.shape) == 4 and indices_grid.shape[-1] == 2
        indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
        indices_grid = (indices_grid_start + indices_grid_end) / 2.0
    elif len(indices_grid.shape) == 4:
        indices_grid = indices_grid[..., 0]

    # Get fractional positions and compute frequency indices
    fractional_positions = get_fractional_positions(indices_grid, max_pos)
    indices = indices.to(device=fractional_positions.device)

    freqs = (
        (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
        .transpose(-1, -2)
        .flatten(2)
    )
    return freqs
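`generate_freqs` maps every coordinate `p` on axis `a` to `2 * p / max_pos[a] - 1`, which lies in [-1, 1] for `0 <= p <= max_pos[a]`. It multiplies that value by every frequency and orders the result frequency-major, axis-minor via `transpose(-1, -2).flatten(2)`. With `use_middle_indices_grid`, each token's coordinate is first taken as the midpoint of its start/end span. The single-token sketch below uses hypothetical helpers, and the two frequencies are illustrative.

```python
def midpoint(start, end):
    """Centre of each token's [start, end] span, as in the use_middle_indices_grid branch."""
    return [(s + e) / 2.0 for s, e in zip(start, end)]


def token_angles(freqs, position, max_pos):
    """Hypothetical per-token mirror of generate_freqs.

    position: one coordinate per axis (time first); max_pos: per-axis normalizers.
    Returns angles ordered frequency-major, axis-minor.
    """
    scaled = [2.0 * p / m - 1.0 for p, m in zip(position, max_pos)]  # [0, max] -> [-1, 1]
    return [f * s for f in freqs for s in scaled]


freqs = [1.0, 10.0]           # two illustrative frequencies
max_pos = [20, 2048, 2048]    # default positional_embedding_max_pos
pos = midpoint([9, 0, 1000], [11, 0, 1048])   # -> [10.0, 0.0, 1024.0]
angles = token_angles(freqs, pos, max_pos)
print(angles)  # [0.0, -1.0, 0.0, 0.0, -10.0, 0.0]
```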
def interleaved_freqs_cis(freqs, pad_size):
    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
    if pad_size != 0:
        cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
        sin_padding = torch.zeros_like(cos_freq[:, :, :pad_size])
        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
        sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
    return cos_freq, sin_freq


def split_freqs_cis(freqs, pad_size, num_attention_heads):
    cos_freq = freqs.cos()
    sin_freq = freqs.sin()

    if pad_size != 0:
        cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
        sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])

        cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
        sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)

    # Reshape freqs to be compatible with multi-head attention
    B, T, half_HD = cos_freq.shape

    cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
    sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)

    cos_freq = torch.swapaxes(cos_freq, 1, 2)  # (B,H,T,D//2)
    sin_freq = torch.swapaxes(sin_freq, 1, 2)  # (B,H,T,D//2)
    return cos_freq, sin_freq
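The padding in `_precompute_freqs_cis` is needed because `inner_dim` need not be a multiple of `2 * n_pos_dims`. The arithmetic below uses an illustrative configuration of 32 heads with 128 channels each (`inner_dim = 4096`) and three position axes. It shows how many frequencies are produced and how much identity padding (cos = 1, sin = 0) each layout prepends. The helper `rope_layout` is hypothetical.

```python
def rope_layout(inner_dim, n_pos_dims, num_heads):
    """Hypothetical helper reproducing the padding arithmetic in _precompute_freqs_cis."""
    n_freqs = (inner_dim // (2 * n_pos_dims)) * n_pos_dims   # freqs.shape[-1]
    interleaved_pad = inner_dim % (2 * n_pos_dims)           # dim % n_elem
    split_pad = inner_dim // 2 - n_freqs                     # expected_freqs - current_freqs
    per_head = (inner_dim // 2) // num_heads                 # last dim after split_freqs_cis
    return {
        "n_freqs": n_freqs,
        "interleaved_width": 2 * n_freqs + interleaved_pad,
        "split_pad": split_pad,
        "split_per_head": per_head,
    }


print(rope_layout(4096, 3, 32))
# {'n_freqs': 2046, 'interleaved_width': 4096, 'split_pad': 2, 'split_per_head': 64}
```

In the interleaved layout, `2 * n_freqs + pad` always equals `inner_dim`, so cos/sin match the query width exactly. In the split layout, the padded width is `inner_dim // 2`, which `split_freqs_cis` divides evenly across heads (64, half of the 128-channel head here).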
| class LTXBaseModel(torch.nn.Module, ABC): | |
| """ | |
| Abstract base class for LTX models (Lightricks Transformer models). | |
| This class defines the common interface and shared functionality for all LTX models, | |
| including LTXV (video) and LTXAV (audio-video) variants. | |
| """ | |
| def __init__( | |
| self, | |
| in_channels: int, | |
| cross_attention_dim: int, | |
| attention_head_dim: int, | |
| num_attention_heads: int, | |
| caption_channels: int, | |
| num_layers: int, | |
| positional_embedding_theta: float = 10000.0, | |
| positional_embedding_max_pos: list = [20, 2048, 2048], | |
| causal_temporal_positioning: bool = False, | |
| vae_scale_factors: tuple = (8, 32, 32), | |
| use_middle_indices_grid=False, | |
| timestep_scale_multiplier = 1000.0, | |
| caption_proj_before_connector=False, | |
| cross_attention_adaln=False, | |
| caption_projection_first_linear=True, | |
| dtype=None, | |
| device=None, | |
| operations=None, | |
| **kwargs, | |
| ): | |
| super().__init__() | |
| self.generator = None | |
| self.vae_scale_factors = vae_scale_factors | |
| self.use_middle_indices_grid = use_middle_indices_grid | |
| self.dtype = dtype | |
| self.in_channels = in_channels | |
| self.cross_attention_dim = cross_attention_dim | |
| self.attention_head_dim = attention_head_dim | |
| self.num_attention_heads = num_attention_heads | |
| self.caption_channels = caption_channels | |
| self.num_layers = num_layers | |
| self.positional_embedding_theta = positional_embedding_theta | |
| self.positional_embedding_max_pos = positional_embedding_max_pos | |
| self.split_positional_embedding = LTXRopeType.from_dict(kwargs) | |
| self.freq_grid_generator = ( | |
| generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64 | |
| else generate_freq_grid_pytorch | |
| ) | |
| self.causal_temporal_positioning = causal_temporal_positioning | |
| self.operations = operations | |
| self.timestep_scale_multiplier = timestep_scale_multiplier | |
| self.caption_proj_before_connector = caption_proj_before_connector | |
| self.cross_attention_adaln = cross_attention_adaln | |
| self.caption_projection_first_linear = caption_projection_first_linear | |
| # Common dimensions | |
| self.inner_dim = num_attention_heads * attention_head_dim | |
| self.out_channels = in_channels | |
| # Initialize common components | |
| self._init_common_components(device, dtype) | |
| # Initialize model-specific components | |
| self._init_model_components(device, dtype, **kwargs) | |
| # Initialize transformer blocks | |
| self._init_transformer_blocks(device, dtype, **kwargs) | |
| # Initialize output components | |
| self._init_output_components(device, dtype) | |
| def _init_common_components(self, device, dtype): | |
| """Initialize components common to all LTX models | |
| - patchify_proj: Linear projection for patchifying input | |
| - adaln_single: AdaLN layer for timestep embedding | |
| - caption_projection: Linear projection for caption embedding | |
| """ | |
| self.patchify_proj = self.operations.Linear( | |
| self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device | |
| ) | |
| embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT | |
| self.adaln_single = AdaLayerNormSingle( | |
| self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations | |
| ) | |
| if self.cross_attention_adaln: | |
| self.prompt_adaln_single = AdaLayerNormSingle( | |
| self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations | |
| ) | |
| else: | |
| self.prompt_adaln_single = None | |
| if self.caption_proj_before_connector: | |
| if self.caption_projection_first_linear: | |
| self.caption_projection = NormSingleLinearTextProjection( | |
| in_features=self.caption_channels, | |
| hidden_size=self.inner_dim, | |
| dtype=dtype, | |
| device=device, | |
| operations=self.operations, | |
| ) | |
| else: | |
| self.caption_projection = lambda a: a | |
| else: | |
| self.caption_projection = PixArtAlphaTextProjection( | |
| in_features=self.caption_channels, | |
| hidden_size=self.inner_dim, | |
| dtype=dtype, | |
| device=device, | |
| operations=self.operations, | |
| ) | |
| def _init_model_components(self, device, dtype, **kwargs): | |
| """Initialize model-specific components. Must be implemented by subclasses.""" | |
| pass | |
| def _init_transformer_blocks(self, device, dtype, **kwargs): | |
| """Initialize transformer blocks. Must be implemented by subclasses.""" | |
| pass | |
| def _init_output_components(self, device, dtype): | |
| """Initialize output components. Must be implemented by subclasses.""" | |
| pass | |
| def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): | |
| """Process input data. Must be implemented by subclasses.""" | |
| pass | |
| def _build_guide_self_attention_mask(self, x, transformer_options, merged_args): | |
| """Build self-attention mask for per-guide attention attenuation. | |
| Base implementation returns None (no attenuation). Subclasses that | |
| support guide-based attention control should override this. | |
| """ | |
| return None | |
| def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs): | |
| """Process transformer blocks. Must be implemented by subclasses.""" | |
| pass | |
| def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): | |
| """Process output data. Must be implemented by subclasses.""" | |
| pass | |
    def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
        """Prepare timestep embeddings."""
        grid_mask = kwargs.get("grid_mask", None)
        if grid_mask is not None:
            timestep = timestep[:, grid_mask]
        timestep_scaled = timestep * self.timestep_scale_multiplier
        timestep, embedded_timestep = self.adaln_single(
            timestep_scaled.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=hidden_dtype,
        )
        # Second dimension is 1, or the number of tokens (if timestep_per_token)
        timestep = timestep.view(batch_size, -1, timestep.shape[-1])
        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
        prompt_timestep = compute_prompt_timestep(
            self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
        )
        return timestep, embedded_timestep, prompt_timestep
    def _prepare_context(self, context, batch_size, x, attention_mask=None):
        """Prepare context for transformer blocks."""
        if self.caption_proj_before_connector is False:
            context = self.caption_projection(context)
        context = context.view(batch_size, -1, x.shape[-1])
        return context, attention_mask
    def _precompute_freqs_cis(
        self,
        indices_grid,
        dim,
        out_dtype,
        theta=10000.0,
        max_pos=[20, 2048, 2048],
        use_middle_indices_grid=False,
        num_attention_heads=32,
    ):
        """Precompute rotary positional embedding (cos, sin) tables from a position grid."""
        split_mode = self.split_positional_embedding == LTXRopeType.SPLIT
        indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device)
        freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
        if split_mode:
            expected_freqs = dim // 2
            current_freqs = freqs.shape[-1]
            pad_size = expected_freqs - current_freqs
            cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
        else:
            # 2 (for cos and sin) times the number of position axes:
            # 3 for (t, x, y), or 1 for temporal-only positions
            n_elem = 2 * indices_grid.shape[1]
            cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
        return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode
    def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
        """Prepare positional embeddings."""
        fractional_coords = pixel_coords.to(torch.float32)
        # Scale the temporal axis (index 0) by 1 / frame_rate, so positions are in seconds
        fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
        pe = self._precompute_freqs_cis(
            fractional_coords,
            dim=self.inner_dim,
            out_dtype=x_dtype,
            max_pos=self.positional_embedding_max_pos,
            use_middle_indices_grid=self.use_middle_indices_grid,
            num_attention_heads=self.num_attention_heads,
        )
        return pe
    def _prepare_attention_mask(self, attention_mask, x_dtype):
        """Prepare attention mask."""
        if attention_mask is not None and not torch.is_floating_point(attention_mask):
            # Convert a 0/1 keep-mask into an additive bias:
            # kept positions become 0, masked positions become -finfo(x_dtype).max
            attention_mask = (attention_mask - 1).to(x_dtype).reshape(
                (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
            ) * torch.finfo(x_dtype).max
        return attention_mask
    def forward(
        self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
    ):
        """
        Forward pass for LTX models.
        Args:
            x: Input tensor (or a list of tensors, e.g. [video, audio] for AV models)
            timestep: Timestep tensor
            context: Context tensor (e.g., text embeddings)
            attention_mask: Attention mask tensor
            frame_rate: Frame rate for temporal processing
            transformer_options: Additional options for transformer blocks
            keyframe_idxs: Keyframe indices for temporal processing
            denoise_mask: Denoising mask, forwarded to _process_input
            **kwargs: Additional keyword arguments
        Returns:
            Processed output tensor
        """
        # Run _forward through any DIFFUSION_MODEL wrappers registered in transformer_options
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(
                comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options
            ),
        ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs)
    def _forward(
        self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
    ):
        """
        Internal forward pass for LTX models.
        Args:
            x: Input tensor (or a list of tensors, e.g. [video, audio] for AV models)
            timestep: Timestep tensor
            context: Context tensor (e.g., text embeddings)
            attention_mask: Attention mask tensor
            frame_rate: Frame rate for temporal processing
            transformer_options: Additional options for transformer blocks
            keyframe_idxs: Keyframe indices for temporal processing
            denoise_mask: Denoising mask, forwarded to _process_input
            **kwargs: Additional keyword arguments
        Returns:
            Processed output tensor
        """
        if isinstance(x, list):
            input_dtype = x[0].dtype
            batch_size = x[0].shape[0]
        else:
            input_dtype = x.dtype
            batch_size = x.shape[0]
        # Process input
        merged_args = {**transformer_options, **kwargs}
        x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args)
        merged_args.update(additional_args)
        # Prepare timestep and context
        timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
        merged_args["prompt_timestep"] = prompt_timestep
        context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
        # Prepare attention mask and positional embeddings
        attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
        pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
        # Build self-attention mask for per-guide attenuation
        self_attention_mask = self._build_guide_self_attention_mask(
            x, transformer_options, merged_args
        )
        # Process transformer blocks
        x = self._process_transformer_blocks(
            x, context, attention_mask, timestep, pe,
            transformer_options=transformer_options,
            self_attention_mask=self_attention_mask,
            **merged_args,
        )
        # Process output
        x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args)
        return x
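The conversion in `_prepare_attention_mask` turns an integer keep-mask (1 = attend, 0 = ignore) into an additive bias via `(mask - 1) * finfo(dtype).max`, so kept positions add 0 to the attention logits and ignored positions add a huge negative number, which drives their softmax weight to zero. The dependency-free sketch below mirrors that arithmetic on plain Python lists, assuming `sys.float_info.max` as a stand-in for `torch.finfo(x_dtype).max`; the function names are hypothetical and this is an illustration, not the model's code.

```python
import math
import sys


def keep_mask_to_additive_bias(keep_mask):
    """Map a 0/1 keep-mask to an additive bias: 1 -> 0.0, 0 -> -float max.

    Mirrors (attention_mask - 1) * finfo.max, with sys.float_info.max
    standing in for torch.finfo(dtype).max (an assumption of this sketch).
    """
    big = sys.float_info.max
    return [(m - 1) * big for m in keep_mask]


def masked_softmax(scores, bias):
    """Softmax over scores after adding the bias; masked entries get ~0 weight."""
    logits = [s + b for s, b in zip(scores, bias)]
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - peak) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    keep = [1, 1, 0]  # the last key position is padding
    bias = keep_mask_to_additive_bias(keep)
    probs = masked_softmax([0.5, 0.5, 3.0], bias)
    print(bias)   # [0.0, 0.0, -1.7976931348623157e+308]
    print(probs)  # [0.5, 0.5, 0.0]
```

Even though the padded position has the largest raw score, the added bias removes it from the distribution. Because the bias is finite rather than `-inf`, a row in which every key is masked still produces finite weights in this sketch instead of NaN.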
class LTXVModel(LTXBaseModel):
    """LTXV model for video generation."""
    def __init__(
        self,
        in_channels=128,
        cross_attention_dim=2048,
        attention_head_dim=64,
        num_attention_heads=32,
        caption_channels=4096,
        num_layers=28,
        positional_embedding_theta=10000.0,
        positional_embedding_max_pos=[20, 2048, 2048],
        causal_temporal_positioning=False,
        vae_scale_factors=(8, 32, 32),
        use_middle_indices_grid=False,
        timestep_scale_multiplier=1000.0,
        caption_proj_before_connector=False,
        cross_attention_adaln=False,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__(
            in_channels=in_channels,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
            caption_channels=caption_channels,
            num_layers=num_layers,
            positional_embedding_theta=positional_embedding_theta,
            positional_embedding_max_pos=positional_embedding_max_pos,
            causal_temporal_positioning=causal_temporal_positioning,
            vae_scale_factors=vae_scale_factors,
            use_middle_indices_grid=use_middle_indices_grid,
            timestep_scale_multiplier=timestep_scale_multiplier,
            caption_proj_before_connector=caption_proj_before_connector,
            cross_attention_adaln=cross_attention_adaln,
            dtype=dtype,
            device=device,
            operations=operations,
            **kwargs,
        )
    def _init_model_components(self, device, dtype, **kwargs):
        """Initialize LTXV-specific components."""
        pass
    def _init_transformer_blocks(self, device, dtype, **kwargs):
        """Initialize transformer blocks for LTXV."""
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.num_attention_heads,
                    self.attention_head_dim,
                    context_dim=self.cross_attention_dim,
                    cross_attention_adaln=self.cross_attention_adaln,
                    dtype=dtype,
                    device=device,
                    operations=self.operations,
                )
                for _ in range(self.num_layers)
            ]
        )
    def _init_output_components(self, device, dtype):
        """Initialize output components for LTXV."""
        self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
        self.norm_out = self.operations.LayerNorm(
            self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
        )
        self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
        self.patchifier = SymmetricPatchifier(1, start_end=True)
    def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
        """Process input for LTXV."""
        additional_args = {"orig_shape": list(x.shape)}
        x, latent_coords = self.patchifier.patchify(x)
        pixel_coords = latent_to_pixel_coords(
            latent_coords=latent_coords,
            scale_factors=self.vae_scale_factors,
            causal_fix=self.causal_temporal_positioning,
        )
        grid_mask = None
        if keyframe_idxs is not None:
            additional_args.update({"orig_patchified_shape": list(x.shape)})
            denoise_mask = self.patchifier.patchify(denoise_mask)[0]
            # Keep only tokens whose patchified denoise mask has no negative entries
            grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
            additional_args.update({"grid_mask": grid_mask})
            x = x[:, grid_mask, :]
            pixel_coords = pixel_coords[:, :, grid_mask, ...]
            kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
            # Compute per-guide surviving token counts from guide_attention_entries.
            # Each entry tracks one guide reference; they are appended in order and
            # their pre_filter_counts partition the kf_grid_mask.
            guide_entries = kwargs.get("guide_attention_entries", None)
            if guide_entries:
                total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
                if total_pfc != len(kf_grid_mask):
                    raise ValueError(
                        f"guide pre_filter_counts ({total_pfc}) != "
                        f"keyframe grid mask length ({len(kf_grid_mask)})"
                    )
                resolved_entries = []
                offset = 0
                for entry in guide_entries:
                    pfc = entry["pre_filter_count"]
                    entry_mask = kf_grid_mask[offset:offset + pfc]
                    surviving = int(entry_mask.sum().item())
                    resolved_entries.append({
                        **entry,
                        "surviving_count": surviving,
                    })
                    offset += pfc
                additional_args["resolved_guide_entries"] = resolved_entries
            keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
            pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
            # Total surviving guide tokens (all guides)
            additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
        x = self.patchify_proj(x)
        return x, pixel_coords, additional_args
    def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
        """Build self-attention mask for per-guide attention attenuation.
        Reads resolved_guide_entries from merged_args (computed in _process_input)
        to build a log-space additive bias mask that attenuates noisy ↔ guide
        attention for each guide reference independently.
        Returns None if no attenuation is needed (all strengths == 1.0 and no
        spatial masks, or no guide tokens).
        """
        if isinstance(x, list):
            # AV model: x = [vx, ax]; use vx for token count, device, and dtype
            total_tokens = x[0].shape[1]
            device = x[0].device
            dtype = x[0].dtype
        else:
            total_tokens = x.shape[1]
            device = x.device
            dtype = x.dtype
        num_guide_tokens = merged_args.get("num_guide_tokens", 0)
        if num_guide_tokens == 0:
            return None
        resolved_entries = merged_args.get("resolved_guide_entries", None)
        if not resolved_entries:
            return None
        # Check if any attenuation is actually needed
        needs_attenuation = any(
            e["strength"] < 1.0 or e.get("pixel_mask") is not None
            for e in resolved_entries
        )
        if not needs_attenuation:
            return None
        # Build per-guide-token weights for all tracked guide tokens.
        # Guides are appended in order at the end of the sequence.
        guide_start = total_tokens - num_guide_tokens
        all_weights = []
        total_tracked = 0
        for entry in resolved_entries:
            surviving = entry["surviving_count"]
            if surviving == 0:
                continue
            strength = entry["strength"]
            pixel_mask = entry.get("pixel_mask")
            latent_shape = entry.get("latent_shape")
            if pixel_mask is not None and latent_shape is not None:
                f_lat, h_lat, w_lat = latent_shape
                per_token = self._downsample_mask_to_latent(
                    pixel_mask.to(device=device, dtype=dtype),
                    f_lat, h_lat, w_lat,
                )
                # per_token shape: (B, f_lat*h_lat*w_lat).
                # Collapse the batch dim: the mask is assumed identical across the
                # batch, so validate and take the first element to get (1, tokens).
                if per_token.shape[0] > 1:
                    ref = per_token[0]
                    for bi in range(1, per_token.shape[0]):
                        if not torch.equal(ref, per_token[bi]):
                            logger.warning(
                                "pixel_mask differs across batch elements; "
                                "using first element only."
                            )
                            break
                    per_token = per_token[:1]
                # `surviving` is the post-grid_mask token count.
                # Clamp to surviving to handle any mismatch safely.
                n_weights = min(per_token.shape[1], surviving)
                weights = per_token[:, :n_weights] * strength  # (1, n_weights)
            else:
                weights = torch.full(
                    (1, surviving), strength, device=device, dtype=dtype
                )
            all_weights.append(weights)
            total_tracked += weights.shape[1]
        if not all_weights:
            return None
        # Concatenate per-token weights for all tracked guides
        tracked_weights = torch.cat(all_weights, dim=1)  # (1, total_tracked)
        # Check if any weight is actually < 1.0 (otherwise no attenuation needed)
        if (tracked_weights >= 1.0).all():
            return None
        # Build the mask: guide tokens are at the end of the sequence.
        # Tracked guides come first (in order), untracked follow.
        return self._build_self_attention_mask(
            total_tokens, num_guide_tokens, total_tracked,
            tracked_weights, guide_start, device, dtype,
        )
    @staticmethod
    def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
        """Downsample a pixel-space mask to per-token latent weights.
        Args:
            mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1].
            f_lat: Number of latent frames (pre-dilation original count).
            h_lat: Latent height (pre-dilation original height).
            w_lat: Latent width (pre-dilation original width).
        Returns:
            (B, F_lat * H_lat * W_lat) flattened per-token weights.
        """
        b = mask.shape[0]
        f_pix = mask.shape[2]
        # Spatial downsampling: area interpolation per frame
        spatial_down = torch.nn.functional.interpolate(
            rearrange(mask, "b 1 f h w -> (b f) 1 h w"),
            size=(h_lat, w_lat),
            mode="area",
        )
        spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b)
        # Temporal downsampling: the first pixel frame maps to the first latent frame;
        # the remaining pixel frames are averaged in groups, matching the causal temporal structure.
        first_frame = spatial_down[:, :, :1, :, :]
        if f_pix > 1 and f_lat > 1:
            remaining_pix = f_pix - 1
            remaining_lat = f_lat - 1
            t = remaining_pix // remaining_lat
            if t < 1:
                # Fewer pixel frames than latent frames: upsample by repeating
                # the available pixel frames via nearest interpolation.
                rest_flat = rearrange(
                    spatial_down[:, :, 1:, :, :],
                    "b 1 f h w -> (b h w) 1 f",
                )
                rest_up = torch.nn.functional.interpolate(
                    rest_flat, size=remaining_lat, mode="nearest",
                )
                rest = rearrange(
                    rest_up, "(b h w) 1 f -> b 1 f h w",
                    b=b, h=h_lat, w=w_lat,
                )
            else:
                # Trim trailing pixel frames that don't fill a complete group
                usable = remaining_lat * t
                rest = rearrange(
                    spatial_down[:, :, 1:1 + usable, :, :],
                    "b 1 (f t) h w -> b 1 f t h w",
                    t=t,
                )
                rest = rest.mean(dim=3)
            latent_mask = torch.cat([first_frame, rest], dim=2)
        elif f_lat > 1:
            # Single pixel frame but multiple latent frames: repeat the
            # single frame across all latent frames.
            latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1)
        else:
            latent_mask = first_frame
        return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
    @staticmethod
    def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
                                   tracked_weights, guide_start, device, dtype):
        """Build a log-space additive self-attention bias mask.
        Attenuates attention between noisy tokens and tracked guide tokens.
        Untracked guide tokens (at the end of the guide portion) keep full attention.
        Adding log(w) to an attention logit scales its unnormalized softmax
        weight exp(logit) by w.
        Args:
            total_tokens: Total sequence length.
            num_guide_tokens: Total guide tokens (all guides) at end of sequence.
            tracked_count: Number of tracked guide tokens (first in the guide portion).
            tracked_weights: (1, tracked_count) tensor, values in [0, 1].
            guide_start: Index where guide tokens begin in the sequence.
            device: Target device.
            dtype: Target dtype.
        Returns:
            (1, 1, total_tokens, total_tokens) additive bias mask.
            0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
        """
        finfo = torch.finfo(dtype)
        mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
        tracked_end = guide_start + tracked_count
        # Convert weights to log-space bias; zero weights map to finfo.min
        w = tracked_weights.to(device=device, dtype=dtype)  # (1, tracked_count)
        log_w = torch.full_like(w, finfo.min)
        positive_mask = w > 0
        if positive_mask.any():
            log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
        # noisy → tracked guides: every noisy row gets the same per-token guide bias
        mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
        # tracked guides → noisy: each guide row broadcasts its bias across the noisy columns
        mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
        return mask
    def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
        """Process transformer blocks for LTXV."""
        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        prompt_timestep = kwargs.get("prompt_timestep", None)
        for i, block in enumerate(self.transformer_blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(
                        args["img"],
                        context=args["txt"],
                        attention_mask=args["attention_mask"],
                        timestep=args["vec"],
                        pe=args["pe"],
                        transformer_options=args["transformer_options"],
                        self_attention_mask=args.get("self_attention_mask"),
                        prompt_timestep=args.get("prompt_timestep"),
                    )
                    return out
                out = blocks_replace[("double_block", i)](
                    {
                        "img": x,
                        "txt": context,
                        "attention_mask": attention_mask,
                        "vec": timestep,
                        "pe": pe,
                        "transformer_options": transformer_options,
                        "self_attention_mask": self_attention_mask,
                        "prompt_timestep": prompt_timestep,
                    },
                    {"original_block": block_wrap},
                )
                x = out["img"]
            else:
                x = block(
                    x,
                    context=context,
                    attention_mask=attention_mask,
                    timestep=timestep,
                    pe=pe,
                    transformer_options=transformer_options,
                    self_attention_mask=self_attention_mask,
                    prompt_timestep=prompt_timestep,
                )
        return x
    def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
        """Process output for LTXV."""
        # Apply scale-shift modulation
        scale_shift_values = (
            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
        )
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
        x = self.norm_out(x)
        x = x * (1 + scale) + shift
        x = self.proj_out(x)
        if keyframe_idxs is not None:
            # Scatter the kept tokens back into a zero-filled tensor of the pre-filter shape
            grid_mask = kwargs["grid_mask"]
            orig_patchified_shape = kwargs["orig_patchified_shape"]
            full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
            full_x[:, grid_mask, :] = x
            x = full_x
        # Unpatchify to restore original dimensions
        orig_shape = kwargs["orig_shape"]
        x = self.patchifier.unpatchify(
            latents=x,
            output_height=orig_shape[3],
            output_width=orig_shape[4],
            output_num_frames=orig_shape[2],
            out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
        )
        return x
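The guide-attenuation path in `LTXVModel` combines three steps. First, `_process_input` splits the keyframe grid mask into per-guide slices using each entry's `pre_filter_count` and counts the surviving tokens. Next, `_downsample_mask_to_latent` maps pixel frames to latent frames causally: the first frame maps on its own, and the rest are averaged in groups of `t = (f_pix - 1) // (f_lat - 1)`. Finally, `_build_self_attention_mask` writes `log(w)` into the noisy-to-guide and guide-to-noisy blocks. The sketch below reproduces each step on plain Python lists for a single batch element and a 1-D per-frame mask (no spatial dimensions, no area interpolation). The helper names and `NEG` (a stand-in for `torch.finfo(dtype).min`) are hypothetical; this is a simplified illustration, not the model's implementation.

```python
import math
import sys

NEG = -sys.float_info.max  # hypothetical stand-in for torch.finfo(dtype).min


def surviving_counts(pre_filter_counts, kf_grid_mask):
    """Split a keyframe keep-mask into per-guide slices and count kept tokens."""
    if sum(pre_filter_counts) != len(kf_grid_mask):
        raise ValueError("pre_filter_counts must partition the keyframe grid mask")
    counts, offset = [], 0
    for pfc in pre_filter_counts:
        counts.append(sum(1 for keep in kf_grid_mask[offset:offset + pfc] if keep))
        offset += pfc
    return counts


def causal_temporal_downsample(frames, f_lat):
    """Map per-pixel-frame values to f_lat latent frames.

    Frame 0 maps to latent frame 0; the remaining frames are averaged in groups
    of t = (f_pix - 1) // (f_lat - 1), dropping leftover frames. With fewer
    pixel frames than latent frames, nearest-neighbour repetition is used.
    """
    f_pix = len(frames)
    first = [frames[0]]
    if f_pix > 1 and f_lat > 1:
        rem_pix, rem_lat = f_pix - 1, f_lat - 1
        t = rem_pix // rem_lat
        rest_src = frames[1:]
        if t < 1:
            # Nearest interpolation: source index = floor(i * rem_pix / rem_lat)
            rest = [rest_src[(i * rem_pix) // rem_lat] for i in range(rem_lat)]
        else:
            rest = [sum(rest_src[i * t:(i + 1) * t]) / t for i in range(rem_lat)]
        return first + rest
    if f_lat > 1:
        # One pixel frame, several latent frames: repeat it
        return first * f_lat
    return first


def log_bias_mask(total_tokens, guide_start, weights):
    """Additive bias of log(w) between noisy tokens and tracked guide tokens.

    Tokens [guide_start, guide_start + len(weights)) are tracked guides. All
    other pairs (including guide-to-guide) keep a bias of 0.0.
    """
    log_w = [math.log(w) if w > 0 else NEG for w in weights]
    tracked_end = guide_start + len(weights)
    mask = [[0.0] * total_tokens for _ in range(total_tokens)]
    for j in range(guide_start, tracked_end):
        lw = log_w[j - guide_start]
        for i in range(guide_start):
            mask[i][j] = lw  # noisy query -> guide key
            mask[j][i] = lw  # guide query -> noisy key
    return mask


if __name__ == "__main__":
    print(surviving_counts([2, 3], [True, False, True, True, False]))  # [1, 2]
    print(causal_temporal_downsample([1.0, 0.0, 1.0, 0.0, 1.0], 3))  # [1.0, 0.5, 0.5]
    m = log_bias_mask(total_tokens=4, guide_start=2, weights=[0.5, 1.0])
    print(round(m[0][2], 4), m[0][3], m[2][3])  # -0.6931 0.0 0.0
```

A weight of 1.0 yields a bias of 0.0 (no change), 0.5 halves the unnormalized attention weight, and 0.0 maps to the most negative finite value, which effectively masks the pair. This matches the docstring's convention for `_build_self_attention_mask`.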