alexnasa commited on
Commit
66cbb01
·
verified ·
1 Parent(s): 7df84fa

Upload 100 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. packages/ltx-core/README.md +280 -0
  2. packages/ltx-core/pyproject.toml +37 -0
  3. packages/ltx-core/src/ltx_core/__init__.py +0 -0
  4. packages/ltx-core/src/ltx_core/components/__init__.py +10 -0
  5. packages/ltx-core/src/ltx_core/components/diffusion_steps.py +22 -0
  6. packages/ltx-core/src/ltx_core/components/guiders.py +198 -0
  7. packages/ltx-core/src/ltx_core/components/noisers.py +35 -0
  8. packages/ltx-core/src/ltx_core/components/patchifiers.py +348 -0
  9. packages/ltx-core/src/ltx_core/components/protocols.py +101 -0
  10. packages/ltx-core/src/ltx_core/components/schedulers.py +129 -0
  11. packages/ltx-core/src/ltx_core/conditioning/__init__.py +12 -0
  12. packages/ltx-core/src/ltx_core/conditioning/exceptions.py +4 -0
  13. packages/ltx-core/src/ltx_core/conditioning/item.py +20 -0
  14. packages/ltx-core/src/ltx_core/conditioning/types/__init__.py +9 -0
  15. packages/ltx-core/src/ltx_core/conditioning/types/keyframe_cond.py +53 -0
  16. packages/ltx-core/src/ltx_core/conditioning/types/latent_cond.py +44 -0
  17. packages/ltx-core/src/ltx_core/guidance/__init__.py +15 -0
  18. packages/ltx-core/src/ltx_core/guidance/perturbations.py +79 -0
  19. packages/ltx-core/src/ltx_core/loader/__init__.py +48 -0
  20. packages/ltx-core/src/ltx_core/loader/fuse_loras.py +100 -0
  21. packages/ltx-core/src/ltx_core/loader/kernels.py +72 -0
  22. packages/ltx-core/src/ltx_core/loader/module_ops.py +14 -0
  23. packages/ltx-core/src/ltx_core/loader/primitives.py +109 -0
  24. packages/ltx-core/src/ltx_core/loader/registry.py +84 -0
  25. packages/ltx-core/src/ltx_core/loader/sd_ops.py +127 -0
  26. packages/ltx-core/src/ltx_core/loader/sft_loader.py +63 -0
  27. packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py +101 -0
  28. packages/ltx-core/src/ltx_core/model/__init__.py +8 -0
  29. packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py +27 -0
  30. packages/ltx-core/src/ltx_core/model/audio_vae/attention.py +71 -0
  31. packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py +480 -0
  32. packages/ltx-core/src/ltx_core/model/audio_vae/causal_conv_2d.py +110 -0
  33. packages/ltx-core/src/ltx_core/model/audio_vae/causality_axis.py +10 -0
  34. packages/ltx-core/src/ltx_core/model/audio_vae/downsample.py +110 -0
  35. packages/ltx-core/src/ltx_core/model/audio_vae/model_configurator.py +123 -0
  36. packages/ltx-core/src/ltx_core/model/audio_vae/ops.py +76 -0
  37. packages/ltx-core/src/ltx_core/model/audio_vae/resnet.py +176 -0
  38. packages/ltx-core/src/ltx_core/model/audio_vae/upsample.py +106 -0
  39. packages/ltx-core/src/ltx_core/model/audio_vae/vocoder.py +123 -0
  40. packages/ltx-core/src/ltx_core/model/common/__init__.py +9 -0
  41. packages/ltx-core/src/ltx_core/model/common/normalization.py +59 -0
  42. packages/ltx-core/src/ltx_core/model/model_protocol.py +10 -0
  43. packages/ltx-core/src/ltx_core/model/transformer/__init__.py +24 -0
  44. packages/ltx-core/src/ltx_core/model/transformer/adaln.py +34 -0
  45. packages/ltx-core/src/ltx_core/model/transformer/attention.py +185 -0
  46. packages/ltx-core/src/ltx_core/model/transformer/feed_forward.py +15 -0
  47. packages/ltx-core/src/ltx_core/model/transformer/gelu_approx.py +10 -0
  48. packages/ltx-core/src/ltx_core/model/transformer/modality.py +23 -0
  49. packages/ltx-core/src/ltx_core/model/transformer/model.py +468 -0
  50. packages/ltx-core/src/ltx_core/model/transformer/model_configurator.py +237 -0
packages/ltx-core/README.md ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LTX-Core
2
+
3
+ The foundational library for the LTX-2 Audio-Video generation model. This package contains the raw model definitions, component implementations, and loading logic used by `ltx-pipelines` and `ltx-trainer`.
4
+
5
+ ## 📦 What's Inside?
6
+
7
+ - **`components/`**: Modular diffusion components (Schedulers, Guiders, Noisers, Patchifiers) following standard protocols
8
+ - **`conditioning/`**: Tools for preparing latent states and applying conditioning (image, video, keyframes)
9
+ - **`guidance/`**: Perturbation system for fine-grained control over attention mechanisms
10
+ - **`loader/`**: Utilities for loading weights from `.safetensors`, fusing LoRAs, and managing memory
11
+ - **`model/`**: PyTorch implementations of the LTX-2 Transformer, Video VAE, Audio VAE, Vocoder and Upscaler
12
+ - **`text_encoders/gemma`**: Gemma text encoder implementation with tokenizers, feature extractors, and separate encoders for audio-video and video-only generation
13
+
14
+ ## 🚀 Quick Start
15
+
16
+ `ltx-core` provides the building blocks (models, components, and utilities) needed to construct inference flows. For ready-made inference pipelines use [`ltx-pipelines`](../ltx-pipelines/) or [`ltx-trainer`](../ltx-trainer/) for training.
17
+
18
+ ## 🔧 Installation
19
+
20
+ ```bash
21
+ # From the repository root
22
+ uv sync --frozen
23
+
24
+ # Or install as a package
25
+ pip install -e packages/ltx-core
26
+ ```
27
+
28
+ ## Building Blocks Overview
29
+
30
+ `ltx-core` provides modular components that can be combined to build custom inference flows:
31
+
32
+ ### Core Models
33
+
34
+ - **Transformer** ([`model/transformer/`](src/ltx_core/model/transformer/)): The 48-layer LTX-2 transformer with cross-modal attention for joint audio-video processing. Expects inputs in [`Modality`](src/ltx_core/model/transformer/modality.py) format
35
+ - **Video VAE** ([`model/video_vae/`](src/ltx_core/model/video_vae/)): Encodes/decodes video pixels to/from latent space with temporal and spatial compression
36
+ - **Audio VAE** ([`model/audio_vae/`](src/ltx_core/model/audio_vae/)): Encodes/decodes audio spectrograms to/from latent space
37
+ - **Vocoder** ([`model/audio_vae/`](src/ltx_core/model/audio_vae/)): Neural vocoder that converts mel spectrograms to audio waveforms
38
+ - **Text Encoder** ([`text_encoders/`](src/ltx_core/text_encoders/)): Gemma-based encoder that produces separate embeddings for video and audio conditioning
39
+ - **Spatial Upscaler** ([`model/upsampler/`](src/ltx_core/model/upsampler/)): Upsamples latent representations for higher-resolution generation
40
+
41
+ ### Diffusion Components
42
+
43
+ - **Schedulers** ([`components/schedulers.py`](src/ltx_core/components/schedulers.py)): Noise schedules (LTX2Scheduler, LinearQuadratic, Beta) that control the denoising process
44
+ - **Guiders** ([`components/guiders.py`](src/ltx_core/components/guiders.py)): Guidance strategies (CFG, STG, APG) for controlling generation quality and adherence to prompts
45
+ - **Noisers** ([`components/noisers.py`](src/ltx_core/components/noisers.py)): Add noise to latents according to the diffusion schedule
46
+ - **Patchifiers** ([`components/patchifiers.py`](src/ltx_core/components/patchifiers.py)): Convert between spatial latents `[B, C, F, H, W]` and sequence format `[B, seq_len, dim]` for transformer processing
47
+
48
+ ### Conditioning & Control
49
+
50
+ - **Conditioning** ([`conditioning/`](src/ltx_core/conditioning/)): Tools for preparing and applying various conditioning types (image, video, keyframes)
51
+ - **Guidance** ([`guidance/`](src/ltx_core/guidance/)): Perturbation system for fine-grained control over attention mechanisms (e.g., skipping specific attention layers)
52
+
53
+ ### Utilities
54
+
55
+ - **Loader** ([`loader/`](src/ltx_core/loader/)): Model loading from `.safetensors`, LoRA fusion, weight remapping, and memory management
56
+
57
+ For complete, production-ready pipeline implementations that combine these building blocks, see the [`ltx-pipelines`](../ltx-pipelines/) package.
58
+
59
+ ---
60
+
61
+ # Architecture Overview
62
+
63
+ This section provides a deep dive into the internal architecture of the LTX-2 Audio-Video generation model.
64
+
65
+ ## Table of Contents
66
+
67
+ 1. [High-Level Architecture](#high-level-architecture)
68
+ 2. [The Transformer](#the-transformer)
69
+ 3. [Video VAE](#video-vae)
70
+ 4. [Audio VAE](#audio-vae)
71
+ 5. [Text Encoding (Gemma)](#text-encoding-gemma)
72
+ 6. [Upscaler](#upscaler)
73
+ 7. [Data Flow](#data-flow)
74
+
75
+ ---
76
+
77
+ ## High-Level Architecture
78
+
79
+ LTX-2 is a **joint Audio-Video diffusion transformer** that processes both modalities simultaneously in a unified architecture. Unlike traditional models that handle video and audio separately, LTX-2 uses cross-modal attention to enable natural synchronization.
80
+
81
+ ```text
82
+ ┌─────────────────────────────────────────────────────────────┐
83
+ │ INPUT PREPARATION │
84
+ │ │
85
+ │ Video Pixels → Video VAE Encoder → Video Latents │
86
+ │ Audio Waveform → Audio VAE Encoder → Audio Latents │
87
+ │ Text Prompt → Gemma Encoder → Text Embeddings │
88
+ └─────────────────────────────────────────────────────────────┘
89
+
90
+ ┌─────────────────────────────────────────────────────────────┐
91
+ │ LTX-2 TRANSFORMER (48 Blocks) │
92
+ │ │
93
+ │ ┌──────────────┐ ┌──────────────┐ │
94
+ │ │ Video Stream │ │ Audio Stream │ │
95
+ │ │ │ │ │ │
96
+ │ │ Self-Attn │ │ Self-Attn │ │
97
+ │ │ Cross-Attn │ │ Cross-Attn │ │
98
+ │ │ │◄────────────►│ │ │
99
+ │ │ A↔V Cross │ │ A↔V Cross │ │
100
+ │ │ Feed-Forward │ │ Feed-Forward │ │
101
+ │ └──────────────┘ └──────────────┘ │
102
+ └─────────────────────────────────────────────────────────────┘
103
+
104
+ ┌─────────────────────────────────────────────────────────────┐
105
+ │ OUTPUT DECODING │
106
+ │ │
107
+ │ Video Latents → Video VAE Decoder → Video Pixels │
108
+ │ Audio Latents → Audio VAE Decoder → Mel Spectrogram │
109
+ │ Mel Spectrogram → Vocoder → Audio Waveform │
110
+ └─────────────────────────────────────────────────────────────┘
111
+ ```
112
+
113
+ ---
114
+
115
+ ## The Transformer
116
+
117
+ The core of LTX-2 is a 48-layer transformer that processes both video and audio tokens simultaneously.
118
+
119
+ ### Model Structure
120
+
121
+ **Source**: [`src/ltx_core/model/transformer/model.py`](src/ltx_core/model/transformer/model.py)
122
+
123
+ The `LTXModel` class implements the transformer. It supports both video-only and audio-video generation modes. For actual usage, see the [`ltx-pipelines`](../ltx-pipelines/) package which handles model loading and initialization.
124
+
125
+ ### Transformer Block Architecture
126
+
127
+ **Source**: [`src/ltx_core/model/transformer/transformer.py`](src/ltx_core/model/transformer/transformer.py)
128
+
129
+ ```text
130
+ ┌─────────────────────────────────────────────────────────────┐
131
+ │ TRANSFORMER BLOCK │
132
+ │ │
133
+ │ VIDEO PATH: │
134
+ │ Input → RMSNorm → AdaLN → Self-Attn (attn1) │
135
+ │ → RMSNorm → Cross-Attn (attn2, text) │
136
+ │ → RMSNorm → AdaLN → A↔V Cross-Attn │
137
+ │ → RMSNorm → AdaLN → Feed-Forward (ff) → Output │
138
+ │ │
139
+ │ AUDIO PATH: │
140
+ │ Input → RMSNorm → AdaLN → Self-Attn (audio_attn1) │
141
+ │ → RMSNorm → Cross-Attn (audio_attn2, text) │
142
+ │ → RMSNorm → AdaLN → A↔V Cross-Attn │
143
+ │ → RMSNorm → AdaLN → Feed-Forward (audio_ff) │
144
+ │ │
145
+ │ AdaLN (Adaptive Layer Normalization): │
146
+ │ - Uses scale_shift_table (6 params) for video/audio │
147
+ │ - Uses scale_shift_table_a2v_ca (5 params) for A↔V CA │
148
+ │ - Conditioned on per-token timestep embeddings │
149
+ └─────────────────────────────────────────────────────────────┘
150
+ ```
151
+
152
+ ### Perturbations
153
+
154
+ The transformer supports [**perturbations**](src/ltx_core/guidance/perturbations.py) that selectively skip attention operations.
155
+
156
+ Perturbations allow you to disable specific attention mechanisms during inference, which is useful for guidance techniques like STG (Spatio-Temporal Guidance).
157
+
158
+ **Supported Perturbation Types**:
159
+
160
+ - `SKIP_VIDEO_SELF_ATTN`: Skip video self-attention
161
+ - `SKIP_AUDIO_SELF_ATTN`: Skip audio self-attention
162
+ - `SKIP_A2V_CROSS_ATTN`: Skip audio-to-video cross-attention
163
+ - `SKIP_V2A_CROSS_ATTN`: Skip video-to-audio cross-attention
164
+
165
+ Perturbations are used internally by guidance mechanisms like STG (Spatio-Temporal Guidance). For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
166
+
167
+ ---
168
+
169
+ ## Video VAE
170
+
171
+ The Video VAE ([`src/ltx_core/model/video_vae/`](src/ltx_core/model/video_vae/)) encodes video pixels into latent representations and decodes them back.
172
+
173
+ ### Architecture
174
+
175
+ - **Encoder**: Compresses `[B, 3, F, H, W]` pixels → `[B, 128, F', H/32, W/32]` latents
176
+ - Where `F' = 1 + (F-1)/8` (frame count must satisfy `(F-1) % 8 == 0`)
177
+ - Example: `[B, 3, 33, 512, 512]` → `[B, 128, 5, 16, 16]`
178
+ - **Decoder**: Expands `[B, 128, F, H, W]` latents → `[B, 3, F', H*32, W*32]` pixels
179
+ - Where `F' = 1 + (F-1)*8`
180
+ - Example: `[B, 128, 5, 16, 16]` → `[B, 3, 33, 512, 512]`
181
+
182
+ The Video VAE is used internally by pipelines for encoding video pixels to latents and decoding latents back to pixels. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
183
+
184
+ ---
185
+
186
+ ## Audio VAE
187
+
188
+ The Audio VAE ([`src/ltx_core/model/audio_vae/`](src/ltx_core/model/audio_vae/)) processes audio spectrograms.
189
+
190
+ ### Audio VAE Architecture
191
+
192
+ - **Encoder**: Compresses mel spectrogram `[B, mel_bins, T]` → `[B, 8, T/4, 16]` latents
193
+ - Temporal downsampling: 4× (`LATENT_DOWNSAMPLE_FACTOR = 4`)
194
+ - Frequency bins: Fixed 16 mel bins in latent space
195
+ - Latent channels: 8
196
+ - **Decoder**: Expands `[B, 8, T, 16]` latents → mel spectrogram `[B, mel_bins, T*4]`
197
+ - **Vocoder**: Converts mel spectrogram → audio waveform
198
+
199
+ **Downsampling**:
200
+
201
+ - Temporal: 4× (time steps)
202
+ - Frequency: Variable (input mel_bins → fixed 16 in latent space)
203
+
204
+ The Audio VAE is used internally by pipelines for encoding mel spectrograms to latents and decoding latents back to mel spectrograms. The vocoder converts mel spectrograms to audio waveforms. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
205
+
206
+ ---
207
+
208
+ ## Text Encoding (Gemma)
209
+
210
+ LTX-2 uses **Gemma** (Google's open LLM) as the text encoder, located in [`src/ltx_core/text_encoders/gemma/`](src/ltx_core/text_encoders/gemma/).
211
+
212
+ ### Text Encoder Architecture
213
+
214
+ - **Tokenizer**: Converts text → token IDs
215
+ - **Gemma Model**: Processes tokens → embeddings
216
+ - **Text Projection**: Uses `PixArtAlphaTextProjection` to project caption embeddings
217
+ - Two-layer MLP with GELU (tanh approximation) or SiLU activation
218
+ - Projects from caption channels (3840) to model dimensions
219
+ - **Feature Extractor**: Extracts video/audio-specific embeddings
220
+ - **Separate Encoders**:
221
+ - `AVEncoder`: For audio-video generation (outputs separate video and audio contexts)
222
+ - `VideoOnlyEncoder`: For video-only generation
223
+
224
+ ### System Prompts
225
+
226
+ System prompts are also used to enhance users' prompts.
227
+
228
+ - **Text-to-Video**: [`gemma_t2v_system_prompt.txt`](src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_t2v_system_prompt.txt)
229
+ - **Image-to-Video**: [`gemma_i2v_system_prompt.txt`](src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_i2v_system_prompt.txt)
230
+
231
+ **Important**: Video and audio receive **different** context embeddings, even from the same prompt. This allows better modality-specific conditioning.
232
+
233
+ **Output Format**:
234
+
235
+ - Video context: `[B, seq_len, 4096]` - Video-specific text embeddings
236
+ - Audio context: `[B, seq_len, 2048]` - Audio-specific text embeddings
237
+
238
+ The text encoder is used internally by pipelines. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
239
+
240
+ ---
241
+
242
+ ## Upscaler
243
+
244
+ The Upscaler ([`src/ltx_core/model/upsampler/`](src/ltx_core/model/upsampler/)) upsamples latent representations for higher-resolution output.
245
+
246
+ The spatial upsampler is used internally by two-stage pipelines (e.g., [`TI2VidTwoStagesPipeline`](../ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py), [`ICLoraPipeline`](../ltx-pipelines/src/ltx_pipelines/ic_lora.py)) to upsample low-resolution latents before final VAE decoding. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
247
+
248
+ ---
249
+
250
+ ## Data Flow
251
+
252
+ ### Complete Generation Pipeline
253
+
254
+ Here's how all the components work together conceptually ([`src/ltx_core/components/`](src/ltx_core/components/)):
255
+
256
+ **Pipeline Steps**:
257
+
258
+ 1. **Text Encoding**: Text prompt → Gemma encoder → separate video/audio embeddings
259
+ 2. **Latent Initialization**: Initialize noise latents in spatial format `[B, C, F, H, W]`
260
+ 3. **Patchification**: Convert spatial latents to sequence format `[B, seq_len, dim]` for transformer
261
+ 4. **Sigma Schedule**: Generate noise schedule (adapts to token count)
262
+ 5. **Denoising Loop**: Iteratively denoise using transformer predictions
263
+ - Create Modality inputs with per-token timesteps and RoPE positions
264
+ - Forward pass through transformer (conditional and unconditional for CFG)
265
+ - Apply guidance (CFG, STG, etc.)
266
+ - Update latents using diffusion step (Euler, etc.)
267
+ 6. **Unpatchification**: Convert sequence back to spatial format
268
+ 7. **VAE Decoding**: Decode latents to pixel space (with optional upsampling for two-stage)
269
+
270
+ These steps are combined into the following ready-made pipelines:
+
+ - [`TI2VidTwoStagesPipeline`](../ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py) - Two-stage text-to-video (recommended)
271
+ - [`ICLoraPipeline`](../ltx-pipelines/src/ltx_pipelines/ic_lora.py) - Video-to-video with IC-LoRA control
272
+ - [`DistilledPipeline`](../ltx-pipelines/src/ltx_pipelines/distilled.py) - Fast inference with distilled model
273
+ - [`KeyframeInterpolationPipeline`](../ltx-pipelines/src/ltx_pipelines/keyframe_interpolation.py) - Keyframe-based interpolation
274
+
275
+ See the [ltx-pipelines README](../ltx-pipelines/README.md) for usage examples.
276
+
277
+ ## 🔗 Related Projects
278
+
279
+ - **[ltx-pipelines](../ltx-pipelines/)** - High-level pipeline implementations for text-to-video, image-to-video, and video-to-video
280
+ - **[ltx-trainer](../ltx-trainer/)** - Training and fine-tuning tools
packages/ltx-core/pyproject.toml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "ltx-core"
3
+ version = "1.0.0"
4
+ description = "Core implementation of Lightricks' LTX-2 model"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "torch~=2.7",
9
+ "torchaudio",
10
+ "einops",
11
+ "numpy",
12
+ "transformers",
13
+ "safetensors",
14
+ "accelerate",
15
+ "scipy>=1.14",
16
+ ]
17
+
18
+ [project.optional-dependencies]
19
+ xformers = ["xformers"]
20
+
21
+
22
+ [tool.uv.sources]
23
+ xformers = { index = "pytorch" }
24
+
25
+ [[tool.uv.index]]
26
+ name = "pytorch"
27
+ url = "https://download.pytorch.org/whl/cu129"
28
+ explicit = true
29
+
30
+ [build-system]
31
+ requires = ["uv_build>=0.9.8,<0.10.0"]
32
+ build-backend = "uv_build"
33
+
34
+ [dependency-groups]
35
+ dev = [
36
+ "scikit-image>=0.25.2",
37
+ ]
packages/ltx-core/src/ltx_core/__init__.py ADDED
File without changes
packages/ltx-core/src/ltx_core/components/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Diffusion pipeline components.
3
+ Submodules:
4
+ diffusion_steps - Diffusion stepping algorithms (EulerDiffusionStep)
5
+ guiders - Guidance strategies (CFGGuider, STGGuider, APG variants)
6
+ noisers - Noise samplers (GaussianNoiser)
7
+ patchifiers - Latent patchification (VideoLatentPatchifier, AudioPatchifier)
8
+ protocols - Protocol definitions (Patchifier, etc.)
9
+ schedulers - Sigma schedulers (LTX2Scheduler, LinearQuadraticScheduler)
10
+ """
packages/ltx-core/src/ltx_core/components/diffusion_steps.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.components.protocols import DiffusionStepProtocol
4
+ from ltx_core.utils import to_velocity
5
+
6
+
7
class EulerDiffusionStep(DiffusionStepProtocol):
    """
    First-order (Euler) update rule for diffusion sampling.

    Moves the sample from the current noise level ``sigmas[step_index]`` to the
    next one by converting the denoised prediction into a velocity and
    integrating it over the sigma gap: ``sample + velocity * dt``.
    """

    def step(
        self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int
    ) -> torch.Tensor:
        # Step size is the (signed) distance between consecutive sigmas.
        dt = sigmas[step_index + 1] - sigmas[step_index]
        velocity = to_velocity(sample, sigmas[step_index], denoised_sample)
        # Integrate in float32 for numerical stability, then cast back.
        updated = sample.to(torch.float32) + velocity.to(torch.float32) * dt
        return updated.to(sample.dtype)
packages/ltx-core/src/ltx_core/components/guiders.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+ from ltx_core.components.protocols import GuiderProtocol
6
+
7
+
8
@dataclass(frozen=True)
class CFGGuider(GuiderProtocol):
    """
    Classifier-free guidance (CFG).

    The guidance delta is ``(scale - 1) * (cond - uncond)``: it pushes the
    denoised estimate away from the unconditional prediction and toward the
    conditional one.

    Attributes:
        scale: Guidance strength. 1.0 contributes nothing; larger values
            increase adherence to the conditioning.
    """

    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        difference = cond - uncond
        return difference * (self.scale - 1)

    def enabled(self) -> bool:
        # At scale == 1.0 the delta is identically zero, so skip the extra pass.
        return self.scale != 1.0
26
+
27
+
28
@dataclass(frozen=True)
class CFGStarRescalingGuider(GuiderProtocol):
    """
    CFG variant that rescales the unconditional sample before differencing.

    The unconditional prediction is first projected onto the conditional one
    (via ``projection_coef``), so the resulting delta moves mostly along the
    conditioning axis and minimizes offset in the denoising direction.

    Attributes:
        scale (float):
            Global guidance strength. 1.0 means no extra guidance; values
            above 1.0 increase the influence of the conditioned sample
            relative to the unconditioned one.
    """

    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        # Rescale the negative sample onto the conditional direction first.
        coef = projection_coef(cond, uncond)
        rescaled_uncond = coef * uncond
        return (self.scale - 1) * (cond - rescaled_uncond)

    def enabled(self) -> bool:
        return self.scale != 1.0
51
+
52
+
53
@dataclass(frozen=True)
class STGGuider(GuiderProtocol):
    """
    Spatio-temporal guidance (STG).

    Pushes the result away from a "perturbed" prediction — one produced with
    selected attentions acting as passthrough — using
    ``scale * (pos_denoised - perturbed_denoised)``.

    Attributes:
        scale (float):
            Guidance strength. 0.0 disables the guider; larger values apply a
            stronger correction toward the unperturbed prediction.
    """

    scale: float

    def delta(self, pos_denoised: torch.Tensor, perturbed_denoised: torch.Tensor) -> torch.Tensor:
        difference = pos_denoised - perturbed_denoised
        return difference * self.scale

    def enabled(self) -> bool:
        return self.scale != 0.0
73
+
74
+
75
@dataclass(frozen=True)
class LtxAPGGuider(GuiderProtocol):
    """
    Adaptive projected guidance (APG).

    The (cond - uncond) delta is decomposed into components parallel and
    orthogonal to the conditioned sample; ``eta`` weights the parallel part
    and ``(scale - 1)`` multiplies the recombined correction. An optional norm
    threshold rescales overly large guidance vectors before projection.

    Attributes:
        scale (float):
            Strength applied to the projected guidance correction. 1.0
            disables the guider.
        eta (float):
            Weight of the component parallel to the conditioned sample.
            1.0 keeps it fully; values in [0, 1] attenuate it; values > 1.0
            amplify motion along the conditioning direction.
        norm_threshold (float):
            Maximum L2 norm (over the last three dims) allowed for the
            guidance delta; larger deltas are scaled down to it. 0 disables
            the clamp.
    """

    scale: float
    eta: float = 1.0
    norm_threshold: float = 0.0

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        guidance = cond - uncond

        if self.norm_threshold > 0:
            # Rescale overly large guidance vectors down to the threshold norm.
            guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            clamp = torch.minimum(torch.ones_like(guidance), self.norm_threshold / guidance_norm)
            guidance = guidance * clamp

        # Decompose relative to the conditioned sample, reweight, recombine.
        parallel = projection_coef(guidance, cond) * cond
        orthogonal = guidance - parallel
        combined = parallel * self.eta + orthogonal

        return combined * (self.scale - 1)

    def enabled(self) -> bool:
        return self.scale != 1.0
124
+
125
+
126
@dataclass(frozen=False)
class LegacyStatefulAPGGuider(GuiderProtocol):
    """
    Stateful APG (adaptive projected guidance) guider.

    Like :class:`LtxAPGGuider`, the (cond - uncond) delta is decomposed into
    components parallel and orthogonal to the conditioned sample, with ``eta``
    weighting the parallel part. This legacy variant additionally accumulates
    guidance across calls with a momentum term and multiplies the final
    correction by ``scale`` (not ``scale - 1``).

    Attributes:
        scale (float):
            Strength applied to the projected guidance correction. 0.0
            disables the guider.
        eta (float):
            Weight of the component parallel to the conditioned sample.
        norm_threshold (float):
            Maximum L2 norm (over the last three dims) of the guidance delta;
            larger deltas are rescaled down to it. 0 disables the clamp.
        momentum (float):
            Accumulation coefficient: running_avg = momentum * running_avg + guidance.
        running_avg (torch.Tensor | None):
            Accumulated guidance state. It is the user's responsibility not to
            reuse one instance across denoisings or modalities, to avoid
            sharing the accumulated average between them.
    """

    scale: float
    eta: float
    norm_threshold: float = 5.0
    momentum: float = 0.0
    running_avg: torch.Tensor | None = None

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        guidance = cond - uncond

        if self.momentum != 0:
            # First call seeds the running average; later calls accumulate.
            if self.running_avg is None:
                self.running_avg = guidance.clone()
            else:
                self.running_avg = self.momentum * self.running_avg + guidance
            guidance = self.running_avg

        if self.norm_threshold > 0:
            # Rescale overly large guidance vectors down to the threshold norm.
            guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            clamp = torch.minimum(torch.ones_like(guidance), self.norm_threshold / guidance_norm)
            guidance = guidance * clamp

        # Decompose relative to the conditioned sample, reweight, recombine.
        parallel = projection_coef(guidance, cond) * cond
        orthogonal = guidance - parallel
        combined = parallel * self.eta + orthogonal

        return combined * self.scale

    def enabled(self) -> bool:
        return self.scale != 0.0
190
+
191
+
192
def projection_coef(to_project: torch.Tensor, project_onto: torch.Tensor) -> torch.Tensor:
    """
    Per-batch scalar coefficient of the projection of ``to_project`` onto
    ``project_onto``.

    Both tensors are flattened per batch element and the coefficient
    ``<a, b> / (||b||^2 + eps)`` is returned with shape ``[B, 1]``.
    """
    batch = to_project.shape[0]
    a_flat = to_project.reshape(batch, -1)
    b_flat = project_onto.reshape(batch, -1)
    numerator = (a_flat * b_flat).sum(dim=1, keepdim=True)
    # Small epsilon guards against division by zero for an all-zero target.
    denominator = (b_flat * b_flat).sum(dim=1, keepdim=True) + 1e-8
    return numerator / denominator
packages/ltx-core/src/ltx_core/components/noisers.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import replace
2
+ from typing import Protocol
3
+
4
+ import torch
5
+
6
+ from ltx_core.types import LatentState
7
+
8
+
9
class Noiser(Protocol):
    """Callable protocol: return a (partially) re-noised copy of a latent state."""

    def __call__(self, latent_state: LatentState, noise_scale: float) -> LatentState: ...
13
+
14
+
15
class GaussianNoiser(Noiser):
    """Blends Gaussian noise into a latent state, weighted by its denoise mask."""

    def __init__(self, generator: torch.Generator):
        super().__init__()
        # Dedicated RNG keeps noise reproducible independent of global state.
        self.generator = generator

    def __call__(self, latent_state: LatentState, noise_scale: float = 1.0) -> LatentState:
        source = latent_state.latent
        noise = torch.randn(
            *source.shape,
            device=source.device,
            dtype=source.dtype,
            generator=self.generator,
        )
        # Mask of 1 (times noise_scale) replaces content with noise; 0 keeps it.
        scaled_mask = latent_state.denoise_mask * noise_scale
        blended = noise * scaled_mask + source * (1 - scaled_mask)
        return replace(
            latent_state,
            latent=blended.to(source.dtype),
        )
packages/ltx-core/src/ltx_core/components/patchifiers.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple
3
+
4
+ import einops
5
+ import torch
6
+
7
+ from ltx_core.components.protocols import Patchifier
8
+ from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape
9
+
10
+
11
class VideoLatentPatchifier(Patchifier):
    """Patchifier for video latents.

    Folds spatial (height/width) patches into the feature axis, producing a
    flat token sequence of shape (batch, tokens, features). The temporal
    patch size is fixed to 1.
    """

    def __init__(self, patch_size: int):
        # Patch sizes for video latents.
        self._patch_size = (
            1,  # temporal dimension
            patch_size,  # height dimension
            patch_size,  # width dimension
        )

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        # (temporal, height, width) patch sizes used by patchify/unpatchify.
        return self._patch_size

    def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
        # Tokens = (frames * height * width) / patch volume; channels (dims
        # 0-1 of the torch shape) are folded into the feature axis and do not
        # affect the token count.
        return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size)

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """Rearrange (b, c, f, h, w) latents into (b, tokens, features) patch tokens."""
        latents = einops.rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )

        return latents

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: VideoLatentShape,
    ) -> torch.Tensor:
        """Inverse of `patchify`: restore a (b, c, f, h, w) latent grid from tokens."""
        assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier"

        patch_grid_frames = output_shape.frames // self._patch_size[0]
        patch_grid_height = output_shape.height // self._patch_size[1]
        patch_grid_width = output_shape.width // self._patch_size[2]

        # Temporal patch size is 1 (asserted above), so only the spatial patch
        # dims (p, q) need unfolding back into height and width.
        latents = einops.rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=patch_grid_frames,
            h=patch_grid_height,
            w=patch_grid_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )

        return latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the per-dimension bounds [inclusive start, exclusive end) for every
        patch produced by `patchify`. The bounds are expressed in the original
        video grid coordinates: frame/time, height, and width.
        The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where:
        - axis 1 (size 3) enumerates (frame/time, height, width) dimensions
        - axis 3 (size 2) stores `[start, end)` indices within each dimension
        Args:
            output_shape: Video grid description containing frames, height, and width.
            device: Device of the latent tensor.
        """
        if not isinstance(output_shape, VideoLatentShape):
            raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")

        frames = output_shape.frames
        height = output_shape.height
        width = output_shape.width
        batch_size = output_shape.batch

        # Validate inputs to ensure positive dimensions
        assert frames > 0, f"frames must be positive, got {frames}"
        assert height > 0, f"height must be positive, got {height}"
        assert width > 0, f"width must be positive, got {width}"
        assert batch_size > 0, f"batch_size must be positive, got {batch_size}"

        # Generate grid coordinates for each dimension (frame, height, width)
        # We use torch.arange to create the starting coordinates for each patch.
        # indexing='ij' ensures the dimensions are in the order (frame, height, width).
        grid_coords = torch.meshgrid(
            torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),
            torch.arange(start=0, end=height, step=self._patch_size[1], device=device),
            torch.arange(start=0, end=width, step=self._patch_size[2], device=device),
            indexing="ij",
        )

        # Stack the grid coordinates to create the start coordinates tensor.
        # Shape becomes (3, grid_f, grid_h, grid_w)
        patch_starts = torch.stack(grid_coords, dim=0)

        # Create a tensor containing the size of a single patch:
        # (frame_patch_size, height_patch_size, width_patch_size).
        # Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates.
        patch_size_delta = torch.tensor(
            self._patch_size,
            device=patch_starts.device,
            dtype=patch_starts.dtype,
        ).view(3, 1, 1, 1)

        # Calculate end coordinates: start + patch_size
        # Shape becomes (3, grid_f, grid_h, grid_w)
        patch_ends = patch_starts + patch_size_delta

        # Stack start and end coordinates together along the last dimension
        # Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end]
        latent_coords = torch.stack((patch_starts, patch_ends), dim=-1)

        # Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence.
        # Final Shape: (batch_size, 3, num_patches, 2)
        latent_coords = einops.repeat(
            latent_coords,
            "c f h w bounds -> b c (f h w) bounds",
            b=batch_size,
            bounds=2,
        )

        return latent_coords
135
+
136
+
137
def get_pixel_coords(
    latent_coords: torch.Tensor,
    scale_factors: SpatioTemporalScaleFactors,
    causal_fix: bool = False,
) -> torch.Tensor:
    """
    Map latent-space `[start, end)` coordinates to pixel space by scaling each
    axis (frame/time, height, width) with the corresponding VAE downsampling
    factors, optionally compensating for causal encoding of the first frame.
    Args:
        latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
        scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied
            per axis.
        causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
            that treat frame zero differently still yield non-negative timestamps.
    """
    # Reshape scale factors so axis 1 (frame/time, height, width) lines up
    # with the `(batch, axis, patch, bound)` layout of `latent_coords`.
    view_shape = [1] * latent_coords.ndim
    view_shape[1] = -1
    factors = torch.tensor(scale_factors, device=latent_coords.device).view(*view_shape)

    pixel_coords = latent_coords * factors

    if causal_fix:
        # The causal VAE uses a temporal stride of 1 (not scale_factors[0]) for
        # the very first frame; shift and clamp so timestamps stay non-negative.
        shifted = pixel_coords[:, 0, ...] + 1 - scale_factors[0]
        pixel_coords[:, 0, ...] = shifted.clamp(min=0)

    return pixel_coords
167
+
168
+
169
class AudioPatchifier(Patchifier):
    """Patchifier for spectrogram/audio latents.

    Tokens are whole latent time steps: channels and mel bins are folded into
    the feature axis by `patchify`, and `get_patch_grid_bounds` reports
    per-token timing in seconds rather than spatial coordinates.
    """

    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
        shift: int = 0,
    ):
        """
        Patchifier tailored for spectrogram/audio latents.
        Args:
            patch_size: Number of mel bins combined into a single patch. This
                controls the resolution along the frequency axis.
            sample_rate: Original waveform sampling rate. Used to map latent
                indices back to seconds so downstream consumers can align audio
                and video cues.
            hop_length: Window hop length used for the spectrogram. Determines
                how many real-time samples separate two consecutive latent frames.
            audio_latent_downsample_factor: Ratio between spectrogram frames and
                latent frames; compensates for additional downsampling inside the
                VAE encoder.
            is_causal: When True, timing is shifted to account for causal
                receptive fields so timestamps do not peek into the future.
            shift: Integer offset applied to the latent indices. Enables
                constructing overlapping windows from the same latent sequence.
        """
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self.shift = shift
        # Stored in (temporal, height, width) form for interface parity with
        # the video patchifier; only the mel-bin axis uses `patch_size`.
        self._patch_size = (1, patch_size, patch_size)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        # (temporal, height, width) patch sizes.
        return self._patch_size

    def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
        # One token per latent time step; frequency/channel axes are folded
        # into the feature dimension by `patchify`.
        return tgt_shape.frames

    def _get_audio_latent_time_in_sec(
        self,
        start_latent: int,
        end_latent: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Converts latent indices into real-time seconds while honoring causal
        offsets and the configured hop length.
        Args:
            start_latent: Inclusive start index inside the latent sequence. This
                sets the first timestamp returned.
            end_latent: Exclusive end index. Determines how many timestamps get
                generated.
            dtype: Floating-point dtype used for the returned tensor, allowing
                callers to control precision.
            device: Target device for the timestamp tensor. When omitted the
                computation occurs on CPU to avoid surprising GPU allocations.
        """
        if device is None:
            device = torch.device("cpu")

        audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)

        # Latent index -> spectrogram (mel) frame index.
        audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor

        if self.is_causal:
            # Frame offset for causal alignment.
            # The "+1" ensures the timestamp corresponds to the first sample that is fully available.
            causal_offset = 1
            audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0)

        # Mel frame -> seconds, via the STFT hop length and sample rate.
        return audio_mel_frame * self.hop_length / self.sample_rate

    def _compute_audio_timings(
        self,
        batch_size: int,
        num_steps: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame.
        This helper method underpins `get_patch_grid_bounds` for the audio patchifier.
        Args:
            batch_size: Number of sequences to broadcast the timings over.
            num_steps: Number of latent frames (time steps) to convert into timestamps.
            device: Device on which the resulting tensor should reside.
        """
        resolved_device = device
        if resolved_device is None:
            resolved_device = torch.device("cpu")

        # Start times of frames [shift, shift + num_steps).
        start_timings = self._get_audio_latent_time_in_sec(
            self.shift,
            num_steps + self.shift,
            torch.float32,
            resolved_device,
        )
        start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        # End times are the start times of the next latent frame.
        end_timings = self._get_audio_latent_time_in_sec(
            self.shift + 1,
            num_steps + self.shift + 1,
            torch.float32,
            resolved_device,
        )
        end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        return torch.stack([start_timings, end_timings], dim=-1)

    def patchify(
        self,
        audio_latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Flattens the (B, C, T, F) audio latent tensor along time. Use
        `get_patch_grid_bounds` to derive timestamps for each latent frame
        based on the configured hop length and downsampling.
        Args:
            audio_latents: Latent tensor to patchify.
        Returns:
            Flattened patch tokens tensor of shape (B, T, C * F). Use
            `get_patch_grid_bounds` to compute the corresponding timing
            metadata when needed.
        """
        audio_latents = einops.rearrange(
            audio_latents,
            "b c t f -> b t (c f)",
        )

        return audio_latents

    def unpatchify(
        self,
        audio_latents: torch.Tensor,
        output_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Restores the `(B, C, T, F)` spectrogram tensor from flattened patches.
        Use `get_patch_grid_bounds` to recompute the timestamps that describe each
        frame's position in real time.
        Args:
            audio_latents: Latent tensor to unpatchify.
            output_shape: Shape of the unpatched output tensor.
        Returns:
            Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing
            metadata associated with the restored latents.
        """
        # audio_latents shape: (batch, time, freq * channels)
        audio_latents = einops.rearrange(
            audio_latents,
            "b t (c f) -> b c t f",
            c=output_shape.channels,
            f=output_shape.mel_bins,
        )

        return audio_latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the temporal bounds `[inclusive start, exclusive end)` for every
        patch emitted by `patchify`. For audio this corresponds to timestamps in
        seconds aligned with the original spectrogram grid.
        The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where:
        - axis 1 (size 1) represents the temporal dimension
        - axis 3 (size 2) stores the `[start, end)` timestamps per patch
        Args:
            output_shape: Audio grid specification describing the number of time steps.
            device: Target device for the returned tensor.
        """
        if not isinstance(output_shape, AudioLatentShape):
            raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")

        return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)
packages/ltx-core/src/ltx_core/components/protocols.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Protocol, Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.types import AudioLatentShape, VideoLatentShape
6
+
7
+
8
class Patchifier(Protocol):
    """
    Protocol for patchifiers that convert latent tensors into patches and assemble them back.
    """

    # NOTE: in the previous version several docstrings were placed AFTER the
    # `...` placeholder, making them plain expression statements instead of
    # docstrings (`__doc__` was None). They are now in docstring position.

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Convert latent tensors into flattened patch tokens.
        Args:
            latents: Latent tensor to patchify.
        Returns:
            Flattened patch tokens tensor.
        """
        ...

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: AudioLatentShape | VideoLatentShape,
    ) -> torch.Tensor:
        """
        Converts latent tensors between spatio-temporal formats and flattened sequence representations.
        Args:
            latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
            output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
                VideoLatentShape.
        Returns:
            Dense latent tensor restored from the flattened representation.
        """
        ...

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """
        Returns the patch size as a tuple of (temporal, height, width) dimensions
        """
        ...

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """
        Compute metadata describing where each latent patch resides within the
        grid specified by `output_shape`.
        Args:
            output_shape: Target grid layout for the patches.
            device: Target device for the returned tensor.
        Returns:
            Tensor containing patch coordinate metadata such as spatial or temporal intervals.
        """
        ...
63
+
64
+
65
class SchedulerProtocol(Protocol):
    """
    Protocol for schedulers: produce a CPU tensor of sigmas for a requested
    number of diffusion steps.
    """

    def execute(self, steps: int, **kwargs) -> torch.FloatTensor: ...
72
+
73
+
74
class GuiderProtocol(Protocol):
    """
    Protocol for guiders that produce a delta tensor from conditioning inputs.
    Each delta is added onto the conditional output (cond), so several guiders
    can be chained simply by accumulating their deltas.
    """

    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: ...

    def enabled(self) -> bool:
        """
        Whether the corresponding perturbation is active. E.g. for CFG this
        should return False when the scale is 1.0.
        """
        ...
91
+
92
+
93
class DiffusionStepProtocol(Protocol):
    """
    Protocol for diffusion steps: map the current sample, its denoised
    estimate, the sigma schedule, and the step index to the next sample.
    """

    def step(
        self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int
    ) -> torch.Tensor: ...
packages/ltx-core/src/ltx_core/components/schedulers.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import lru_cache
3
+
4
+ import numpy
5
+ import scipy
6
+ import torch
7
+
8
+ from ltx_core.components.protocols import SchedulerProtocol
9
+
10
+ BASE_SHIFT_ANCHOR = 1024
11
+ MAX_SHIFT_ANCHOR = 4096
12
+
13
+
14
class LTX2Scheduler(SchedulerProtocol):
    """
    Default scheduler for LTX-2 diffusion sampling.
    Produces a sigma schedule whose shift depends on the latent token count,
    with optional stretching so the last non-zero sigma hits `terminal`.
    """

    def execute(
        self,
        steps: int,
        latent: torch.Tensor | None = None,
        max_shift: float = 2.05,
        base_shift: float = 0.95,
        stretch: bool = True,
        terminal: float = 0.1,
        **_kwargs,
    ) -> torch.FloatTensor:
        # Token count drives the shift; fall back to the max anchor when no
        # latent is supplied.
        token_count = MAX_SHIFT_ANCHOR if latent is None else math.prod(latent.shape[2:])
        sigmas = torch.linspace(1.0, 0.0, steps + 1)

        # Linearly interpolate the shift between the two token-count anchors.
        slope = (max_shift - base_shift) / (MAX_SHIFT_ANCHOR - BASE_SHIFT_ANCHOR)
        intercept = base_shift - slope * BASE_SHIFT_ANCHOR
        sigma_shift = token_count * slope + intercept

        # Apply the sigmoid-style time shift everywhere except the final zero.
        exp_shift = math.exp(sigma_shift)
        power = 1
        sigmas = torch.where(
            sigmas != 0,
            exp_shift / (exp_shift + (1 / sigmas - 1) ** power),
            0,
        )

        # Stretch sigmas so that the final non-zero value matches `terminal`.
        if stretch:
            non_zero_mask = sigmas != 0
            one_minus = 1.0 - sigmas[non_zero_mask]
            scale_factor = one_minus[-1] / (1.0 - terminal)
            sigmas[non_zero_mask] = 1.0 - one_minus / scale_factor

        return sigmas.to(torch.float32)
57
+
58
+
59
class LinearQuadraticScheduler(SchedulerProtocol):
    """
    Scheduler with linear steps followed by quadratic steps.
    Sigmas fall linearly until a noise threshold, then follow a quadratic
    curve down for the remaining steps.
    """

    def execute(
        self, steps: int, threshold_noise: float = 0.025, linear_steps: int | None = None, **_kwargs
    ) -> torch.FloatTensor:
        # Degenerate single-step schedule.
        if steps == 1:
            return torch.FloatTensor([1.0, 0.0])

        n_linear = steps // 2 if linear_steps is None else linear_steps
        n_quadratic = steps - n_linear

        # Linear ramp from 0 up to (just under) threshold_noise.
        linear_part = [i * threshold_noise / n_linear for i in range(n_linear)]

        quadratic_part: list[float] = []
        if n_quadratic > 0:
            # Coefficients chosen so the quadratic segment continues smoothly
            # from the linear one and reaches 1.0 at the final step.
            step_diff = n_linear - threshold_noise * steps
            quad_coef = step_diff / (n_linear * n_quadratic**2)
            lin_coef = threshold_noise / n_linear - 2 * step_diff / (n_quadratic**2)
            const = quad_coef * (n_linear**2)
            quadratic_part = [quad_coef * (i**2) + lin_coef * i + const for i in range(n_linear, steps)]

        # Noise levels increase across the schedule; sigmas are their complement.
        noise_levels = linear_part + quadratic_part + [1.0]
        return torch.FloatTensor([1.0 - level for level in noise_levels])
88
+
89
+
90
class BetaScheduler(SchedulerProtocol):
    """
    Scheduler using a beta distribution to sample timesteps.
    Based on: https://arxiv.org/abs/2407.12173
    """

    shift = 2.37
    timesteps_length = 10000

    def execute(self, steps: int, alpha: float = 0.6, beta: float = 0.6) -> torch.FloatTensor:
        """
        Execute the beta scheduler.
        Args:
            steps: The number of steps to execute the scheduler for.
            alpha: The alpha parameter for the beta distribution.
            beta: The beta parameter for the beta distribution.
        Warnings:
            The number of steps within `sigmas` theoretically might be less than `steps+1`,
            because of the deduplication of the identical timesteps
        Returns:
            A tensor of sigmas.
        """
        sampling_sigmas = _precalculate_model_sampling_sigmas(self.shift, self.timesteps_length)
        max_index = len(sampling_sigmas) - 1

        # Sample the beta quantiles on a descending uniform grid, then map
        # them to integer timestep indices.
        quantiles = 1 - numpy.linspace(0, 1, steps, endpoint=False)
        indices = numpy.rint(scipy.stats.beta.ppf(quantiles, alpha, beta) * max_index).tolist()

        # Drop duplicate indices while preserving order.
        indices = list(dict.fromkeys(indices))

        schedule = [float(sampling_sigmas[int(t)]) for t in indices] + [0.0]
        return torch.FloatTensor(schedule)
120
+
121
+
122
@lru_cache(maxsize=5)
def _precalculate_model_sampling_sigmas(shift: float, timesteps_length: int) -> torch.Tensor:
    """Build (and cache) the shifted sigma lookup table for a (shift, length) pair."""
    normalized_ts = torch.arange(1, timesteps_length + 1, 1) / timesteps_length
    return torch.Tensor([flux_time_shift(shift, 1.0, t) for t in normalized_ts])
126
+
127
+
128
def flux_time_shift(mu: float, sigma: float, t: float) -> float:
    """Flux time-shift map: a sigmoid-like warp of t controlled by mu and sigma."""
    shifted = math.exp(mu)
    return shifted / (shifted + (1 / t - 1) ** sigma)
packages/ltx-core/src/ltx_core/conditioning/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Conditioning utilities: latent state, tools, and conditioning types."""
2
+
3
+ from ltx_core.conditioning.exceptions import ConditioningError
4
+ from ltx_core.conditioning.item import ConditioningItem
5
+ from ltx_core.conditioning.types import VideoConditionByKeyframeIndex, VideoConditionByLatentIndex
6
+
7
+ __all__ = [
8
+ "ConditioningError",
9
+ "ConditioningItem",
10
+ "VideoConditionByKeyframeIndex",
11
+ "VideoConditionByLatentIndex",
12
+ ]
packages/ltx-core/src/ltx_core/conditioning/exceptions.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
class ConditioningError(Exception):
    """Base exception for conditioning-related failures."""
packages/ltx-core/src/ltx_core/conditioning/item.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Protocol
2
+
3
+ from ltx_core.tools import LatentTools
4
+ from ltx_core.types import LatentState
5
+
6
+
7
class ConditioningItem(Protocol):
    """Protocol for conditioning items that modify latent state during diffusion."""

    def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
        """
        Apply the conditioning to an (already patchified) latent state and
        return the updated state.
        Args:
            latent_state: The patchified latent state to condition.
        Returns:
            The conditioned latent state.
        IMPORTANT: A conditioning that introduces extra tokens must append
        them at the end of the latent sequence.
        """
        ...
packages/ltx-core/src/ltx_core/conditioning/types/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """Conditioning type implementations."""
2
+
3
+ from ltx_core.conditioning.types.keyframe_cond import VideoConditionByKeyframeIndex
4
+ from ltx_core.conditioning.types.latent_cond import VideoConditionByLatentIndex
5
+
6
+ __all__ = [
7
+ "VideoConditionByKeyframeIndex",
8
+ "VideoConditionByLatentIndex",
9
+ ]
packages/ltx-core/src/ltx_core/conditioning/types/keyframe_cond.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.components.patchifiers import get_pixel_coords
4
+ from ltx_core.conditioning.item import ConditioningItem
5
+ from ltx_core.tools import VideoLatentTools
6
+ from ltx_core.types import LatentState, VideoLatentShape
7
+
8
+
9
class VideoConditionByKeyframeIndex(ConditioningItem):
    """
    Conditions video generation on keyframe latents at a specific frame index.
    Appends keyframe tokens to the latent state with positions offset by frame_idx,
    and sets denoise strength according to the strength parameter.
    """

    def __init__(self, keyframes: torch.Tensor, frame_idx: int, strength: float):
        # keyframes: latent tensor of the keyframe(s); patchified in apply_to.
        # frame_idx: pixel-space frame index the keyframe tokens are anchored to.
        # strength: the appended tokens get a denoise mask of (1 - strength);
        #   presumably in [0, 1] -- TODO confirm expected range.
        self.keyframes = keyframes
        self.frame_idx = frame_idx
        self.strength = strength

    def apply_to(
        self,
        latent_state: LatentState,
        latent_tools: VideoLatentTools,
    ) -> LatentState:
        """Append keyframe tokens (latent, denoise mask, positions) to the state."""
        tokens = latent_tools.patchifier.patchify(self.keyframes)
        latent_coords = latent_tools.patchifier.get_patch_grid_bounds(
            output_shape=VideoLatentShape.from_torch_shape(self.keyframes.shape),
            device=self.keyframes.device,
        )
        # The causal first-frame fix only makes sense when conditioning frame 0.
        positions = get_pixel_coords(
            latent_coords=latent_coords,
            scale_factors=latent_tools.scale_factors,
            causal_fix=latent_tools.causal_fix if self.frame_idx == 0 else False,
        )

        # Shift the temporal axis (axis 0 of dim 1) to the target frame, then
        # convert frame indices into seconds using the video frame rate.
        positions[:, 0, ...] += self.frame_idx
        positions = positions.to(dtype=torch.float32)
        positions[:, 0, ...] /= latent_tools.fps

        # One mask value per token; (1 - strength) so stronger conditioning
        # means less re-noising of the appended tokens.
        denoise_mask = torch.full(
            size=(*tokens.shape[:2], 1),
            fill_value=1.0 - self.strength,
            device=self.keyframes.device,
            dtype=self.keyframes.dtype,
        )

        # Per the ConditioningItem contract, extra tokens are appended at the
        # end of each per-token tensor (positions concatenate on dim 2 because
        # their layout is (batch, axis, token, bound)).
        return LatentState(
            latent=torch.cat([latent_state.latent, tokens], dim=1),
            denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
            positions=torch.cat([latent_state.positions, positions], dim=2),
            clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
        )
packages/ltx-core/src/ltx_core/conditioning/types/latent_cond.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.conditioning.exceptions import ConditioningError
4
+ from ltx_core.conditioning.item import ConditioningItem
5
+ from ltx_core.tools import LatentTools
6
+ from ltx_core.types import LatentState
7
+
8
+
9
class VideoConditionByLatentIndex(ConditioningItem):
    """
    Conditions video generation by injecting latents at a specific latent frame index.
    Replaces tokens in the latent state at positions corresponding to latent_idx,
    and sets denoise strength according to the strength parameter.
    """

    def __init__(self, latent: torch.Tensor, strength: float, latent_idx: int):
        # latent: conditioning latent; must match the target's spatial shape.
        # strength: token denoise mask is set to (1 - strength).
        # latent_idx: latent-frame index at which tokens are overwritten.
        self.latent = latent
        self.strength = strength
        self.latent_idx = latent_idx

    def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
        cond_batch, cond_channels, _, cond_height, cond_width = self.latent.shape
        tgt_batch, tgt_channels, tgt_frames, tgt_height, tgt_width = latent_tools.target_shape.to_torch_shape()

        # All dims except the frame count must match the target latent grid.
        if (cond_batch, cond_channels, cond_height, cond_width) != (tgt_batch, tgt_channels, tgt_height, tgt_width):
            raise ConditioningError(
                f"Can't apply image conditioning item to latent with shape {latent_tools.target_shape}, expected "
                f"shape is ({tgt_batch}, {tgt_channels}, {tgt_frames}, {tgt_height}, {tgt_width}). Make sure "
                "the image and latent have the same spatial shape."
            )

        tokens = latent_tools.patchifier.patchify(self.latent)
        # Token offset of latent_idx: the token count of a grid truncated to
        # `latent_idx` frames equals the number of tokens preceding that frame.
        start_token = latent_tools.patchifier.get_token_count(
            latent_tools.target_shape._replace(frames=self.latent_idx)
        )
        stop_token = start_token + tokens.shape[1]

        # NOTE(review): assumes LatentState.clone() deep-copies its tensors so
        # the in-place writes below don't alias the caller's state -- confirm.
        latent_state = latent_state.clone()

        latent_state.latent[:, start_token:stop_token] = tokens
        latent_state.clean_latent[:, start_token:stop_token] = tokens
        latent_state.denoise_mask[:, start_token:stop_token] = 1.0 - self.strength

        return latent_state
packages/ltx-core/src/ltx_core/guidance/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Guidance and perturbation utilities for attention manipulation."""
2
+
3
+ from ltx_core.guidance.perturbations import (
4
+ BatchedPerturbationConfig,
5
+ Perturbation,
6
+ PerturbationConfig,
7
+ PerturbationType,
8
+ )
9
+
10
+ __all__ = [
11
+ "BatchedPerturbationConfig",
12
+ "Perturbation",
13
+ "PerturbationConfig",
14
+ "PerturbationType",
15
+ ]
packages/ltx-core/src/ltx_core/guidance/perturbations.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from enum import Enum
3
+
4
+ import torch
5
+ from torch._prims_common import DeviceLikeType
6
+
7
+
8
class PerturbationType(Enum):
    """Attention perturbation kinds used by STG (Spatio-Temporal Guidance)."""

    SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn"
    SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn"
    SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn"
    SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn"
16
+
17
@dataclass(frozen=True)
class Perturbation:
    """One attention-skip perturbation, optionally restricted to specific blocks."""

    type: PerturbationType
    blocks: list[int] | None  # None applies the perturbation to every block

    def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
        """True when this perturbation targets `perturbation_type` at `block`."""
        matches_type = self.type == perturbation_type
        return matches_type and (self.blocks is None or block in self.blocks)
32
+
33
+
34
@dataclass(frozen=True)
class PerturbationConfig:
    """Per-sample perturbation list; None behaves like an empty list."""

    perturbations: list[Perturbation] | None

    def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
        """True when any configured perturbation targets (type, block)."""
        if self.perturbations is None:
            return False

        return any(p.is_perturbed(perturbation_type, block) for p in self.perturbations)

    @staticmethod
    def empty() -> "PerturbationConfig":
        """A configuration with no perturbations at all."""
        return PerturbationConfig([])
49
+
50
+
51
@dataclass(frozen=True)
class BatchedPerturbationConfig:
    """Batchwise perturbation configs plus helpers to build attention masks."""

    perturbations: list[PerturbationConfig]

    def mask(
        self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype
    ) -> torch.Tensor:
        """1D mask over the batch: 0 where perturbed, 1 where untouched."""
        flags = [
            0.0 if cfg.is_perturbed(perturbation_type, block) else 1.0
            for cfg in self.perturbations
        ]
        return torch.tensor(flags, device=device, dtype=dtype)

    def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor:
        """Mask reshaped to broadcast against `values` (batch on dim 0)."""
        flat = self.mask(perturbation_type, block, values.device, values.dtype)
        trailing = [1] * (values.ndim - 1)
        return flat.view(flat.numel(), *trailing)

    def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
        return any(cfg.is_perturbed(perturbation_type, block) for cfg in self.perturbations)

    def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
        return all(cfg.is_perturbed(perturbation_type, block) for cfg in self.perturbations)

    @staticmethod
    def empty(batch_size: int) -> "BatchedPerturbationConfig":
        """A batch of `batch_size` empty configurations."""
        return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)])
packages/ltx-core/src/ltx_core/loader/__init__.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Loader utilities for model weights, LoRAs, and safetensor operations."""
2
+
3
+ from ltx_core.loader.fuse_loras import apply_loras
4
+ from ltx_core.loader.module_ops import ModuleOps
5
+ from ltx_core.loader.primitives import (
6
+ LoRAAdaptableProtocol,
7
+ LoraPathStrengthAndSDOps,
8
+ LoraStateDictWithStrength,
9
+ ModelBuilderProtocol,
10
+ StateDict,
11
+ StateDictLoader,
12
+ )
13
+ from ltx_core.loader.registry import DummyRegistry, Registry, StateDictRegistry
14
+ from ltx_core.loader.sd_ops import (
15
+ LTXV_LORA_COMFY_RENAMING_MAP,
16
+ ContentMatching,
17
+ ContentReplacement,
18
+ KeyValueOperation,
19
+ KeyValueOperationResult,
20
+ SDKeyValueOperation,
21
+ SDOps,
22
+ )
23
+ from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader, SafetensorsStateDictLoader
24
+ from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
25
+
26
+ __all__ = [
27
+ "LTXV_LORA_COMFY_RENAMING_MAP",
28
+ "ContentMatching",
29
+ "ContentReplacement",
30
+ "DummyRegistry",
31
+ "KeyValueOperation",
32
+ "KeyValueOperationResult",
33
+ "LoRAAdaptableProtocol",
34
+ "LoraPathStrengthAndSDOps",
35
+ "LoraStateDictWithStrength",
36
+ "ModelBuilderProtocol",
37
+ "ModuleOps",
38
+ "Registry",
39
+ "SDKeyValueOperation",
40
+ "SDOps",
41
+ "SafetensorsModelStateDictLoader",
42
+ "SafetensorsStateDictLoader",
43
+ "SingleGPUModelBuilder",
44
+ "StateDict",
45
+ "StateDictLoader",
46
+ "StateDictRegistry",
47
+ "apply_loras",
48
+ ]
packages/ltx-core/src/ltx_core/loader/fuse_loras.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import triton
3
+
4
+ from ltx_core.loader.kernels import fused_add_round_kernel
5
+ from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict
6
+
7
+ BLOCK_SIZE = 1024
8
+
9
+
10
def fused_add_round_launch(target_weight: torch.Tensor, original_weight: torch.Tensor, seed: int) -> torch.Tensor:
    """Launch the fused add + stochastic-rounding Triton kernel.

    Adds the fp8 `original_weight` into the bf16 `target_weight` in place,
    upcasting with stochastic rounding driven by `seed`.

    Args:
        target_weight: bfloat16 tensor holding the deltas; overwritten with the sum.
        original_weight: fp8 (e4m3fn or e5m2) tensor to add.
        seed: Seed for the kernel's random rounding offsets.

    Returns:
        `target_weight`, mutated in place.

    Raises:
        ValueError: If `original_weight` is not an fp8 dtype or `target_weight`
            is not bfloat16.
    """
    # Select the fp8 format parameters (exponent_bits is informational only).
    if original_weight.dtype == torch.float8_e4m3fn:
        exponent_bits, mantissa_bits, exponent_bias = 4, 3, 7
    elif original_weight.dtype == torch.float8_e5m2:
        exponent_bits, mantissa_bits, exponent_bias = 5, 2, 15  # noqa: F841
    else:
        raise ValueError("Unsupported dtype")

    if target_weight.dtype != torch.bfloat16:
        raise ValueError("target_weight dtype must be bfloat16")

    # Calculate grid and block sizes
    n_elements = original_weight.numel()
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    # Launch kernel
    fused_add_round_kernel[grid](
        original_weight,
        target_weight,
        seed,
        n_elements,
        exponent_bias,
        mantissa_bits,
        BLOCK_SIZE,
    )
    return target_weight
36
+
37
+
38
def calculate_weight_float8_(target_weights: torch.Tensor, original_weights: torch.Tensor) -> torch.Tensor:
    """In place: add the fp8 `original_weights` (stochastically rounded) into `target_weights`.

    Thin wrapper around `fused_add_round_launch` with a fixed seed; returns
    `target_weights` for convenience.
    """
    summed = fused_add_round_launch(target_weights, original_weights, seed=0)
    summed = summed.to(target_weights.dtype)
    target_weights.copy_(summed, non_blocking=True)
    return target_weights
42
+
43
+
44
def _prepare_deltas(
    lora_sd_and_strengths: list[LoraStateDictWithStrength], key: str, dtype: torch.dtype, device: torch.device
) -> torch.Tensor | None:
    """Accumulate the LoRA weight delta for a single `<prefix>.weight` key.

    For every LoRA carrying both `lora_A`/`lora_B` factors for this key,
    computes `strength * B @ A` and sums the contributions. Returns None when
    no LoRA touches the key.
    """
    base = key[: -len(".weight")]
    a_key = f"{base}.lora_A.weight"
    b_key = f"{base}.lora_B.weight"

    collected = []
    for lora_sd, strength in lora_sd_and_strengths:
        factors = lora_sd.sd
        if a_key in factors and b_key in factors:
            delta = torch.matmul(factors[b_key] * strength, factors[a_key])
            collected.append(delta.to(dtype=dtype, device=device))

    if not collected:
        return None
    if len(collected) == 1:
        return collected[0]
    return torch.stack(collected, dim=0).sum(dim=0)
61
+
62
+
63
def apply_loras(
    model_sd: StateDict,
    lora_sd_and_strengths: list[LoraStateDictWithStrength],
    dtype: torch.dtype | None,
    destination_sd: StateDict | None = None,
) -> StateDict:
    """Fuse LoRA deltas into a model state dict.

    For every weight in `model_sd`, computes the summed LoRA delta
    (strength * B @ A per LoRA) and adds the base weight to it. fp8 (e4m3fn)
    base weights on CUDA are fused via the stochastic-rounding Triton kernel.

    Args:
        model_sd: Base model state dict.
        lora_sd_and_strengths: LoRA state dicts with their blend strengths.
        dtype: Target dtype for the fused weights; None keeps each weight's dtype.
        destination_sd: Optional state dict whose `sd` mapping is mutated in
            place; keys already present are skipped when no delta applies.

    Returns:
        `destination_sd` when given, otherwise a freshly constructed StateDict.

    Raises:
        ValueError: If a base weight has a dtype other than float8_e4m3fn/bfloat16
            while a LoRA delta applies to it.
    """
    sd = {}
    if destination_sd is not None:
        sd = destination_sd.sd
    size = 0
    device = torch.device("meta")
    inner_dtypes = set()
    for key, weight in model_sd.sd.items():
        if weight is None:
            continue
        device = weight.device
        target_dtype = dtype if dtype is not None else weight.dtype
        # Deltas are accumulated in bf16 when the target is an fp8 format.
        deltas_dtype = target_dtype if target_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2] else torch.bfloat16
        deltas = _prepare_deltas(lora_sd_and_strengths, key, deltas_dtype, device)
        if deltas is None:
            # No LoRA touches this key: copy the base weight through (unless
            # the destination already holds it).
            if key in sd:
                continue
            deltas = weight.clone().to(dtype=target_dtype, device=device)
        elif weight.dtype == torch.float8_e4m3fn:
            if str(device).startswith("cuda"):
                # Fused add + stochastic rounding on GPU.
                deltas = calculate_weight_float8_(deltas, weight)
            else:
                deltas.add_(weight.to(dtype=deltas.dtype, device=device))
        elif weight.dtype == torch.bfloat16:
            deltas.add_(weight)
        else:
            raise ValueError(f"Unsupported dtype: {weight.dtype}")
        sd[key] = deltas.to(dtype=target_dtype)
        inner_dtypes.add(target_dtype)
        size += deltas.nbytes
    if destination_sd is not None:
        return destination_sd
    return StateDict(sd, device, size, inner_dtypes)
packages/ltx-core/src/ltx_core/loader/kernels.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ruff: noqa: ANN001, ANN201, ERA001, N803, N806
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
@triton.jit
def fused_add_round_kernel(
    x_ptr,
    output_ptr,  # contents will be added to the output
    seed,
    n_elements,
    EXPONENT_BIAS,
    MANTISSA_BITS,
    BLOCK_SIZE: tl.constexpr,
):
    """
    A kernel to upcast 8bit quantized weights to bfloat16 with stochastic rounding
    and add them to bfloat16 output weights. Might be used to upcast original model weights
    and to further add them to precalculated deltas coming from LoRAs.

    Args:
        x_ptr: fp8 source weights.
        output_ptr: bf16 tensor read as the addend and overwritten with the result.
        seed: Seed for the per-element uniform noise driving the rounding.
        n_elements: Total number of elements to process.
        EXPONENT_BIAS: Exponent bias of the fp8 source format (7 for e4m3, 15 for e5m2).
        MANTISSA_BITS: Mantissa width of the fp8 source format (3 for e4m3, 2 for e5m2).
        BLOCK_SIZE: Elements handled per program instance.
    """
    # Get program ID and compute offsets
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Load data
    x = tl.load(x_ptr + offsets, mask=mask)
    # Uniform noise in [-0.5, 0.5); scaled by one ULP below, this implements
    # stochastic rounding.
    rand_vals = tl.rand(seed, offsets) - 0.5

    x = tl.cast(x, tl.float16)
    delta = tl.load(output_ptr + offsets, mask=mask)
    delta = tl.cast(delta, tl.float16)
    x = x + delta

    x_bits = tl.cast(x, tl.int16, bitcast=True)

    # Calculate the exponent. Unbiased fp16 exponent is ((x_bits & 0x7C00) >> 10) - 15 for
    # normal numbers and -14 for subnormals.
    fp16_exponent_bits = (x_bits & 0x7C00) >> 10
    fp16_normals = fp16_exponent_bits > 0
    fp16_exponent = tl.where(fp16_normals, fp16_exponent_bits - 15, -14)

    # Add the target dtype's exponent bias and clamp to the target dtype's exponent range.
    exponent = fp16_exponent + EXPONENT_BIAS
    MAX_EXPONENT = 2 * EXPONENT_BIAS + 1
    exponent = tl.where(exponent > MAX_EXPONENT, MAX_EXPONENT, exponent)
    exponent = tl.where(exponent < 0, 0, exponent)

    # Normal ULP exponent, expressed as an fp16 exponent field:
    # (exponent - EXPONENT_BIAS - MANTISSA_BITS) + 15
    # Simplifies to: fp16_exponent - MANTISSA_BITS + 15
    # See https://en.wikipedia.org/wiki/Unit_in_the_last_place
    eps_exp = tl.maximum(0, tl.minimum(31, exponent - EXPONENT_BIAS - MANTISSA_BITS + 15))

    # Calculate epsilon in the target dtype
    eps_normal = tl.cast(tl.cast(eps_exp << 10, tl.int16), tl.float16, bitcast=True)

    # Subnormal ULP: 2^(1 - EXPONENT_BIAS - MANTISSA_BITS) ->
    # fp16 exponent bits: (1 - EXPONENT_BIAS - MANTISSA_BITS) + 15 =
    # 16 - EXPONENT_BIAS - MANTISSA_BITS
    eps_subnormal = tl.cast(tl.cast((16 - EXPONENT_BIAS - MANTISSA_BITS) << 10, tl.int16), tl.float16, bitcast=True)
    eps = tl.where(exponent > 0, eps_normal, eps_subnormal)

    # Apply zero mask to epsilon
    eps = tl.where(x == 0, 0.0, eps)

    # Apply stochastic rounding
    output = tl.cast(x + rand_vals * eps, tl.bfloat16)

    # Store the result
    tl.store(output_ptr + offsets, output, mask=mask)
packages/ltx-core/src/ltx_core/loader/module_ops.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, NamedTuple
2
+
3
+ import torch
4
+
5
+
6
class ModuleOps(NamedTuple):
    """
    Defines a named operation for matching and mutating PyTorch modules.
    Used to selectively transform modules in a model (e.g., replacing layers with quantized versions).
    """

    # Human-readable identifier for the operation.
    name: str
    # Predicate deciding whether `mutator` should be applied to a module.
    matcher: Callable[[torch.nn.Module], bool]
    # Transformation returning the (possibly replaced) module.
    mutator: Callable[[torch.nn.Module], torch.nn.Module]
packages/ltx-core/src/ltx_core/loader/primitives.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import NamedTuple, Protocol
3
+
4
+ import torch
5
+
6
+ from ltx_core.loader.module_ops import ModuleOps
7
+ from ltx_core.loader.sd_ops import SDOps
8
+ from ltx_core.model.model_protocol import ModelType
9
+
10
+
11
@dataclass(frozen=True)
class StateDict:
    """
    Immutable container for a PyTorch state dictionary.
    Contains:
    - sd: Dictionary of tensors (weights, buffers, etc.)
    - device: Device where tensors are stored
    - size: Total memory footprint in bytes
    - dtype: Set of tensor dtypes present
    """

    # Mapping of parameter/buffer names to tensors (values may be None,
    # see apply_loras which skips None entries).
    sd: dict
    device: torch.device
    size: int
    dtype: set[torch.dtype]

    def footprint(self) -> tuple[int, torch.device]:
        """Return (size in bytes, device) for memory-placement decisions."""
        return self.size, self.device
29
+
30
+
31
class StateDictLoader(Protocol):
    """
    Protocol for loading state dictionaries from various sources.
    Implementations must provide:
    - metadata: Extract model metadata from a single path
    - load: Load state dict from path(s) and apply SDOps transformations
    """

    def metadata(self, path: str) -> dict:
        """
        Load the model metadata stored at `path`.
        """

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """
        Load a state dict from a path or list of paths (for sharded model
        storage) and apply the optional `sd_ops` transformations.
        """
48
+
49
+
50
class ModelBuilderProtocol(Protocol[ModelType]):
    """
    Protocol for building PyTorch models from configuration dictionaries.
    Implementations must provide:
    - meta_model: Create a model from configuration dictionary and apply module operations
    - build: Create and initialize a model from state dictionary and apply dtype transformations
    """

    def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType:
        """
        Create a model on the meta device from a configuration dictionary.
        This decouples model creation from weight loading, allowing the model
        architecture to be instantiated without allocating memory for parameters.
        Args:
            config: Model configuration dictionary.
            module_ops: Optional list of module operations to apply (e.g., quantization).
        Returns:
            Model instance on meta device (no actual memory allocated for parameters).
        """
        ...

    def build(self, dtype: torch.dtype | None = None) -> ModelType:
        """
        Build the model, loading weights and materializing parameters.
        Args:
            dtype: Target dtype for the model; if None, uses the dtype of the
                weights stored at the model path.
        Returns:
            Model instance.
        """
        ...
80
+
81
+
82
class LoRAAdaptableProtocol(Protocol):
    """
    Protocol for models that can be adapted with LoRAs.
    Implementations must provide:
    - lora: Add a LoRA to the model
    """

    def lora(self, lora_path: str, strength: float) -> "LoRAAdaptableProtocol":
        """Return a variant of this object with the LoRA at `lora_path` attached at `strength`."""
        pass
91
+
92
+
93
class LoraPathStrengthAndSDOps(NamedTuple):
    """
    Tuple containing a LoRA path, strength, and SDOps for applying to the LoRA state dict.
    """

    # Filesystem path to the LoRA weights file.
    path: str
    # Blend strength multiplier applied to the LoRA delta.
    strength: float
    # Optional key/value transformations for the LoRA state dict
    # (SingleGPUModelBuilder.lora passes None by default).
    sd_ops: SDOps | None
101
+
102
+
103
class LoraStateDictWithStrength(NamedTuple):
    """
    Tuple containing a LoRA state dict and strength for applying to the model.
    """

    # Loaded LoRA weights.
    state_dict: StateDict
    # Blend strength multiplier for this LoRA's delta.
    strength: float
packages/ltx-core/src/ltx_core/loader/registry.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import threading
3
+ from dataclasses import dataclass, field
4
+ from pathlib import Path
5
+ from typing import Protocol
6
+
7
+ from ltx_core.loader.primitives import StateDict
8
+ from ltx_core.loader.sd_ops import SDOps
9
+
10
+
11
class Registry(Protocol):
    """
    Protocol for managing state dictionaries in a registry.
    It is used to store state dictionaries and reuse them later without loading them again.
    Implementations must provide:
    - add: Add a state dictionary to the registry
    - pop: Remove a state dictionary from the registry
    - get: Retrieve a state dictionary from the registry
    - clear: Clear all state dictionaries from the registry
    """

    # Register `state_dict` under the (paths, sd_ops) key.
    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None: ...

    # Remove and return the entry, or None when absent.
    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ...

    # Return the entry without removing it, or None when absent.
    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ...

    # Drop all entries.
    def clear(self) -> None: ...
29
+
30
+
31
class DummyRegistry(Registry):
    """No-op registry: never caches anything, every lookup is a miss."""

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None:
        """Discard the state dict."""
        return None

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Always a miss."""
        return None

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Always a miss."""
        return None

    def clear(self) -> None:
        """Nothing to clear."""
        return None
47
+
48
+
49
@dataclass
class StateDictRegistry(Registry):
    """
    Registry that stores state dictionaries in a dictionary.

    Entries are keyed by a SHA-256 of the resolved paths plus the SDOps name;
    all access is guarded by a lock.
    """

    _state_dicts: dict[str, StateDict] = field(default_factory=dict)
    _lock: threading.Lock = field(default_factory=threading.Lock)

    def _generate_id(self, paths: list[str], sd_ops: SDOps | None) -> str:
        """Derive a stable cache key from the resolved paths and the SDOps name."""
        m = hashlib.sha256()
        parts = [str(Path(p).resolve()) for p in paths]
        if sd_ops is not None:
            parts.append(sd_ops.name)
        m.update("\0".join(parts).encode("utf-8"))
        return m.hexdigest()

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> str:
        """Store `state_dict` and return its id; raises ValueError on a duplicate key."""
        sd_id = self._generate_id(paths, sd_ops)
        with self._lock:
            if sd_id in self._state_dicts:
                raise ValueError(f"State dict retrieved from {paths} with {sd_ops} already added, check with get first")
            self._state_dicts[sd_id] = state_dict
        return sd_id

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Remove and return the cached entry, or None on a miss."""
        with self._lock:
            return self._state_dicts.pop(self._generate_id(paths, sd_ops), None)

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Return the cached entry, or None on a miss."""
        with self._lock:
            return self._state_dicts.get(self._generate_id(paths, sd_ops), None)

    def clear(self) -> None:
        """Drop all cached entries."""
        with self._lock:
            self._state_dicts.clear()
packages/ltx-core/src/ltx_core/loader/sd_ops.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, replace
2
+ from typing import NamedTuple, Protocol
3
+
4
+ import torch
5
+
6
+
7
@dataclass(frozen=True, slots=True)
class ContentReplacement:
    """
    Represents a content replacement operation.
    Used to replace a specific content with a replacement in a state dict key.
    """

    # Substring to search for in keys.
    content: str
    # Text substituted for every occurrence of `content`.
    replacement: str
16
+
17
+
18
@dataclass(frozen=True, slots=True)
class ContentMatching:
    """
    Represents a content matching operation.
    Used to match a specific prefix and suffix in a state dict key.
    """

    # Empty strings match every key (str.startswith("")/endswith("") are True).
    prefix: str = ""
    suffix: str = ""
28
+
29
class KeyValueOperationResult(NamedTuple):
    """
    Represents the result of a key-value operation.
    Contains the new key and value after the operation has been applied.
    """

    new_key: str
    new_value: torch.Tensor
38
+
39
class KeyValueOperation(Protocol):
    """
    Protocol for key-value operations.
    Used to apply operations to a specific key and value in a state dict.
    A single input pair may expand into several output pairs.
    """

    def __call__(self, tensor_key: str, tensor_value: torch.Tensor) -> list[KeyValueOperationResult]: ...
46
+
47
+
48
@dataclass(frozen=True, slots=True)
class SDKeyValueOperation:
    """
    Represents a key-value operation.
    Used to apply operations to a specific key and value in a state dict.
    """

    # Prefix/suffix filter selecting which keys the operation applies to.
    key_matcher: ContentMatching
    # Transformation run on matching (key, tensor) pairs.
    kv_operation: KeyValueOperation
57
+
58
+
59
@dataclass(frozen=True, slots=True)
class SDOps:
    """Immutable description of state-dict key/value transformations.

    `mapping` mixes three element kinds:
    - ContentMatching: at least one must match for a key to be kept;
    - ContentReplacement: substring rewrites applied to kept keys;
    - SDKeyValueOperation: per-tensor transforms gated by a key filter.
    """

    name: str
    mapping: tuple[
        ContentReplacement | ContentMatching | SDKeyValueOperation, ...
    ] = ()  # Immutable tuple of (key, value) pairs

    def with_replacement(self, content: str, replacement: str) -> "SDOps":
        """Return a copy with an extra substring replacement rule."""
        return replace(self, mapping=self.mapping + (ContentReplacement(content, replacement),))

    def with_matching(self, prefix: str = "", suffix: str = "") -> "SDOps":
        """Return a copy with an extra prefix/suffix keep-filter."""
        return replace(self, mapping=self.mapping + (ContentMatching(prefix, suffix),))

    def with_kv_operation(
        self,
        operation: KeyValueOperation,
        key_prefix: str = "",
        key_suffix: str = "",
    ) -> "SDOps":
        """Return a copy with an extra key/value operation bound to a key filter."""
        entry = SDKeyValueOperation(ContentMatching(key_prefix, key_suffix), operation)
        return replace(self, mapping=self.mapping + (entry,))

    def apply_to_key(self, key: str) -> str | None:
        """Rewrite `key` through all replacement rules; None if no matcher accepts it."""
        accepted = any(
            key.startswith(m.prefix) and key.endswith(m.suffix)
            for m in self.mapping
            if isinstance(m, ContentMatching)
        )
        if not accepted:
            return None

        for entry in self.mapping:
            if isinstance(entry, ContentReplacement) and entry.content in key:
                key = key.replace(entry.content, entry.replacement)
        return key

    def apply_to_key_value(self, key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
        """Run the first key/value operation whose filter matches `key`; identity otherwise."""
        for entry in self.mapping:
            if not isinstance(entry, SDKeyValueOperation):
                continue
            matcher = entry.key_matcher
            if key.startswith(matcher.prefix) and key.endswith(matcher.suffix):
                return entry.kv_operation(key, value)
        return [KeyValueOperationResult(key, value)]
114
+
115
+
116
# Predefined SDOps instances

# Strips the ComfyUI "diffusion_model." prefix from keys (matches all keys).
LTXV_LORA_COMFY_RENAMING_MAP = (
    SDOps("LTXV_LORA_COMFY_PREFIX_MAP").with_matching().with_replacement("diffusion_model.", "")
)

# Maps ComfyUI LoRA factor keys (".lora_A.weight"/".lora_B.weight") onto the
# base model's ".weight" keys, also stripping the "diffusion_model." prefix.
LTXV_LORA_COMFY_TARGET_MAP = (
    SDOps("LTXV_LORA_COMFY_TARGET_MAP")
    .with_matching()
    .with_replacement("diffusion_model.", "")
    .with_replacement(".lora_A.weight", ".weight")
    .with_replacement(".lora_B.weight", ".weight")
)
packages/ltx-core/src/ltx_core/loader/sft_loader.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import safetensors
4
+ import torch
5
+
6
+ from ltx_core.loader.primitives import StateDict, StateDictLoader
7
+ from ltx_core.loader.sd_ops import SDOps
8
+
9
+
10
class SafetensorsStateDictLoader(StateDictLoader):
    """
    Loads weights from safetensors files without metadata support.
    Use this for loading raw weight files. For model files that include
    configuration metadata, use SafetensorsModelStateDictLoader instead.
    """

    def metadata(self, path: str) -> dict:
        """Raw weight files carry no configuration metadata."""
        raise NotImplementedError("Not implemented")

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """
        Load a state dict from a path or list of shard paths, applying `sd_ops`.

        Keys rejected by `sd_ops.apply_to_key` (returning None) are skipped;
        kept keys are renamed and their tensors optionally transformed via
        `sd_ops.apply_to_key_value`.

        Args:
            path: Single file path or list of shard paths.
            sd_ops: Optional key/value transformations. Defaults to None so the
                signature matches the StateDictLoader protocol (previously this
                parameter was required even though None was handled).
            device: Target device; defaults to CPU.

        Returns:
            A StateDict with the loaded tensors, total byte size, and the set
            of dtypes encountered.
        """
        sd = {}
        size = 0
        dtype = set()
        device = device or torch.device("cpu")
        model_paths = path if isinstance(path, list) else [path]
        for shard_path in model_paths:
            with safetensors.safe_open(shard_path, framework="pt", device=str(device)) as f:
                for name in f.keys():
                    expected_name = name if sd_ops is None else sd_ops.apply_to_key(name)
                    if expected_name is None:
                        # Filtered out by the key matcher.
                        continue
                    tensor = f.get_tensor(name).to(device=device, non_blocking=True, copy=False)
                    if sd_ops is not None:
                        key_value_pairs = sd_ops.apply_to_key_value(expected_name, tensor)
                    else:
                        key_value_pairs = ((expected_name, tensor),)
                    for out_key, out_value in key_value_pairs:
                        size += out_value.nbytes
                        dtype.add(out_value.dtype)
                        sd[out_key] = out_value

        return StateDict(sd=sd, device=device, size=size, dtype=dtype)
46
+
47
+
48
class SafetensorsModelStateDictLoader(StateDictLoader):
    """
    Loads weights plus configuration metadata from safetensors model files.

    The model configuration is read from the file's safetensors metadata (the
    JSON-encoded "config" entry); weight loading is delegated to a
    SafetensorsStateDictLoader.
    """

    def __init__(self, weight_loader: SafetensorsStateDictLoader | None = None):
        if weight_loader is None:
            weight_loader = SafetensorsStateDictLoader()
        self.weight_loader = weight_loader

    def metadata(self, path: str) -> dict:
        """Parse the JSON "config" entry from the safetensors metadata."""
        with safetensors.safe_open(path, framework="pt") as f:
            raw_config = f.metadata()["config"]
        return json.loads(raw_config)

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """Delegate weight loading to the wrapped loader."""
        return self.weight_loader.load(path, sd_ops, device)
packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from dataclasses import dataclass, field, replace
3
+ from typing import Generic
4
+
5
+ import torch
6
+
7
+ from ltx_core.loader.fuse_loras import apply_loras
8
+ from ltx_core.loader.module_ops import ModuleOps
9
+ from ltx_core.loader.primitives import (
10
+ LoRAAdaptableProtocol,
11
+ LoraPathStrengthAndSDOps,
12
+ LoraStateDictWithStrength,
13
+ ModelBuilderProtocol,
14
+ StateDict,
15
+ StateDictLoader,
16
+ )
17
+ from ltx_core.loader.registry import DummyRegistry, Registry
18
+ from ltx_core.loader.sd_ops import SDOps
19
+ from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader
20
+ from ltx_core.model.model_protocol import ModelConfigurator, ModelType
21
+
22
+ logger: logging.Logger = logging.getLogger(__name__)
23
+
24
+
25
@dataclass(frozen=True)
class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol):
    """
    Builder for PyTorch models residing on a single GPU.

    Immutable: configuration methods such as `lora()` return a new builder.
    Weight loading goes through `model_loader` and may be cached in `registry`.
    """

    # Configurator class used to instantiate the model from its config dict.
    model_class_configurator: type[ModelConfigurator[ModelType]]
    # One path, or a tuple of shard paths, to the model weights.
    model_path: str | tuple[str, ...]
    # Optional key/value transformations for the model state dict.
    model_sd_ops: SDOps | None = None
    # Module-level mutations (e.g. quantization) applied to the meta model.
    module_ops: tuple[ModuleOps, ...] = field(default_factory=tuple)
    # LoRAs (path, strength, sd_ops) to fuse into the weights at build time.
    loras: tuple[LoraPathStrengthAndSDOps, ...] = field(default_factory=tuple)
    model_loader: StateDictLoader = field(default_factory=SafetensorsModelStateDictLoader)
    registry: Registry = field(default_factory=DummyRegistry)

    def lora(self, lora_path: str, strength: float = 1.0, sd_ops: SDOps | None = None) -> "SingleGPUModelBuilder":
        """Return a new builder with an extra LoRA scheduled for fusion."""
        return replace(self, loras=(*self.loras, LoraPathStrengthAndSDOps(lora_path, strength, sd_ops)))

    def model_config(self) -> dict:
        """Read the model configuration from the first shard's metadata."""
        first_shard_path = self.model_path[0] if isinstance(self.model_path, tuple) else self.model_path
        return self.model_loader.metadata(first_shard_path)

    def meta_model(self, config: dict, module_ops: tuple[ModuleOps, ...]) -> ModelType:
        """Instantiate the architecture on the meta device (no weight memory),
        applying each matching module operation."""
        with torch.device("meta"):
            model = self.model_class_configurator.from_config(config)
            for module_op in module_ops:
                if module_op.matcher(model):
                    model = module_op.mutator(model)
            return model

    def load_sd(
        self, paths: list[str], registry: Registry, device: torch.device | None, sd_ops: SDOps | None = None
    ) -> StateDict:
        """Load a state dict, consulting/populating the registry cache."""
        state_dict = registry.get(paths, sd_ops)
        if state_dict is None:
            state_dict = self.model_loader.load(paths, sd_ops=sd_ops, device=device)
            registry.add(paths, sd_ops=sd_ops, state_dict=state_dict)
        return state_dict

    def _return_model(self, meta_model: ModelType, device: torch.device) -> ModelType:
        """Move the model to `device`, or return it unmoved (with a warning)
        when some parameters/buffers were never materialized off meta."""
        uninitialized_params = [name for name, param in meta_model.named_parameters() if str(param.device) == "meta"]
        uninitialized_buffers = [name for name, buffer in meta_model.named_buffers() if str(buffer.device) == "meta"]
        if uninitialized_params or uninitialized_buffers:
            logger.warning(f"Uninitialized parameters or buffers: {uninitialized_params + uninitialized_buffers}")
            return meta_model
        retval = meta_model.to(device)
        return retval

    def build(self, device: torch.device | None = None, dtype: torch.dtype | None = None) -> ModelType:
        """Build the model: load weights, optionally fuse LoRAs, move to device.

        Args:
            device: Target device; defaults to CUDA.
            dtype: Target dtype; if None, weights keep their stored dtype.

        Returns:
            The materialized model instance.
        """
        device = torch.device("cuda") if device is None else device
        config = self.model_config()
        meta_model = self.meta_model(config, self.module_ops)
        model_paths = self.model_path if isinstance(self.model_path, tuple) else [self.model_path]
        model_state_dict = self.load_sd(model_paths, sd_ops=self.model_sd_ops, registry=self.registry, device=device)

        lora_strengths = [lora.strength for lora in self.loras]
        if not lora_strengths or (min(lora_strengths) == 0 and max(lora_strengths) == 0):
            # No effective LoRAs: load the base weights directly.
            sd = model_state_dict.sd
            if dtype is not None:
                sd = {key: value.to(dtype=dtype) for key, value in model_state_dict.sd.items()}
            meta_model.load_state_dict(sd, strict=False, assign=True)
            return self._return_model(meta_model, device)

        lora_state_dicts = [
            self.load_sd([lora.path], sd_ops=lora.sd_ops, registry=self.registry, device=device) for lora in self.loras
        ]
        lora_sd_and_strengths = [
            LoraStateDictWithStrength(sd, strength)
            for sd, strength in zip(lora_state_dicts, lora_strengths, strict=True)
        ]
        # With the no-op registry the base state dict may be mutated in place;
        # with a real cache a fresh destination must be allocated instead.
        final_sd = apply_loras(
            model_sd=model_state_dict,
            lora_sd_and_strengths=lora_sd_and_strengths,
            dtype=dtype,
            destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None,
        )
        meta_model.load_state_dict(final_sd.sd, strict=False, assign=True)
        return self._return_model(meta_model, device)
packages/ltx-core/src/ltx_core/model/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """Model definitions for LTX-2."""
2
+
3
+ from ltx_core.model.model_protocol import ModelConfigurator, ModelType
4
+
5
+ __all__ = [
6
+ "ModelConfigurator",
7
+ "ModelType",
8
+ ]
packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio VAE model components."""
2
+
3
+ from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder, decode_audio
4
+ from ltx_core.model.audio_vae.model_configurator import (
5
+ AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
6
+ AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
7
+ VOCODER_COMFY_KEYS_FILTER,
8
+ AudioDecoderConfigurator,
9
+ AudioEncoderConfigurator,
10
+ VocoderConfigurator,
11
+ )
12
+ from ltx_core.model.audio_vae.ops import AudioProcessor
13
+ from ltx_core.model.audio_vae.vocoder import Vocoder
14
+
15
+ __all__ = [
16
+ "AUDIO_VAE_DECODER_COMFY_KEYS_FILTER",
17
+ "AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER",
18
+ "VOCODER_COMFY_KEYS_FILTER",
19
+ "AudioDecoder",
20
+ "AudioDecoderConfigurator",
21
+ "AudioEncoder",
22
+ "AudioEncoderConfigurator",
23
+ "AudioProcessor",
24
+ "Vocoder",
25
+ "VocoderConfigurator",
26
+ "decode_audio",
27
+ ]
packages/ltx-core/src/ltx_core/model/audio_vae/attention.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.common.normalization import NormType, build_normalization_layer
6
+
7
+
8
class AttentionType(Enum):
    """Enum for specifying the attention mechanism type.

    VANILLA: full softmax self-attention (AttnBlock).
    LINEAR: declared but not implemented yet (see make_attn).
    NONE: no attention (identity pass-through).
    """

    VANILLA = "vanilla"
    LINEAR = "linear"
    NONE = "none"
14
+
15
+
16
class AttnBlock(torch.nn.Module):
    """Single-head self-attention over 2D feature maps.

    Computes softmax(Q K^T / sqrt(C)) V with 1x1 convolutions as the Q/K/V and
    output projections, and a residual connection around the whole block.
    """

    def __init__(
        self,
        in_channels: int,
        norm_type: NormType = NormType.GROUP,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels

        self.norm = build_normalization_layer(in_channels, normtype=norm_type)
        # 1x1 convolutions act as per-position linear projections.
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to `x` of shape (B, C, H, W); output has the same shape."""
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).contiguous()
        q = q.permute(0, 2, 1).contiguous()  # b,hw,c
        k = k.reshape(b, c, h * w).contiguous()  # b,c,hw
        w_ = torch.bmm(q, k).contiguous()  # b,hw,hw   w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))  # scale by 1/sqrt(C)
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w).contiguous()
        w_ = w_.permute(0, 2, 1).contiguous()  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_).contiguous()  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w).contiguous()

        h_ = self.proj_out(h_)

        # residual connection
        return x + h_
56
+
57
+
58
def make_attn(
    in_channels: int,
    attn_type: AttentionType = AttentionType.VANILLA,
    norm_type: NormType = NormType.GROUP,
) -> torch.nn.Module:
    """Factory for the attention layer used by the audio VAE blocks.

    VANILLA builds an AttnBlock, NONE yields an identity pass-through, and
    LINEAR is declared but not implemented yet.
    """
    if attn_type is AttentionType.VANILLA:
        return AttnBlock(in_channels, norm_type=norm_type)
    if attn_type is AttentionType.NONE:
        return torch.nn.Identity()
    if attn_type is AttentionType.LINEAR:
        raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
    raise ValueError(f"Unknown attention type: {attn_type}")
packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Set, Tuple
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from ltx_core.components.patchifiers import AudioPatchifier
7
+ from ltx_core.model.audio_vae.attention import AttentionType, make_attn
8
+ from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
9
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
10
+ from ltx_core.model.audio_vae.downsample import build_downsampling_path
11
+ from ltx_core.model.audio_vae.ops import PerChannelStatistics
12
+ from ltx_core.model.audio_vae.resnet import ResnetBlock
13
+ from ltx_core.model.audio_vae.upsample import build_upsampling_path
14
+ from ltx_core.model.audio_vae.vocoder import Vocoder
15
+ from ltx_core.model.common.normalization import NormType, build_normalization_layer
16
+ from ltx_core.types import AudioLatentShape
17
+
18
+ LATENT_DOWNSAMPLE_FACTOR = 4
19
+
20
+
21
def build_mid_block(
    channels: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    add_attention: bool,
) -> torch.nn.Module:
    """Build the middle block: ResNet -> (optional) attention -> ResNet, all at `channels` width."""

    def _make_resnet() -> ResnetBlock:
        # Both mid ResNet blocks share the same configuration.
        return ResnetBlock(
            in_channels=channels,
            out_channels=channels,
            temb_channels=temb_channels,
            dropout=dropout,
            norm_type=norm_type,
            causality_axis=causality_axis,
        )

    mid = torch.nn.Module()
    # Attribute assignment order fixes submodule registration order (block_1, attn_1, block_2).
    mid.block_1 = _make_resnet()
    if add_attention:
        mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type)
    else:
        mid.attn_1 = torch.nn.Identity()
    mid.block_2 = _make_resnet()
    return mid
50
+
51
+
52
def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor:
    """Run features through the middle block: resnet -> attention -> resnet."""
    out = mid.block_1(features, temb=None)
    out = mid.attn_1(out)
    out = mid.block_2(out, temb=None)
    return out
57
+
58
+
59
class AudioEncoder(torch.nn.Module):
    """
    Encoder that compresses audio spectrograms into latent representations.
    The encoder uses a series of downsampling blocks with residual connections,
    attention mechanisms, and configurable causal convolutions.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        ch: int,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int,
        attn_resolutions: Set[int],
        dropout: float = 0.0,
        resamp_with_conv: bool = True,
        in_channels: int,
        resolution: int,
        z_channels: int,
        double_z: bool = True,
        attn_type: AttentionType = AttentionType.VANILLA,
        mid_block_add_attention: bool = True,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        n_fft: int = 1024,
        is_causal: bool = True,
        mel_bins: int = 64,
        **_ignore_kwargs,
    ) -> None:
        """
        Initialize the Encoder.
        Args:
            Arguments are configuration parameters, loaded from the audio VAE checkpoint config
            (audio_vae.model.params.ddconfig):
            ch: Base number of feature channels used in the first convolution layer.
            ch_mult: Multiplicative factors for the number of channels at each resolution level.
            num_res_blocks: Number of residual blocks to use at each resolution level.
            attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention.
            resolution: Input spatial resolution of the spectrogram (height, width).
            z_channels: Number of channels in the latent representation.
            norm_type: Normalization layer type to use within the network (e.g., group, batch).
            causality_axis: Axis along which convolutions should be causal (e.g., time axis).
            sample_rate: Audio sample rate in Hz for the input signals.
            mel_hop_length: Hop length used when computing the mel spectrogram.
            n_fft: FFT size used to compute the spectrogram.
            mel_bins: Number of mel-frequency bins in the input spectrogram.
            in_channels: Number of channels in the input spectrogram tensor.
            double_z: If True, predict both mean and log-variance (doubling latent channels).
            is_causal: If True, use causal convolutions suitable for streaming setups.
            dropout: Dropout probability used in residual and mid blocks.
            attn_type: Type of attention mechanism to use in attention blocks.
            resamp_with_conv: If True, perform resolution changes using strided convolutions.
            mid_block_add_attention: If True, add an attention block in the mid-level of the encoder.
        """
        super().__init__()

        # NOTE(review): statistics are sized with `ch` (base conv width) rather than
        # `z_channels`; presumably the checkpoint load supplies the actual buffer
        # shapes — confirm against PerChannelStatistics.
        self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.n_fft = n_fft
        self.is_causal = is_causal
        self.mel_bins = mel_bins

        # Used only to flatten latents for per-channel normalization (patch_size=1).
        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=sample_rate,
            hop_length=mel_hop_length,
            is_causal=is_causal,
        )

        self.ch = ch
        self.temb_ch = 0  # No timestep embedding is used in the audio VAE.
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.z_channels = z_channels
        self.double_z = double_z
        self.norm_type = norm_type
        self.causality_axis = causality_axis
        self.attn_type = attn_type

        # downsampling
        self.conv_in = make_conv2d(
            in_channels,
            self.ch,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

        self.non_linearity = torch.nn.SiLU()

        self.down, block_in = build_downsampling_path(
            ch=ch,
            ch_mult=ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=num_res_blocks,
            resolution=resolution,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=attn_resolutions,
            resamp_with_conv=resamp_with_conv,
        )

        self.mid = build_mid_block(
            channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=mid_block_add_attention,
        )

        self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
        # Emit 2x channels when double_z so the head predicts mean and log-variance.
        self.conv_out = make_conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

    def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
        """
        Encode audio spectrogram into latent representations.
        Args:
            spectrogram: Input spectrogram of shape (batch, channels, time, frequency)
        Returns:
            Encoded latent representation of shape (batch, channels, frames, mel_bins)
        """
        h = self.conv_in(spectrogram)
        h = self._run_downsampling_path(h)
        h = run_mid_block(self.mid, h)
        h = self._finalize_output(h)

        return self._normalize_latents(h)

    def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor:
        """Apply each resolution level's residual/attention blocks, downsampling between levels."""
        for level in range(self.num_resolutions):
            stage = self.down[level]
            for block_idx in range(self.num_res_blocks):
                h = stage.block[block_idx](h, temb=None)
                if stage.attn:
                    h = stage.attn[block_idx](h)

            # The last level has no downsample module (see build_downsampling_path).
            if level != self.num_resolutions - 1:
                h = stage.downsample(h)

        return h

    def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
        """Final norm -> nonlinearity -> projection to latent channels."""
        h = self.norm_out(h)
        h = self.non_linearity(h)
        return self.conv_out(h)

    def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor:
        """
        Normalize encoder latents using per-channel statistics.
        When the encoder is configured with ``double_z=True``, the final
        convolution produces twice the number of latent channels, typically
        interpreted as two concatenated tensors along the channel dimension
        (e.g., mean and variance or other auxiliary parameters).
        This method intentionally uses only the first half of the channels
        (the "mean" component) as input to the patchifier and normalization
        logic. The remaining channels are left unchanged by this method and
        are expected to be consumed elsewhere in the VAE pipeline.
        If ``double_z=False``, the encoder output already contains only the
        mean latents and is used as-is.
        """
        # Bug fix: the previous unconditional chunk would incorrectly halve the
        # channel count when double_z=False; only split when 2x channels exist.
        means = torch.chunk(latent_output, 2, dim=1)[0] if self.double_z else latent_output
        latent_shape = AudioLatentShape(
            batch=means.shape[0],
            channels=means.shape[1],
            frames=means.shape[2],
            mel_bins=means.shape[3],
        )
        latent_patched = self.patchifier.patchify(means)
        latent_normalized = self.per_channel_statistics.normalize(latent_patched)
        return self.patchifier.unpatchify(latent_normalized, latent_shape)
+
247
+
248
class AudioDecoder(torch.nn.Module):
    """
    Symmetric decoder that reconstructs audio spectrograms from latent features.
    The decoder mirrors the encoder structure with configurable channel multipliers,
    attention resolutions, and causal convolutions.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        ch: int,
        out_ch: int,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int,
        attn_resolutions: Set[int],
        resolution: int,
        z_channels: int,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
        dropout: float = 0.0,
        mid_block_add_attention: bool = True,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        is_causal: bool = True,
        mel_bins: int | None = None,
    ) -> None:
        """
        Initialize the Decoder.
        Args:
            Arguments are configuration parameters, loaded from the audio VAE checkpoint config
            (audio_vae.model.params.ddconfig):
            - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions
            - resolution, z_channels
            - norm_type, causality_axis
        """
        super().__init__()

        # Internal behavioural defaults that are not driven by the checkpoint.
        resamp_with_conv = True
        attn_type = AttentionType.VANILLA

        # Per-channel statistics for denormalizing latents
        # NOTE(review): sized with `ch` (base conv width), not `z_channels`;
        # presumably the checkpoint load supplies the real buffer shapes — confirm.
        self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.is_causal = is_causal
        self.mel_bins = mel_bins
        # Used only to flatten latents for per-channel denormalization (patch_size=1).
        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=sample_rate,
            hop_length=mel_hop_length,
            is_causal=is_causal,
        )

        self.ch = ch
        self.temb_ch = 0  # No timestep embedding is used in the audio VAE.
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.out_ch = out_ch
        self.give_pre_end = False  # If True, forward would return pre-head features.
        self.tanh_out = False  # If True, outputs would be squashed with tanh.
        self.norm_type = norm_type
        self.z_channels = z_channels
        self.channel_multipliers = ch_mult
        self.attn_resolutions = attn_resolutions
        self.causality_axis = causality_axis
        self.attn_type = attn_type

        # Spatial size/width at the bottleneck, after (num_resolutions - 1) halvings.
        base_block_channels = ch * self.channel_multipliers[-1]
        base_resolution = resolution // (2 ** (self.num_resolutions - 1))
        self.z_shape = (1, z_channels, base_resolution, base_resolution)

        self.conv_in = make_conv2d(
            z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )
        self.non_linearity = torch.nn.SiLU()
        self.mid = build_mid_block(
            channels=base_block_channels,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=mid_block_add_attention,
        )
        self.up, final_block_channels = build_upsampling_path(
            ch=ch,
            ch_mult=ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=num_res_blocks,
            resolution=resolution,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=attn_resolutions,
            resamp_with_conv=resamp_with_conv,
            initial_block_channels=base_block_channels,
        )

        self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
        self.conv_out = make_conv2d(
            final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """
        Decode latent features back to audio spectrograms.
        Args:
            sample: Encoded latent representation of shape (batch, channels, frames, mel_bins)
        Returns:
            Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
        """
        sample, target_shape = self._denormalize_latents(sample)

        h = self.conv_in(sample)
        h = run_mid_block(self.mid, h)
        h = self._run_upsampling_path(h)
        h = self._finalize_output(h)

        return self._adjust_output_shape(h, target_shape)

    def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]:
        """Undo per-channel normalization and compute the expected output shape."""
        latent_shape = AudioLatentShape(
            batch=sample.shape[0],
            channels=sample.shape[1],
            frames=sample.shape[2],
            mel_bins=sample.shape[3],
        )

        sample_patched = self.patchifier.patchify(sample)
        sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
        sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)

        # Each latent frame expands to LATENT_DOWNSAMPLE_FACTOR output frames;
        # causal modes drop the (factor - 1) frames introduced by causal padding.
        target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
        if self.causality_axis != CausalityAxis.NONE:
            target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)

        target_shape = AudioLatentShape(
            batch=latent_shape.batch,
            channels=self.out_ch,
            frames=target_frames,
            mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
        )

        return sample, target_shape

    def _adjust_output_shape(
        self,
        decoded_output: torch.Tensor,
        target_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Adjust output shape to match target dimensions for variable-length audio.
        This function handles the common case where decoded audio spectrograms need to be
        resized to match a specific target shape.
        Args:
            decoded_output: Tensor of shape (batch, channels, time, frequency)
            target_shape: AudioLatentShape describing (batch, channels, time, mel bins)
        Returns:
            Tensor adjusted to match target_shape exactly
        """
        # Current output shape: (batch, channels, time, frequency)
        _, _, current_time, current_freq = decoded_output.shape
        target_channels = target_shape.channels
        target_time = target_shape.frames
        target_freq = target_shape.mel_bins

        # Step 1: Crop first to avoid exceeding target dimensions
        decoded_output = decoded_output[
            :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
        ]

        # Step 2: Calculate padding needed for time and frequency dimensions
        time_padding_needed = target_time - decoded_output.shape[2]
        freq_padding_needed = target_freq - decoded_output.shape[3]

        # Step 3: Apply padding if needed
        if time_padding_needed > 0 or freq_padding_needed > 0:
            # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
            # For audio: pad_left/right = frequency, pad_top/bottom = time
            padding = (
                0,
                max(freq_padding_needed, 0),  # frequency padding (left, right)
                0,
                max(time_padding_needed, 0),  # time padding (top, bottom)
            )
            decoded_output = F.pad(decoded_output, padding)

        # Step 4: Final safety crop to ensure exact target shape
        decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]

        return decoded_output

    def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor:
        """Traverse levels from coarsest to finest, upsampling between levels."""
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for block_idx, block in enumerate(stage.block):
                h = block(h, temb=None)
                if stage.attn:
                    h = stage.attn[block_idx](h)

            # Level 0 (finest) has no upsample module.
            if level != 0 and hasattr(stage, "upsample"):
                h = stage.upsample(h)

        return h

    def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
        """Final norm -> nonlinearity -> projection, honoring give_pre_end/tanh_out flags."""
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = self.non_linearity(h)
        h = self.conv_out(h)
        return torch.tanh(h) if self.tanh_out else h
+
467
+
468
def decode_audio(latent: torch.Tensor, audio_decoder: "AudioDecoder", vocoder: "Vocoder") -> torch.Tensor:
    """
    Decode an audio latent representation using the provided audio decoder and vocoder.
    Args:
        latent: Input audio latent tensor.
        audio_decoder: Model to decode the latent to waveform features.
        vocoder: Model to convert decoded features to audio waveform.
    Returns:
        Decoded audio as a float tensor.
    """
    features = audio_decoder(latent)
    waveform = vocoder(features)
    # Drop the batch dimension and ensure float32 output.
    return waveform.squeeze(0).float()
+ return decoded_audio
packages/ltx-core/src/ltx_core/model/audio_vae/causal_conv_2d.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
5
+
6
+
7
class CausalConv2d(torch.nn.Module):
    """
    A causal 2D convolution.
    Output at position t depends only on inputs at positions <= t along the
    configured causality axis. Implemented by applying asymmetric F.pad before
    a Conv2d that itself uses no padding.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int],
        stride: int = 1,
        dilation: int | tuple[int, int] = 1,
        groups: int = 1,
        bias: bool = True,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        super().__init__()

        self.causality_axis = causality_axis

        # Normalize scalar arguments to (h, w) pairs.
        kernel_size = torch.nn.modules.utils._pair(kernel_size)
        dilation = torch.nn.modules.utils._pair(dilation)

        # Total padding required to preserve size along each spatial dim.
        total_h = (kernel_size[0] - 1) * dilation[0]
        total_w = (kernel_size[1] - 1) * dilation[1]

        # F.pad ordering: (pad_left, pad_right, pad_top, pad_bottom).
        # The causal axis gets all of its padding on the "past" side.
        if self.causality_axis == CausalityAxis.NONE:
            self.padding = (total_w // 2, total_w - total_w // 2, total_h // 2, total_h - total_h // 2)
        elif self.causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY):
            self.padding = (total_w, 0, total_h // 2, total_h - total_h // 2)
        elif self.causality_axis == CausalityAxis.HEIGHT:
            self.padding = (total_w // 2, total_w - total_w // 2, total_h, 0)
        else:
            raise ValueError(f"Invalid causality_axis: {causality_axis}")

        # Padding is applied manually in forward, so the conv pads nothing.
        self.conv = torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pad asymmetrically, then convolve."""
        return self.conv(F.pad(x, self.padding))
65
+
66
+
67
def make_conv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    stride: int = 1,
    padding: tuple[int, int, int, int] | None = None,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = True,
    causality_axis: CausalityAxis | None = None,
) -> torch.nn.Module:
    """
    Create a 2D convolution layer that can be either causal or non-causal.
    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        kernel_size: Size of the convolution kernel
        stride: Convolution stride
        padding: Padding (if None, will be calculated based on causal flag)
        dilation: Dilation rate
        groups: Number of groups for grouped convolution
        bias: Whether to use bias
        causality_axis: Dimension along which to apply causality.
    Returns:
        Either a regular Conv2d or CausalConv2d layer
    """
    if causality_axis is None:
        # Non-causal path: plain Conv2d with symmetric "same" padding by default.
        if padding is None:
            padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size)
        return torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    # Causal path: CausalConv2d computes and applies its own asymmetric padding,
    # so the `padding` argument is intentionally ignored here.
    return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)
packages/ltx-core/src/ltx_core/model/audio_vae/causality_axis.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+
4
class CausalityAxis(Enum):
    """Enum for specifying the causality axis in causal convolutions."""

    # No causal constraint: padding is symmetric on both spatial axes.
    NONE = None
    # Causal along the width axis: all width padding goes on the left.
    WIDTH = "width"
    # Causal along the height axis: all height padding goes on the top.
    HEIGHT = "height"
    # Width-causal variant kept for compatibility with older checkpoints
    # (pads width by a smaller amount in Downsample than WIDTH does).
    WIDTH_COMPATIBILITY = "width-compatibility"
packages/ltx-core/src/ltx_core/model/audio_vae/downsample.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Set, Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.audio_vae.attention import AttentionType, make_attn
6
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
7
+ from ltx_core.model.audio_vae.resnet import ResnetBlock
8
+ from ltx_core.model.common.normalization import NormType
9
+
10
+
11
class Downsample(torch.nn.Module):
    """
    Halves spatial resolution using either a stride-2 convolution or average
    pooling. The convolutional mode supports causal (asymmetric) padding.
    """

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
    ) -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis

        # Average pooling has no notion of asymmetric padding, so causal
        # modes require the convolutional variant.
        if not self.with_conv and self.causality_axis != CausalityAxis.NONE:
            raise ValueError("causality is only supported when `with_conv=True`.")

        if self.with_conv:
            # Asymmetric padding is applied manually in forward(); the conv pads nothing.
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.with_conv:
            # Only reachable with causality_axis == NONE (enforced in __init__).
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)

        # F.pad ordering: (left, right, top, bottom); the causal axis is padded
        # entirely on the "past" side.
        pads = {
            CausalityAxis.NONE: (0, 1, 0, 1),
            CausalityAxis.WIDTH: (2, 0, 0, 1),
            CausalityAxis.HEIGHT: (0, 1, 2, 0),
            CausalityAxis.WIDTH_COMPATIBILITY: (1, 0, 0, 1),
        }
        try:
            pad = pads[self.causality_axis]
        except KeyError:
            raise ValueError(f"Invalid causality_axis: {self.causality_axis}") from None

        padded = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        return self.conv(padded)
58
+
59
+
60
def build_downsampling_path(  # noqa: PLR0913
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
) -> tuple[torch.nn.ModuleList, int]:
    """Build the downsampling path with residual blocks, attention, and downsampling layers."""
    stages = torch.nn.ModuleList()
    current_resolution = resolution
    # Input multiplier for level i is the output multiplier of level i-1 (1 for the first).
    input_multipliers = (1, *tuple(ch_mult))
    channels = ch

    for level in range(num_resolutions):
        res_blocks = torch.nn.ModuleList()
        attn_blocks = torch.nn.ModuleList()
        channels = ch * input_multipliers[level]
        level_out_channels = ch * ch_mult[level]

        for _ in range(num_res_blocks):
            res_blocks.append(
                ResnetBlock(
                    in_channels=channels,
                    out_channels=level_out_channels,
                    temb_channels=temb_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                    causality_axis=causality_axis,
                )
            )
            channels = level_out_channels
            # Attach attention only at the configured spatial resolutions.
            if current_resolution in attn_resolutions:
                attn_blocks.append(make_attn(channels, attn_type=attn_type, norm_type=norm_type))

        stage = torch.nn.Module()
        stage.block = res_blocks
        stage.attn = attn_blocks
        # Every level except the last ends with a 2x downsample.
        if level != num_resolutions - 1:
            stage.downsample = Downsample(channels, resamp_with_conv, causality_axis=causality_axis)
            current_resolution //= 2
        stages.append(stage)

    return stages, channels
packages/ltx-core/src/ltx_core/model/audio_vae/model_configurator.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ltx_core.loader.sd_ops import SDOps
2
+ from ltx_core.model.audio_vae.attention import AttentionType
3
+ from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder
4
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
5
+ from ltx_core.model.audio_vae.vocoder import Vocoder
6
+ from ltx_core.model.common.normalization import NormType
7
+ from ltx_core.model.model_protocol import ModelConfigurator
8
+
9
+
10
class VocoderConfigurator(ModelConfigurator[Vocoder]):
    """Builds a Vocoder from the `vocoder` section of a checkpoint config."""

    @classmethod
    def from_config(cls: type[Vocoder], config: dict) -> Vocoder:
        cfg = config.get("vocoder", {})
        # Defaults match the reference vocoder configuration.
        defaults = {
            "resblock_kernel_sizes": [3, 7, 11],
            "upsample_rates": [6, 5, 2, 2, 2],
            "upsample_kernel_sizes": [16, 15, 8, 4, 4],
            "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            "upsample_initial_channel": 1024,
            "stereo": True,
            "resblock": "1",
            "output_sample_rate": 24000,
        }
        kwargs = {key: cfg.get(key, default) for key, default in defaults.items()}
        return Vocoder(**kwargs)
24
+
25
+
26
# State-dict key filter: keep only `vocoder.*` entries and strip the prefix so
# keys line up with the Vocoder module's own parameter names.
VOCODER_COMFY_KEYS_FILTER = (
    SDOps("VOCODER_COMFY_KEYS_FILTER").with_matching(prefix="vocoder.").with_replacement("vocoder.", "")
)
29
+
30
+
31
class AudioDecoderConfigurator(ModelConfigurator[AudioDecoder]):
    """Builds an AudioDecoder from the `audio_vae` section of a checkpoint config."""

    @classmethod
    def from_config(cls: type[AudioDecoder], config: dict) -> AudioDecoder:
        audio_vae_cfg = config.get("audio_vae", {})
        model_params = audio_vae_cfg.get("model", {}).get("params", {})
        ddconfig = model_params.get("ddconfig", {})
        preprocessing_cfg = audio_vae_cfg.get("preprocessing", {})
        stft_cfg = preprocessing_cfg.get("stft", {})
        mel_cfg = preprocessing_cfg.get("mel", {})
        variables_cfg = audio_vae_cfg.get("variables", {})

        # mel_bins may live in any of three config sections depending on
        # checkpoint vintage; take the first one present.
        mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins")

        return AudioDecoder(
            ch=ddconfig.get("ch", 128),
            out_ch=ddconfig.get("out_ch", 2),
            ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))),
            num_res_blocks=ddconfig.get("num_res_blocks", 2),
            attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}),
            resolution=ddconfig.get("resolution", 256),
            z_channels=ddconfig.get("z_channels", 8),
            norm_type=NormType(ddconfig.get("norm_type", "pixel")),
            causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")),
            dropout=ddconfig.get("dropout", 0.0),
            mid_block_add_attention=ddconfig.get("mid_block_add_attention", True),
            sample_rate=model_params.get("sampling_rate", 16000),
            mel_hop_length=stft_cfg.get("hop_length", 160),
            is_causal=stft_cfg.get("causal", True),
            mel_bins=mel_bins,
        )
65
+
66
+
67
class AudioEncoderConfigurator(ModelConfigurator[AudioEncoder]):
    """Builds an AudioEncoder from the `audio_vae` section of a checkpoint config."""

    @classmethod
    def from_config(cls: type[AudioEncoder], config: dict) -> AudioEncoder:
        audio_vae_cfg = config.get("audio_vae", {})
        model_params = audio_vae_cfg.get("model", {}).get("params", {})
        ddconfig = model_params.get("ddconfig", {})
        preprocessing_cfg = audio_vae_cfg.get("preprocessing", {})
        stft_cfg = preprocessing_cfg.get("stft", {})
        mel_cfg = preprocessing_cfg.get("mel", {})
        variables_cfg = audio_vae_cfg.get("variables", {})

        # mel_bins may live in any of three config sections depending on
        # checkpoint vintage; take the first one present.
        mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins")

        return AudioEncoder(
            ch=ddconfig.get("ch", 128),
            ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))),
            num_res_blocks=ddconfig.get("num_res_blocks", 2),
            attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}),
            resolution=ddconfig.get("resolution", 256),
            z_channels=ddconfig.get("z_channels", 8),
            double_z=ddconfig.get("double_z", True),
            dropout=ddconfig.get("dropout", 0.0),
            resamp_with_conv=ddconfig.get("resamp_with_conv", True),
            in_channels=ddconfig.get("in_channels", 2),
            attn_type=AttentionType(ddconfig.get("attn_type", "vanilla")),
            mid_block_add_attention=ddconfig.get("mid_block_add_attention", True),
            norm_type=NormType(ddconfig.get("norm_type", "pixel")),
            causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")),
            sample_rate=model_params.get("sampling_rate", 16000),
            mel_hop_length=stft_cfg.get("hop_length", 160),
            n_fft=stft_cfg.get("filter_length", 1024),
            is_causal=stft_cfg.get("causal", True),
            mel_bins=mel_bins,
        )
106
+
107
+
108
# State-dict key filters mapping ComfyUI-style checkpoint keys onto the
# decoder/encoder modules: keep only the relevant `audio_vae.*` entries,
# strip the submodule prefix, and remap the shared per-channel statistics.
AUDIO_VAE_DECODER_COMFY_KEYS_FILTER = (
    SDOps("AUDIO_VAE_DECODER_COMFY_KEYS_FILTER")
    .with_matching(prefix="audio_vae.decoder.")
    .with_matching(prefix="audio_vae.per_channel_statistics.")
    .with_replacement("audio_vae.decoder.", "")
    .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.")
)


AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER = (
    SDOps("AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER")
    .with_matching(prefix="audio_vae.encoder.")
    .with_matching(prefix="audio_vae.per_channel_statistics.")
    .with_replacement("audio_vae.encoder.", "")
    .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.")
)
packages/ltx-core/src/ltx_core/model/audio_vae/ops.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ from torch import nn
4
+
5
+
6
class AudioProcessor(nn.Module):
    """Turns raw waveforms into log-mel spectrograms, resampling on the fly when needed."""

    def __init__(
        self,
        sample_rate: int,
        mel_bins: int,
        mel_hop_length: int,
        n_fft: int,
    ) -> None:
        super().__init__()
        self.sample_rate = sample_rate
        # Slaney-style mel filterbank over the full Nyquist range; magnitude
        # (power=1.0) spectrogram with a centered, reflect-padded Hann window.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=mel_hop_length,
            f_min=0.0,
            f_max=sample_rate / 2.0,
            n_mels=mel_bins,
            window_fn=torch.hann_window,
            center=True,
            pad_mode="reflect",
            power=1.0,
            mel_scale="slaney",
            norm="slaney",
        )

    def resample_waveform(
        self,
        waveform: torch.Tensor,
        source_rate: int,
        target_rate: int,
    ) -> torch.Tensor:
        """Return *waveform* resampled from ``source_rate`` to ``target_rate`` (no-op when equal)."""
        if source_rate != target_rate:
            converted = torchaudio.functional.resample(waveform, source_rate, target_rate)
            # Resampling may change device/dtype defaults; pin them back to the input's.
            return converted.to(device=waveform.device, dtype=waveform.dtype)
        return waveform

    def waveform_to_mel(
        self,
        waveform: torch.Tensor,
        waveform_sample_rate: int,
    ) -> torch.Tensor:
        """Return a log-mel spectrogram shaped [batch, channels, time, n_mels]."""
        resampled = self.resample_waveform(waveform, waveform_sample_rate, self.sample_rate)

        # Clamp magnitudes away from zero before the log for numerical safety.
        log_mel = torch.log(self.mel_transform(resampled).clamp(min=1e-5))
        log_mel = log_mel.to(device=resampled.device, dtype=resampled.dtype)

        # (batch, channels, n_mels, frames) -> (batch, channels, frames, n_mels)
        return log_mel.permute(0, 1, 3, 2).contiguous()
59
+
60
+
61
class PerChannelStatistics(nn.Module):
    """
    Dataset-wide per-channel latent statistics used to (de)normalize the latent space.

    The buffers are not trained here; they are loaded from the AudioVAE
    checkpoint's state_dict. The buffer names contain dashes, so they are only
    reachable through ``get_buffer`` rather than attribute access.
    """

    def __init__(self, latent_channels: int = 128) -> None:
        super().__init__()
        for stat_name in ("std-of-means", "mean-of-means"):
            self.register_buffer(stat_name, torch.empty(latent_channels))

    def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map a normalized latent back to the original scale."""
        std = self.get_buffer("std-of-means").to(x)
        mean = self.get_buffer("mean-of-means").to(x)
        return x * std + mean

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Standardize a latent using the stored per-channel statistics."""
        std = self.get_buffer("std-of-means").to(x)
        mean = self.get_buffer("mean-of-means").to(x)
        return (x - mean) / std
packages/ltx-core/src/ltx_core/model/audio_vae/resnet.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
6
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
7
+ from ltx_core.model.common.normalization import NormType, build_normalization_layer
8
+
9
+ LRELU_SLOPE = 0.1
10
+
11
+
12
+ class ResBlock1(torch.nn.Module):
13
+ def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)):
14
+ super(ResBlock1, self).__init__()
15
+ self.convs1 = torch.nn.ModuleList(
16
+ [
17
+ torch.nn.Conv1d(
18
+ channels,
19
+ channels,
20
+ kernel_size,
21
+ 1,
22
+ dilation=dilation[0],
23
+ padding="same",
24
+ ),
25
+ torch.nn.Conv1d(
26
+ channels,
27
+ channels,
28
+ kernel_size,
29
+ 1,
30
+ dilation=dilation[1],
31
+ padding="same",
32
+ ),
33
+ torch.nn.Conv1d(
34
+ channels,
35
+ channels,
36
+ kernel_size,
37
+ 1,
38
+ dilation=dilation[2],
39
+ padding="same",
40
+ ),
41
+ ]
42
+ )
43
+
44
+ self.convs2 = torch.nn.ModuleList(
45
+ [
46
+ torch.nn.Conv1d(
47
+ channels,
48
+ channels,
49
+ kernel_size,
50
+ 1,
51
+ dilation=1,
52
+ padding="same",
53
+ ),
54
+ torch.nn.Conv1d(
55
+ channels,
56
+ channels,
57
+ kernel_size,
58
+ 1,
59
+ dilation=1,
60
+ padding="same",
61
+ ),
62
+ torch.nn.Conv1d(
63
+ channels,
64
+ channels,
65
+ kernel_size,
66
+ 1,
67
+ dilation=1,
68
+ padding="same",
69
+ ),
70
+ ]
71
+ )
72
+
73
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
74
+ for conv1, conv2 in zip(self.convs1, self.convs2, strict=True):
75
+ xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
76
+ xt = conv1(xt)
77
+ xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
78
+ xt = conv2(xt)
79
+ x = xt + x
80
+ return x
81
+
82
+
83
class ResBlock2(torch.nn.Module):
    """Lighter residual block variant: two dilated convolutions, one leaky-ReLU each."""

    def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)):
        super().__init__()
        self.convs = torch.nn.ModuleList(
            torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[i], padding="same")
            for i in range(2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv in self.convs:
            branch = conv(torch.nn.functional.leaky_relu(x, LRELU_SLOPE))
            x = x + branch
        return x
113
+
114
+
115
class ResnetBlock(torch.nn.Module):
    """
    Pre-activation residual block (norm -> SiLU -> conv, twice) with optional
    timestep-embedding injection and a learned shortcut when channel counts differ.

    All convolutions are built via ``make_conv2d`` so padding honors the
    configured ``causality_axis``.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int | None = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        super().__init__()
        self.causality_axis = causality_axis

        # GroupNorm statistics span spatial positions, which would leak "future"
        # values through a causal axis — hence the combination is rejected.
        if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP:
            raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
        self.in_channels = in_channels
        # Default to a channel-preserving block.
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
        self.non_linearity = torch.nn.SiLU()
        self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        # NOTE(review): temb_proj only exists when temb_channels > 0, but
        # forward() calls it whenever `temb` is passed — callers must keep the
        # two consistent or an AttributeError is raised.
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        # Only needed when the residual path changes channel count: either a full
        # 3x3 conv shortcut or a 1x1 "network-in-network" projection.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                )
            else:
                self.nin_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
                )

    def forward(
        self,
        x: torch.Tensor,
        temb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the block; `temb` is an optional timestep embedding (see NOTE in __init__)."""
        h = x
        h = self.norm1(h)
        h = self.non_linearity(h)
        h = self.conv1(h)

        if temb is not None:
            # Broadcast the projected embedding over both spatial dimensions.
            h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = self.non_linearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        # Project the identity path when channel counts differ so the sum is valid.
        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)

        return x + h
packages/ltx-core/src/ltx_core/model/audio_vae/upsample.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Set, Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.audio_vae.attention import AttentionType, make_attn
6
+ from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
7
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
8
+ from ltx_core.model.audio_vae.resnet import ResnetBlock
9
+ from ltx_core.model.common.normalization import NormType
10
+
11
+
12
class Upsample(torch.nn.Module):
    """Nearest-neighbor 2x upsampling, optionally followed by a (causal) 3x3 convolution."""

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis
        if self.with_conv:
            self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        # Drop the FIRST element along the causal axis to undo the encoder's padding
        # while keeping the length 1 + 2 * n. For example, input [0, 1, 2] interpolates
        # to [0, 0, 1, 1, 2, 2]; the causal convolution pads it to [-, -, 0, 0, 1, 1, 2, 2],
        # so the output windows are:
        #   0: [-,-,0]   1: [-,0,0]   2: [0,0,1]   3: [0,1,1]   4: [1,1,2]   5: [1,2,2]
        # The first two outputs depend only on input element 0, so dropping the first
        # (rather than the last) removes the redundancy. No-op for non-causal axes.
        axis = self.causality_axis
        if axis in (CausalityAxis.NONE, CausalityAxis.WIDTH_COMPATIBILITY):
            return x  # unchanged
        if axis == CausalityAxis.HEIGHT:
            return x[:, :, 1:, :]
        if axis == CausalityAxis.WIDTH:
            return x[:, :, :, 1:]
        raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
56
+
57
+
58
def build_upsampling_path(  # noqa: PLR0913
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
    initial_block_channels: int,
) -> tuple[torch.nn.ModuleList, int]:
    """Build the upsampling path with residual blocks, attention, and upsampling layers.

    Returns:
        A ``(modules, channels)`` pair: the per-resolution stages ordered from
        the shallowest level (index 0) to the deepest, and the channel count
        produced by the final (shallowest) stage.
    """
    up_modules = torch.nn.ModuleList()
    block_in = initial_block_channels
    # Spatial resolution at the deepest level; doubled after each upsampling stage.
    curr_res = resolution // (2 ** (num_resolutions - 1))

    # Stages are constructed deep-to-shallow (reversed levels) but stored
    # shallow-first via insert(0, ...) below.
    for level in reversed(range(num_resolutions)):
        stage = torch.nn.Module()
        stage.block = torch.nn.ModuleList()
        stage.attn = torch.nn.ModuleList()
        block_out = ch * ch_mult[level]

        # num_res_blocks + 1 residual blocks per level; attention layers are added
        # only at resolutions listed in attn_resolutions.
        for _ in range(num_res_blocks + 1):
            stage.block.append(
                ResnetBlock(
                    in_channels=block_in,
                    out_channels=block_out,
                    temb_channels=temb_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                    causality_axis=causality_axis,
                )
            )
            block_in = block_out
            if curr_res in attn_resolutions:
                stage.attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type))

        # Every level except the shallowest ends with a 2x upsample.
        if level != 0:
            stage.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
            curr_res *= 2

        up_modules.insert(0, stage)

    return up_modules, block_in
packages/ltx-core/src/ltx_core/model/audio_vae/vocoder.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List
3
+
4
+ import einops
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ from ltx_core.model.audio_vae.resnet import LRELU_SLOPE, ResBlock1, ResBlock2
10
+
11
+
12
class Vocoder(torch.nn.Module):
    """
    Vocoder model for synthesizing audio from Mel spectrograms.

    Generator structure: a pre-convolution, a stack of transposed-conv upsampling
    stages each followed by a group of residual blocks whose outputs are averaged,
    then a post-convolution squashed through tanh.

    Args:
        resblock_kernel_sizes: List of kernel sizes for the residual blocks.
            This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`.
        upsample_rates: List of upsampling rates.
            This value is read from the checkpoint at `config.vocoder.upsample_rates`.
        upsample_kernel_sizes: List of kernel sizes for the upsampling layers.
            This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`.
        resblock_dilation_sizes: List of dilation sizes for the residual blocks.
            This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`.
        upsample_initial_channel: Initial number of channels for the upsampling layers.
            This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`.
        stereo: Whether to use stereo output.
            This value is read from the checkpoint at `config.vocoder.stereo`.
        resblock: Type of residual block to use ("1" -> ResBlock1, otherwise ResBlock2).
            This value is read from the checkpoint at `config.vocoder.resblock`.
        output_sample_rate: Waveform sample rate.
            This value is read from the checkpoint at `config.vocoder.output_sample_rate`.
    """

    def __init__(
        self,
        resblock_kernel_sizes: List[int] | None = None,
        upsample_rates: List[int] | None = None,
        upsample_kernel_sizes: List[int] | None = None,
        resblock_dilation_sizes: List[List[int]] | None = None,
        upsample_initial_channel: int = 1024,
        stereo: bool = True,
        resblock: str = "1",
        output_sample_rate: int = 24000,
    ):
        super().__init__()

        # Initialize default values if not provided. Note that mutable default values are not supported.
        if resblock_kernel_sizes is None:
            resblock_kernel_sizes = [3, 7, 11]
        if upsample_rates is None:
            upsample_rates = [6, 5, 2, 2, 2]
        if upsample_kernel_sizes is None:
            upsample_kernel_sizes = [16, 15, 8, 4, 4]
        if resblock_dilation_sizes is None:
            resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]

        self.output_sample_rate = output_sample_rate
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        # Stereo inputs arrive with both channels' mel bins stacked (2 * 64).
        in_channels = 128 if stereo else 64
        self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
        resblock_class = ResBlock1 if resblock == "1" else ResBlock2

        # Channel count halves at every upsampling stage.
        self.ups = nn.ModuleList()
        for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True)):
            self.ups.append(
                nn.ConvTranspose1d(
                    upsample_initial_channel // (2**i),
                    upsample_initial_channel // (2 ** (i + 1)),
                    kernel_size,
                    stride,
                    padding=(kernel_size - stride) // 2,
                )
            )

        # num_kernels residual blocks per upsampling stage; forward() averages them.
        self.resblocks = nn.ModuleList()
        for i, _ in enumerate(self.ups):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):
                self.resblocks.append(resblock_class(ch, kernel_size, dilations))

        out_channels = 2 if stereo else 1
        final_channels = upsample_initial_channel // (2**self.num_upsamples)
        self.conv_post = nn.Conv1d(final_channels, out_channels, 7, 1, padding=3)

        # Total time-domain upsampling factor (product of all transposed-conv strides).
        self.upsample_factor = math.prod(layer.stride[0] for layer in self.ups)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the vocoder.
        Args:
            x: Input Mel spectrogram tensor. Can be either:
                - 3D: (batch_size, time, mel_bins) for mono
                - 4D: (batch_size, 2, time, mel_bins) for stereo
        Returns:
            Audio waveform tensor of shape (batch_size, out_channels, audio_length)
        """
        # Bugfix: the documented 3D (mono) input used to crash because
        # transpose(2, 3) ran before the rank check (dim 3 does not exist on a
        # 3D tensor). Move the mel axis into the Conv1d channel slot per rank.
        if x.dim() == 4:  # stereo
            x = x.transpose(2, 3)  # (batch, channels, time, mel_bins) -> (batch, channels, mel_bins, time)
            assert x.shape[1] == 2, "Input must have 2 channels for stereo"
            x = einops.rearrange(x, "b s c t -> b (s c) t")
        else:  # mono
            x = x.transpose(1, 2)  # (batch, time, mel_bins) -> (batch, mel_bins, time)

        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            start = i * self.num_kernels
            end = start + self.num_kernels

            # Evaluate all resblocks with the same input tensor so they can run
            # independently (and thus in parallel on accelerator hardware) before
            # aggregating their outputs via mean.
            block_outputs = torch.stack(
                [self.resblocks[idx](x) for idx in range(start, end)],
                dim=0,
            )

            x = block_outputs.mean(dim=0)

        # NOTE(review): this leaky_relu uses the default 0.01 slope rather than
        # LRELU_SLOPE — presumably intentional (common in HiFi-GAN-style
        # generators), so it is kept as-is; confirm against the training code.
        x = self.conv_post(F.leaky_relu(x))
        return torch.tanh(x)
packages/ltx-core/src/ltx_core/model/common/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """Common model utilities."""
2
+
3
+ from ltx_core.model.common.normalization import NormType, PixelNorm, build_normalization_layer
4
+
5
+ __all__ = [
6
+ "NormType",
7
+ "PixelNorm",
8
+ "build_normalization_layer",
9
+ ]
packages/ltx-core/src/ltx_core/model/common/normalization.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
class NormType(Enum):
    """Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""

    GROUP = "group"  # torch.nn.GroupNorm with affine parameters
    PIXEL = "pixel"  # PixelNorm: parameter-free per-location RMS normalization
12
+
13
+
14
class PixelNorm(nn.Module):
    """
    Per-pixel (per-location) RMS normalization:

        y = x / sqrt(mean(x**2, dim=dim, keepdim=True) + eps)

    There are no learned parameters; each location is simply rescaled to unit
    RMS along the chosen dimension.
    """

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        """
        Args:
            dim: Dimension along which the RMS is computed (typically channels).
            eps: Numerical-stability constant added under the square root.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rescale `x` to unit RMS along the configured dimension."""
        denom = torch.sqrt((x**2).mean(dim=self.dim, keepdim=True) + self.eps)
        return x / denom
41
+
42
+
43
def build_normalization_layer(
    in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
) -> nn.Module:
    """
    Create a normalization layer based on the normalization type.

    Args:
        in_channels: Number of channels the layer normalizes.
        num_groups: Group count, used only for group normalization.
        normtype: Type of normalization: "group" or "pixel".

    Returns:
        An affine GroupNorm or a parameter-free PixelNorm, both with eps=1e-6.

    Raises:
        ValueError: If `normtype` is not a recognized NormType.
    """
    builders = {
        NormType.GROUP: lambda: torch.nn.GroupNorm(
            num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
        ),
        NormType.PIXEL: lambda: PixelNorm(dim=1, eps=1e-6),
    }
    try:
        return builders[normtype]()
    except KeyError:
        raise ValueError(f"Invalid normalization type: {normtype}") from None
packages/ltx-core/src/ltx_core/model/model_protocol.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Protocol, TypeVar
2
+
3
+ ModelType = TypeVar("ModelType")
4
+
5
+
6
class ModelConfigurator(Protocol[ModelType]):
    """Protocol for model loaders that instantiate models from a configuration dictionary."""

    @classmethod
    def from_config(cls, config: dict) -> ModelType:
        """Build and return a `ModelType` instance from a raw config mapping."""
        ...
packages/ltx-core/src/ltx_core/model/transformer/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Transformer model components."""
2
+
3
+ from ltx_core.model.transformer.modality import Modality
4
+ from ltx_core.model.transformer.model import LTXModel, X0Model
5
+ from ltx_core.model.transformer.model_configurator import (
6
+ LTXV_MODEL_COMFY_RENAMING_MAP,
7
+ LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP,
8
+ UPCAST_DURING_INFERENCE,
9
+ LTXModelConfigurator,
10
+ LTXVideoOnlyModelConfigurator,
11
+ UpcastWithStochasticRounding,
12
+ )
13
+
14
+ __all__ = [
15
+ "LTXV_MODEL_COMFY_RENAMING_MAP",
16
+ "LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP",
17
+ "UPCAST_DURING_INFERENCE",
18
+ "LTXModel",
19
+ "LTXModelConfigurator",
20
+ "LTXVideoOnlyModelConfigurator",
21
+ "Modality",
22
+ "UpcastWithStochasticRounding",
23
+ "X0Model",
24
+ ]
packages/ltx-core/src/ltx_core/model/transformer/adaln.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTimestepSizeEmbeddings
6
+
7
+
8
class AdaLayerNormSingle(torch.nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        embedding_coefficient (`int`): Number of modulation chunks produced per
            embedding; the linear layer outputs `embedding_coefficient * embedding_dim`.
    """

    def __init__(self, embedding_dim: int, embedding_coefficient: int = 6):
        super().__init__()

        # Combined timestep/size embedding; size_emb_dim follows the PixArt-Alpha split.
        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim,
            size_emb_dim=embedding_dim // 3,
        )

        self.silu = torch.nn.SiLU()
        self.linear = torch.nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (modulation parameters, raw embedded timestep)."""
        embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep
packages/ltx-core/src/ltx_core/model/transformer/attention.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from typing import Protocol
3
+
4
+ import torch
5
+
6
+ from ltx_core.model.transformer.rope import LTXRopeType, apply_rotary_emb
7
+
8
# Optional attention backends. Both are probed at import time and left as None
# when unavailable; the backend classes below check for None before use.
memory_efficient_attention = None
flash_attn_interface = None
try:
    from xformers.ops import memory_efficient_attention
except ImportError:
    memory_efficient_attention = None

# Bugfix: this import was previously unguarded, so merely importing this module
# crashed on machines without FlashAttention-3 — defeating the `is None` checks
# that the backend classes already perform.
try:
    import flash_attn_interface
except ImportError:
    flash_attn_interface = None
16
+
17
+ class AttentionCallable(Protocol):
18
+ def __call__(
19
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
20
+ ) -> torch.Tensor: ...
21
+
22
+
23
+ class PytorchAttention(AttentionCallable):
24
+ def __call__(
25
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
26
+ ) -> torch.Tensor:
27
+ b, _, dim_head = q.shape
28
+ dim_head //= heads
29
+ q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
30
+
31
+ if mask is not None:
32
+ # add a batch dimension if there isn't already one
33
+ if mask.ndim == 2:
34
+ mask = mask.unsqueeze(0)
35
+ # add a heads dimension if there isn't already one
36
+ if mask.ndim == 3:
37
+ mask = mask.unsqueeze(1)
38
+
39
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
40
+ out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
41
+ return out
42
+
43
+
44
class XFormersAttention(AttentionCallable):
    """Attention backend using xformers' memory_efficient_attention kernel."""

    def __call__(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        heads: int,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if memory_efficient_attention is None:
            raise RuntimeError("XFormersAttention was selected but `xformers` is not installed.")

        b, _, dim_head = q.shape
        dim_head //= heads

        # xformers expects [B, M, H, K]
        q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v))

        if mask is not None:
            # add a singleton batch dimension
            if mask.ndim == 2:
                mask = mask.unsqueeze(0)
            # add a singleton heads dimension
            if mask.ndim == 3:
                mask = mask.unsqueeze(1)
            # pad to a multiple of 8
            pad = 8 - mask.shape[-1] % 8
            # xformers accepts a (1, Nq, Nk) mask, but with separated heads the
            # shape has to be (B, H, Nq, Nk); in flux, this matrix ends up being
            # over 1GB. A fresh buffer is allocated with the key length padded
            # to a multiple of 8, filled with the mask, then sliced back to the
            # logical key length.
            # NOTE(review): the slice restores the logical Nk while the backing
            # buffer's last-dim size stays padded — presumably to satisfy the
            # kernel's alignment requirements on attn_bias strides; confirm
            # against xformers' documentation before changing.
            mask_out = torch.empty(
                [mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device
            )

            mask_out[..., : mask.shape[-1]] = mask
            mask = mask_out[..., : mask.shape[-1]]
            # Materialize full batch/head dims (kernel requires non-singleton here).
            mask = mask.expand(b, heads, -1, -1)

        out = memory_efficient_attention(q.to(v.dtype), k.to(v.dtype), v, attn_bias=mask, p=0.0)
        out = out.reshape(b, -1, heads * dim_head)
        return out
87
+
88
+
89
class FlashAttention3(AttentionCallable):
    """Attention backend dispatching to the FlashAttention-3 fused kernel (no mask support)."""

    def __call__(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        heads: int,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if flash_attn_interface is None:
            raise RuntimeError("FlashAttention3 was selected but `FlashAttention3` is not installed.")
        if mask is not None:
            raise NotImplementedError("Mask is not supported for FlashAttention3")

        batch, _, packed_dim = q.shape
        head_dim = packed_dim // heads

        # FlashAttention expects (B, T, H, D); queries/keys are cast to the value dtype.
        q, k, v = (t.view(batch, -1, heads, head_dim) for t in (q, k, v))
        fused = flash_attn_interface.flash_attn_func(q.to(v.dtype), k.to(v.dtype), v)
        return fused.reshape(batch, -1, heads * head_dim)
112
+
113
+
114
class AttentionFunction(Enum):
    """
    Selectable attention backend.

    DEFAULT keeps the historical heuristic: FlashAttention-3 for unmasked
    attention, otherwise xformers when available, falling back to PyTorch SDPA.
    Bugfix: explicit members now dispatch to the backend they name — previously
    the member value was ignored, so e.g. AttentionFunction.PYTORCH still ran
    FlashAttention3 for unmasked inputs.
    """

    PYTORCH = "pytorch"
    XFORMERS = "xformers"
    FLASH_ATTENTION_3 = "flash_attention_3"
    DEFAULT = "default"

    def __call__(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        # Honor an explicitly selected backend first.
        if self is AttentionFunction.PYTORCH:
            return PytorchAttention()(q, k, v, heads, mask)
        if self is AttentionFunction.XFORMERS:
            return XFormersAttention()(q, k, v, heads, mask)
        if self is AttentionFunction.FLASH_ATTENTION_3:
            return FlashAttention3()(q, k, v, heads, mask)
        # DEFAULT heuristic (behavior unchanged from the original implementation).
        if mask is None:
            return FlashAttention3()(q, k, v, heads, mask)
        if memory_efficient_attention is not None:
            return XFormersAttention()(q, k, v, heads, mask)
        return PytorchAttention()(q, k, v, heads, mask)
131
+
132
+
133
class Attention(torch.nn.Module):
    """
    Multi-head attention with RMS-normalized queries/keys and optional rotary embeddings.

    Self-attention when `context` is None, cross-attention otherwise. The actual
    attention computation is delegated to `attention_function`, which may be an
    `AttentionFunction` enum member or any custom `AttentionCallable`.
    """

    def __init__(
        self,
        query_dim: int,
        context_dim: int | None = None,
        heads: int = 8,
        dim_head: int = 64,
        norm_eps: float = 1e-6,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
        attention_function: AttentionCallable | AttentionFunction = AttentionFunction.DEFAULT,
    ) -> None:
        super().__init__()
        self.rope_type = rope_type
        self.attention_function = attention_function

        inner_dim = dim_head * heads
        # Keys/values are projected from context_dim; defaults to self-attention.
        context_dim = query_dim if context_dim is None else context_dim

        self.heads = heads
        self.dim_head = dim_head

        # Q/K norms operate on the packed (heads * dim_head) dimension, before
        # the backend splits heads apart.
        self.q_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps)
        self.k_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps)

        self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True)
        self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True)
        self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True)

        # NOTE(review): the trailing Identity keeps the output projection at
        # index 0 of a Sequential — presumably for checkpoint key compatibility
        # ("to_out.0"); confirm against the checkpoint layout before removing.
        self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity())

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor | None = None,
        mask: torch.Tensor | None = None,
        pe: torch.Tensor | None = None,
        k_pe: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            x: Query tokens, shape (B, T, query_dim).
            context: Key/value tokens for cross-attention; defaults to `x` (self-attention).
            mask: Optional attention mask forwarded to the backend.
            pe: Rotary positional embedding applied to queries (and keys, unless `k_pe` is given).
            k_pe: Optional separate rotary embedding for keys.
        """
        q = self.to_q(x)
        context = x if context is None else context
        k = self.to_k(context)
        v = self.to_v(context)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if pe is not None:
            q = apply_rotary_emb(q, pe, self.rope_type)
            k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type)

        # attention_function can be an enum *or* a custom callable
        out = self.attention_function(q, k, v, self.heads, mask)
        return self.to_out(out)
packages/ltx-core/src/ltx_core/model/transformer/feed_forward.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.model.transformer.gelu_approx import GELUApprox
4
+
5
+
6
class FeedForward(torch.nn.Module):
    """Transformer MLP: tanh-approximated GELU projection, then a linear map to `dim_out`."""

    def __init__(self, dim: int, dim_out: int, mult: int = 4) -> None:
        super().__init__()
        hidden_dim = int(dim * mult)
        # The Identity at index 1 preserves the original Sequential layout
        # (and therefore the parameter key numbering).
        self.net = torch.nn.Sequential(
            GELUApprox(dim, hidden_dim),
            torch.nn.Identity(),
            torch.nn.Linear(hidden_dim, dim_out),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
packages/ltx-core/src/ltx_core/model/transformer/gelu_approx.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
class GELUApprox(torch.nn.Module):
    """Linear projection followed by the tanh-approximated GELU nonlinearity."""

    def __init__(self, dim_in: int, dim_out: int) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.proj(x)
        return torch.nn.functional.gelu(projected, approximate="tanh")
packages/ltx-core/src/ltx_core/model/transformer/modality.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+
6
@dataclass(frozen=True)
class Modality:
    """
    Input data for a single modality (video or audio) in the transformer.
    Bundles the latent tokens, timestep embeddings, positional information,
    and text conditioning context for processing by the diffusion transformer.
    """

    # (B, T, D): batch size, number of tokens, input feature dimension.
    latent: torch.Tensor
    # (B, T): per-token diffusion timesteps.
    timesteps: torch.Tensor
    # (B, 3, T) for video: one row per positional axis, per token.
    positions: torch.Tensor
    # Text-conditioning tensor. NOTE(review): shape not established here —
    # presumably (B, text_tokens, text_dim); confirm at call sites.
    context: torch.Tensor
    # When False this modality is presumably skipped by the transformer —
    # TODO confirm semantics where Modality is consumed.
    enabled: bool = True
    # Optional mask over `context` tokens.
    context_mask: torch.Tensor | None = None
packages/ltx-core/src/ltx_core/model/transformer/model.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ import torch
4
+
5
+ from ltx_core.guidance.perturbations import BatchedPerturbationConfig
6
+ from ltx_core.model.transformer.adaln import AdaLayerNormSingle
7
+ from ltx_core.model.transformer.attention import AttentionCallable, AttentionFunction
8
+ from ltx_core.model.transformer.modality import Modality
9
+ from ltx_core.model.transformer.rope import LTXRopeType
10
+ from ltx_core.model.transformer.text_projection import PixArtAlphaTextProjection
11
+ from ltx_core.model.transformer.transformer import BasicAVTransformerBlock, TransformerConfig
12
+ from ltx_core.model.transformer.transformer_args import (
13
+ MultiModalTransformerArgsPreprocessor,
14
+ TransformerArgs,
15
+ TransformerArgsPreprocessor,
16
+ )
17
+ from ltx_core.utils import to_denoised
18
+
19
+
20
class LTXModelType(Enum):
    """Which modalities a given LTX model variant supports."""

    AudioVideo = "ltx av model"
    VideoOnly = "ltx video only model"
    AudioOnly = "ltx audio only model"

    def is_video_enabled(self) -> bool:
        """Video is produced by every variant except the audio-only model."""
        return self is not LTXModelType.AudioOnly

    def is_audio_enabled(self) -> bool:
        """Audio is produced by every variant except the video-only model."""
        return self is not LTXModelType.VideoOnly
30
+
31
+
32
class LTXModel(torch.nn.Module):
    """
    LTX model transformer implementation.
    This class implements the transformer blocks for the LTX model.

    Depending on ``model_type``, video components, audio components, or both are
    instantiated; when both are enabled, extra AdaLN layers for audio-video
    cross-attention modulation are built as well.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        model_type: LTXModelType = LTXModelType.AudioVideo,
        num_attention_heads: int = 32,
        attention_head_dim: int = 128,
        in_channels: int = 128,
        out_channels: int = 128,
        num_layers: int = 48,
        cross_attention_dim: int = 4096,
        norm_eps: float = 1e-06,
        attention_type: AttentionFunction | AttentionCallable = AttentionFunction.DEFAULT,
        caption_channels: int = 3840,
        positional_embedding_theta: float = 10000.0,
        positional_embedding_max_pos: list[int] | None = None,
        timestep_scale_multiplier: int = 1000,
        use_middle_indices_grid: bool = True,
        audio_num_attention_heads: int = 32,
        audio_attention_head_dim: int = 64,
        audio_in_channels: int = 128,
        audio_out_channels: int = 128,
        audio_cross_attention_dim: int = 2048,
        audio_positional_embedding_max_pos: list[int] | None = None,
        av_ca_timestep_scale_multiplier: int = 1,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
        double_precision_rope: bool = False,
    ):
        """
        Build an LTX transformer.

        Video/audio-specific parameters are only used when the corresponding
        modality is enabled by ``model_type``. ``positional_embedding_max_pos``
        defaults to [20, 2048, 2048] (video) and
        ``audio_positional_embedding_max_pos`` to [20] (audio) when None.
        """
        super().__init__()
        self._enable_gradient_checkpointing = False
        self.use_middle_indices_grid = use_middle_indices_grid
        self.rope_type = rope_type
        self.double_precision_rope = double_precision_rope
        self.timestep_scale_multiplier = timestep_scale_multiplier
        self.positional_embedding_theta = positional_embedding_theta
        self.model_type = model_type
        # Only meaningful for audio-video models; stays None otherwise.
        cross_pe_max_pos = None
        if model_type.is_video_enabled():
            if positional_embedding_max_pos is None:
                positional_embedding_max_pos = [20, 2048, 2048]
            self.positional_embedding_max_pos = positional_embedding_max_pos
            self.num_attention_heads = num_attention_heads
            self.inner_dim = num_attention_heads * attention_head_dim
            self._init_video(
                in_channels=in_channels,
                out_channels=out_channels,
                caption_channels=caption_channels,
                norm_eps=norm_eps,
            )

        if model_type.is_audio_enabled():
            if audio_positional_embedding_max_pos is None:
                audio_positional_embedding_max_pos = [20]
            self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
            self.audio_num_attention_heads = audio_num_attention_heads
            self.audio_inner_dim = self.audio_num_attention_heads * audio_attention_head_dim
            self._init_audio(
                in_channels=audio_in_channels,
                out_channels=audio_out_channels,
                caption_channels=caption_channels,
                norm_eps=norm_eps,
            )

        if model_type.is_video_enabled() and model_type.is_audio_enabled():
            # Shared positional range for the cross-modal (time) axis: take the
            # larger of the video and audio temporal extents.
            cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
            self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
            self.audio_cross_attention_dim = audio_cross_attention_dim
            self._init_audio_video(num_scale_shift_values=4)

        self._init_preprocessors(cross_pe_max_pos)
        # Initialize transformer blocks
        self._init_transformer_blocks(
            num_layers=num_layers,
            attention_head_dim=attention_head_dim if model_type.is_video_enabled() else 0,
            cross_attention_dim=cross_attention_dim,
            audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0,
            audio_cross_attention_dim=audio_cross_attention_dim,
            norm_eps=norm_eps,
            attention_type=attention_type,
        )

    def _init_video(
        self,
        in_channels: int,
        out_channels: int,
        caption_channels: int,
        norm_eps: float,
    ) -> None:
        """Initialize video-specific components."""
        # Video input components
        self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True)

        self.adaln_single = AdaLayerNormSingle(self.inner_dim)

        # Video caption projection
        self.caption_projection = PixArtAlphaTextProjection(
            in_features=caption_channels,
            hidden_size=self.inner_dim,
        )

        # Video output components
        # NOTE: created with torch.empty (uninitialized) — presumably populated
        # from a checkpoint before use.
        self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim))
        self.norm_out = torch.nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=norm_eps)
        self.proj_out = torch.nn.Linear(self.inner_dim, out_channels)

    def _init_audio(
        self,
        in_channels: int,
        out_channels: int,
        caption_channels: int,
        norm_eps: float,
    ) -> None:
        """Initialize audio-specific components."""

        # Audio input components
        self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True)

        self.audio_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
        )

        # Audio caption projection
        self.audio_caption_projection = PixArtAlphaTextProjection(
            in_features=caption_channels,
            hidden_size=self.audio_inner_dim,
        )

        # Audio output components
        # NOTE: created with torch.empty (uninitialized) — presumably populated
        # from a checkpoint before use.
        self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim))
        self.audio_norm_out = torch.nn.LayerNorm(self.audio_inner_dim, elementwise_affine=False, eps=norm_eps)
        self.audio_proj_out = torch.nn.Linear(self.audio_inner_dim, out_channels)

    def _init_audio_video(
        self,
        num_scale_shift_values: int,
    ) -> None:
        """Initialize audio-video cross-attention components."""
        # Scale/shift modulation for the video side of AV cross-attention.
        self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
            embedding_coefficient=num_scale_shift_values,
        )

        # Scale/shift modulation for the audio side of AV cross-attention.
        self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            embedding_coefficient=num_scale_shift_values,
        )

        # Gates for the audio->video and video->audio cross-attention outputs.
        self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
            embedding_coefficient=1,
        )

        self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            embedding_coefficient=1,
        )

    def _init_preprocessors(
        self,
        cross_pe_max_pos: int | None = None,
    ) -> None:
        """
        Initialize preprocessors for LTX.

        Audio-video models get MultiModalTransformerArgsPreprocessor instances
        (which carry the AV cross-attention AdaLN modules); single-modality
        models get plain TransformerArgsPreprocessor instances.
        """

        if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():
            self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
                patchify_proj=self.patchify_proj,
                adaln=self.adaln_single,
                caption_projection=self.caption_projection,
                cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
                cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
                inner_dim=self.inner_dim,
                max_pos=self.positional_embedding_max_pos,
                num_attention_heads=self.num_attention_heads,
                cross_pe_max_pos=cross_pe_max_pos,
                use_middle_indices_grid=self.use_middle_indices_grid,
                audio_cross_attention_dim=self.audio_cross_attention_dim,
                timestep_scale_multiplier=self.timestep_scale_multiplier,
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
                av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
            )
            self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
                patchify_proj=self.audio_patchify_proj,
                adaln=self.audio_adaln_single,
                caption_projection=self.audio_caption_projection,
                cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
                cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
                inner_dim=self.audio_inner_dim,
                max_pos=self.audio_positional_embedding_max_pos,
                num_attention_heads=self.audio_num_attention_heads,
                cross_pe_max_pos=cross_pe_max_pos,
                use_middle_indices_grid=self.use_middle_indices_grid,
                audio_cross_attention_dim=self.audio_cross_attention_dim,
                timestep_scale_multiplier=self.timestep_scale_multiplier,
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
                av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
            )
        elif self.model_type.is_video_enabled():
            self.video_args_preprocessor = TransformerArgsPreprocessor(
                patchify_proj=self.patchify_proj,
                adaln=self.adaln_single,
                caption_projection=self.caption_projection,
                inner_dim=self.inner_dim,
                max_pos=self.positional_embedding_max_pos,
                num_attention_heads=self.num_attention_heads,
                use_middle_indices_grid=self.use_middle_indices_grid,
                timestep_scale_multiplier=self.timestep_scale_multiplier,
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
            )
        elif self.model_type.is_audio_enabled():
            self.audio_args_preprocessor = TransformerArgsPreprocessor(
                patchify_proj=self.audio_patchify_proj,
                adaln=self.audio_adaln_single,
                caption_projection=self.audio_caption_projection,
                inner_dim=self.audio_inner_dim,
                max_pos=self.audio_positional_embedding_max_pos,
                num_attention_heads=self.audio_num_attention_heads,
                use_middle_indices_grid=self.use_middle_indices_grid,
                timestep_scale_multiplier=self.timestep_scale_multiplier,
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
            )

    def _init_transformer_blocks(
        self,
        num_layers: int,
        attention_head_dim: int,
        cross_attention_dim: int,
        audio_attention_head_dim: int,
        audio_cross_attention_dim: int,
        norm_eps: float,
        attention_type: AttentionFunction | AttentionCallable,
    ) -> None:
        """Initialize transformer blocks for LTX."""
        # A None config disables that modality's sub-block inside each layer.
        video_config = (
            TransformerConfig(
                dim=self.inner_dim,
                heads=self.num_attention_heads,
                d_head=attention_head_dim,
                context_dim=cross_attention_dim,
            )
            if self.model_type.is_video_enabled()
            else None
        )
        audio_config = (
            TransformerConfig(
                dim=self.audio_inner_dim,
                heads=self.audio_num_attention_heads,
                d_head=audio_attention_head_dim,
                context_dim=audio_cross_attention_dim,
            )
            if self.model_type.is_audio_enabled()
            else None
        )
        self.transformer_blocks = torch.nn.ModuleList(
            [
                BasicAVTransformerBlock(
                    idx=idx,
                    video=video_config,
                    audio=audio_config,
                    rope_type=self.rope_type,
                    norm_eps=norm_eps,
                    attention_function=attention_type,
                )
                for idx in range(num_layers)
            ]
        )

    def set_gradient_checkpointing(self, enable: bool) -> None:
        """Enable or disable gradient checkpointing for transformer blocks.
        Gradient checkpointing trades compute for memory by recomputing activations
        during the backward pass instead of storing them. This can significantly
        reduce memory usage at the cost of ~20-30% slower training.
        Args:
            enable: Whether to enable gradient checkpointing
        """
        self._enable_gradient_checkpointing = enable

    def _process_transformer_blocks(
        self,
        video: TransformerArgs | None,
        audio: TransformerArgs | None,
        perturbations: BatchedPerturbationConfig,
    ) -> tuple[TransformerArgs | None, TransformerArgs | None]:
        """Process transformer blocks for LTXAV.

        Returns the (video, audio) args after all blocks; either element may be
        None if the corresponding input was None.
        """

        # Process transformer blocks
        for block in self.transformer_blocks:
            if self._enable_gradient_checkpointing and self.training:
                # Use gradient checkpointing to save memory during training.
                # With use_reentrant=False, we can pass dataclasses directly -
                # PyTorch will track all tensor leaves in the computation graph.
                video, audio = torch.utils.checkpoint.checkpoint(
                    block,
                    video,
                    audio,
                    perturbations,
                    use_reentrant=False,
                )
            else:
                video, audio = block(
                    video=video,
                    audio=audio,
                    perturbations=perturbations,
                )

        return video, audio

    def _process_output(
        self,
        scale_shift_table: torch.Tensor,
        norm_out: torch.nn.LayerNorm,
        proj_out: torch.nn.Linear,
        x: torch.Tensor,
        embedded_timestep: torch.Tensor,
    ) -> torch.Tensor:
        """Process output for LTXV: AdaLN-style scale/shift, then final projection."""
        # Apply scale-shift modulation
        scale_shift_values = (
            scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
        )
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

        x = norm_out(x)
        x = x * (1 + scale) + shift
        x = proj_out(x)
        return x

    def forward(
        self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """
        Forward pass for LTX models.

        Raises:
            ValueError: if a modality is supplied that this model type does not support.
        Returns:
            Processed output tensors (video, audio); an element is None when the
            corresponding input modality was None.
        """
        if not self.model_type.is_video_enabled() and video is not None:
            raise ValueError("Video is not enabled for this model")
        if not self.model_type.is_audio_enabled() and audio is not None:
            raise ValueError("Audio is not enabled for this model")

        video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
        audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
        # Process transformer blocks
        video_out, audio_out = self._process_transformer_blocks(
            video=video_args,
            audio=audio_args,
            perturbations=perturbations,
        )

        # Process output
        vx = (
            self._process_output(
                self.scale_shift_table, self.norm_out, self.proj_out, video_out.x, video_out.embedded_timestep
            )
            if video_out is not None
            else None
        )
        ax = (
            self._process_output(
                self.audio_scale_shift_table,
                self.audio_norm_out,
                self.audio_proj_out,
                audio_out.x,
                audio_out.embedded_timestep,
            )
            if audio_out is not None
            else None
        )
        return vx, ax
413
+
414
+
415
class LegacyX0Model(torch.nn.Module):
    """
    Legacy X0 wrapper around a velocity-predicting LTX model.

    Converts the velocities returned by the base model into fully denoised
    outputs using a single scalar sigma shared by all tokens.
    """

    def __init__(self, velocity_model: LTXModel):
        super().__init__()
        self.velocity_model = velocity_model

    def forward(
        self,
        video: Modality | None,
        audio: Modality | None,
        perturbations: BatchedPerturbationConfig,
        sigma: float,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """
        Denoise the video and audio according to the sigma.
        Returns:
            Denoised video and audio (None for a modality whose velocity is None)
        """
        video_velocity, audio_velocity = self.velocity_model(video, audio, perturbations)
        denoised_video = None if video_velocity is None else to_denoised(video.latent, video_velocity, sigma)
        denoised_audio = None if audio_velocity is None else to_denoised(audio.latent, audio_velocity, sigma)
        return denoised_video, denoised_audio
441
+
442
+
443
class X0Model(torch.nn.Module):
    """
    X0 wrapper around a velocity-predicting LTX model.

    Converts the velocities returned by the base model into fully denoised
    outputs, using each modality's own per-token timesteps
    (timesteps = sigma * denoising_mask) as the denoising scale.
    """

    def __init__(self, velocity_model: LTXModel):
        super().__init__()
        self.velocity_model = velocity_model

    def forward(
        self,
        video: Modality | None,
        audio: Modality | None,
        perturbations: BatchedPerturbationConfig,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """
        Denoise the video and audio according to their per-token timesteps.
        Returns:
            Denoised video and audio (None for a modality whose velocity is None)
        """
        video_velocity, audio_velocity = self.velocity_model(video, audio, perturbations)
        denoised_video = None if video_velocity is None else to_denoised(video.latent, video_velocity, video.timesteps)
        denoised_audio = None if audio_velocity is None else to_denoised(audio.latent, audio_velocity, audio.timesteps)
        return denoised_video, denoised_audio
packages/ltx-core/src/ltx_core/model/transformer/model_configurator.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.loader.fuse_loras import fused_add_round_launch
4
+ from ltx_core.loader.module_ops import ModuleOps
5
+ from ltx_core.loader.sd_ops import KeyValueOperationResult, SDOps
6
+ from ltx_core.model.model_protocol import ModelConfigurator
7
+ from ltx_core.model.transformer.attention import AttentionFunction
8
+ from ltx_core.model.transformer.model import LTXModel, LTXModelType
9
+ from ltx_core.model.transformer.rope import LTXRopeType
10
+ from ltx_core.utils import check_config_value
11
+
12
+
13
class LTXModelConfigurator(ModelConfigurator[LTXModel]):
    """
    Configurator for LTX model.
    Used to create an LTX model from a configuration dictionary.
    """

    @classmethod
    def from_config(cls, config: dict) -> LTXModel:
        """Build an audio-video LTXModel from the ``transformer`` section of *config*.

        The check_config_value calls validate that config entries this
        implementation does not support deviating from still hold their
        expected fixed values (presumably raising otherwise — verify in
        ltx_core.utils.check_config_value).
        """
        config = config.get("transformer", {})

        check_config_value(config, "dropout", 0.0)
        check_config_value(config, "attention_bias", True)
        check_config_value(config, "num_vector_embeds", None)
        check_config_value(config, "activation_fn", "gelu-approximate")
        check_config_value(config, "num_embeds_ada_norm", 1000)
        check_config_value(config, "use_linear_projection", False)
        check_config_value(config, "only_cross_attention", False)
        check_config_value(config, "cross_attention_norm", True)
        check_config_value(config, "double_self_attention", False)
        check_config_value(config, "upcast_attention", False)
        check_config_value(config, "standardization_norm", "rms_norm")
        check_config_value(config, "norm_elementwise_affine", False)
        check_config_value(config, "qk_norm", "rms_norm")
        check_config_value(config, "positional_embedding_type", "rope")
        check_config_value(config, "use_audio_video_cross_attention", True)
        check_config_value(config, "share_ff", False)
        check_config_value(config, "av_cross_ada_norm", True)
        check_config_value(config, "use_middle_indices_grid", True)

        return LTXModel(
            model_type=LTXModelType.AudioVideo,
            num_attention_heads=config.get("num_attention_heads", 32),
            attention_head_dim=config.get("attention_head_dim", 128),
            in_channels=config.get("in_channels", 128),
            out_channels=config.get("out_channels", 128),
            num_layers=config.get("num_layers", 48),
            cross_attention_dim=config.get("cross_attention_dim", 4096),
            norm_eps=config.get("norm_eps", 1e-06),
            attention_type=AttentionFunction(config.get("attention_type", "default")),
            caption_channels=config.get("caption_channels", 3840),
            positional_embedding_theta=config.get("positional_embedding_theta", 10000.0),
            positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
            timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
            use_middle_indices_grid=config.get("use_middle_indices_grid", True),
            audio_num_attention_heads=config.get("audio_num_attention_heads", 32),
            audio_attention_head_dim=config.get("audio_attention_head_dim", 64),
            audio_in_channels=config.get("audio_in_channels", 128),
            audio_out_channels=config.get("audio_out_channels", 128),
            audio_cross_attention_dim=config.get("audio_cross_attention_dim", 2048),
            audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]),
            av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1),
            rope_type=LTXRopeType(config.get("rope_type", "interleaved")),
            # double precision only when the config explicitly requests float64
            # frequencies; the False default makes the comparison default to False.
            double_precision_rope=config.get("frequencies_precision", False) == "float64",
        )
67
+
68
+
69
class LTXVideoOnlyModelConfigurator(ModelConfigurator[LTXModel]):
    """
    Configurator for LTX video only model.
    Used to create an LTX video only model from a configuration dictionary.
    """

    @classmethod
    def from_config(cls, config: dict) -> LTXModel:
        """Build a video-only LTXModel from the ``transformer`` section of *config*.

        Validates fixed config entries via check_config_value (same set as the
        AV configurator minus the audio-video cross-attention flags), then
        forwards only video-relevant options to LTXModel.
        """
        config = config.get("transformer", {})

        check_config_value(config, "dropout", 0.0)
        check_config_value(config, "attention_bias", True)
        check_config_value(config, "num_vector_embeds", None)
        check_config_value(config, "activation_fn", "gelu-approximate")
        check_config_value(config, "num_embeds_ada_norm", 1000)
        check_config_value(config, "use_linear_projection", False)
        check_config_value(config, "only_cross_attention", False)
        check_config_value(config, "cross_attention_norm", True)
        check_config_value(config, "double_self_attention", False)
        check_config_value(config, "upcast_attention", False)
        check_config_value(config, "standardization_norm", "rms_norm")
        check_config_value(config, "norm_elementwise_affine", False)
        check_config_value(config, "qk_norm", "rms_norm")
        check_config_value(config, "positional_embedding_type", "rope")
        check_config_value(config, "use_middle_indices_grid", True)

        return LTXModel(
            model_type=LTXModelType.VideoOnly,
            num_attention_heads=config.get("num_attention_heads", 32),
            attention_head_dim=config.get("attention_head_dim", 128),
            in_channels=config.get("in_channels", 128),
            out_channels=config.get("out_channels", 128),
            num_layers=config.get("num_layers", 48),
            cross_attention_dim=config.get("cross_attention_dim", 4096),
            norm_eps=config.get("norm_eps", 1e-06),
            attention_type=AttentionFunction(config.get("attention_type", "default")),
            caption_channels=config.get("caption_channels", 3840),
            positional_embedding_theta=config.get("positional_embedding_theta", 10000.0),
            positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
            timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
            use_middle_indices_grid=config.get("use_middle_indices_grid", True),
            rope_type=LTXRopeType(config.get("rope_type", "interleaved")),
            # double precision only when the config explicitly requests float64
            # frequencies; the False default makes the comparison default to False.
            double_precision_rope=config.get("frequencies_precision", False) == "float64",
        )
113
+
114
+
115
def _naive_weight_or_bias_downcast(key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
    """Downcast a weight or bias tensor to float8_e4m3fn, keeping its key unchanged."""
    downcast = value.to(dtype=torch.float8_e4m3fn)
    return [KeyValueOperationResult(key, downcast)]
120
+
121
+
122
def _upcast_and_round(
    weight: torch.Tensor, dtype: torch.dtype, with_stochastic_rounding: bool = False, seed: int = 0
) -> torch.Tensor:
    """
    Upcast *weight* to *dtype*, optionally applying stochastic rounding.

    The input weight is expected to have float8_e4m3fn or float8_e5m2 dtype.
    """
    if with_stochastic_rounding:
        # Adding the fp8 weight into a zero tensor of the target dtype routes
        # the conversion through the stochastic-rounding kernel.
        accumulator = torch.zeros_like(weight, dtype=dtype)
        return fused_add_round_launch(accumulator, weight, seed)
    return weight.to(dtype)
132
+
133
+
134
def replace_fwd_with_upcast(layer: torch.nn.Linear, with_stochastic_rounding: bool = False, seed: int = 0) -> None:
    """
    Monkey-patch ``layer.forward`` so each call:
    - upcasts the layer's weight (and bias, if present) to the input's dtype
    - computes F.linear in that dtype

    The original bound forward is kept on ``layer.original_forward``.
    """

    layer.original_forward = layer.forward

    def upcasting_forward(*args, **_kwargs) -> torch.Tensor:
        # assume first arg is the input tensor
        inputs = args[0]
        weight = _upcast_and_round(layer.weight, inputs.dtype, with_stochastic_rounding, seed)
        bias = (
            _upcast_and_round(layer.bias, inputs.dtype, with_stochastic_rounding, seed)
            if layer.bias is not None
            else None
        )
        return torch.nn.functional.linear(inputs, weight, bias)

    layer.forward = upcasting_forward
155
+
156
+
157
def amend_forward_with_upcast(
    model: torch.nn.Module, with_stochastic_rounding: bool = False, seed: int = 0
) -> torch.nn.Module:
    """
    Patch every Linear layer in *model* to upcast its weights/bias to the
    input's dtype on forward (optionally with stochastic rounding).

    Returns the same model instance for call chaining.
    """
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            replace_fwd_with_upcast(module, with_stochastic_rounding, seed)
    return model
168
+
169
+
170
# Renames ComfyUI-style checkpoint keys by stripping the
# "model.diffusion_model." prefix to match this package's key layout.
LTXV_MODEL_COMFY_RENAMING_MAP = (
    SDOps("LTXV_MODEL_COMFY_PREFIX_MAP")
    .with_matching(prefix="model.diffusion_model.")
    .with_replacement("model.diffusion_model.", "")
)

# Same prefix renaming, plus downcasting of the transformer blocks' attention
# projections (to_q / to_k / to_v / to_out.0) and feed-forward linears
# (ff.net.0.proj / ff.net.2) — weights and biases — to float8_e4m3fn.
LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP = (
    SDOps("LTXV_MODEL_COMFY_PREFIX_MAP")
    .with_matching(prefix="model.diffusion_model.")
    .with_replacement("model.diffusion_model.", "")
    .with_kv_operation(
        key_prefix="transformer_blocks.", key_suffix=".to_q.weight", operation=_naive_weight_or_bias_downcast
    )
    .with_kv_operation(
        key_prefix="transformer_blocks.", key_suffix=".to_q.bias", operation=_naive_weight_or_bias_downcast
    )
    .with_kv_operation(
        key_prefix="transformer_blocks.", key_suffix=".to_k.weight", operation=_naive_weight_or_bias_downcast
    )
    .with_kv_operation(
        key_prefix="transformer_blocks.", key_suffix=".to_k.bias", operation=_naive_weight_or_bias_downcast
    )
    .with_kv_operation(
        key_prefix="transformer_blocks.", key_suffix=".to_v.weight", operation=_naive_weight_or_bias_downcast
    )
    .with_kv_operation(
        key_prefix="transformer_blocks.", key_suffix=".to_v.bias", operation=_naive_weight_or_bias_downcast
    )
    .with_kv_operation(
        key_prefix="transformer_blocks.", key_suffix=".to_out.0.weight", operation=_naive_weight_or_bias_downcast
    )
    .with_kv_operation(
        key_prefix="transformer_blocks.", key_suffix=".to_out.0.bias", operation=_naive_weight_or_bias_downcast
    )
    .with_kv_operation(
        key_prefix="transformer_blocks.", key_suffix=".ff.net.0.proj.weight", operation=_naive_weight_or_bias_downcast
    )
    .with_kv_operation(
        key_prefix="transformer_blocks.", key_suffix=".ff.net.0.proj.bias", operation=_naive_weight_or_bias_downcast
    )
    .with_kv_operation(
        key_prefix="transformer_blocks.", key_suffix=".ff.net.2.weight", operation=_naive_weight_or_bias_downcast
    )
    .with_kv_operation(
        key_prefix="transformer_blocks.", key_suffix=".ff.net.2.bias", operation=_naive_weight_or_bias_downcast
    )
)

# Module op that patches every Linear in an LTXModel to upcast fp8 weights to
# the activation dtype on each forward (no stochastic rounding).
UPCAST_DURING_INFERENCE = ModuleOps(
    name="upcast_fp8_during_linear_forward",
    matcher=lambda model: isinstance(model, LTXModel),
    mutator=lambda model: amend_forward_with_upcast(model, False),
)
223
+
224
+
225
class UpcastWithStochasticRounding(ModuleOps):
    """
    ModuleOps for upcasting the model's float8_e4m3fn weights and biases to
    the activation dtype and applying stochastic rounding during linear forward.
    """

    # NOTE(review): construction via super().__new__(cls, **kwargs) assumes
    # ModuleOps is a NamedTuple-like type whose fields are set in __new__ —
    # confirm against ltx_core.loader.module_ops.
    def __new__(cls, seed: int = 0):
        return super().__new__(
            cls,
            name="upcast_fp8_during_linear_forward_with_stochastic_rounding",
            matcher=lambda model: isinstance(model, LTXModel),
            # seed is forwarded to the stochastic-rounding kernel for every patched layer
            mutator=lambda model: amend_forward_with_upcast(model, True, seed),
        )