Fabrice-TIERCELIN committed on
Commit
d975146
·
verified ·
1 Parent(s): cb0a691

Upload 13 files

Browse files
packages/ltx-core/src/ltx_core/model/transformer/__init__.py CHANGED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Transformer model components."""
2
+
3
+ from ltx_core.model.transformer.modality import Modality
4
+ from ltx_core.model.transformer.model import LTXModel, X0Model
5
+ from ltx_core.model.transformer.model_configurator import (
6
+ LTXV_MODEL_COMFY_RENAMING_MAP,
7
+ LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP,
8
+ UPCAST_DURING_INFERENCE,
9
+ LTXModelConfigurator,
10
+ LTXVideoOnlyModelConfigurator,
11
+ UpcastWithStochasticRounding,
12
+ )
13
+
14
+ __all__ = [
15
+ "LTXV_MODEL_COMFY_RENAMING_MAP",
16
+ "LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP",
17
+ "UPCAST_DURING_INFERENCE",
18
+ "LTXModel",
19
+ "LTXModelConfigurator",
20
+ "LTXVideoOnlyModelConfigurator",
21
+ "Modality",
22
+ "UpcastWithStochasticRounding",
23
+ "X0Model",
24
+ ]
packages/ltx-core/src/ltx_core/model/transformer/adaln.py CHANGED
@@ -1,6 +1,3 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
  from typing import Optional, Tuple
5
 
6
  import torch
@@ -11,9 +8,7 @@ from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTim
11
  class AdaLayerNormSingle(torch.nn.Module):
12
  r"""
13
  Norm layer adaptive layer norm single (adaLN-single).
14
-
15
  As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
16
-
17
  Parameters:
18
  embedding_dim (`int`): The size of each embedding vector.
19
  use_additional_conditions (`bool`): To use additional conditions for normalization or not.
 
 
 
 
1
  from typing import Optional, Tuple
2
 
3
  import torch
 
8
  class AdaLayerNormSingle(torch.nn.Module):
9
  r"""
10
  Norm layer adaptive layer norm single (adaLN-single).
 
11
  As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
 
12
  Parameters:
13
  embedding_dim (`int`): The size of each embedding vector.
14
  use_additional_conditions (`bool`): To use additional conditions for normalization or not.
packages/ltx-core/src/ltx_core/model/transformer/attention.py CHANGED
@@ -1,6 +1,3 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
  from enum import Enum
5
  from typing import Protocol
6
 
@@ -14,13 +11,8 @@ try:
14
  from xformers.ops import memory_efficient_attention
15
  except ImportError:
16
  memory_efficient_attention = None
17
- try:
18
- # FlashAttention3 and XFormersAttention cannot be used together
19
- if memory_efficient_attention is None:
20
- import flash_attn_interface
21
- except ImportError:
22
- flash_attn_interface = None
23
-
24
 
25
  class AttentionCallable(Protocol):
26
  def __call__(
@@ -67,7 +59,6 @@ class XFormersAttention(AttentionCallable):
67
  # xformers expects [B, M, H, K]
68
  q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v))
69
 
70
- # LT_INTERNAL: https://github.com/LightricksResearch/ComfyUI/blob/ee2a50cd8fb3544c66f8a3096390c741fff12ae3/comfy/ldm/modules/attention.py#L441-L459
71
  if mask is not None:
72
  # add a singleton batch dimension
73
  if mask.ndim == 2:
@@ -129,14 +120,9 @@ class AttentionFunction(Enum):
129
  def __call__(
130
  self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
131
  ) -> torch.Tensor:
132
- if self is AttentionFunction.PYTORCH:
133
- return PytorchAttention()(q, k, v, heads, mask)
134
- elif self is AttentionFunction.XFORMERS:
135
- return XFormersAttention()(q, k, v, heads, mask)
136
- elif self is AttentionFunction.FLASH_ATTENTION_3:
137
  return FlashAttention3()(q, k, v, heads, mask)
138
  else:
139
- # Default behavior: XFormers if installed else - PyTorch
140
  return (
141
  XFormersAttention()(q, k, v, heads, mask)
142
  if memory_efficient_attention is not None
 
 
 
 
1
  from enum import Enum
2
  from typing import Protocol
3
 
 
11
  from xformers.ops import memory_efficient_attention
12
  except ImportError:
13
  memory_efficient_attention = None
14
+
15
+ import flash_attn_interface
 
 
 
 
 
16
 
17
  class AttentionCallable(Protocol):
18
  def __call__(
 
59
  # xformers expects [B, M, H, K]
60
  q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v))
61
 
 
62
  if mask is not None:
63
  # add a singleton batch dimension
64
  if mask.ndim == 2:
 
120
  def __call__(
121
  self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
122
  ) -> torch.Tensor:
123
+ if mask is None:
 
 
 
 
124
  return FlashAttention3()(q, k, v, heads, mask)
125
  else:
 
126
  return (
127
  XFormersAttention()(q, k, v, heads, mask)
128
  if memory_efficient_attention is not None
packages/ltx-core/src/ltx_core/model/transformer/feed_forward.py CHANGED
@@ -1,6 +1,3 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
  import torch
5
 
6
  from ltx_core.model.transformer.gelu_approx import GELUApprox
 
 
 
 
1
  import torch
2
 
3
  from ltx_core.model.transformer.gelu_approx import GELUApprox
packages/ltx-core/src/ltx_core/model/transformer/gelu_approx.py CHANGED
@@ -1,6 +1,3 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
  import torch
5
 
6
 
 
 
 
 
1
  import torch
2
 
3
 
packages/ltx-core/src/ltx_core/model/transformer/modality.py CHANGED
@@ -1,6 +1,3 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
  from dataclasses import dataclass
5
 
6
  import torch
@@ -8,6 +5,12 @@ import torch
8
 
9
  @dataclass(frozen=True)
10
  class Modality:
 
 
 
 
 
 
11
  latent: (
12
  torch.Tensor
13
  ) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
 
 
 
 
1
  from dataclasses import dataclass
2
 
3
  import torch
 
5
 
6
  @dataclass(frozen=True)
7
  class Modality:
8
+ """
9
+ Input data for a single modality (video or audio) in the transformer.
10
+ Bundles the latent tokens, timestep embeddings, positional information,
11
+ and text conditioning context for processing by the diffusion transformer.
12
+ """
13
+
14
  latent: (
15
  torch.Tensor
16
  ) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
packages/ltx-core/src/ltx_core/model/transformer/model.py CHANGED
@@ -1,7 +1,3 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
-
5
  from enum import Enum
6
 
7
  import torch
@@ -36,7 +32,6 @@ class LTXModelType(Enum):
36
  class LTXModel(torch.nn.Module):
37
  """
38
  LTX model transformer implementation.
39
-
40
  This class implements the transformer blocks for the LTX model.
41
  """
42
 
@@ -315,11 +310,9 @@ class LTXModel(torch.nn.Module):
315
 
316
  def set_gradient_checkpointing(self, enable: bool) -> None:
317
  """Enable or disable gradient checkpointing for transformer blocks.
318
-
319
  Gradient checkpointing trades compute for memory by recomputing activations
320
  during the backward pass instead of storing them. This can significantly
321
  reduce memory usage at the cost of ~20-30% slower training.
322
-
323
  Args:
324
  enable: Whether to enable gradient checkpointing
325
  """
@@ -380,7 +373,6 @@ class LTXModel(torch.nn.Module):
380
  ) -> tuple[torch.Tensor, torch.Tensor]:
381
  """
382
  Forward pass for LTX models.
383
-
384
  Returns:
385
  Processed output tensors
386
  """
@@ -424,10 +416,6 @@ class LegacyX0Model(torch.nn.Module):
424
  """
425
  Legacy X0 model implementation.
426
  Returns fully denoised output based on the velocities produced by the base model.
427
- LT_INTERNAL_BEGIN
428
- Applies full sigma when denoising which is mathematically incorrect but in accordance with:
429
- https://github.com/LightricksResearch/ComfyUI/blob/cc26711bd34135a3eac782b81f9526c5acfcf94d/comfy/model_sampling.py#L62-L68
430
- LT_INTERNAL_END
431
  """
432
 
433
  def __init__(self, velocity_model: LTXModel):
@@ -443,7 +431,6 @@ class LegacyX0Model(torch.nn.Module):
443
  ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
444
  """
445
  Denoise the video and audio according to the sigma.
446
-
447
  Returns:
448
  Denoised video and audio
449
  """
@@ -472,7 +459,6 @@ class X0Model(torch.nn.Module):
472
  ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
473
  """
474
  Denoise the video and audio according to the sigma.
475
-
476
  Returns:
477
  Denoised video and audio
478
  """
 
 
 
 
 
1
  from enum import Enum
2
 
3
  import torch
 
32
  class LTXModel(torch.nn.Module):
33
  """
34
  LTX model transformer implementation.
 
35
  This class implements the transformer blocks for the LTX model.
36
  """
37
 
 
310
 
311
  def set_gradient_checkpointing(self, enable: bool) -> None:
312
  """Enable or disable gradient checkpointing for transformer blocks.
 
313
  Gradient checkpointing trades compute for memory by recomputing activations
314
  during the backward pass instead of storing them. This can significantly
315
  reduce memory usage at the cost of ~20-30% slower training.
 
316
  Args:
317
  enable: Whether to enable gradient checkpointing
318
  """
 
373
  ) -> tuple[torch.Tensor, torch.Tensor]:
374
  """
375
  Forward pass for LTX models.
 
376
  Returns:
377
  Processed output tensors
378
  """
 
416
  """
417
  Legacy X0 model implementation.
418
  Returns fully denoised output based on the velocities produced by the base model.
 
 
 
 
419
  """
420
 
421
  def __init__(self, velocity_model: LTXModel):
 
431
  ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
432
  """
433
  Denoise the video and audio according to the sigma.
 
434
  Returns:
435
  Denoised video and audio
436
  """
 
459
  ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
460
  """
461
  Denoise the video and audio according to the sigma.
 
462
  Returns:
463
  Denoised video and audio
464
  """
packages/ltx-core/src/ltx_core/model/transformer/model_configurator.py CHANGED
@@ -11,6 +11,11 @@ from ltx_core.utils import check_config_value
11
 
12
 
13
  class LTXModelConfigurator(ModelConfigurator[LTXModel]):
 
 
 
 
 
14
  @classmethod
15
  def from_config(cls: type[LTXModel], config: dict) -> LTXModel:
16
  config = config.get("transformer", {})
@@ -62,6 +67,11 @@ class LTXModelConfigurator(ModelConfigurator[LTXModel]):
62
 
63
 
64
  class LTXVideoOnlyModelConfigurator(ModelConfigurator[LTXModel]):
 
 
 
 
 
65
  @classmethod
66
  def from_config(cls: type[LTXModel], config: dict) -> LTXModel:
67
  config = config.get("transformer", {})
@@ -213,6 +223,11 @@ UPCAST_DURING_INFERENCE = ModuleOps(
213
 
214
 
215
  class UpcastWithStochasticRounding(ModuleOps):
 
 
 
 
 
216
  def __new__(cls, seed: int = 0):
217
  return super().__new__(
218
  cls,
 
11
 
12
 
13
  class LTXModelConfigurator(ModelConfigurator[LTXModel]):
14
+ """
15
+ Configurator for LTX model.
16
+ Used to create an LTX model from a configuration dictionary.
17
+ """
18
+
19
  @classmethod
20
  def from_config(cls: type[LTXModel], config: dict) -> LTXModel:
21
  config = config.get("transformer", {})
 
67
 
68
 
69
  class LTXVideoOnlyModelConfigurator(ModelConfigurator[LTXModel]):
70
+ """
71
+ Configurator for LTX video only model.
72
+ Used to create an LTX video only model from a configuration dictionary.
73
+ """
74
+
75
  @classmethod
76
  def from_config(cls: type[LTXModel], config: dict) -> LTXModel:
77
  config = config.get("transformer", {})
 
223
 
224
 
225
  class UpcastWithStochasticRounding(ModuleOps):
226
+ """
227
+ ModuleOps for upcasting the model's float8_e4m3fn weights and biases to the bfloat16 dtype
228
+ and applying stochastic rounding during linear forward.
229
+ """
230
+
231
  def __new__(cls, seed: int = 0):
232
  return super().__new__(
233
  cls,
packages/ltx-core/src/ltx_core/model/transformer/rope.py CHANGED
@@ -1,6 +1,3 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
  import functools
5
  import math
6
  from enum import Enum
 
 
 
 
1
  import functools
2
  import math
3
  from enum import Enum
packages/ltx-core/src/ltx_core/model/transformer/text_projection.py CHANGED
@@ -1,13 +1,9 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
  import torch
5
 
6
 
7
  class PixArtAlphaTextProjection(torch.nn.Module):
8
  """
9
  Projects caption embeddings. Also handles dropout for classifier-free guidance.
10
-
11
  Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
12
  """
13
 
 
 
 
 
1
  import torch
2
 
3
 
4
  class PixArtAlphaTextProjection(torch.nn.Module):
5
  """
6
  Projects caption embeddings. Also handles dropout for classifier-free guidance.
 
7
  Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
8
  """
9
 
packages/ltx-core/src/ltx_core/model/transformer/timestep_embedding.py CHANGED
@@ -1,6 +1,3 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
  import math
5
 
6
  import torch
@@ -16,7 +13,6 @@ def get_timestep_embedding(
16
  ) -> torch.Tensor:
17
  """
18
  This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
19
-
20
  Args
21
  timesteps (torch.Tensor):
22
  a 1-D Tensor of N indices, one per batch element. These may be fractional.
@@ -122,7 +118,6 @@ class Timesteps(torch.nn.Module):
122
  class PixArtAlphaCombinedTimestepSizeEmbeddings(torch.nn.Module):
123
  """
124
  For PixArt-Alpha.
125
-
126
  Reference:
127
  https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
128
  """
 
 
 
 
1
  import math
2
 
3
  import torch
 
13
  ) -> torch.Tensor:
14
  """
15
  This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
 
16
  Args
17
  timesteps (torch.Tensor):
18
  a 1-D Tensor of N indices, one per batch element. These may be fractional.
 
118
  class PixArtAlphaCombinedTimestepSizeEmbeddings(torch.nn.Module):
119
  """
120
  For PixArt-Alpha.
 
121
  Reference:
122
  https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
123
  """
packages/ltx-core/src/ltx_core/model/transformer/transformer.py CHANGED
@@ -1,6 +1,3 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
  from dataclasses import dataclass, replace
5
 
6
  import torch
@@ -107,16 +104,13 @@ class BasicAVTransformerBlock(torch.nn.Module):
107
  self.norm_eps = norm_eps
108
 
109
  def get_ada_values(
110
- self,
111
- scale_shift_table: torch.Tensor,
112
- batch_size: int,
113
- timestep: torch.Tensor,
114
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
115
  num_ada_params = scale_shift_table.shape[0]
116
 
117
  ada_values = (
118
- scale_shift_table.unsqueeze(0).unsqueeze(0).to(timestep.dtype)
119
- + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)
120
  ).unbind(dim=2)
121
  return ada_values
122
 
@@ -129,14 +123,10 @@ class BasicAVTransformerBlock(torch.nn.Module):
129
  num_scale_shift_values: int = 4,
130
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
131
  scale_shift_ada_values = self.get_ada_values(
132
- scale_shift_table[:num_scale_shift_values, :],
133
- batch_size,
134
- scale_shift_timestep,
135
  )
136
  gate_ada_values = self.get_ada_values(
137
- scale_shift_table[num_scale_shift_values:, :],
138
- batch_size,
139
- gate_timestep,
140
  )
141
 
142
  scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
@@ -144,7 +134,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
144
 
145
  return (*scale_shift_chunks, *gate_ada_values)
146
 
147
- def forward(
148
  self,
149
  video: TransformerArgs | None,
150
  audio: TransformerArgs | None,
@@ -164,8 +154,8 @@ class BasicAVTransformerBlock(torch.nn.Module):
164
  run_v2a = run_ax and (video is not None and video.enabled and vx.numel() > 0)
165
 
166
  if run_vx:
167
- vshift_msa, vscale_msa, vgate_msa, vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
168
- self.scale_shift_table, vx.shape[0], video.timesteps
169
  )
170
  if not perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx):
171
  norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
@@ -174,9 +164,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
174
 
175
  vx = vx + self.attn2(rms_norm(vx, eps=self.norm_eps), context=video.context, mask=video.context_mask)
176
 
 
 
177
  if run_ax:
178
- ashift_msa, ascale_msa, agate_msa, ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
179
- self.audio_scale_shift_table, ax.shape[0], audio.timesteps
180
  )
181
 
182
  if not perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx):
@@ -186,6 +178,8 @@ class BasicAVTransformerBlock(torch.nn.Module):
186
 
187
  ax = ax + self.audio_attn2(rms_norm(ax, eps=self.norm_eps), context=audio.context, mask=audio.context_mask)
188
 
 
 
189
  # Audio - Video cross attention.
190
  if run_a2v or run_v2a:
191
  vx_norm3 = rms_norm(vx, eps=self.norm_eps)
@@ -247,12 +241,34 @@ class BasicAVTransformerBlock(torch.nn.Module):
247
  * v2a_mask
248
  )
249
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  if run_vx:
 
 
 
251
  vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
252
  vx = vx + self.ff(vx_scaled) * vgate_mlp
253
 
 
 
254
  if run_ax:
 
 
 
255
  ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
256
  ax = ax + self.audio_ff(ax_scaled) * agate_mlp
257
 
 
 
258
  return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None
 
 
 
 
1
  from dataclasses import dataclass, replace
2
 
3
  import torch
 
104
  self.norm_eps = norm_eps
105
 
106
  def get_ada_values(
107
+ self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice
108
+ ) -> tuple[torch.Tensor, ...]:
 
 
 
109
  num_ada_params = scale_shift_table.shape[0]
110
 
111
  ada_values = (
112
+ scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
113
+ + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
114
  ).unbind(dim=2)
115
  return ada_values
116
 
 
123
  num_scale_shift_values: int = 4,
124
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
125
  scale_shift_ada_values = self.get_ada_values(
126
+ scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, slice(None, None)
 
 
127
  )
128
  gate_ada_values = self.get_ada_values(
129
+ scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None)
 
 
130
  )
131
 
132
  scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
 
134
 
135
  return (*scale_shift_chunks, *gate_ada_values)
136
 
137
+ def forward( # noqa: PLR0915
138
  self,
139
  video: TransformerArgs | None,
140
  audio: TransformerArgs | None,
 
154
  run_v2a = run_ax and (video is not None and video.enabled and vx.numel() > 0)
155
 
156
  if run_vx:
157
+ vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
158
+ self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
159
  )
160
  if not perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx):
161
  norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
 
164
 
165
  vx = vx + self.attn2(rms_norm(vx, eps=self.norm_eps), context=video.context, mask=video.context_mask)
166
 
167
+ del vshift_msa, vscale_msa, vgate_msa
168
+
169
  if run_ax:
170
+ ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
171
+ self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
172
  )
173
 
174
  if not perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx):
 
178
 
179
  ax = ax + self.audio_attn2(rms_norm(ax, eps=self.norm_eps), context=audio.context, mask=audio.context_mask)
180
 
181
+ del ashift_msa, ascale_msa, agate_msa
182
+
183
  # Audio - Video cross attention.
184
  if run_a2v or run_v2a:
185
  vx_norm3 = rms_norm(vx, eps=self.norm_eps)
 
241
  * v2a_mask
242
  )
243
 
244
+ del gate_out_a2v, gate_out_v2a
245
+ del (
246
+ scale_ca_video_hidden_states_a2v,
247
+ shift_ca_video_hidden_states_a2v,
248
+ scale_ca_audio_hidden_states_a2v,
249
+ shift_ca_audio_hidden_states_a2v,
250
+ scale_ca_video_hidden_states_v2a,
251
+ shift_ca_video_hidden_states_v2a,
252
+ scale_ca_audio_hidden_states_v2a,
253
+ shift_ca_audio_hidden_states_v2a,
254
+ )
255
+
256
  if run_vx:
257
+ vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
258
+ self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None)
259
+ )
260
  vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
261
  vx = vx + self.ff(vx_scaled) * vgate_mlp
262
 
263
+ del vshift_mlp, vscale_mlp, vgate_mlp
264
+
265
  if run_ax:
266
+ ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
267
+ self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None)
268
+ )
269
  ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
270
  ax = ax + self.audio_ff(ax_scaled) * agate_mlp
271
 
272
+ del ashift_mlp, ascale_mlp, agate_mlp
273
+
274
  return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None
packages/ltx-core/src/ltx_core/model/transformer/transformer_args.py CHANGED
@@ -1,6 +1,3 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
  from dataclasses import dataclass, replace
5
 
6
  import torch
 
 
 
 
1
  from dataclasses import dataclass, replace
2
 
3
  import torch