Upload 13 files
Browse files
- packages/ltx-core/src/ltx_core/model/transformer/__init__.py +24 -0
- packages/ltx-core/src/ltx_core/model/transformer/adaln.py +0 -5
- packages/ltx-core/src/ltx_core/model/transformer/attention.py +3 -17
- packages/ltx-core/src/ltx_core/model/transformer/feed_forward.py +0 -3
- packages/ltx-core/src/ltx_core/model/transformer/gelu_approx.py +0 -3
- packages/ltx-core/src/ltx_core/model/transformer/modality.py +6 -3
- packages/ltx-core/src/ltx_core/model/transformer/model.py +0 -14
- packages/ltx-core/src/ltx_core/model/transformer/model_configurator.py +15 -0
- packages/ltx-core/src/ltx_core/model/transformer/rope.py +0 -3
- packages/ltx-core/src/ltx_core/model/transformer/text_projection.py +0 -4
- packages/ltx-core/src/ltx_core/model/transformer/timestep_embedding.py +0 -5
- packages/ltx-core/src/ltx_core/model/transformer/transformer.py +37 -21
- packages/ltx-core/src/ltx_core/model/transformer/transformer_args.py +0 -3
packages/ltx-core/src/ltx_core/model/transformer/__init__.py
CHANGED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Transformer model components."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.model.transformer.modality import Modality
|
| 4 |
+
from ltx_core.model.transformer.model import LTXModel, X0Model
|
| 5 |
+
from ltx_core.model.transformer.model_configurator import (
|
| 6 |
+
LTXV_MODEL_COMFY_RENAMING_MAP,
|
| 7 |
+
LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP,
|
| 8 |
+
UPCAST_DURING_INFERENCE,
|
| 9 |
+
LTXModelConfigurator,
|
| 10 |
+
LTXVideoOnlyModelConfigurator,
|
| 11 |
+
UpcastWithStochasticRounding,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"LTXV_MODEL_COMFY_RENAMING_MAP",
|
| 16 |
+
"LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP",
|
| 17 |
+
"UPCAST_DURING_INFERENCE",
|
| 18 |
+
"LTXModel",
|
| 19 |
+
"LTXModelConfigurator",
|
| 20 |
+
"LTXVideoOnlyModelConfigurator",
|
| 21 |
+
"Modality",
|
| 22 |
+
"UpcastWithStochasticRounding",
|
| 23 |
+
"X0Model",
|
| 24 |
+
]
|
packages/ltx-core/src/ltx_core/model/transformer/adaln.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
from typing import Optional, Tuple
|
| 5 |
|
| 6 |
import torch
|
|
@@ -11,9 +8,7 @@ from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTim
|
|
| 11 |
class AdaLayerNormSingle(torch.nn.Module):
|
| 12 |
r"""
|
| 13 |
Norm layer adaptive layer norm single (adaLN-single).
|
| 14 |
-
|
| 15 |
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
|
| 16 |
-
|
| 17 |
Parameters:
|
| 18 |
embedding_dim (`int`): The size of each embedding vector.
|
| 19 |
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from typing import Optional, Tuple
|
| 2 |
|
| 3 |
import torch
|
|
|
|
| 8 |
class AdaLayerNormSingle(torch.nn.Module):
|
| 9 |
r"""
|
| 10 |
Norm layer adaptive layer norm single (adaLN-single).
|
|
|
|
| 11 |
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
|
|
|
|
| 12 |
Parameters:
|
| 13 |
embedding_dim (`int`): The size of each embedding vector.
|
| 14 |
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
packages/ltx-core/src/ltx_core/model/transformer/attention.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
from enum import Enum
|
| 5 |
from typing import Protocol
|
| 6 |
|
|
@@ -14,13 +11,8 @@ try:
|
|
| 14 |
from xformers.ops import memory_efficient_attention
|
| 15 |
except ImportError:
|
| 16 |
memory_efficient_attention = None
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
if memory_efficient_attention is None:
|
| 20 |
-
import flash_attn_interface
|
| 21 |
-
except ImportError:
|
| 22 |
-
flash_attn_interface = None
|
| 23 |
-
|
| 24 |
|
| 25 |
class AttentionCallable(Protocol):
|
| 26 |
def __call__(
|
|
@@ -67,7 +59,6 @@ class XFormersAttention(AttentionCallable):
|
|
| 67 |
# xformers expects [B, M, H, K]
|
| 68 |
q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v))
|
| 69 |
|
| 70 |
-
# LT_INTERNAL: https://github.com/LightricksResearch/ComfyUI/blob/ee2a50cd8fb3544c66f8a3096390c741fff12ae3/comfy/ldm/modules/attention.py#L441-L459
|
| 71 |
if mask is not None:
|
| 72 |
# add a singleton batch dimension
|
| 73 |
if mask.ndim == 2:
|
|
@@ -129,14 +120,9 @@ class AttentionFunction(Enum):
|
|
| 129 |
def __call__(
|
| 130 |
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
|
| 131 |
) -> torch.Tensor:
|
| 132 |
-
if
|
| 133 |
-
return PytorchAttention()(q, k, v, heads, mask)
|
| 134 |
-
elif self is AttentionFunction.XFORMERS:
|
| 135 |
-
return XFormersAttention()(q, k, v, heads, mask)
|
| 136 |
-
elif self is AttentionFunction.FLASH_ATTENTION_3:
|
| 137 |
return FlashAttention3()(q, k, v, heads, mask)
|
| 138 |
else:
|
| 139 |
-
# Default behavior: XFormers if installed else - PyTorch
|
| 140 |
return (
|
| 141 |
XFormersAttention()(q, k, v, heads, mask)
|
| 142 |
if memory_efficient_attention is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from enum import Enum
|
| 2 |
from typing import Protocol
|
| 3 |
|
|
|
|
| 11 |
from xformers.ops import memory_efficient_attention
|
| 12 |
except ImportError:
|
| 13 |
memory_efficient_attention = None
|
| 14 |
+
|
| 15 |
+
import flash_attn_interface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
class AttentionCallable(Protocol):
|
| 18 |
def __call__(
|
|
|
|
| 59 |
# xformers expects [B, M, H, K]
|
| 60 |
q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v))
|
| 61 |
|
|
|
|
| 62 |
if mask is not None:
|
| 63 |
# add a singleton batch dimension
|
| 64 |
if mask.ndim == 2:
|
|
|
|
| 120 |
def __call__(
|
| 121 |
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
|
| 122 |
) -> torch.Tensor:
|
| 123 |
+
if mask is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
return FlashAttention3()(q, k, v, heads, mask)
|
| 125 |
else:
|
|
|
|
| 126 |
return (
|
| 127 |
XFormersAttention()(q, k, v, heads, mask)
|
| 128 |
if memory_efficient_attention is not None
|
packages/ltx-core/src/ltx_core/model/transformer/feed_forward.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
import torch
|
| 5 |
|
| 6 |
from ltx_core.model.transformer.gelu_approx import GELUApprox
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
|
| 3 |
from ltx_core.model.transformer.gelu_approx import GELUApprox
|
packages/ltx-core/src/ltx_core/model/transformer/gelu_approx.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
import torch
|
| 5 |
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
|
| 3 |
|
packages/ltx-core/src/ltx_core/model/transformer/modality.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
from dataclasses import dataclass
|
| 5 |
|
| 6 |
import torch
|
|
@@ -8,6 +5,12 @@ import torch
|
|
| 8 |
|
| 9 |
@dataclass(frozen=True)
|
| 10 |
class Modality:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
latent: (
|
| 12 |
torch.Tensor
|
| 13 |
) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
|
| 3 |
import torch
|
|
|
|
| 5 |
|
| 6 |
@dataclass(frozen=True)
|
| 7 |
class Modality:
|
| 8 |
+
"""
|
| 9 |
+
Input data for a single modality (video or audio) in the transformer.
|
| 10 |
+
Bundles the latent tokens, timestep embeddings, positional information,
|
| 11 |
+
and text conditioning context for processing by the diffusion transformer.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
latent: (
|
| 15 |
torch.Tensor
|
| 16 |
) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
|
packages/ltx-core/src/ltx_core/model/transformer/model.py
CHANGED
|
@@ -1,7 +1,3 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
-
|
| 5 |
from enum import Enum
|
| 6 |
|
| 7 |
import torch
|
|
@@ -36,7 +32,6 @@ class LTXModelType(Enum):
|
|
| 36 |
class LTXModel(torch.nn.Module):
|
| 37 |
"""
|
| 38 |
LTX model transformer implementation.
|
| 39 |
-
|
| 40 |
This class implements the transformer blocks for the LTX model.
|
| 41 |
"""
|
| 42 |
|
|
@@ -315,11 +310,9 @@ class LTXModel(torch.nn.Module):
|
|
| 315 |
|
| 316 |
def set_gradient_checkpointing(self, enable: bool) -> None:
|
| 317 |
"""Enable or disable gradient checkpointing for transformer blocks.
|
| 318 |
-
|
| 319 |
Gradient checkpointing trades compute for memory by recomputing activations
|
| 320 |
during the backward pass instead of storing them. This can significantly
|
| 321 |
reduce memory usage at the cost of ~20-30% slower training.
|
| 322 |
-
|
| 323 |
Args:
|
| 324 |
enable: Whether to enable gradient checkpointing
|
| 325 |
"""
|
|
@@ -380,7 +373,6 @@ class LTXModel(torch.nn.Module):
|
|
| 380 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 381 |
"""
|
| 382 |
Forward pass for LTX models.
|
| 383 |
-
|
| 384 |
Returns:
|
| 385 |
Processed output tensors
|
| 386 |
"""
|
|
@@ -424,10 +416,6 @@ class LegacyX0Model(torch.nn.Module):
|
|
| 424 |
"""
|
| 425 |
Legacy X0 model implementation.
|
| 426 |
Returns fully denoised output based on the velocities produced by the base model.
|
| 427 |
-
LT_INTERNAL_BEGIN
|
| 428 |
-
Applies full sigma when denoising which is mathematically incorrect but in accordance with:
|
| 429 |
-
https://github.com/LightricksResearch/ComfyUI/blob/cc26711bd34135a3eac782b81f9526c5acfcf94d/comfy/model_sampling.py#L62-L68
|
| 430 |
-
LT_INTERNAL_END
|
| 431 |
"""
|
| 432 |
|
| 433 |
def __init__(self, velocity_model: LTXModel):
|
|
@@ -443,7 +431,6 @@ class LegacyX0Model(torch.nn.Module):
|
|
| 443 |
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
| 444 |
"""
|
| 445 |
Denoise the video and audio according to the sigma.
|
| 446 |
-
|
| 447 |
Returns:
|
| 448 |
Denoised video and audio
|
| 449 |
"""
|
|
@@ -472,7 +459,6 @@ class X0Model(torch.nn.Module):
|
|
| 472 |
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
| 473 |
"""
|
| 474 |
Denoise the video and audio according to the sigma.
|
| 475 |
-
|
| 476 |
Returns:
|
| 477 |
Denoised video and audio
|
| 478 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from enum import Enum
|
| 2 |
|
| 3 |
import torch
|
|
|
|
| 32 |
class LTXModel(torch.nn.Module):
|
| 33 |
"""
|
| 34 |
LTX model transformer implementation.
|
|
|
|
| 35 |
This class implements the transformer blocks for the LTX model.
|
| 36 |
"""
|
| 37 |
|
|
|
|
| 310 |
|
| 311 |
def set_gradient_checkpointing(self, enable: bool) -> None:
|
| 312 |
"""Enable or disable gradient checkpointing for transformer blocks.
|
|
|
|
| 313 |
Gradient checkpointing trades compute for memory by recomputing activations
|
| 314 |
during the backward pass instead of storing them. This can significantly
|
| 315 |
reduce memory usage at the cost of ~20-30% slower training.
|
|
|
|
| 316 |
Args:
|
| 317 |
enable: Whether to enable gradient checkpointing
|
| 318 |
"""
|
|
|
|
| 373 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 374 |
"""
|
| 375 |
Forward pass for LTX models.
|
|
|
|
| 376 |
Returns:
|
| 377 |
Processed output tensors
|
| 378 |
"""
|
|
|
|
| 416 |
"""
|
| 417 |
Legacy X0 model implementation.
|
| 418 |
Returns fully denoised output based on the velocities produced by the base model.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
"""
|
| 420 |
|
| 421 |
def __init__(self, velocity_model: LTXModel):
|
|
|
|
| 431 |
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
| 432 |
"""
|
| 433 |
Denoise the video and audio according to the sigma.
|
|
|
|
| 434 |
Returns:
|
| 435 |
Denoised video and audio
|
| 436 |
"""
|
|
|
|
| 459 |
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
| 460 |
"""
|
| 461 |
Denoise the video and audio according to the sigma.
|
|
|
|
| 462 |
Returns:
|
| 463 |
Denoised video and audio
|
| 464 |
"""
|
packages/ltx-core/src/ltx_core/model/transformer/model_configurator.py
CHANGED
|
@@ -11,6 +11,11 @@ from ltx_core.utils import check_config_value
|
|
| 11 |
|
| 12 |
|
| 13 |
class LTXModelConfigurator(ModelConfigurator[LTXModel]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
@classmethod
|
| 15 |
def from_config(cls: type[LTXModel], config: dict) -> LTXModel:
|
| 16 |
config = config.get("transformer", {})
|
|
@@ -62,6 +67,11 @@ class LTXModelConfigurator(ModelConfigurator[LTXModel]):
|
|
| 62 |
|
| 63 |
|
| 64 |
class LTXVideoOnlyModelConfigurator(ModelConfigurator[LTXModel]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
@classmethod
|
| 66 |
def from_config(cls: type[LTXModel], config: dict) -> LTXModel:
|
| 67 |
config = config.get("transformer", {})
|
|
@@ -213,6 +223,11 @@ UPCAST_DURING_INFERENCE = ModuleOps(
|
|
| 213 |
|
| 214 |
|
| 215 |
class UpcastWithStochasticRounding(ModuleOps):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
def __new__(cls, seed: int = 0):
|
| 217 |
return super().__new__(
|
| 218 |
cls,
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
class LTXModelConfigurator(ModelConfigurator[LTXModel]):
|
| 14 |
+
"""
|
| 15 |
+
Configurator for LTX model.
|
| 16 |
+
Used to create an LTX model from a configuration dictionary.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
@classmethod
|
| 20 |
def from_config(cls: type[LTXModel], config: dict) -> LTXModel:
|
| 21 |
config = config.get("transformer", {})
|
|
|
|
| 67 |
|
| 68 |
|
| 69 |
class LTXVideoOnlyModelConfigurator(ModelConfigurator[LTXModel]):
|
| 70 |
+
"""
|
| 71 |
+
Configurator for LTX video only model.
|
| 72 |
+
Used to create an LTX video only model from a configuration dictionary.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
@classmethod
|
| 76 |
def from_config(cls: type[LTXModel], config: dict) -> LTXModel:
|
| 77 |
config = config.get("transformer", {})
|
|
|
|
| 223 |
|
| 224 |
|
| 225 |
class UpcastWithStochasticRounding(ModuleOps):
|
| 226 |
+
"""
|
| 227 |
+
ModuleOps for upcasting the model's float8_e4m3fn weights and biases to the bfloat16 dtype
|
| 228 |
+
and applying stochastic rounding during linear forward.
|
| 229 |
+
"""
|
| 230 |
+
|
| 231 |
def __new__(cls, seed: int = 0):
|
| 232 |
return super().__new__(
|
| 233 |
cls,
|
packages/ltx-core/src/ltx_core/model/transformer/rope.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
import functools
|
| 5 |
import math
|
| 6 |
from enum import Enum
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import functools
|
| 2 |
import math
|
| 3 |
from enum import Enum
|
packages/ltx-core/src/ltx_core/model/transformer/text_projection.py
CHANGED
|
@@ -1,13 +1,9 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
import torch
|
| 5 |
|
| 6 |
|
| 7 |
class PixArtAlphaTextProjection(torch.nn.Module):
|
| 8 |
"""
|
| 9 |
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
| 10 |
-
|
| 11 |
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
| 12 |
"""
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
|
| 3 |
|
| 4 |
class PixArtAlphaTextProjection(torch.nn.Module):
|
| 5 |
"""
|
| 6 |
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
|
|
|
| 7 |
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
| 8 |
"""
|
| 9 |
|
packages/ltx-core/src/ltx_core/model/transformer/timestep_embedding.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
import math
|
| 5 |
|
| 6 |
import torch
|
|
@@ -16,7 +13,6 @@ def get_timestep_embedding(
|
|
| 16 |
) -> torch.Tensor:
|
| 17 |
"""
|
| 18 |
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
| 19 |
-
|
| 20 |
Args
|
| 21 |
timesteps (torch.Tensor):
|
| 22 |
a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
|
@@ -122,7 +118,6 @@ class Timesteps(torch.nn.Module):
|
|
| 122 |
class PixArtAlphaCombinedTimestepSizeEmbeddings(torch.nn.Module):
|
| 123 |
"""
|
| 124 |
For PixArt-Alpha.
|
| 125 |
-
|
| 126 |
Reference:
|
| 127 |
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
| 128 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import math
|
| 2 |
|
| 3 |
import torch
|
|
|
|
| 13 |
) -> torch.Tensor:
|
| 14 |
"""
|
| 15 |
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
|
|
|
| 16 |
Args
|
| 17 |
timesteps (torch.Tensor):
|
| 18 |
a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
|
|
|
| 118 |
class PixArtAlphaCombinedTimestepSizeEmbeddings(torch.nn.Module):
|
| 119 |
"""
|
| 120 |
For PixArt-Alpha.
|
|
|
|
| 121 |
Reference:
|
| 122 |
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
| 123 |
"""
|
packages/ltx-core/src/ltx_core/model/transformer/transformer.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
from dataclasses import dataclass, replace
|
| 5 |
|
| 6 |
import torch
|
|
@@ -107,16 +104,13 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
|
| 107 |
self.norm_eps = norm_eps
|
| 108 |
|
| 109 |
def get_ada_values(
|
| 110 |
-
self,
|
| 111 |
-
|
| 112 |
-
batch_size: int,
|
| 113 |
-
timestep: torch.Tensor,
|
| 114 |
-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 115 |
num_ada_params = scale_shift_table.shape[0]
|
| 116 |
|
| 117 |
ada_values = (
|
| 118 |
-
scale_shift_table.unsqueeze(0).unsqueeze(0).to(timestep.dtype)
|
| 119 |
-
+ timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)
|
| 120 |
).unbind(dim=2)
|
| 121 |
return ada_values
|
| 122 |
|
|
@@ -129,14 +123,10 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
|
| 129 |
num_scale_shift_values: int = 4,
|
| 130 |
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 131 |
scale_shift_ada_values = self.get_ada_values(
|
| 132 |
-
scale_shift_table[:num_scale_shift_values, :],
|
| 133 |
-
batch_size,
|
| 134 |
-
scale_shift_timestep,
|
| 135 |
)
|
| 136 |
gate_ada_values = self.get_ada_values(
|
| 137 |
-
scale_shift_table[num_scale_shift_values:, :],
|
| 138 |
-
batch_size,
|
| 139 |
-
gate_timestep,
|
| 140 |
)
|
| 141 |
|
| 142 |
scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
|
|
@@ -144,7 +134,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
|
| 144 |
|
| 145 |
return (*scale_shift_chunks, *gate_ada_values)
|
| 146 |
|
| 147 |
-
def forward(
|
| 148 |
self,
|
| 149 |
video: TransformerArgs | None,
|
| 150 |
audio: TransformerArgs | None,
|
|
@@ -164,8 +154,8 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
|
| 164 |
run_v2a = run_ax and (video is not None and video.enabled and vx.numel() > 0)
|
| 165 |
|
| 166 |
if run_vx:
|
| 167 |
-
vshift_msa, vscale_msa, vgate_msa
|
| 168 |
-
self.scale_shift_table, vx.shape[0], video.timesteps
|
| 169 |
)
|
| 170 |
if not perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx):
|
| 171 |
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
|
@@ -174,9 +164,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
|
| 174 |
|
| 175 |
vx = vx + self.attn2(rms_norm(vx, eps=self.norm_eps), context=video.context, mask=video.context_mask)
|
| 176 |
|
|
|
|
|
|
|
| 177 |
if run_ax:
|
| 178 |
-
ashift_msa, ascale_msa, agate_msa
|
| 179 |
-
self.audio_scale_shift_table, ax.shape[0], audio.timesteps
|
| 180 |
)
|
| 181 |
|
| 182 |
if not perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx):
|
|
@@ -186,6 +178,8 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
|
| 186 |
|
| 187 |
ax = ax + self.audio_attn2(rms_norm(ax, eps=self.norm_eps), context=audio.context, mask=audio.context_mask)
|
| 188 |
|
|
|
|
|
|
|
| 189 |
# Audio - Video cross attention.
|
| 190 |
if run_a2v or run_v2a:
|
| 191 |
vx_norm3 = rms_norm(vx, eps=self.norm_eps)
|
|
@@ -247,12 +241,34 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
|
| 247 |
* v2a_mask
|
| 248 |
)
|
| 249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
if run_vx:
|
|
|
|
|
|
|
|
|
|
| 251 |
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
|
| 252 |
vx = vx + self.ff(vx_scaled) * vgate_mlp
|
| 253 |
|
|
|
|
|
|
|
| 254 |
if run_ax:
|
|
|
|
|
|
|
|
|
|
| 255 |
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
|
| 256 |
ax = ax + self.audio_ff(ax_scaled) * agate_mlp
|
| 257 |
|
|
|
|
|
|
|
| 258 |
return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from dataclasses import dataclass, replace
|
| 2 |
|
| 3 |
import torch
|
|
|
|
| 104 |
self.norm_eps = norm_eps
|
| 105 |
|
| 106 |
def get_ada_values(
|
| 107 |
+
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice
|
| 108 |
+
) -> tuple[torch.Tensor, ...]:
|
|
|
|
|
|
|
|
|
|
| 109 |
num_ada_params = scale_shift_table.shape[0]
|
| 110 |
|
| 111 |
ada_values = (
|
| 112 |
+
scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
|
| 113 |
+
+ timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
|
| 114 |
).unbind(dim=2)
|
| 115 |
return ada_values
|
| 116 |
|
|
|
|
| 123 |
num_scale_shift_values: int = 4,
|
| 124 |
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 125 |
scale_shift_ada_values = self.get_ada_values(
|
| 126 |
+
scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, slice(None, None)
|
|
|
|
|
|
|
| 127 |
)
|
| 128 |
gate_ada_values = self.get_ada_values(
|
| 129 |
+
scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None)
|
|
|
|
|
|
|
| 130 |
)
|
| 131 |
|
| 132 |
scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
|
|
|
|
| 134 |
|
| 135 |
return (*scale_shift_chunks, *gate_ada_values)
|
| 136 |
|
| 137 |
+
def forward( # noqa: PLR0915
|
| 138 |
self,
|
| 139 |
video: TransformerArgs | None,
|
| 140 |
audio: TransformerArgs | None,
|
|
|
|
| 154 |
run_v2a = run_ax and (video is not None and video.enabled and vx.numel() > 0)
|
| 155 |
|
| 156 |
if run_vx:
|
| 157 |
+
vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
|
| 158 |
+
self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
|
| 159 |
)
|
| 160 |
if not perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx):
|
| 161 |
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
|
|
|
| 164 |
|
| 165 |
vx = vx + self.attn2(rms_norm(vx, eps=self.norm_eps), context=video.context, mask=video.context_mask)
|
| 166 |
|
| 167 |
+
del vshift_msa, vscale_msa, vgate_msa
|
| 168 |
+
|
| 169 |
if run_ax:
|
| 170 |
+
ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
|
| 171 |
+
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
|
| 172 |
)
|
| 173 |
|
| 174 |
if not perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx):
|
|
|
|
| 178 |
|
| 179 |
ax = ax + self.audio_attn2(rms_norm(ax, eps=self.norm_eps), context=audio.context, mask=audio.context_mask)
|
| 180 |
|
| 181 |
+
del ashift_msa, ascale_msa, agate_msa
|
| 182 |
+
|
| 183 |
# Audio - Video cross attention.
|
| 184 |
if run_a2v or run_v2a:
|
| 185 |
vx_norm3 = rms_norm(vx, eps=self.norm_eps)
|
|
|
|
| 241 |
* v2a_mask
|
| 242 |
)
|
| 243 |
|
| 244 |
+
del gate_out_a2v, gate_out_v2a
|
| 245 |
+
del (
|
| 246 |
+
scale_ca_video_hidden_states_a2v,
|
| 247 |
+
shift_ca_video_hidden_states_a2v,
|
| 248 |
+
scale_ca_audio_hidden_states_a2v,
|
| 249 |
+
shift_ca_audio_hidden_states_a2v,
|
| 250 |
+
scale_ca_video_hidden_states_v2a,
|
| 251 |
+
shift_ca_video_hidden_states_v2a,
|
| 252 |
+
scale_ca_audio_hidden_states_v2a,
|
| 253 |
+
shift_ca_audio_hidden_states_v2a,
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
if run_vx:
|
| 257 |
+
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
|
| 258 |
+
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None)
|
| 259 |
+
)
|
| 260 |
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
|
| 261 |
vx = vx + self.ff(vx_scaled) * vgate_mlp
|
| 262 |
|
| 263 |
+
del vshift_mlp, vscale_mlp, vgate_mlp
|
| 264 |
+
|
| 265 |
if run_ax:
|
| 266 |
+
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
|
| 267 |
+
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None)
|
| 268 |
+
)
|
| 269 |
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
|
| 270 |
ax = ax + self.audio_ff(ax_scaled) * agate_mlp
|
| 271 |
|
| 272 |
+
del ashift_mlp, ascale_mlp, agate_mlp
|
| 273 |
+
|
| 274 |
return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None
|
packages/ltx-core/src/ltx_core/model/transformer/transformer_args.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
from dataclasses import dataclass, replace
|
| 5 |
|
| 6 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from dataclasses import dataclass, replace
|
| 2 |
|
| 3 |
import torch
|