prithivMLmods commited on
Commit
23791ce
·
verified ·
1 Parent(s): 40e111f

upload app [.files]

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. app.py +726 -0
  2. ltx2_two_stage.py +84 -0
  3. packages/ltx-core/README.md +280 -0
  4. packages/ltx-core/pyproject.toml +37 -0
  5. packages/ltx-core/src/ltx_core/__init__.py +0 -0
  6. packages/ltx-core/src/ltx_core/components/__init__.py +10 -0
  7. packages/ltx-core/src/ltx_core/components/diffusion_steps.py +22 -0
  8. packages/ltx-core/src/ltx_core/components/guiders.py +198 -0
  9. packages/ltx-core/src/ltx_core/components/noisers.py +35 -0
  10. packages/ltx-core/src/ltx_core/components/patchifiers.py +348 -0
  11. packages/ltx-core/src/ltx_core/components/protocols.py +101 -0
  12. packages/ltx-core/src/ltx_core/components/schedulers.py +129 -0
  13. packages/ltx-core/src/ltx_core/conditioning/__init__.py +12 -0
  14. packages/ltx-core/src/ltx_core/conditioning/exceptions.py +4 -0
  15. packages/ltx-core/src/ltx_core/conditioning/item.py +20 -0
  16. packages/ltx-core/src/ltx_core/conditioning/types/__init__.py +9 -0
  17. packages/ltx-core/src/ltx_core/conditioning/types/keyframe_cond.py +53 -0
  18. packages/ltx-core/src/ltx_core/conditioning/types/latent_cond.py +44 -0
  19. packages/ltx-core/src/ltx_core/guidance/__init__.py +15 -0
  20. packages/ltx-core/src/ltx_core/guidance/perturbations.py +79 -0
  21. packages/ltx-core/src/ltx_core/loader/__init__.py +48 -0
  22. packages/ltx-core/src/ltx_core/loader/fuse_loras.py +100 -0
  23. packages/ltx-core/src/ltx_core/loader/kernels.py +72 -0
  24. packages/ltx-core/src/ltx_core/loader/module_ops.py +14 -0
  25. packages/ltx-core/src/ltx_core/loader/primitives.py +109 -0
  26. packages/ltx-core/src/ltx_core/loader/registry.py +84 -0
  27. packages/ltx-core/src/ltx_core/loader/sd_ops.py +127 -0
  28. packages/ltx-core/src/ltx_core/loader/sft_loader.py +63 -0
  29. packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py +101 -0
  30. packages/ltx-core/src/ltx_core/model/__init__.py +8 -0
  31. packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py +27 -0
  32. packages/ltx-core/src/ltx_core/model/audio_vae/attention.py +71 -0
  33. packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py +480 -0
  34. packages/ltx-core/src/ltx_core/model/audio_vae/causal_conv_2d.py +110 -0
  35. packages/ltx-core/src/ltx_core/model/audio_vae/causality_axis.py +10 -0
  36. packages/ltx-core/src/ltx_core/model/audio_vae/downsample.py +110 -0
  37. packages/ltx-core/src/ltx_core/model/audio_vae/model_configurator.py +123 -0
  38. packages/ltx-core/src/ltx_core/model/audio_vae/ops.py +76 -0
  39. packages/ltx-core/src/ltx_core/model/audio_vae/resnet.py +176 -0
  40. packages/ltx-core/src/ltx_core/model/audio_vae/upsample.py +106 -0
  41. packages/ltx-core/src/ltx_core/model/audio_vae/vocoder.py +123 -0
  42. packages/ltx-core/src/ltx_core/model/common/__init__.py +9 -0
  43. packages/ltx-core/src/ltx_core/model/common/normalization.py +59 -0
  44. packages/ltx-core/src/ltx_core/model/model_protocol.py +10 -0
  45. packages/ltx-core/src/ltx_core/model/transformer/__init__.py +24 -0
  46. packages/ltx-core/src/ltx_core/model/transformer/adaln.py +34 -0
  47. packages/ltx-core/src/ltx_core/model/transformer/attention.py +185 -0
  48. packages/ltx-core/src/ltx_core/model/transformer/feed_forward.py +15 -0
  49. packages/ltx-core/src/ltx_core/model/transformer/gelu_approx.py +10 -0
  50. packages/ltx-core/src/ltx_core/model/transformer/modality.py +23 -0
app.py ADDED
@@ -0,0 +1,726 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+ import uuid
4
+ import tempfile
5
+
6
+ # Add packages to Python path
7
+ current_dir = Path(__file__).parent
8
+ sys.path.insert(0, str(current_dir / "packages" / "ltx-pipelines" / "src"))
9
+ sys.path.insert(0, str(current_dir / "packages" / "ltx-core" / "src"))
10
+
11
+ import spaces
12
+ import flash_attn_interface
13
+ import time
14
+ import gradio as gr
15
+ import numpy as np
16
+ import random
17
+ import torch
18
+ from typing import Optional
19
+ from pathlib import Path
20
+ from huggingface_hub import hf_hub_download, snapshot_download
21
+ from ltx_pipelines.distilled import DistilledPipeline
22
+ from ltx_core.model.video_vae import TilingConfig
23
+ from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
24
+ from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
25
+ from ltx_pipelines.utils.constants import (
26
+ DEFAULT_SEED,
27
+ DEFAULT_1_STAGE_HEIGHT,
28
+ DEFAULT_1_STAGE_WIDTH ,
29
+ DEFAULT_NUM_FRAMES,
30
+ DEFAULT_FRAME_RATE,
31
+ DEFAULT_LORA_STRENGTH,
32
+ )
33
+
34
+
35
+ MAX_SEED = np.iinfo(np.int32).max
36
+ # Import from public LTX-2 package
37
+ # Install with: pip install git+https://github.com/Lightricks/LTX-2.git
38
+ from ltx_pipelines.utils import ModelLedger
39
+ from ltx_pipelines.utils.helpers import generate_enhanced_prompt
40
+
41
+ # HuggingFace Hub defaults
42
+ DEFAULT_REPO_ID = "Lightricks/LTX-2"
43
+ DEFAULT_GEMMA_REPO_ID = "unsloth/gemma-3-12b-it-qat-bnb-4bit"
44
+ DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev.safetensors"
45
+
46
+
47
def get_hub_or_local_checkpoint(repo_id: str, filename: str):
    """Download a checkpoint file from the HuggingFace Hub.

    Args:
        repo_id: Hub repository id, e.g. "Lightricks/LTX-2".
        filename: File to fetch from that repository.

    Returns:
        Local filesystem path of the downloaded (or cached) file.
    """
    # NOTE(review): this name is redefined later in the file with an
    # Optional-args signature; only the later definition is live at runtime.
    # BUG FIX: the log line previously printed the literal "(unknown)"
    # instead of the file being fetched.
    print(f"Downloading {filename} from {repo_id}...")
    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
    print(f"Downloaded to {ckpt_path}")
    return ckpt_path
53
+
54
def download_gemma_model(repo_id: str):
    """Fetch the complete Gemma model snapshot; return its local directory."""
    print(f"Downloading Gemma model from {repo_id}...")
    gemma_dir = snapshot_download(repo_id=repo_id)
    print(f"Gemma model downloaded to {gemma_dir}")
    return gemma_dir
60
+
61
# Initialize model ledger and text encoder at startup (load once, keep in memory)
print("=" * 80)
print("Loading Gemma Text Encoder...")
print("=" * 80)

checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
gemma_local_path = download_gemma_model(DEFAULT_GEMMA_REPO_ID)
device = "cuda"

print("Initializing text encoder with:")
print(f" checkpoint_path={checkpoint_path}")
print(f" gemma_root={gemma_local_path}")
print(f" device={device}")

# BUG FIX: the snapshot downloaded into `gemma_local_path` was previously
# unused — the ledger was handed DEFAULT_GEMMA_REPO_ID instead, contradicting
# the `gemma_root={gemma_local_path}` log line above. Point the ledger at the
# local snapshot so the download is not wasted.
# NOTE(review): assumes ModelLedger accepts a local directory for
# gemma_root_path (snapshot_download returns one) — confirm against
# ltx_pipelines.utils.ModelLedger.
model_ledger = ModelLedger(
    dtype=torch.bfloat16,
    device=device,
    checkpoint_path=checkpoint_path,
    gemma_root_path=gemma_local_path,
    local_files_only=False
)

# Load text encoder once and keep it resident for all requests.
text_encoder = model_ledger.text_encoder()

print("=" * 80)
print("Text encoder loaded and ready!")
print("=" * 80)
91
+
92
def encode_text_simple(text_encoder, prompt: str):
    """Encode *prompt* with *text_encoder*; return (video, audio) contexts.

    The encoder returns a third value, which is discarded here.
    """
    video_ctx, audio_ctx, _unused = text_encoder(prompt)
    return video_ctx, audio_ctx
96
+
97
@spaces.GPU()
def encode_prompt(
    prompt: str,
    enhance_prompt: bool = True,
    input_image=None,  # filepath (string) or None
    seed: int = 42,
    negative_prompt: str = ""
):
    """Encode a (optionally enhanced) prompt into video/audio context tensors.

    Args:
        prompt: User prompt text.
        enhance_prompt: When True, run the prompt enhancer before encoding.
        input_image: Optional image filepath forwarded to the enhancer.
        seed: Seed forwarded to the prompt enhancer.
        negative_prompt: Optional negative prompt; encoded when non-empty.

    Returns:
        Tuple of (embedding dict or None on error, final prompt text, status string).
        All tensors in the dict are detached and moved to CPU.
    """
    start_time = time.time()
    try:
        final_prompt = prompt
        if enhance_prompt:
            final_prompt = generate_enhanced_prompt(
                text_encoder=text_encoder,
                prompt=prompt,
                image_path=input_image,  # already None when no image was given
                seed=seed,
            )

        with torch.inference_mode():
            video_context, audio_context = encode_text_simple(text_encoder, final_prompt)

        video_context_negative = None
        audio_context_negative = None
        if negative_prompt:
            video_context_negative, audio_context_negative = encode_text_simple(text_encoder, negative_prompt)

        # IMPORTANT: return tensors directly (no torch.save); detach to CPU so
        # callers can hold them without pinning GPU memory.
        embedding_data = {
            "video_context": video_context.detach().cpu(),
            "audio_context": audio_context.detach().cpu(),
            "prompt": final_prompt,
            "original_prompt": prompt,
        }
        if video_context_negative is not None:
            # BUG FIX: negatives were previously stored as live GPU tensors
            # while positives were detached to CPU — keep both consistent.
            embedding_data["video_context_negative"] = video_context_negative.detach().cpu()
            embedding_data["audio_context_negative"] = audio_context_negative.detach().cpu()
            embedding_data["negative_prompt"] = negative_prompt

        elapsed_time = time.time() - start_time
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3
            peak = torch.cuda.max_memory_allocated() / 1024**3
            status = f"✓ Encoded in {elapsed_time:.2f}s | VRAM: {allocated:.2f}GB allocated, {peak:.2f}GB peak"
        else:
            status = f"✓ Encoded in {elapsed_time:.2f}s (CPU mode)"

        return embedding_data, final_prompt, status

    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, prompt, error_msg
151
+
152
+
153
# Default prompt from docstring example
DEFAULT_PROMPT = "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot."

# Additional Hub artifacts for the distilled pipeline.
# NOTE: DEFAULT_REPO_ID and DEFAULT_CHECKPOINT_FILENAME are already defined
# near the top of the file with identical values; the redundant
# re-assignments that used to sit here were removed.
DEFAULT_DISTILLED_LORA_FILENAME = "ltx-2-19b-distilled-lora-384.safetensors"
DEFAULT_SPATIAL_UPSAMPLER_FILENAME = "ltx-2-spatial-upscaler-x2-1.0.safetensors"
161
+
162
def get_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None):
    """Resolve a checkpoint: download from the HuggingFace Hub or use a local path.

    Args:
        repo_id: Hub repository id. When None, `filename` is treated as a
            local filesystem path and returned unchanged.
        filename: File name within the repo, or a local path when repo_id is None.

    Returns:
        Path to the checkpoint on the local filesystem.

    Raises:
        ValueError: If both arguments are None, or repo_id is given without filename.
    """
    if repo_id is None and filename is None:
        raise ValueError("Please supply at least one of `repo_id` or `filename`")

    if repo_id is not None:
        if filename is None:
            raise ValueError("If repo_id is specified, filename must also be specified.")
        # BUG FIX: the log line previously printed the literal "(unknown)"
        # instead of the file being fetched.
        print(f"Downloading {filename} from {repo_id}...")
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
        print(f"Downloaded to {ckpt_path}")
    else:
        ckpt_path = filename

    return ckpt_path
177
+
178
+
179
# Initialize pipeline at startup
print("=" * 80)
print("Loading LTX-2 Distilled pipeline...")
print("=" * 80)

# Resolve model artifacts (downloaded from the Hub on first run, then cached).
checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
distilled_lora_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_DISTILLED_LORA_FILENAME)
spatial_upsampler_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_SPATIAL_UPSAMPLER_FILENAME)

print(f"Initializing pipeline with:")
print(f" checkpoint_path={checkpoint_path}")
print(f" distilled_lora_path={distilled_lora_path}")
print(f" spatial_upsampler_path={spatial_upsampler_path}")


# Load distilled LoRA as a regular LoRA
loras = [
    LoraPathStrengthAndSDOps(
        path=distilled_lora_path,
        strength=DEFAULT_LORA_STRENGTH,
        sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
    )
]

# Initialize pipeline WITHOUT text encoder (gemma_root=None)
# Text encoding will be done by external space
pipeline = DistilledPipeline(
    device=torch.device("cuda"),
    checkpoint_path=checkpoint_path,
    spatial_upsampler_path=spatial_upsampler_path,
    gemma_root=None,  # No text encoder in this space
    loras=loras,
    fp8transformer=False,
    local_files_only=False,
)

# Pre-instantiate the heavy submodels once at startup so per-request calls
# do not pay the load cost again.
# NOTE(review): these assign to underscore-prefixed (non-public) attributes of
# the pipeline — verify this matches DistilledPipeline's lazy-loading contract.
pipeline._video_encoder = pipeline.model_ledger.video_encoder()
pipeline._transformer = pipeline.model_ledger.transformer()
# pipeline.device = torch.device("cuda")
# pipeline.model_ledger.device = torch.device("cuda")


print("=" * 80)
print("Pipeline fully loaded and ready!")
print("=" * 80)
224
+
225
def get_duration(
    input_image,
    prompt,
    duration,
    enhance_prompt,
    seed,
    randomize_seed,
    height,
    width,
    progress
):
    """Return the GPU time budget (seconds) for a generation request.

    Clips of up to 5 seconds get an 80-second allocation; anything longer
    gets 120. All parameters other than `duration` are accepted only so the
    signature mirrors `generate_video` (required by `spaces.GPU(duration=...)`).
    """
    return 80 if duration <= 5 else 120
240
+
241
class RadioAnimated(gr.HTML):
    """
    Animated segmented radio (like iOS pill selector).
    Outputs: selected option string, e.g. "768x512"
    """
    def __init__(self, choices, value=None, **kwargs):
        # A segmented control needs at least two segments to be meaningful.
        if not choices or len(choices) < 2:
            raise ValueError("RadioAnimated requires at least 2 choices.")
        if value is None:
            value = choices[0]

        uid = uuid.uuid4().hex[:8]  # unique per instance
        group_name = f"ra-{uid}"

        # One hidden <input type=radio> + visible <label> pair per choice;
        # the shared `name` preserves native radio-group semantics.
        inputs_html = "\n".join(
            f"""
            <input class="ra-input" type="radio" name="{group_name}" id="{group_name}-{i}" value="{c}">
            <label class="ra-label" for="{group_name}-{i}">{c}</label>
            """
            for i, c in enumerate(choices)
        )

        # NOTE: use classes instead of duplicate IDs
        html_template = f"""
        <div class="ra-wrap" data-ra="{uid}">
            <div class="ra-inner">
                <div class="ra-highlight"></div>
                {inputs_html}
            </div>
        </div>
        """

        # Client-side behavior: slide the highlight under the checked option
        # and mirror it into props.value, firing a 'change' event on user
        # interaction. NOTE(review): `element`, `props` and `trigger` appear
        # to be injected by gr.HTML's js_on_load runtime — confirm against
        # the Gradio docs for the version in use.
        js_on_load = r"""
        (() => {
            const wrap = element.querySelector('.ra-wrap');
            const inner = element.querySelector('.ra-inner');
            const highlight = element.querySelector('.ra-highlight');
            const inputs = Array.from(element.querySelectorAll('.ra-input'));

            if (!inputs.length) return;

            const choices = inputs.map(i => i.value);

            function setHighlightByIndex(idx) {
                const n = choices.length;
                const pct = 100 / n;
                highlight.style.width = `calc(${pct}% - 6px)`;
                highlight.style.transform = `translateX(${idx * 100}%)`;
            }

            function setCheckedByValue(val, shouldTrigger=false) {
                const idx = Math.max(0, choices.indexOf(val));
                inputs.forEach((inp, i) => { inp.checked = (i === idx); });
                setHighlightByIndex(idx);

                props.value = choices[idx];
                if (shouldTrigger) trigger('change', props.value);
            }

            // Init from props.value
            setCheckedByValue(props.value ?? choices[0], false);

            // Input handlers
            inputs.forEach((inp) => {
                inp.addEventListener('change', () => {
                    setCheckedByValue(inp.value, true);
                });
            });
        })();
        """

        super().__init__(
            value=value,
            html_template=html_template,
            js_on_load=js_on_load,
            **kwargs
        )
318
+
319
def generate_video_example(input_image, prompt, duration=5, progress=gr.Progress(track_tqdm=True)):
    """Cached-example wrapper around `generate_video`.

    BUG FIX: `duration` now has a default. `gr.Examples` below invokes this
    function with only (input_image, prompt), which previously raised a
    TypeError for the missing positional argument. Note the wrapper always
    renders a fixed 5-second clip at the default resolution regardless of
    `duration`, matching the original hard-coded behavior.
    """
    output_video, _seed = generate_video(
        input_image, prompt, 5, True, 42, True,
        DEFAULT_1_STAGE_HEIGHT, DEFAULT_1_STAGE_WIDTH, progress,
    )
    return output_video
323
+
324
@spaces.GPU(duration=get_duration)
def generate_video(
    input_image,
    prompt: str,
    duration: float,
    enhance_prompt: bool = True,
    seed: int = 42,
    randomize_seed: bool = True,
    height: int = DEFAULT_1_STAGE_HEIGHT,
    width: int = DEFAULT_1_STAGE_WIDTH,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generate a short cinematic video from a text prompt and optional input image using the LTX-2 distilled pipeline.

    Args:
        input_image: Optional input image path for image-to-video. If provided, it is injected at frame 0 to guide motion.
        prompt: Text description of the scene, motion, and cinematic style to generate.
        duration: Desired video length in seconds. Converted to frames using a fixed 24 FPS rate.
        enhance_prompt: Whether to enhance the prompt using the prompt enhancer before encoding.
        seed: Base random seed for reproducibility (ignored if randomize_seed is True).
        randomize_seed: If True, a random seed is generated for each run.
        height: Output video height in pixels.
        width: Output video width in pixels.
        progress: Gradio progress tracker.

    Returns:
        A tuple of:
        - output_path: Path to the generated MP4 video file (None on error).
        - seed: The seed used for generation.

    Notes:
        - Uses a fixed frame rate of 24 FPS.
        - GPU cache is cleared after generation to reduce VRAM pressure.
    """
    # BUG FIX: resolve the seed *before* the try-block. It used to be the
    # first statement inside `try`, so an early failure (e.g. a non-numeric
    # `seed`) made the except-handler reference an unbound `current_seed`,
    # masking the real error with a NameError.
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    try:
        # Calculate num_frames from duration (fixed 24 fps); +1 frame so the
        # clip covers the full requested duration.
        frame_rate = 24.0
        num_frames = int(duration * frame_rate) + 1

        # Reserve a temporary .mp4 path for the pipeline to write into.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            output_path = tmpfile.name

        # Optional image conditioning: inject at frame 0 with strength 1.0.
        # (The duplicated `images = []` and the dead `temp_image_path` /
        # `image_input` locals from the original were removed.)
        images = []
        if input_image is not None:
            images = [(input_image, 0, 1.0)]  # input_image is already a path

        # Encode (and optionally enhance) the prompt using the resident encoder.
        embeddings, final_prompt, status = encode_prompt(
            prompt=prompt,
            enhance_prompt=enhance_prompt,
            input_image=input_image,
            seed=current_seed,
            negative_prompt="",
        )

        video_context = embeddings["video_context"].to("cuda", non_blocking=True)
        audio_context = embeddings["audio_context"].to("cuda", non_blocking=True)
        print("✓ Embeddings loaded successfully")

        # free prompt enhancer / encoder temps ASAP
        del embeddings, final_prompt, status
        torch.cuda.empty_cache()

        # Run inference - progress automatically tracks tqdm from pipeline
        pipeline(
            prompt=prompt,
            output_path=str(output_path),
            seed=current_seed,
            height=height,
            width=width,
            num_frames=num_frames,
            frame_rate=frame_rate,
            images=images,
            tiling_config=TilingConfig.default(),
            video_context=video_context,
            audio_context=audio_context,
        )
        del video_context, audio_context
        torch.cuda.empty_cache()
        print("successful generation")

        return str(output_path), current_seed

    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, current_seed
422
+
423
+
424
def apply_resolution(resolution: str):
    """Parse a "WIDTHxHEIGHT" label into a (width, height) int pair."""
    width_str, height_str = resolution.split("x")
    return int(width_str), int(height_str)
427
+
428
def apply_duration(duration: str):
    """Convert a duration label like "3s" or "10s" into whole seconds.

    Robustness fix: the original sliced off the last character blindly
    (`duration[:-1]`), which corrupts inputs without a trailing unit
    (e.g. "3" became "" and raised). Stripping the "s" suffix keeps the
    original behavior for "3s"/"5s"/"10s" while also accepting bare numbers.
    """
    return int(duration.strip().rstrip("s"))
431
+
432
# Custom stylesheet applied to the Gradio app: page layout containers, the
# animated gradient "Generate" button, and a generic pill-toggle style.
css = """
#col-container {
    margin: 0 auto;
    max-width: 1600px;
}
#modal-container {
    width: 100vw; /* Take full viewport width */
    height: 100vh; /* Take full viewport height (optional) */
    display: flex;
    justify-content: center; /* Center content horizontally */
    align-items: center; /* Center content vertically if desired */
}
#modal-content {
    width: 100%;
    max-width: 700px; /* Limit content width */
    margin: 0 auto;
    border-radius: 8px;
    padding: 1.5rem;
}
#step-column {
    padding: 10px;
    border-radius: 8px;
    box-shadow: var(--card-shadow);
    margin: 10px;
}
#col-showcase {
    margin: 0 auto;
    max-width: 1100px;
}
.button-gradient {
    background: linear-gradient(45deg, rgb(255, 65, 108), rgb(255, 75, 43), rgb(255, 155, 0), rgb(255, 65, 108)) 0% 0% / 400% 400%;
    border: none;
    padding: 14px 28px;
    font-size: 16px;
    font-weight: bold;
    color: white;
    border-radius: 10px;
    cursor: pointer;
    transition: 0.3s ease-in-out;
    animation: 2s linear 0s infinite normal none running gradientAnimation;
    box-shadow: rgba(255, 65, 108, 0.6) 0px 4px 10px;
}
.toggle-container {
    display: inline-flex;
    background-color: #ffd6ff; /* light pink background */
    border-radius: 9999px;
    padding: 4px;
    position: relative;
    width: fit-content;
    font-family: sans-serif;
}
.toggle-container input[type="radio"] {
    display: none;
}
.toggle-container label {
    position: relative;
    z-index: 2;
    flex: 1;
    text-align: center;
    font-weight: 700;
    color: #4b2ab5; /* dark purple text for unselected */
    padding: 6px 22px;
    border-radius: 9999px;
    cursor: pointer;
    transition: color 0.25s ease;
}
/* Moving highlight */
.toggle-highlight {
    position: absolute;
    top: 4px;
    left: 4px;
    width: calc(50% - 4px);
    height: calc(100% - 8px);
    background-color: #4b2ab5; /* dark purple background */
    border-radius: 9999px;
    transition: transform 0.25s ease;
    z-index: 1;
}
/* When "True" is checked */
#true:checked ~ label[for="true"] {
    color: #ffd6ff; /* light pink text */
}
/* When "False" is checked */
#false:checked ~ label[for="false"] {
    color: #ffd6ff; /* light pink text */
}
/* Move highlight to right side when False is checked */
#false:checked ~ .toggle-highlight {
    transform: translateX(100%);
}
"""

# Styles for the RadioAnimated custom component (dark pill with a sliding
# green highlight under the selected option).
css += """
/* ---- radioanimated ---- */
.ra-wrap{
    width: fit-content;
}
.ra-inner{
    position: relative;
    display: inline-flex;
    align-items: center;
    gap: 0;
    padding: 6px;
    background: #0b0b0b;
    border-radius: 9999px;
    overflow: hidden;
    user-select: none;
}
.ra-input{
    display: none;
}
.ra-label{
    position: relative;
    z-index: 2;
    padding: 10px 18px;
    font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Arial;
    font-size: 14px;
    font-weight: 600;
    color: rgba(255,255,255,0.7);
    cursor: pointer;
    transition: color 180ms ease;
    white-space: nowrap;
}
.ra-highlight{
    position: absolute;
    z-index: 1;
    top: 6px;
    left: 6px;
    height: calc(100% - 12px);
    border-radius: 9999px;
    background: #8bff97; /* green knob */
    transition: transform 200ms ease, width 200ms ease;
}
/* selected label becomes darker like your screenshot */
.ra-input:checked + .ra-label{
    color: rgba(0,0,0,0.75);
}
"""
570
+
571
+
572
# UI definition.
# BUG FIX: the custom stylesheet is now passed to gr.Blocks(css=...).
# `Blocks.launch()` has no `css` parameter — custom CSS is a Blocks
# constructor argument — so the original `demo.launch(..., css=css)` call
# would fail (or silently drop the styling, depending on Gradio version).
with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈", css=css) as demo:
    # Static header: model blurb and attribution links.
    gr.HTML(
        """
        <div style="text-align: center;">
            <p style="font-size:16px; display: inline; margin: 0;">
                <strong>LTX-2 Distilled</strong> DiT-based audio-video foundation model
            </p>
            <a href="https://huggingface.co/Lightricks/LTX-2" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                [model]
            </a>
        </div>
        <div style="text-align: center;">
            <p style="font-size:16px; display: inline; margin: 0;">
                Using FA3 and Gemma 3 12B 4bit Quantisation for Faster Inference
            </p>
        </div>
        <div style="text-align: center;">
            <strong>HF Space by:</strong>
            <a href="https://huggingface.co/alexnasa" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                <img src="https://img.shields.io/badge/🤗-Follow Me-green.svg">
            </a>
        </div>
        """
    )
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            # Left column: generation inputs.
            with gr.Column(elem_id="step-column"):

                input_image = gr.Image(
                    label="Input Image (Optional)",
                    type="filepath",  # the pipeline expects a path, not a PIL image
                    height=512
                )

                prompt = gr.Textbox(
                    label="Prompt",
                    value="Make this image come alive with cinematic motion, smooth animation",
                    lines=3,
                    max_lines=3,
                    placeholder="Describe the motion and animation you want..."
                )

                # Hidden: prompt enhancement is always on in this Space.
                enhance_prompt = gr.Checkbox(
                    label="Enhance Prompt",
                    value=True,
                    visible=False
                )

                with gr.Accordion("Advanced Settings", open=False, visible=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        value=DEFAULT_SEED,
                        step=1
                    )

                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)

            # Right column: output video plus duration/resolution selectors.
            with gr.Column(elem_id="step-column"):
                output_video = gr.Video(label="Generated Video", autoplay=True, height=512)

                with gr.Row():

                    with gr.Column():
                        radioanimated_duration = RadioAnimated(
                            choices=["3s", "5s", "10s"],
                            value="3s",
                            elem_id="radioanimated_duration"
                        )

                        # Hidden slider mirrors the animated duration selector.
                        duration = gr.Slider(
                            label="Duration (seconds)",
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.1,
                            visible=False
                        )

                    with gr.Column():
                        radioanimated_resolution = RadioAnimated(
                            choices=["768x512", "512x512", "512x768"],
                            value=f"{DEFAULT_1_STAGE_WIDTH}x{DEFAULT_1_STAGE_HEIGHT}",
                            elem_id="radioanimated_resolution"
                        )

                        # Hidden numbers mirror the selected resolution.
                        width = gr.Number(label="Width", value=DEFAULT_1_STAGE_WIDTH, precision=0, visible=False)
                        height = gr.Number(label="Height", value=DEFAULT_1_STAGE_HEIGHT, precision=0, visible=False)

                generate_btn = gr.Button("🤩 Generate Video", variant="primary", elem_classes="button-gradient")

    # Wiring: animated selectors feed the hidden mirror components.
    radioanimated_duration.change(
        fn=apply_duration,
        inputs=radioanimated_duration,
        outputs=[duration],
        api_visibility="private"
    )
    radioanimated_resolution.change(
        fn=apply_resolution,
        inputs=radioanimated_resolution,
        outputs=[width, height],
        api_visibility="private"
    )

    generate_btn.click(
        fn=generate_video,
        inputs=[
            input_image,
            prompt,
            duration,
            enhance_prompt,
            seed,
            randomize_seed,
            height,
            width,
        ],
        outputs=[output_video, seed]
    )

    # Cached examples, rendered through the example wrapper.
    gr.Examples(
        examples=[
            [
                "supergirl.png",
                "A fuzzy puppet superhero character resembling a female puppet with blonde hair and a blue superhero suit stands inside an icy cave made of frozen walls and icicles, she looks panicked and frantic, rapidly turning her head left and right and scanning the cave while waving her arms and shouting angrily and desperately, mouthing the words “where the hell is my dog,” her movements exaggerated and puppet-like with high energy and urgency, suddenly a second puppet dog bursts into frame from the side, jumping up excitedly and tackling her affectionately while licking her face repeatedly, she freezes in surprise and then breaks into relief and laughter as the dog continues licking her, the scene feels chaotic, comedic, and emotional with expressive puppet reactions, cinematic lighting, smooth camera motion, shallow depth of field, and high-quality puppet-style animation"
            ],
            [
                "highland.png",
                "Realistic POV selfie-style video in a snowy, foggy field. Two shaggy Highland cows with long curved horns stand ahead. The camera is handheld and slightly shaky. The woman filming talks nervously and excitedly in a vlog tone: \"Oh my god guys… look how big those horns are… I’m kinda scared.\" The cow on the left walks toward the camera in a cute, bouncy, hopping way, curious and gentle. Snow crunches under its hooves, breath visible in the cold air. The horns look massive from the POV. As the cow gets very close, its wet nose with slight dripping fills part of the frame. She laughs nervously but reaches out and pets the cow. The cow makes deep, soft, interesting mooing and snorting sounds, calm and friendly. Ultra-realistic, natural lighting, immersive audio, documentary-style realism.",
            ],
            [
                "wednesday.png",
                "A cinematic close-up of Wednesday Addams frozen mid-dance on a dark, blue-lit ballroom floor as students move indistinctly behind her, their footsteps and muffled music reduced to a distant, underwater thrum; the audio foregrounds her steady breathing and the faint rustle of fabric as she slowly raises one arm, never breaking eye contact with the camera, then after a deliberately long silence she speaks in a flat, dry, perfectly controlled voice, “I don’t dance… I vibe code,” each word crisp and unemotional, followed by an abrupt cutoff of her voice as the background sound swells slightly, reinforcing the deadpan humor, with precise lip sync, minimal facial movement, stark gothic lighting, and cinematic realism.",
            ],
            [
                "astronaut.png",
                "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot.",
            ]

        ],
        fn=generate_video_example,
        inputs=[input_image, prompt],
        outputs=[output_video],
        label="Example",
        cache_examples=True,
    )




if __name__ == "__main__":
    demo.launch(ssr_mode=False, mcp_server=True)
ltx2_two_stage.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ python ltx2_two_stage.py \
3
+ --image "astronaut.jpg" 0 1.0 \
4
+ --prompt="An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot." \
5
+ --output_path="t2v_2.mp4" \
6
+ --gemma_root="google/gemma-3-12b-it-qat-q4_0-unquantized" \
7
+ --checkpoint_path="rc1/ltx-2-19b-dev-rc1.safetensors" \
8
+ --distilled_lora_path "rc1/ltx-2-19b-distilled-lora-384-rc1.safetensors" \
9
+ --spatial_upsampler_path "rc1/ltx-2-spatial-upscaler-x2-1.0-rc1.safetensors"
10
+
11
+ """
12
+
13
+ from huggingface_hub import hf_hub_download
14
+ from typing import Optional
15
+ from ltx_pipelines import utils
16
+ from ltx_pipelines.constants import DEFAULT_LORA_STRENGTH
17
+ from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
18
+ from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
19
+ from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline
20
+ from ltx_core.tiling import TilingConfig
21
+
22
+
23
def get_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None):
    """Resolve a checkpoint path, downloading from the HF Hub when a repo is given.

    Args:
        repo_id: Optional Hugging Face repo id. When set, `filename` is fetched
            from that repo via `hf_hub_download`.
        filename: Local filesystem path (when `repo_id` is None) or the file
            name inside the repo (when `repo_id` is set).

    Returns:
        A local filesystem path to the checkpoint file.

    Raises:
        ValueError: if both arguments are None, or if `repo_id` is given
            without a `filename`.
    """
    if repo_id is None and filename is None:
        raise ValueError("Please supply at least one of `repo_id` or `filename`")

    if repo_id is None:
        # No repo given: the caller handed us a local path directly.
        return filename

    if filename is None:
        raise ValueError("If repo_id is specified, filename must also be specified.")
    return hf_hub_download(repo_id=repo_id, filename=filename)
35
+
36
+
37
def default_2_stage_arg_parser_mod():
    """Build the default two-stage argument parser, extended with Hub options.

    Adds `--local_files_only` (flag) and `--checkpoint_id` (Hub repo id used to
    resolve checkpoint files) on top of the stock two-stage parser.
    """
    arg_parser = utils.default_2_stage_arg_parser()
    arg_parser.add_argument("--local_files_only", action="store_true")
    arg_parser.add_argument(
        "--checkpoint_id",
        type=str,
        default="diffusers-internal-dev/new-ltx-model",
    )
    return arg_parser
42
+
43
+
44
def main() -> None:
    """CLI entry point: parse arguments, build the two-stage pipeline, run it."""
    args = default_2_stage_arg_parser_mod().parse_args()

    # Each checkpoint argument may be a local path or a file inside --checkpoint_id.
    ckpt = get_hub_or_local_checkpoint(args.checkpoint_id, args.checkpoint_path)
    distilled_lora = get_hub_or_local_checkpoint(args.checkpoint_id, args.distilled_lora_path)
    upsampler = get_hub_or_local_checkpoint(args.checkpoint_id, args.spatial_upsampler_path)

    # Pad user-provided strengths with the default so there is one per LoRA,
    # then truncate to exactly len(args.lora) entries.
    n_loras = len(args.lora)
    strengths = (args.lora_strength + [DEFAULT_LORA_STRENGTH] * n_loras)[:n_loras]
    lora_specs = [
        LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
        for path, strength in zip(args.lora, strengths, strict=True)
    ]

    pipeline = TI2VidTwoStagesPipeline(
        checkpoint_path=ckpt,
        distilled_lora_path=distilled_lora,
        distilled_lora_strength=args.distilled_lora_strength,
        spatial_upsampler_path=upsampler,
        gemma_root=args.gemma_root,
        loras=lora_specs,
        fp8transformer=args.enable_fp8,
        local_files_only=args.local_files_only,
    )
    pipeline(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        output_path=args.output_path,
        seed=args.seed,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        frame_rate=args.frame_rate,
        num_inference_steps=args.num_inference_steps,
        cfg_guidance_scale=args.cfg_guidance_scale,
        images=args.images,
        tiling_config=TilingConfig.default(),
    )
81
+
82
+
83
# Script entry point.
if __name__ == "__main__":
    main()
packages/ltx-core/README.md ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LTX-Core
2
+
3
+ The foundational library for the LTX-2 Audio-Video generation model. This package contains the raw model definitions, component implementations, and loading logic used by `ltx-pipelines` and `ltx-trainer`.
4
+
5
+ ## 📦 What's Inside?
6
+
7
+ - **`components/`**: Modular diffusion components (Schedulers, Guiders, Noisers, Patchifiers) following standard protocols
8
+ - **`conditioning/`**: Tools for preparing latent states and applying conditioning (image, video, keyframes)
9
+ - **`guidance/`**: Perturbation system for fine-grained control over attention mechanisms
10
+ - **`loader/`**: Utilities for loading weights from `.safetensors`, fusing LoRAs, and managing memory
11
+ - **`model/`**: PyTorch implementations of the LTX-2 Transformer, Video VAE, Audio VAE, Vocoder and Upscaler
12
+ - **`text_encoders/gemma`**: Gemma text encoder implementation with tokenizers, feature extractors, and separate encoders for audio-video and video-only generation
13
+
14
+ ## 🚀 Quick Start
15
+
16
+ `ltx-core` provides the building blocks (models, components, and utilities) needed to construct inference flows. For ready-made inference pipelines use [`ltx-pipelines`](../ltx-pipelines/) or [`ltx-trainer`](../ltx-trainer/) for training.
17
+
18
+ ## 🔧 Installation
19
+
20
+ ```bash
21
+ # From the repository root
22
+ uv sync --frozen
23
+
24
+ # Or install as a package
25
+ pip install -e packages/ltx-core
26
+ ```
27
+
28
+ ## Building Blocks Overview
29
+
30
+ `ltx-core` provides modular components that can be combined to build custom inference flows:
31
+
32
+ ### Core Models
33
+
34
+ - **Transformer** ([`model/transformer/`](src/ltx_core/model/transformer/)): The 48-layer LTX-2 transformer with cross-modal attention for joint audio-video processing. Expects inputs in [`Modality`](src/ltx_core/model/transformer/modality.py) format
35
+ - **Video VAE** ([`model/video_vae/`](src/ltx_core/model/video_vae/)): Encodes/decodes video pixels to/from latent space with temporal and spatial compression
36
+ - **Audio VAE** ([`model/audio_vae/`](src/ltx_core/model/audio_vae/)): Encodes/decodes audio spectrograms to/from latent space
37
+ - **Vocoder** ([`model/audio_vae/`](src/ltx_core/model/audio_vae/)): Neural vocoder that converts mel spectrograms to audio waveforms
38
+ - **Text Encoder** ([`text_encoders/`](src/ltx_core/text_encoders/)): Gemma-based encoder that produces separate embeddings for video and audio conditioning
39
+ - **Spatial Upscaler** ([`model/upsampler/`](src/ltx_core/model/upsampler/)): Upsamples latent representations for higher-resolution generation
40
+
41
+ ### Diffusion Components
42
+
43
+ - **Schedulers** ([`components/schedulers.py`](src/ltx_core/components/schedulers.py)): Noise schedules (LTX2Scheduler, LinearQuadratic, Beta) that control the denoising process
44
+ - **Guiders** ([`components/guiders.py`](src/ltx_core/components/guiders.py)): Guidance strategies (CFG, STG, APG) for controlling generation quality and adherence to prompts
45
+ - **Noisers** ([`components/noisers.py`](src/ltx_core/components/noisers.py)): Add noise to latents according to the diffusion schedule
46
+ - **Patchifiers** ([`components/patchifiers.py`](src/ltx_core/components/patchifiers.py)): Convert between spatial latents `[B, C, F, H, W]` and sequence format `[B, seq_len, dim]` for transformer processing
47
+
48
+ ### Conditioning & Control
49
+
50
+ - **Conditioning** ([`conditioning/`](src/ltx_core/conditioning/)): Tools for preparing and applying various conditioning types (image, video, keyframes)
51
+ - **Guidance** ([`guidance/`](src/ltx_core/guidance/)): Perturbation system for fine-grained control over attention mechanisms (e.g., skipping specific attention layers)
52
+
53
+ ### Utilities
54
+
55
+ - **Loader** ([`loader/`](src/ltx_core/loader/)): Model loading from `.safetensors`, LoRA fusion, weight remapping, and memory management
56
+
57
+ For complete, production-ready pipeline implementations that combine these building blocks, see the [`ltx-pipelines`](../ltx-pipelines/) package.
58
+
59
+ ---
60
+
61
+ # Architecture Overview
62
+
63
+ This section provides a deep dive into the internal architecture of the LTX-2 Audio-Video generation model.
64
+
65
+ ## Table of Contents
66
+
67
+ 1. [High-Level Architecture](#high-level-architecture)
68
+ 2. [The Transformer](#the-transformer)
69
+ 3. [Video VAE](#video-vae)
70
+ 4. [Audio VAE](#audio-vae)
71
+ 5. [Text Encoding (Gemma)](#text-encoding-gemma)
72
6. [Upscaler](#upscaler)
73
+ 7. [Data Flow](#data-flow)
74
+
75
+ ---
76
+
77
+ ## High-Level Architecture
78
+
79
+ LTX-2 is a **joint Audio-Video diffusion transformer** that processes both modalities simultaneously in a unified architecture. Unlike traditional models that handle video and audio separately, LTX-2 uses cross-modal attention to enable natural synchronization.
80
+
81
+ ```text
82
+ ┌─────────────────────────────────────────────────────────────┐
83
+ │ INPUT PREPARATION │
84
+ │ │
85
+ │ Video Pixels → Video VAE Encoder → Video Latents │
86
+ │ Audio Waveform → Audio VAE Encoder → Audio Latents │
87
+ │ Text Prompt → Gemma Encoder → Text Embeddings │
88
+ └─────────────────────────────────────────────────────────────┘
89
+
90
+ ┌─────────────────────────────────────────────────────────────┐
91
+ │ LTX-2 TRANSFORMER (48 Blocks) │
92
+ │ │
93
+ │ ┌──────────────┐ ┌──────────────┐ │
94
+ │ │ Video Stream │ │ Audio Stream │ │
95
+ │ │ │ │ │ │
96
+ │ │ Self-Attn │ │ Self-Attn │ │
97
+ │ │ Cross-Attn │ │ Cross-Attn │ │
98
+ │ │ │◄────────────►│ │ │
99
+ │ │ A↔V Cross │ │ A↔V Cross │ │
100
+ │ │ Feed-Forward │ │ Feed-Forward │ │
101
+ │ └──────────────┘ └──────────────┘ │
102
+ └─────────────────────────────────────────────────────────────┘
103
+
104
+ ┌─────────────────────────────────────────────────────────────┐
105
+ │ OUTPUT DECODING │
106
+ │ │
107
+ │ Video Latents → Video VAE Decoder → Video Pixels │
108
+ │ Audio Latents → Audio VAE Decoder → Mel Spectrogram │
109
+ │ Mel Spectrogram → Vocoder → Audio Waveform │
110
+ └─────────────────────────────────────────────────────────────┘
111
+ ```
112
+
113
+ ---
114
+
115
+ ## The Transformer
116
+
117
+ The core of LTX-2 is a 48-layer transformer that processes both video and audio tokens simultaneously.
118
+
119
+ ### Model Structure
120
+
121
+ **Source**: [`src/ltx_core/model/transformer/model.py`](src/ltx_core/model/transformer/model.py)
122
+
123
+ The `LTXModel` class implements the transformer. It supports both video-only and audio-video generation modes. For actual usage, see the [`ltx-pipelines`](../ltx-pipelines/) package which handles model loading and initialization.
124
+
125
+ ### Transformer Block Architecture
126
+
127
+ **Source**: [`src/ltx_core/model/transformer/transformer.py`](src/ltx_core/model/transformer/transformer.py)
128
+
129
+ ```text
130
+ ┌─────────────────────────────────────────────────────────────┐
131
+ │ TRANSFORMER BLOCK │
132
+ │ │
133
+ │ VIDEO PATH: │
134
+ │ Input → RMSNorm → AdaLN → Self-Attn (attn1) │
135
+ │ → RMSNorm → Cross-Attn (attn2, text) │
136
+ │ → RMSNorm → AdaLN → A↔V Cross-Attn │
137
+ │ → RMSNorm → AdaLN → Feed-Forward (ff) → Output │
138
+ │ │
139
+ │ AUDIO PATH: │
140
+ │ Input → RMSNorm → AdaLN → Self-Attn (audio_attn1) │
141
+ │ → RMSNorm → Cross-Attn (audio_attn2, text) │
142
+ │ → RMSNorm → AdaLN → A↔V Cross-Attn │
143
+ │ → RMSNorm → AdaLN → Feed-Forward (audio_ff) │
144
+ │ │
145
+ │ AdaLN (Adaptive Layer Normalization): │
146
+ │ - Uses scale_shift_table (6 params) for video/audio │
147
+ │ - Uses scale_shift_table_a2v_ca (5 params) for A↔V CA │
148
+ │ - Conditioned on per-token timestep embeddings │
149
+ └─────────────────────────────────────────────────────────────┘
150
+ ```
151
+
152
+ ### Perturbations
153
+
154
+ The transformer supports [**perturbations**](src/ltx_core/guidance/perturbations.py) that selectively skip attention operations.
155
+
156
+ Perturbations allow you to disable specific attention mechanisms during inference, which is useful for guidance techniques like STG (Spatio-Temporal Guidance).
157
+
158
+ **Supported Perturbation Types**:
159
+
160
+ - `SKIP_VIDEO_SELF_ATTN`: Skip video self-attention
161
+ - `SKIP_AUDIO_SELF_ATTN`: Skip audio self-attention
162
+ - `SKIP_A2V_CROSS_ATTN`: Skip audio-to-video cross-attention
163
+ - `SKIP_V2A_CROSS_ATTN`: Skip video-to-audio cross-attention
164
+
165
+ Perturbations are used internally by guidance mechanisms like STG (Spatio-Temporal Guidance). For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
166
+
167
+ ---
168
+
169
+ ## Video VAE
170
+
171
+ The Video VAE ([`src/ltx_core/model/video_vae/`](src/ltx_core/model/video_vae/)) encodes video pixels into latent representations and decodes them back.
172
+
173
+ ### Architecture
174
+
175
+ - **Encoder**: Compresses `[B, 3, F, H, W]` pixels → `[B, 128, F', H/32, W/32]` latents
176
+ - Where `F' = 1 + (F-1)/8` (frame count must satisfy `(F-1) % 8 == 0`)
177
+ - Example: `[B, 3, 33, 512, 512]` → `[B, 128, 5, 16, 16]`
178
+ - **Decoder**: Expands `[B, 128, F, H, W]` latents → `[B, 3, F', H*32, W*32]` pixels
179
+ - Where `F' = 1 + (F-1)*8`
180
+ - Example: `[B, 128, 5, 16, 16]` → `[B, 3, 33, 512, 512]`
181
+
182
+ The Video VAE is used internally by pipelines for encoding video pixels to latents and decoding latents back to pixels. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
183
+
184
+ ---
185
+
186
+ ## Audio VAE
187
+
188
+ The Audio VAE ([`src/ltx_core/model/audio_vae/`](src/ltx_core/model/audio_vae/)) processes audio spectrograms.
189
+
190
+ ### Audio VAE Architecture
191
+
192
+ - **Encoder**: Compresses mel spectrogram `[B, mel_bins, T]` → `[B, 8, T/4, 16]` latents
193
+ - Temporal downsampling: 4× (`LATENT_DOWNSAMPLE_FACTOR = 4`)
194
+ - Frequency bins: Fixed 16 mel bins in latent space
195
+ - Latent channels: 8
196
+ - **Decoder**: Expands `[B, 8, T, 16]` latents → mel spectrogram `[B, mel_bins, T*4]`
197
+ - **Vocoder**: Converts mel spectrogram → audio waveform
198
+
199
+ **Downsampling**:
200
+
201
+ - Temporal: 4× (time steps)
202
+ - Frequency: Variable (input mel_bins → fixed 16 in latent space)
203
+
204
+ The Audio VAE is used internally by pipelines for encoding mel spectrograms to latents and decoding latents back to mel spectrograms. The vocoder converts mel spectrograms to audio waveforms. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
205
+
206
+ ---
207
+
208
+ ## Text Encoding (Gemma)
209
+
210
+ LTX-2 uses **Gemma** (Google's open LLM) as the text encoder, located in [`src/ltx_core/text_encoders/gemma/`](src/ltx_core/text_encoders/gemma/).
211
+
212
+ ### Text Encoder Architecture
213
+
214
+ - **Tokenizer**: Converts text → token IDs
215
+ - **Gemma Model**: Processes tokens → embeddings
216
+ - **Text Projection**: Uses `PixArtAlphaTextProjection` to project caption embeddings
217
+ - Two-layer MLP with GELU (tanh approximation) or SiLU activation
218
+ - Projects from caption channels (3840) to model dimensions
219
+ - **Feature Extractor**: Extracts video/audio-specific embeddings
220
+ - **Separate Encoders**:
221
+ - `AVEncoder`: For audio-video generation (outputs separate video and audio contexts)
222
+ - `VideoOnlyEncoder`: For video-only generation
223
+
224
+ ### System Prompts
225
+
226
System prompts are also used to enhance users' prompts.
227
+
228
+ - **Text-to-Video**: [`gemma_t2v_system_prompt.txt`](src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_t2v_system_prompt.txt)
229
+ - **Image-to-Video**: [`gemma_i2v_system_prompt.txt`](src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_i2v_system_prompt.txt)
230
+
231
+ **Important**: Video and audio receive **different** context embeddings, even from the same prompt. This allows better modality-specific conditioning.
232
+
233
+ **Output Format**:
234
+
235
+ - Video context: `[B, seq_len, 4096]` - Video-specific text embeddings
236
+ - Audio context: `[B, seq_len, 2048]` - Audio-specific text embeddings
237
+
238
+ The text encoder is used internally by pipelines. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
239
+
240
+ ---
241
+
242
+ ## Upscaler
243
+
244
+ The Upscaler ([`src/ltx_core/model/upsampler/`](src/ltx_core/model/upsampler/)) upsamples latent representations for higher-resolution output.
245
+
246
+ The spatial upsampler is used internally by two-stage pipelines (e.g., [`TI2VidTwoStagesPipeline`](../ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py), [`ICLoraPipeline`](../ltx-pipelines/src/ltx_pipelines/ic_lora.py)) to upsample low-resolution latents before final VAE decoding. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
247
+
248
+ ---
249
+
250
+ ## Data Flow
251
+
252
+ ### Complete Generation Pipeline
253
+
254
+ Here's how all the components work together conceptually ([`src/ltx_core/components/`](src/ltx_core/components/)):
255
+
256
+ **Pipeline Steps**:
257
+
258
+ 1. **Text Encoding**: Text prompt → Gemma encoder → separate video/audio embeddings
259
+ 2. **Latent Initialization**: Initialize noise latents in spatial format `[B, C, F, H, W]`
260
+ 3. **Patchification**: Convert spatial latents to sequence format `[B, seq_len, dim]` for transformer
261
+ 4. **Sigma Schedule**: Generate noise schedule (adapts to token count)
262
+ 5. **Denoising Loop**: Iteratively denoise using transformer predictions
263
+ - Create Modality inputs with per-token timesteps and RoPE positions
264
+ - Forward pass through transformer (conditional and unconditional for CFG)
265
+ - Apply guidance (CFG, STG, etc.)
266
+ - Update latents using diffusion step (Euler, etc.)
267
+ 6. **Unpatchification**: Convert sequence back to spatial format
268
+ 7. **VAE Decoding**: Decode latents to pixel space (with optional upsampling for two-stage)
269
+
270
+ - [`TI2VidTwoStagesPipeline`](../ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py) - Two-stage text-to-video (recommended)
271
+ - [`ICLoraPipeline`](../ltx-pipelines/src/ltx_pipelines/ic_lora.py) - Video-to-video with IC-LoRA control
272
+ - [`DistilledPipeline`](../ltx-pipelines/src/ltx_pipelines/distilled.py) - Fast inference with distilled model
273
+ - [`KeyframeInterpolationPipeline`](../ltx-pipelines/src/ltx_pipelines/keyframe_interpolation.py) - Keyframe-based interpolation
274
+
275
+ See the [ltx-pipelines README](../ltx-pipelines/README.md) for usage examples.
276
+
277
+ ## 🔗 Related Projects
278
+
279
+ - **[ltx-pipelines](../ltx-pipelines/)** - High-level pipeline implementations for text-to-video, image-to-video, and video-to-video
280
+ - **[ltx-trainer](../ltx-trainer/)** - Training and fine-tuning tools
packages/ltx-core/pyproject.toml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "ltx-core"
3
+ version = "1.0.0"
4
+ description = "Core implementation of Lightricks' LTX-2 model"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "torch~=2.7",
9
+ "torchaudio",
10
+ "einops",
11
+ "numpy",
12
+ "transformers",
13
+ "safetensors",
14
+ "accelerate",
15
+ "scipy>=1.14",
16
+ ]
17
+
18
+ [project.optional-dependencies]
19
+ xformers = ["xformers"]
20
+
21
+
22
+ [tool.uv.sources]
23
+ xformers = { index = "pytorch" }
24
+
25
+ [[tool.uv.index]]
26
+ name = "pytorch"
27
+ url = "https://download.pytorch.org/whl/cu129"
28
+ explicit = true
29
+
30
+ [build-system]
31
+ requires = ["uv_build>=0.9.8,<0.10.0"]
32
+ build-backend = "uv_build"
33
+
34
+ [dependency-groups]
35
+ dev = [
36
+ "scikit-image>=0.25.2",
37
+ ]
packages/ltx-core/src/ltx_core/__init__.py ADDED
File without changes
packages/ltx-core/src/ltx_core/components/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Diffusion pipeline components.
3
+ Submodules:
4
+ diffusion_steps - Diffusion stepping algorithms (EulerDiffusionStep)
5
+ guiders - Guidance strategies (CFGGuider, STGGuider, APG variants)
6
+ noisers - Noise samplers (GaussianNoiser)
7
+ patchifiers - Latent patchification (VideoLatentPatchifier, AudioPatchifier)
8
+ protocols - Protocol definitions (Patchifier, etc.)
9
+ schedulers - Sigma schedulers (LTX2Scheduler, LinearQuadraticScheduler)
10
+ """
packages/ltx-core/src/ltx_core/components/diffusion_steps.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.components.protocols import DiffusionStepProtocol
4
+ from ltx_core.utils import to_velocity
5
+
6
+
7
+ class EulerDiffusionStep(DiffusionStepProtocol):
8
+ """
9
+ First-order Euler method for diffusion sampling.
10
+ Takes a single step from the current noise level (sigma) to the next by
11
+ computing velocity from the denoised prediction and applying: sample + velocity * dt.
12
+ """
13
+
14
+ def step(
15
+ self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int
16
+ ) -> torch.Tensor:
17
+ sigma = sigmas[step_index]
18
+ sigma_next = sigmas[step_index + 1]
19
+ dt = sigma_next - sigma
20
+ velocity = to_velocity(sample, sigma, denoised_sample)
21
+
22
+ return (sample.to(torch.float32) + velocity.to(torch.float32) * dt).to(sample.dtype)
packages/ltx-core/src/ltx_core/components/guiders.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+ from ltx_core.components.protocols import GuiderProtocol
6
+
7
+
8
@dataclass(frozen=True)
class CFGGuider(GuiderProtocol):
    """
    Classifier-free guidance (CFG).

    The returned delta, (scale - 1) * (cond - uncond), pushes the denoised
    prediction away from the unconditional branch and toward the
    conditional one.

    Attributes:
        scale: Guidance strength; 1.0 disables guidance, larger values
            follow the conditioning more strongly.
    """

    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        direction = cond - uncond
        return direction * (self.scale - 1)

    def enabled(self) -> bool:
        return self.scale != 1.0
26
+
27
+
28
@dataclass(frozen=True)
class CFGStarRescalingGuider(GuiderProtocol):
    """
    CFG variant with a norm-rescaled negative branch.

    Before taking the (cond - uncond) difference, the unconditional sample
    is rescaled by its projection coefficient relative to the conditional
    one, so the guidance moves mostly along the conditioning axis instead
    of offsetting the overall denoising direction.

    Attributes:
        scale (float):
            Global guidance strength. 1.0 means no extra guidance beyond
            the base model prediction; values > 1.0 weight the conditioned
            sample more heavily relative to the unconditioned one.
    """

    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        uncond_rescaled = uncond * projection_coef(cond, uncond)
        return (cond - uncond_rescaled) * (self.scale - 1)

    def enabled(self) -> bool:
        return self.scale != 1.0
51
+
52
+
53
@dataclass(frozen=True)
class STGGuider(GuiderProtocol):
    """
    Spatio-temporal guidance (STG).

    The delta is scale * (pos_denoised - perturbed_denoised), where the
    perturbed sample comes from a denoising pass with perturbations applied
    (e.g. certain attention layers acting as passthrough for some
    modalities).

    Attributes:
        scale (float):
            STG strength; 0.0 disables it, larger values apply a bigger
            correction toward the unperturbed prediction.
    """

    scale: float

    def delta(self, pos_denoised: torch.Tensor, perturbed_denoised: torch.Tensor) -> torch.Tensor:
        correction = pos_denoised - perturbed_denoised
        return correction * self.scale

    def enabled(self) -> bool:
        return self.scale != 0.0
73
+
74
+
75
@dataclass(frozen=True)
class LtxAPGGuider(GuiderProtocol):
    """
    Adaptive projected guidance (APG).

    The (cond - uncond) guidance is decomposed into components parallel and
    orthogonal to the conditioned sample, so movement stays mostly along the
    conditioning axis within the distribution. `eta` weights the parallel
    component; `scale` is applied to the recombined correction. An optional
    norm threshold clamps large guidance magnitudes before decomposition.

    Attributes:
        scale (float):
            Overall guidance strength applied to the recombined correction.
            1.0 disables guidance.
        eta (float):
            Weight of the component parallel to the conditioned sample.
            1.0 keeps it fully; values in [0, 1] attenuate it; values > 1.0
            amplify motion along the conditioning direction.
        norm_threshold (float):
            When > 0, the guidance is rescaled so its L2 norm (over the last
            three dims) never exceeds this value, avoiding unstable updates.
    """

    scale: float
    eta: float = 1.0
    norm_threshold: float = 0.0

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        diff = cond - uncond
        if self.norm_threshold > 0:
            # Clamp guidance magnitude: shrink it whenever its L2 norm over
            # the last three dims exceeds the threshold; never amplify.
            norms = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            clamp = torch.minimum(torch.ones_like(diff), self.norm_threshold / norms)
            diff = diff * clamp
        # Split guidance into components parallel / orthogonal to `cond`.
        parallel = projection_coef(diff, cond) * cond
        orthogonal = diff - parallel
        combined = orthogonal + parallel * self.eta
        return combined * (self.scale - 1)

    def enabled(self) -> bool:
        return self.scale != 1.0
124
+
125
+
126
@dataclass(frozen=False)
class LegacyStatefulAPGGuider(GuiderProtocol):
    """
    Calculates the APG (adaptive projected guidance) delta between conditioned
    and unconditioned samples, with optional momentum accumulation.
    To minimize offset in the denoising direction and move mostly along the
    conditioning axis within the distribution, the (cond - uncond) delta is
    decomposed into components parallel and orthogonal to the conditioned
    sample. The `eta` parameter weights the parallel component, while `scale`
    is applied to the full recombined correction. Optionally, a norm threshold
    can be used to suppress guidance when the magnitude of the correction is
    large.
    Unlike LtxAPGGuider, this class is mutable: it carries a running guidance
    accumulator across calls, and multiplies the final delta by `scale`
    directly (not `scale - 1`).
    Attributes:
        scale (float):
            Strength applied to the recombined guidance correction. 0.0
            disables guidance entirely (see `enabled`).
        eta (float):
            Weight of the component of the guidance that is parallel to the
            conditioned sample. A value of 1.0 keeps the full parallel
            component; values in [0, 1] attenuate it, and values > 1.0 amplify
            motion along the conditioning direction.
        norm_threshold (float):
            When > 0, the guidance is rescaled so its L2 norm (over the last
            three dims) never exceeds this value. This is useful for avoiding
            noisy or unstable updates when the guidance signal is very large.
        momentum (float):
            Accumulation coefficient for guidance over successive calls:
            running_avg = momentum * running_avg + guidance. Note this is an
            unnormalized accumulator, not a true exponential moving average.
    """

    scale: float
    eta: float
    norm_threshold: float = 5.0
    momentum: float = 0.0
    # it is user's responsibility not to use same APGGuider for several denoisings or different modalities
    # in order not to share accumulated average across different denoisings or modalities
    running_avg: torch.Tensor | None = None

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        guidance = cond - uncond
        if self.momentum != 0:
            # Stateful accumulation: first call seeds the accumulator, later
            # calls fold new guidance in. Mutates self (hence frozen=False).
            if self.running_avg is None:
                self.running_avg = guidance.clone()
            else:
                self.running_avg = self.momentum * self.running_avg + guidance
            guidance = self.running_avg

        if self.norm_threshold > 0:
            # Clamp guidance magnitude so its L2 norm (over the last three
            # dims) never exceeds the threshold; never amplifies.
            ones = torch.ones_like(guidance)
            guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            scale_factor = torch.minimum(ones, self.norm_threshold / guidance_norm)
            guidance = guidance * scale_factor

        # Decompose into components parallel/orthogonal to `cond`, then
        # recombine with the parallel part weighted by eta.
        proj_coeff = projection_coef(guidance, cond)
        g_parallel = proj_coeff * cond
        g_orth = guidance - g_parallel
        g_apg = g_parallel * self.eta + g_orth

        return g_apg * self.scale

    def enabled(self) -> bool:
        return self.scale != 0.0
190
+
191
+
192
def projection_coef(to_project: torch.Tensor, project_onto: torch.Tensor) -> torch.Tensor:
    """Per-batch scalar coefficient of projecting `to_project` onto `project_onto`.

    Both tensors are flattened per batch element; the result has shape
    (batch, 1) and equals <a, b> / (||b||^2 + 1e-8), where the epsilon
    guards against division by zero for an all-zero `project_onto`.

    NOTE(review): the (batch, 1) result broadcasts cleanly against 2-D
    operands; for higher-rank tensors callers rely on torch broadcasting
    semantics — confirm the intended alignment at call sites.
    """
    batch = to_project.shape[0]
    a = to_project.reshape(batch, -1)
    b = project_onto.reshape(batch, -1)
    dot = (a * b).sum(dim=1, keepdim=True)
    denom = (b * b).sum(dim=1, keepdim=True) + 1e-8
    return dot / denom
packages/ltx-core/src/ltx_core/components/noisers.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import replace
2
+ from typing import Protocol
3
+
4
+ import torch
5
+
6
+ from ltx_core.types import LatentState
7
+
8
+
9
class Noiser(Protocol):
    """Protocol for adding noise to a latent state during diffusion.

    Implementations take the current LatentState and a `noise_scale` factor
    and return the noised state.
    """

    def __call__(self, latent_state: LatentState, noise_scale: float) -> LatentState: ...
13
+
14
+
15
class GaussianNoiser(Noiser):
    """Mixes Gaussian noise into a latent state, weighted by the denoise mask."""

    def __init__(self, generator: torch.Generator):
        super().__init__()
        # Dedicated RNG so noising is reproducible given a seeded generator.
        self.generator = generator

    def __call__(self, latent_state: LatentState, noise_scale: float = 1.0) -> LatentState:
        source = latent_state.latent
        noise = torch.randn(
            *source.shape,
            device=source.device,
            dtype=source.dtype,
            generator=self.generator,
        )
        # Per-element blend weight: the denoise mask scaled by the requested amount.
        weight = latent_state.denoise_mask * noise_scale
        blended = source * (1 - weight) + noise * weight
        return replace(latent_state, latent=blended.to(source.dtype))
packages/ltx-core/src/ltx_core/components/patchifiers.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple
3
+
4
+ import einops
5
+ import torch
6
+
7
+ from ltx_core.components.protocols import Patchifier
8
+ from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape
9
+
10
+
11
class VideoLatentPatchifier(Patchifier):
    """
    Patchifier for video latents.

    Splits a (B, C, F, H, W) latent into spatial patches of `patch_size` by
    `patch_size` (the temporal patch size is fixed to 1) and flattens them into
    a token sequence; `unpatchify` restores the dense layout.
    """

    def __init__(self, patch_size: int):
        # Patch sizes for video latents.
        self._patch_size = (
            1,  # temporal dimension
            patch_size,  # height dimension
            patch_size,  # width dimension
        )

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """Patch size as a (temporal, height, width) tuple."""
        return self._patch_size

    def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
        """Number of patch tokens produced for a latent of shape `tgt_shape`."""
        # Spatio-temporal volume (F * H * W) divided by the volume of one patch.
        return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size)

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Convert a dense (B, C, F, H, W) latent into (B, tokens, features)
        patch tokens, ordered frame-major, then height, then width.
        """
        latents = einops.rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )

        return latents

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: VideoLatentShape,
    ) -> torch.Tensor:
        """
        Inverse of `patchify`: rebuild the dense (B, C, F, H, W) latent from
        flattened patch tokens using the grid described by `output_shape`.
        """
        assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier"

        # Grid extents in patch units along each axis.
        patch_grid_frames = output_shape.frames // self._patch_size[0]
        patch_grid_height = output_shape.height // self._patch_size[1]
        patch_grid_width = output_shape.width // self._patch_size[2]

        latents = einops.rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=patch_grid_frames,
            h=patch_grid_height,
            w=patch_grid_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )

        return latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the per-dimension bounds [inclusive start, exclusive end) for every
        patch produced by `patchify`. The bounds are expressed in the original
        video grid coordinates: frame/time, height, and width.

        The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where:
        - axis 1 (size 3) enumerates (frame/time, height, width) dimensions
        - axis 3 (size 2) stores `[start, end)` indices within each dimension

        Args:
            output_shape: Video grid description containing frames, height, and width.
            device: Device of the latent tensor.

        Raises:
            ValueError: If `output_shape` is not a VideoLatentShape.
        """
        if not isinstance(output_shape, VideoLatentShape):
            raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")

        frames = output_shape.frames
        height = output_shape.height
        width = output_shape.width
        batch_size = output_shape.batch

        # Validate inputs to ensure positive dimensions
        assert frames > 0, f"frames must be positive, got {frames}"
        assert height > 0, f"height must be positive, got {height}"
        assert width > 0, f"width must be positive, got {width}"
        assert batch_size > 0, f"batch_size must be positive, got {batch_size}"

        # Generate grid coordinates for each dimension (frame, height, width)
        # We use torch.arange to create the starting coordinates for each patch.
        # indexing='ij' ensures the dimensions are in the order (frame, height, width).
        grid_coords = torch.meshgrid(
            torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),
            torch.arange(start=0, end=height, step=self._patch_size[1], device=device),
            torch.arange(start=0, end=width, step=self._patch_size[2], device=device),
            indexing="ij",
        )

        # Stack the grid coordinates to create the start coordinates tensor.
        # Shape becomes (3, grid_f, grid_h, grid_w)
        patch_starts = torch.stack(grid_coords, dim=0)

        # Create a tensor containing the size of a single patch:
        # (frame_patch_size, height_patch_size, width_patch_size).
        # Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates.
        patch_size_delta = torch.tensor(
            self._patch_size,
            device=patch_starts.device,
            dtype=patch_starts.dtype,
        ).view(3, 1, 1, 1)

        # Calculate end coordinates: start + patch_size
        # Shape becomes (3, grid_f, grid_h, grid_w)
        patch_ends = patch_starts + patch_size_delta

        # Stack start and end coordinates together along the last dimension
        # Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end]
        latent_coords = torch.stack((patch_starts, patch_ends), dim=-1)

        # Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence.
        # Final Shape: (batch_size, 3, num_patches, 2)
        latent_coords = einops.repeat(
            latent_coords,
            "c f h w bounds -> b c (f h w) bounds",
            b=batch_size,
            bounds=2,
        )

        return latent_coords
135
+
136
+
137
def get_pixel_coords(
    latent_coords: torch.Tensor,
    scale_factors: "SpatioTemporalScaleFactors",
    causal_fix: bool = False,
) -> torch.Tensor:
    """
    Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
    each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
    Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.

    Args:
        latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
        scale_factors: `(temporal, height, width)` integer scale factors applied per axis.
        causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
            that treat frame zero differently still yield non-negative timestamps.
    """
    # Shape the per-axis factors for broadcasting over the
    # (batch, axis, patch, bound) layout: every dim is 1 except the axis dim.
    view_shape = [1] * latent_coords.ndim
    view_shape[1] = -1
    factors = torch.tensor(scale_factors, device=latent_coords.device).view(*view_shape)

    # Per-axis scaling converts latent bounds into pixel-space coordinates.
    pixel_coords = latent_coords * factors

    if causal_fix:
        # VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
        # Shift and clamp to keep the first-frame timestamps causal and non-negative.
        pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)

    return pixel_coords
167
+
168
+
169
class AudioPatchifier(Patchifier):
    """
    Patchifier for spectrogram/audio latents: flattens latents along time and
    maps latent frame indices to real-time second timestamps.
    """

    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
        shift: int = 0,
    ):
        """
        Patchifier tailored for spectrogram/audio latents.

        Args:
            patch_size: Number of mel bins combined into a single patch. This
                controls the resolution along the frequency axis.
            sample_rate: Original waveform sampling rate. Used to map latent
                indices back to seconds so downstream consumers can align audio
                and video cues.
            hop_length: Window hop length used for the spectrogram. Determines
                how many real-time samples separate two consecutive latent frames.
            audio_latent_downsample_factor: Ratio between spectrogram frames and
                latent frames; compensates for additional downsampling inside the
                VAE encoder.
            is_causal: When True, timing is shifted to account for causal
                receptive fields so timestamps do not peek into the future.
            shift: Integer offset applied to the latent indices. Enables
                constructing overlapping windows from the same latent sequence.
        """
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self.shift = shift
        # Temporal patch size is fixed to 1; `patch_size` applies to the mel/frequency axis.
        self._patch_size = (1, patch_size, patch_size)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """Patch size as a (temporal, height, width) tuple."""
        return self._patch_size

    def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
        """One token per latent time frame."""
        return tgt_shape.frames

    def _get_audio_latent_time_in_sec(
        self,
        start_latent: int,
        end_latent: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Converts latent indices into real-time seconds while honoring causal
        offsets and the configured hop length.

        Args:
            start_latent: Inclusive start index inside the latent sequence. This
                sets the first timestamp returned.
            end_latent: Exclusive end index. Determines how many timestamps get
                generated.
            dtype: Floating-point dtype used for the returned tensor, allowing
                callers to control precision.
            device: Target device for the timestamp tensor. When omitted the
                computation occurs on CPU to avoid surprising GPU allocations.
        """
        if device is None:
            device = torch.device("cpu")

        audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)

        # Latent frames -> spectrogram (mel) frames.
        audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor

        if self.is_causal:
            # Frame offset for causal alignment.
            # The "+1" ensures the timestamp corresponds to the first sample that is fully available.
            causal_offset = 1
            audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0)

        # Mel frames -> seconds via hop length over the sampling rate.
        return audio_mel_frame * self.hop_length / self.sample_rate

    def _compute_audio_timings(
        self,
        batch_size: int,
        num_steps: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame.
        This helper method underpins `get_patch_grid_bounds` for the audio patchifier.

        Args:
            batch_size: Number of sequences to broadcast the timings over.
            num_steps: Number of latent frames (time steps) to convert into timestamps.
            device: Device on which the resulting tensor should reside.
        """
        resolved_device = device
        if resolved_device is None:
            resolved_device = torch.device("cpu")

        # Start timestamps for latent indices [shift, shift + num_steps).
        start_timings = self._get_audio_latent_time_in_sec(
            self.shift,
            num_steps + self.shift,
            torch.float32,
            resolved_device,
        )
        start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        # End timestamps are the start timestamps of the following latent index.
        end_timings = self._get_audio_latent_time_in_sec(
            self.shift + 1,
            num_steps + self.shift + 1,
            torch.float32,
            resolved_device,
        )
        end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        return torch.stack([start_timings, end_timings], dim=-1)

    def patchify(
        self,
        audio_latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Flattens the audio latent tensor along time. Use `get_patch_grid_bounds`
        to derive timestamps for each latent frame based on the configured hop
        length and downsampling.

        Args:
            audio_latents: Latent tensor to patchify.

        Returns:
            Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute the
            corresponding timing metadata when needed.
        """
        audio_latents = einops.rearrange(
            audio_latents,
            "b c t f -> b t (c f)",
        )

        return audio_latents

    def unpatchify(
        self,
        audio_latents: torch.Tensor,
        output_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Restores the `(B, C, T, F)` spectrogram tensor from flattened patches.
        Use `get_patch_grid_bounds` to recompute the timestamps that describe each
        frame's position in real time.

        Args:
            audio_latents: Latent tensor to unpatchify.
            output_shape: Shape of the unpatched output tensor.

        Returns:
            Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing
            metadata associated with the restored latents.
        """
        # audio_latents shape: (batch, time, freq * channels)
        audio_latents = einops.rearrange(
            audio_latents,
            "b t (c f) -> b c t f",
            c=output_shape.channels,
            f=output_shape.mel_bins,
        )

        return audio_latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the temporal bounds `[inclusive start, exclusive end)` for every
        patch emitted by `patchify`. For audio this corresponds to timestamps in
        seconds aligned with the original spectrogram grid.

        The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where:
        - axis 1 (size 1) represents the temporal dimension
        - axis 3 (size 2) stores the `[start, end)` timestamps per patch

        Args:
            output_shape: Audio grid specification describing the number of time steps.
            device: Target device for the returned tensor.

        Raises:
            ValueError: If `output_shape` is not an AudioLatentShape.
        """
        if not isinstance(output_shape, AudioLatentShape):
            raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")

        return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)
packages/ltx-core/src/ltx_core/components/protocols.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Protocol, Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.types import AudioLatentShape, VideoLatentShape
6
+
7
+
8
class Patchifier(Protocol):
    """
    Protocol for patchifiers that convert latent tensors into patches and assemble them back.
    """

    # NOTE: docstrings must be the FIRST statement of each method body; the
    # previous placement after `...` left them as dead expressions with
    # `__doc__` unset.

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Convert latent tensors into flattened patch tokens.

        Args:
            latents: Latent tensor to patchify.

        Returns:
            Flattened patch tokens tensor.
        """
        ...

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: AudioLatentShape | VideoLatentShape,
    ) -> torch.Tensor:
        """
        Converts latent tensors between spatio-temporal formats and flattened sequence representations.

        Args:
            latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
            output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
                VideoLatentShape.

        Returns:
            Dense latent tensor restored from the flattened representation.
        """
        ...

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """
        Returns the patch size as a tuple of (temporal, height, width) dimensions
        """
        ...

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """
        Compute metadata describing where each latent patch resides within the
        grid specified by `output_shape`.

        Args:
            output_shape: Target grid layout for the patches.
            device: Target device for the returned tensor.

        Returns:
            Tensor containing patch coordinate metadata such as spatial or temporal intervals.
        """
        ...
62
+ """
63
+
64
+
65
class SchedulerProtocol(Protocol):
    """
    Protocol for schedulers that provide a sigmas schedule tensor for a
    given number of steps. Device is cpu.
    """

    def execute(self, steps: int, **kwargs) -> torch.FloatTensor:
        """
        Build a sigma schedule for `steps` diffusion steps.

        Args:
            steps: Number of diffusion steps; implementations typically return
                `steps + 1` values running from high noise down to 0.
            **kwargs: Implementation-specific options.
        """
        ...
72
+
73
+
74
class GuiderProtocol(Protocol):
    """
    Protocol for guiders that compute a delta tensor given conditioning inputs.
    The returned delta should be added to the conditional output (cond), enabling
    multiple guiders to be chained together by accumulating their deltas.
    """

    # Guidance strength; exact semantics are guider-specific.
    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        """
        Compute the guidance delta to add to the conditional prediction.

        Args:
            cond: Conditional model output.
            uncond: Unconditional (or perturbed) model output.
        """
        ...

    def enabled(self) -> bool:
        """
        Returns whether the corresponding perturbation is enabled. E.g. for CFG, this should return False if the scale
        is 1.0.
        """
        ...
91
+
92
+
93
class DiffusionStepProtocol(Protocol):
    """
    Protocol for diffusion steps that provide a next sample tensor for a given current sample tensor,
    current denoised sample tensor, and sigmas tensor.
    """

    def step(
        self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int
    ) -> torch.Tensor:
        """
        Advance the sample by one diffusion step.

        Args:
            sample: Current noisy sample.
            denoised_sample: Denoised prediction for the current step.
            sigmas: Full sigma schedule.
            step_index: Index of the current step within `sigmas`.

        Returns:
            The sample for the next step.
        """
        ...
packages/ltx-core/src/ltx_core/components/schedulers.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import lru_cache
3
+
4
+ import numpy
5
+ import scipy
6
+ import torch
7
+
8
+ from ltx_core.components.protocols import SchedulerProtocol
9
+
10
+ BASE_SHIFT_ANCHOR = 1024
11
+ MAX_SHIFT_ANCHOR = 4096
12
+
13
+
14
class LTX2Scheduler(SchedulerProtocol):
    """
    Default scheduler for LTX-2 diffusion sampling.

    Generates a sigma schedule with token-count-dependent shifting and optional
    stretching to a terminal value.
    """

    def execute(
        self,
        steps: int,
        latent: torch.Tensor | None = None,
        max_shift: float = 2.05,
        base_shift: float = 0.95,
        stretch: bool = True,
        terminal: float = 0.1,
        **_kwargs,
    ) -> torch.FloatTensor:
        # Token count drives how strongly the schedule is shifted; with no
        # latent available we fall back to the maximum anchor.
        token_count = MAX_SHIFT_ANCHOR if latent is None else math.prod(latent.shape[2:])

        sigmas = torch.linspace(1.0, 0.0, steps + 1)

        # Linearly interpolate the shift between the two token-count anchors.
        slope = (max_shift - base_shift) / (MAX_SHIFT_ANCHOR - BASE_SHIFT_ANCHOR)
        intercept = base_shift - slope * BASE_SHIFT_ANCHOR
        sigma_shift = token_count * slope + intercept

        # Apply the flux-style time shift everywhere except sigma == 0, which
        # must stay exactly zero.
        exp_shift = math.exp(sigma_shift)
        sigmas = torch.where(
            sigmas != 0,
            exp_shift / (exp_shift + (1 / sigmas - 1)),
            0,
        )

        # Stretch sigmas so that its final value matches the given terminal value.
        if stretch:
            nonzero_mask = sigmas != 0
            nonzero_sigmas = sigmas[nonzero_mask]
            one_minus = 1.0 - nonzero_sigmas
            scale_factor = one_minus[-1] / (1.0 - terminal)
            sigmas[nonzero_mask] = 1.0 - one_minus / scale_factor

        return sigmas.to(torch.float32)
57
+
58
+
59
class LinearQuadraticScheduler(SchedulerProtocol):
    """
    Scheduler with linear steps followed by quadratic steps.
    Produces a sigma schedule that transitions linearly up to a threshold,
    then follows a quadratic curve for the remaining steps.
    """

    def execute(
        self, steps: int, threshold_noise: float = 0.025, linear_steps: int | None = None, **_kwargs
    ) -> torch.FloatTensor:
        """
        Build the sigma schedule.

        Args:
            steps: Total number of steps; the result has `steps + 1` values.
            threshold_noise: Noise level at which the schedule hands over from
                the linear segment to the quadratic one.
            linear_steps: Number of linear steps; defaults to `steps // 2`.
        """
        # A single step degenerates to a direct full-noise -> clean schedule.
        if steps == 1:
            return torch.FloatTensor([1.0, 0.0])

        if linear_steps is None:
            linear_steps = steps // 2
        # Linear ramp from 0 up to (just below) threshold_noise.
        linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
        threshold_noise_step_diff = linear_steps - threshold_noise * steps
        quadratic_steps = steps - linear_steps
        quadratic_sigma_schedule = []
        if quadratic_steps > 0:
            # Quadratic-segment coefficients, derived from threshold_noise_step_diff
            # so the curve continues from the linear segment up to 1.0.
            quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
            linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
            const = quadratic_coef * (linear_steps**2)
            quadratic_sigma_schedule = [
                quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, steps)
            ]
        sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
        # Flip so the schedule runs from 1.0 (pure noise) down to 0.0.
        sigma_schedule = [1.0 - x for x in sigma_schedule]
        return torch.FloatTensor(sigma_schedule)
88
+
89
+
90
class BetaScheduler(SchedulerProtocol):
    """
    Scheduler using a beta distribution to sample timesteps.
    Based on: https://arxiv.org/abs/2407.12173
    """

    # Flux-style time shift used when precomputing the sigma table.
    shift = 2.37
    # Resolution of the precomputed sigma table.
    timesteps_length = 10000

    def execute(self, steps: int, alpha: float = 0.6, beta: float = 0.6, **_kwargs) -> torch.FloatTensor:
        """
        Execute the beta scheduler.

        Args:
            steps: The number of steps to execute the scheduler for.
            alpha: The alpha parameter for the beta distribution.
            beta: The beta parameter for the beta distribution.
            **_kwargs: Ignored; accepted for SchedulerProtocol compatibility,
                matching the other schedulers in this module.

        Warnings:
            The number of steps within `sigmas` theoretically might be less than `steps+1`,
            because of the deduplication of the identical timesteps

        Returns:
            A tensor of sigmas.
        """
        # `import scipy` alone does not guarantee the `stats` submodule is
        # initialized; import it explicitly before accessing scipy.stats.
        import scipy.stats

        model_sampling_sigmas = _precalculate_model_sampling_sigmas(self.shift, self.timesteps_length)
        total_timesteps = len(model_sampling_sigmas) - 1
        ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
        ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps).tolist()
        # Deduplicate while preserving order; rounding can map neighbours to the same index.
        ts = list(dict.fromkeys(ts))

        sigmas = [float(model_sampling_sigmas[int(t)]) for t in ts] + [0.0]
        return torch.FloatTensor(sigmas)
120
+
121
+
122
+ @lru_cache(maxsize=5)
123
+ def _precalculate_model_sampling_sigmas(shift: float, timesteps_length: int) -> torch.Tensor:
124
+ timesteps = torch.arange(1, timesteps_length + 1, 1) / timesteps_length
125
+ return torch.Tensor([flux_time_shift(shift, 1.0, t) for t in timesteps])
126
+
127
+
128
def flux_time_shift(mu: float, sigma: float, t: float) -> float:
    """
    Flux-style time shift: map a timestep `t` in (0, 1] to
    exp(mu) / (exp(mu) + (1/t - 1) ** sigma).
    """
    shifted = math.exp(mu)
    return shifted / (shifted + (1 / t - 1) ** sigma)
packages/ltx-core/src/ltx_core/conditioning/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Conditioning utilities: latent state, tools, and conditioning types."""
2
+
3
+ from ltx_core.conditioning.exceptions import ConditioningError
4
+ from ltx_core.conditioning.item import ConditioningItem
5
+ from ltx_core.conditioning.types import VideoConditionByKeyframeIndex, VideoConditionByLatentIndex
6
+
7
+ __all__ = [
8
+ "ConditioningError",
9
+ "ConditioningItem",
10
+ "VideoConditionByKeyframeIndex",
11
+ "VideoConditionByLatentIndex",
12
+ ]
packages/ltx-core/src/ltx_core/conditioning/exceptions.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
class ConditioningError(Exception):
    """
    Class for conditioning-related errors.

    Raised when a conditioning item cannot be applied to a latent state,
    e.g. because of a shape mismatch.
    """
packages/ltx-core/src/ltx_core/conditioning/item.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Protocol
2
+
3
+ from ltx_core.tools import LatentTools
4
+ from ltx_core.types import LatentState
5
+
6
+
7
class ConditioningItem(Protocol):
    """Protocol for conditioning items that modify latent state during diffusion."""

    def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
        """
        Apply the conditioning to the latent state.

        Args:
            latent_state: The latent state to apply the conditioning to. This state is always patchified.
            latent_tools: Tools (patchifier, target shape, etc.) used to manipulate the latent.

        Returns:
            The latent state after the conditioning has been applied.

        IMPORTANT: If the conditioning needs to add extra tokens to the latent, it should add them to the end of the
        latent.
        """
        ...
packages/ltx-core/src/ltx_core/conditioning/types/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """Conditioning type implementations."""
2
+
3
+ from ltx_core.conditioning.types.keyframe_cond import VideoConditionByKeyframeIndex
4
+ from ltx_core.conditioning.types.latent_cond import VideoConditionByLatentIndex
5
+
6
+ __all__ = [
7
+ "VideoConditionByKeyframeIndex",
8
+ "VideoConditionByLatentIndex",
9
+ ]
packages/ltx-core/src/ltx_core/conditioning/types/keyframe_cond.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.components.patchifiers import get_pixel_coords
4
+ from ltx_core.conditioning.item import ConditioningItem
5
+ from ltx_core.tools import VideoLatentTools
6
+ from ltx_core.types import LatentState, VideoLatentShape
7
+
8
+
9
class VideoConditionByKeyframeIndex(ConditioningItem):
    """
    Conditions video generation on keyframe latents at a specific frame index.
    Appends keyframe tokens to the latent state with positions offset by frame_idx,
    and sets denoise strength according to the strength parameter.
    """

    def __init__(self, keyframes: torch.Tensor, frame_idx: int, strength: float):
        # Latent tensor holding the keyframe(s) to condition on.
        self.keyframes = keyframes
        # Pixel-frame offset added to the temporal axis of the token positions.
        self.frame_idx = frame_idx
        # Conditioning strength; appended tokens get a denoise mask of 1 - strength.
        self.strength = strength

    def apply_to(
        self,
        latent_state: LatentState,
        latent_tools: VideoLatentTools,
    ) -> LatentState:
        """
        Append the patchified keyframes (tokens, positions, denoise mask) to
        `latent_state`. Extra tokens go at the end of the sequence, as required
        by the ConditioningItem protocol.
        """
        tokens = latent_tools.patchifier.patchify(self.keyframes)
        latent_coords = latent_tools.patchifier.get_patch_grid_bounds(
            output_shape=VideoLatentShape.from_torch_shape(self.keyframes.shape),
            device=self.keyframes.device,
        )
        # The causal first-frame fix only applies when conditioning at frame 0.
        positions = get_pixel_coords(
            latent_coords=latent_coords,
            scale_factors=latent_tools.scale_factors,
            causal_fix=latent_tools.causal_fix if self.frame_idx == 0 else False,
        )

        # Shift the temporal axis by the requested frame index, then divide by
        # fps to express token times in seconds.
        positions[:, 0, ...] += self.frame_idx
        positions = positions.to(dtype=torch.float32)
        positions[:, 0, ...] /= latent_tools.fps

        # One mask value per token; 1 - strength so full strength means "keep clean".
        denoise_mask = torch.full(
            size=(*tokens.shape[:2], 1),
            fill_value=1.0 - self.strength,
            device=self.keyframes.device,
            dtype=self.keyframes.dtype,
        )

        return LatentState(
            latent=torch.cat([latent_state.latent, tokens], dim=1),
            denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
            positions=torch.cat([latent_state.positions, positions], dim=2),
            clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
        )
packages/ltx-core/src/ltx_core/conditioning/types/latent_cond.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.conditioning.exceptions import ConditioningError
4
+ from ltx_core.conditioning.item import ConditioningItem
5
+ from ltx_core.tools import LatentTools
6
+ from ltx_core.types import LatentState
7
+
8
+
9
class VideoConditionByLatentIndex(ConditioningItem):
    """
    Conditions video generation by injecting latents at a specific latent frame index.
    Replaces tokens in the latent state at positions corresponding to latent_idx,
    and sets denoise strength according to the strength parameter.
    """

    def __init__(self, latent: torch.Tensor, strength: float, latent_idx: int):
        # 5-D latent (batch, channels, frames, height, width) to splice into the target.
        self.latent = latent
        # Conditioning strength; injected tokens get a denoise mask of 1 - strength.
        self.strength = strength
        # Latent-frame index at which the injection starts.
        self.latent_idx = latent_idx

    def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
        """
        Overwrite the token span corresponding to `latent_idx` with the
        patchified conditioning latent and mark it with reduced denoise strength.

        Raises:
            ConditioningError: If batch/channel/spatial dims do not match the target shape.
        """
        cond_batch, cond_channels, _, cond_height, cond_width = self.latent.shape
        tgt_batch, tgt_channels, tgt_frames, tgt_height, tgt_width = latent_tools.target_shape.to_torch_shape()

        # The frame count is deliberately excluded from this check; the
        # conditioning latent may span fewer frames than the target.
        if (cond_batch, cond_channels, cond_height, cond_width) != (tgt_batch, tgt_channels, tgt_height, tgt_width):
            raise ConditioningError(
                f"Can't apply image conditioning item to latent with shape {latent_tools.target_shape}, expected "
                f"shape is ({tgt_batch}, {tgt_channels}, {tgt_frames}, {tgt_height}, {tgt_width}). Make sure "
                "the image and latent have the same spatial shape."
            )

        tokens = latent_tools.patchifier.patchify(self.latent)
        # Token offset of the first patch belonging to frame `latent_idx`.
        start_token = latent_tools.patchifier.get_token_count(
            latent_tools.target_shape._replace(frames=self.latent_idx)
        )
        stop_token = start_token + tokens.shape[1]

        # Work on a copy so the caller's state is not mutated in place.
        latent_state = latent_state.clone()

        latent_state.latent[:, start_token:stop_token] = tokens
        latent_state.clean_latent[:, start_token:stop_token] = tokens
        latent_state.denoise_mask[:, start_token:stop_token] = 1.0 - self.strength

        return latent_state
packages/ltx-core/src/ltx_core/guidance/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Guidance and perturbation utilities for attention manipulation."""
2
+
3
+ from ltx_core.guidance.perturbations import (
4
+ BatchedPerturbationConfig,
5
+ Perturbation,
6
+ PerturbationConfig,
7
+ PerturbationType,
8
+ )
9
+
10
+ __all__ = [
11
+ "BatchedPerturbationConfig",
12
+ "Perturbation",
13
+ "PerturbationConfig",
14
+ "PerturbationType",
15
+ ]
packages/ltx-core/src/ltx_core/guidance/perturbations.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from enum import Enum
3
+
4
+ import torch
5
+ from torch._prims_common import DeviceLikeType
6
+
7
+
8
class PerturbationType(Enum):
    """Types of attention perturbations for STG (Spatio-Temporal Guidance)."""

    # Skip audio-to-video cross-attention.
    SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn"
    # Skip video-to-audio cross-attention.
    SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn"
    # Skip video self-attention.
    SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn"
    # Skip audio self-attention.
    SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn"
15
+
16
+
17
@dataclass(frozen=True)
class Perturbation:
    """A single perturbation specifying which attention type to skip and in which blocks."""

    type: PerturbationType
    blocks: list[int] | None  # None means all blocks

    def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
        """True when this perturbation targets `perturbation_type` in `block`."""
        if self.type != perturbation_type:
            return False
        # A missing block list means the perturbation applies everywhere.
        return self.blocks is None or block in self.blocks
32
+
33
+
34
@dataclass(frozen=True)
class PerturbationConfig:
    """Configuration holding a list of perturbations for a single sample."""

    perturbations: list[Perturbation] | None

    def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
        """True when any configured perturbation targets (type, block)."""
        # Treat None the same as an empty list: nothing is perturbed.
        candidates = self.perturbations or []
        return any(p.is_perturbed(perturbation_type, block) for p in candidates)

    @staticmethod
    def empty() -> "PerturbationConfig":
        """A configuration with no perturbations."""
        return PerturbationConfig([])
49
+
50
+
51
@dataclass(frozen=True)
class BatchedPerturbationConfig:
    """Perturbation configurations for a batch, with utilities for generating attention masks."""

    perturbations: list[PerturbationConfig]

    def mask(
        self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype
    ) -> torch.Tensor:
        """Per-sample multiplier: 0 where the sample is perturbed, 1 elsewhere."""
        keep = [
            0.0 if config.is_perturbed(perturbation_type, block) else 1.0
            for config in self.perturbations
        ]
        return torch.tensor(keep, device=device, dtype=dtype)

    def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor:
        """Same as `mask`, reshaped to broadcast against `values` (batch first, singletons after)."""
        flat = self.mask(perturbation_type, block, values.device, values.dtype)
        trailing = [1] * (values.ndim - 1)
        return flat.view(flat.numel(), *trailing)

    def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
        """True if at least one sample in the batch is perturbed for (type, block)."""
        return any(config.is_perturbed(perturbation_type, block) for config in self.perturbations)

    def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
        """True only if every sample in the batch is perturbed for (type, block)."""
        return all(config.is_perturbed(perturbation_type, block) for config in self.perturbations)

    @staticmethod
    def empty(batch_size: int) -> "BatchedPerturbationConfig":
        """Batch of `batch_size` empty per-sample configurations."""
        return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)])
packages/ltx-core/src/ltx_core/loader/__init__.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Loader utilities for model weights, LoRAs, and safetensor operations."""
2
+
3
+ from ltx_core.loader.fuse_loras import apply_loras
4
+ from ltx_core.loader.module_ops import ModuleOps
5
+ from ltx_core.loader.primitives import (
6
+ LoRAAdaptableProtocol,
7
+ LoraPathStrengthAndSDOps,
8
+ LoraStateDictWithStrength,
9
+ ModelBuilderProtocol,
10
+ StateDict,
11
+ StateDictLoader,
12
+ )
13
+ from ltx_core.loader.registry import DummyRegistry, Registry, StateDictRegistry
14
+ from ltx_core.loader.sd_ops import (
15
+ LTXV_LORA_COMFY_RENAMING_MAP,
16
+ ContentMatching,
17
+ ContentReplacement,
18
+ KeyValueOperation,
19
+ KeyValueOperationResult,
20
+ SDKeyValueOperation,
21
+ SDOps,
22
+ )
23
+ from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader, SafetensorsStateDictLoader
24
+ from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
25
+
26
+ __all__ = [
27
+ "LTXV_LORA_COMFY_RENAMING_MAP",
28
+ "ContentMatching",
29
+ "ContentReplacement",
30
+ "DummyRegistry",
31
+ "KeyValueOperation",
32
+ "KeyValueOperationResult",
33
+ "LoRAAdaptableProtocol",
34
+ "LoraPathStrengthAndSDOps",
35
+ "LoraStateDictWithStrength",
36
+ "ModelBuilderProtocol",
37
+ "ModuleOps",
38
+ "Registry",
39
+ "SDKeyValueOperation",
40
+ "SDOps",
41
+ "SafetensorsModelStateDictLoader",
42
+ "SafetensorsStateDictLoader",
43
+ "SingleGPUModelBuilder",
44
+ "StateDict",
45
+ "StateDictLoader",
46
+ "StateDictRegistry",
47
+ "apply_loras",
48
+ ]
packages/ltx-core/src/ltx_core/loader/fuse_loras.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import triton
3
+
4
+ from ltx_core.loader.kernels import fused_add_round_kernel
5
+ from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict
6
+
7
+ BLOCK_SIZE = 1024
8
+
9
+
10
def fused_add_round_launch(target_weight: torch.Tensor, original_weight: torch.Tensor, seed: int) -> torch.Tensor:
    """Launch the fused upcast-and-add Triton kernel.

    Upcasts ``original_weight`` (an fp8 tensor) to bfloat16 with stochastic
    rounding and adds it into ``target_weight`` in place (the kernel writes
    its result back through the output pointer).

    Args:
        target_weight: bfloat16 tensor acting as both addend and destination.
        original_weight: fp8 (e4m3fn or e5m2) tensor to upcast and add.
        seed: RNG seed for the stochastic-rounding noise.

    Returns:
        ``target_weight``, modified in place.

    Raises:
        ValueError: If ``original_weight`` is not an fp8 dtype or
            ``target_weight`` is not bfloat16.
    """
    # Pick the fp8 format parameters; exponent_bits is currently unused by the
    # kernel (hence the noqa), only bias and mantissa width are passed down.
    if original_weight.dtype == torch.float8_e4m3fn:
        exponent_bits, mantissa_bits, exponent_bias = 4, 3, 7
    elif original_weight.dtype == torch.float8_e5m2:
        exponent_bits, mantissa_bits, exponent_bias = 5, 2, 15  # noqa: F841
    else:
        raise ValueError("Unsupported dtype")

    if target_weight.dtype != torch.bfloat16:
        raise ValueError("target_weight dtype must be bfloat16")

    # Calculate grid and block sizes
    n_elements = original_weight.numel()
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    # Launch kernel
    fused_add_round_kernel[grid](
        original_weight,
        target_weight,
        seed,
        n_elements,
        exponent_bias,
        mantissa_bits,
        BLOCK_SIZE,
    )
    return target_weight
36
+
37
+
38
def calculate_weight_float8_(target_weights: torch.Tensor, original_weights: torch.Tensor) -> torch.Tensor:
    """Add fp8 base weights into bf16 LoRA deltas in place (CUDA fused path).

    ``fused_add_round_launch`` writes its result directly into
    ``target_weights`` and returns that same tensor; the previous
    ``.to(target_weights.dtype)`` (a no-op cast to the tensor's own dtype)
    followed by ``copy_`` was copying the tensor onto itself.

    Args:
        target_weights: bfloat16 deltas; receives the sum in place.
        original_weights: fp8 base weights to upcast (stochastic rounding) and add.

    Returns:
        ``target_weights``, modified in place.
    """
    return fused_add_round_launch(target_weights, original_weights, seed=0)
42
+
43
+
44
def _prepare_deltas(
    lora_sd_and_strengths: list[LoraStateDictWithStrength], key: str, dtype: torch.dtype, device: torch.device
) -> torch.Tensor | None:
    """Combine the weight deltas of all LoRAs that target ``key``.

    For each LoRA providing both factor matrices, the delta is
    ``(B * strength) @ A``; all deltas are summed. Returns None when no LoRA
    contributes to this key.
    """
    base = key[: -len(".weight")]
    down_key = f"{base}.lora_A.weight"
    up_key = f"{base}.lora_B.weight"
    deltas = [
        torch.matmul(lora_sd.sd[up_key] * strength, lora_sd.sd[down_key]).to(dtype=dtype, device=device)
        for lora_sd, strength in lora_sd_and_strengths
        if down_key in lora_sd.sd and up_key in lora_sd.sd
    ]
    if not deltas:
        return None
    if len(deltas) == 1:
        return deltas[0]
    return torch.sum(torch.stack(deltas, dim=0), dim=0)
61
+
62
+
63
def apply_loras(
    model_sd: StateDict,
    lora_sd_and_strengths: list[LoraStateDictWithStrength],
    dtype: torch.dtype,
    destination_sd: StateDict | None = None,
) -> StateDict:
    """Fuse LoRA deltas into a model state dict.

    For every weight in ``model_sd``, the combined LoRA delta (if any) is
    added to the base weight and the result converted to ``dtype``.

    Args:
        model_sd: Base model state dict.
        lora_sd_and_strengths: Loaded LoRA state dicts with blend strengths.
        dtype: Target dtype for fused weights; ``None`` keeps each weight's dtype.
        destination_sd: Optional state dict whose ``sd`` mapping receives the
            fused weights in place; when given, it is returned as-is (note its
            ``size``/``dtype`` bookkeeping fields are NOT recomputed).

    Returns:
        ``destination_sd`` when provided, otherwise a freshly built StateDict.

    Raises:
        ValueError: If a base weight with a LoRA delta is neither
            float8_e4m3fn nor bfloat16.
    """
    sd = {}
    if destination_sd is not None:
        sd = destination_sd.sd
    size = 0
    device = torch.device("meta")  # placeholder until the first real weight is seen
    inner_dtypes = set()
    for key, weight in model_sd.sd.items():
        if weight is None:
            continue
        device = weight.device
        target_dtype = dtype if dtype is not None else weight.dtype
        # Accumulate deltas in bfloat16 when the target is fp8 — fp8 cannot
        # represent the intermediate sums accurately.
        deltas_dtype = target_dtype if target_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2] else torch.bfloat16
        deltas = _prepare_deltas(lora_sd_and_strengths, key, deltas_dtype, device)
        if deltas is None:
            # No LoRA touches this key: keep the (converted) base weight,
            # unless the destination already holds an entry for it.
            if key in sd:
                continue
            deltas = weight.clone().to(dtype=target_dtype, device=device)
        elif weight.dtype == torch.float8_e4m3fn:
            if str(device).startswith("cuda"):
                # Fused CUDA path: stochastic-rounding upcast of the fp8 base,
                # added into the bf16 deltas in place.
                deltas = calculate_weight_float8_(deltas, weight)
            else:
                deltas.add_(weight.to(dtype=deltas.dtype, device=device))
        elif weight.dtype == torch.bfloat16:
            deltas.add_(weight)
        else:
            raise ValueError(f"Unsupported dtype: {weight.dtype}")
        sd[key] = deltas.to(dtype=target_dtype)
        inner_dtypes.add(target_dtype)
        size += deltas.nbytes
    if destination_sd is not None:
        return destination_sd
    return StateDict(sd, device, size, inner_dtypes)
packages/ltx-core/src/ltx_core/loader/kernels.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ruff: noqa: ANN001, ANN201, ERA001, N803, N806
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
@triton.jit
def fused_add_round_kernel(
    x_ptr,
    output_ptr,  # contents will be added to the output
    seed,
    n_elements,
    EXPONENT_BIAS,
    MANTISSA_BITS,
    BLOCK_SIZE: tl.constexpr,
):
    """
    A kernel to upcast 8bit quantized weights to bfloat16 with stochastic rounding
    and add them to bfloat16 output weights. Might be used to upcast original model weights
    and to further add them to precalculated deltas coming from LoRAs.

    Args:
        x_ptr: Pointer to the fp8 source weights (read-only).
        output_ptr: Pointer to bfloat16 values; read as the addend and then
            overwritten with the stochastically-rounded sum (in/out).
        seed: RNG seed for the rounding noise.
        n_elements: Total number of elements to process.
        EXPONENT_BIAS: Exponent bias of the fp8 source format (7 for e4m3, 15 for e5m2).
        MANTISSA_BITS: Mantissa width of the fp8 source format (3 for e4m3, 2 for e5m2).
        BLOCK_SIZE: Elements handled per program instance (compile-time constant).
    """
    # Get program ID and compute offsets
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Load data; rand_vals in [-0.5, 0.5) drive the stochastic rounding below.
    x = tl.load(x_ptr + offsets, mask=mask)
    rand_vals = tl.rand(seed, offsets) - 0.5

    # The sum is formed in fp16 so the ULP arithmetic below can work on the
    # fp16 bit layout (5 exponent bits at bit 10).
    x = tl.cast(x, tl.float16)
    delta = tl.load(output_ptr + offsets, mask=mask)
    delta = tl.cast(delta, tl.float16)
    x = x + delta

    x_bits = tl.cast(x, tl.int16, bitcast=True)

    # Calculate the exponent. Unbiased fp16 exponent is ((x_bits & 0x7C00) >> 10) - 15 for
    # normal numbers and -14 for subnormals.
    fp16_exponent_bits = (x_bits & 0x7C00) >> 10
    fp16_normals = fp16_exponent_bits > 0
    fp16_exponent = tl.where(fp16_normals, fp16_exponent_bits - 15, -14)

    # Add the target dtype's exponent bias and clamp to the target dtype's exponent range.
    exponent = fp16_exponent + EXPONENT_BIAS
    MAX_EXPONENT = 2 * EXPONENT_BIAS + 1
    exponent = tl.where(exponent > MAX_EXPONENT, MAX_EXPONENT, exponent)
    exponent = tl.where(exponent < 0, 0, exponent)

    # Normal ULP exponent, expressed as an fp16 exponent field:
    # (exponent - EXPONENT_BIAS - MANTISSA_BITS) + 15
    # Simplifies to: fp16_exponent - MANTISSA_BITS + 15
    # See https://en.wikipedia.org/wiki/Unit_in_the_last_place
    eps_exp = tl.maximum(0, tl.minimum(31, exponent - EXPONENT_BIAS - MANTISSA_BITS + 15))

    # Calculate epsilon in the target dtype
    eps_normal = tl.cast(tl.cast(eps_exp << 10, tl.int16), tl.float16, bitcast=True)

    # Subnormal ULP: 2^(1 - EXPONENT_BIAS - MANTISSA_BITS) ->
    # fp16 exponent bits: (1 - EXPONENT_BIAS - MANTISSA_BITS) + 15 =
    # 16 - EXPONENT_BIAS - MANTISSA_BITS
    eps_subnormal = tl.cast(tl.cast((16 - EXPONENT_BIAS - MANTISSA_BITS) << 10, tl.int16), tl.float16, bitcast=True)
    eps = tl.where(exponent > 0, eps_normal, eps_subnormal)

    # Apply zero mask to epsilon
    eps = tl.where(x == 0, 0.0, eps)

    # Apply stochastic rounding
    output = tl.cast(x + rand_vals * eps, tl.bfloat16)

    # Store the result
    tl.store(output_ptr + offsets, output, mask=mask)
+ tl.store(output_ptr + offsets, output, mask=mask)
packages/ltx-core/src/ltx_core/loader/module_ops.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, NamedTuple
2
+
3
+ import torch
4
+
5
+
6
class ModuleOps(NamedTuple):
    """
    A named matcher/mutator pair for selectively transforming PyTorch modules.

    ``matcher`` decides whether a given module should be transformed and
    ``mutator`` produces its replacement (e.g. a quantized variant).
    """

    name: str
    matcher: Callable[[torch.nn.Module], bool]
    mutator: Callable[[torch.nn.Module], torch.nn.Module]
packages/ltx-core/src/ltx_core/loader/primitives.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import NamedTuple, Protocol
3
+
4
+ import torch
5
+
6
+ from ltx_core.loader.module_ops import ModuleOps
7
+ from ltx_core.loader.sd_ops import SDOps
8
+ from ltx_core.model.model_protocol import ModelType
9
+
10
+
11
@dataclass(frozen=True)
class StateDict:
    """
    Immutable wrapper around a PyTorch state dictionary.

    Attributes:
        sd: Mapping from tensor name to tensor (weights, buffers, etc.).
        device: Device on which the tensors are stored.
        size: Total memory footprint in bytes.
        dtype: Set of tensor dtypes present in ``sd``.
    """

    sd: dict
    device: torch.device
    size: int
    dtype: set[torch.dtype]

    def footprint(self) -> tuple[int, torch.device]:
        """Return the (bytes, device) memory footprint of this state dict."""
        return self.size, self.device
29
+
30
+
31
class StateDictLoader(Protocol):
    """
    Protocol for loading state dictionaries from arbitrary sources.

    Implementations provide:
    - metadata: read model metadata from a single file path
    - load: read a (possibly sharded) state dict and apply SDOps transformations
    """

    def metadata(self, path: str) -> dict:
        """
        Return the metadata stored at ``path``.
        """

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """
        Load a state dict from one path or several shard paths, applying ``sd_ops``.
        """
48
+
49
+
50
class ModelBuilderProtocol(Protocol[ModelType]):
    """
    Protocol for constructing PyTorch models from configuration dictionaries.

    Implementations provide:
    - meta_model: instantiate the architecture on the meta device
    - build: materialize the model with loaded weights in the requested dtype
    """

    def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType:
        """
        Create a model on the meta device from a configuration dictionary.

        This separates architecture instantiation from weight loading, so no
        parameter memory is allocated at this stage.

        Args:
            config: Model configuration dictionary.
            module_ops: Optional module operations to apply (e.g. quantization).

        Returns:
            Model instance on the meta device (no memory allocated for parameters).
        """
        ...

    def build(self, dtype: torch.dtype | None = None) -> ModelType:
        """
        Build the model.

        Args:
            dtype: Target dtype for the model; ``None`` keeps the checkpoint's dtype.

        Returns:
            Model instance.
        """
        ...
80
+
81
+
82
class LoRAAdaptableProtocol(Protocol):
    """
    Protocol for builders/models that can have LoRAs attached.

    Implementations provide ``lora``, which registers a LoRA and returns the
    adapted object.
    """

    def lora(self, lora_path: str, strength: float) -> "LoRAAdaptableProtocol":
        ...
91
+
92
+
93
class LoraPathStrengthAndSDOps(NamedTuple):
    """
    A LoRA checkpoint path, its blend strength, and the SDOps to apply to its
    state dict when loading.
    """

    path: str
    strength: float
    sd_ops: SDOps
101
+
102
+
103
class LoraStateDictWithStrength(NamedTuple):
    """
    A loaded LoRA state dict together with the strength used to blend it into
    the model.
    """

    state_dict: StateDict
    strength: float
packages/ltx-core/src/ltx_core/loader/registry.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import threading
3
+ from dataclasses import dataclass, field
4
+ from pathlib import Path
5
+ from typing import Protocol
6
+
7
+ from ltx_core.loader.primitives import StateDict
8
+ from ltx_core.loader.sd_ops import SDOps
9
+
10
+
11
class Registry(Protocol):
    """
    Protocol for caching loaded state dictionaries.

    A registry lets callers store a loaded state dict and reuse it later
    instead of re-reading it from disk. Implementations must provide:
    - add: store a state dictionary under its source paths and ops
    - pop: remove and return a stored state dictionary
    - get: return a stored state dictionary without removing it
    - clear: drop every stored state dictionary
    """

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None: ...

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ...

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ...

    def clear(self) -> None: ...
29
+
30
+
31
class DummyRegistry(Registry):
    """
    No-op registry: nothing is ever stored and every lookup misses.
    """

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None:
        return None

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        return None

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        return None

    def clear(self) -> None:
        return None
47
+
48
+
49
@dataclass
class StateDictRegistry(Registry):
    """
    Thread-safe in-memory registry of loaded state dictionaries.

    Entries are keyed by a SHA-256 digest of the resolved source paths plus
    the SDOps name, so the same files loaded with different ops are cached
    separately.
    """

    _state_dicts: dict[str, StateDict] = field(default_factory=dict)
    _lock: threading.Lock = field(default_factory=threading.Lock)

    def _generate_id(self, paths: list[str], sd_ops: SDOps | None) -> str:
        """Derive a stable cache id from the resolved paths and the SDOps name.

        Fix: ``sd_ops`` is optional — add/pop/get all accept None and the body
        already guards for it, so the annotation now says so.
        """
        m = hashlib.sha256()
        parts = [str(Path(p).resolve()) for p in paths]
        if sd_ops is not None:
            parts.append(sd_ops.name)
        m.update("\0".join(parts).encode("utf-8"))
        return m.hexdigest()

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> str:
        """Store ``state_dict``; raises if an entry for (paths, sd_ops) exists."""
        sd_id = self._generate_id(paths, sd_ops)
        with self._lock:
            if sd_id in self._state_dicts:
                raise ValueError(f"State dict retrieved from {paths} with {sd_ops} already added, check with get first")
            self._state_dicts[sd_id] = state_dict
        return sd_id

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Remove and return the cached entry, or None on a miss."""
        with self._lock:
            return self._state_dicts.pop(self._generate_id(paths, sd_ops), None)

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Return the cached entry without removing it, or None on a miss."""
        with self._lock:
            return self._state_dicts.get(self._generate_id(paths, sd_ops), None)

    def clear(self) -> None:
        """Drop every cached state dict."""
        with self._lock:
            self._state_dicts.clear()
packages/ltx-core/src/ltx_core/loader/sd_ops.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, replace
2
+ from typing import NamedTuple, Protocol
3
+
4
+ import torch
5
+
6
+
7
+ @dataclass(frozen=True, slots=True)
8
+ class ContentReplacement:
9
+ """
10
+ Represents a content replacement operation.
11
+ Used to replace a specific content with a replacement in a state dict key.
12
+ """
13
+
14
+ content: str
15
+ replacement: str
16
+
17
+
18
+ @dataclass(frozen=True, slots=True)
19
+ class ContentMatching:
20
+ """
21
+ Represents a content matching operation.
22
+ Used to match a specific prefix and suffix in a state dict key.
23
+ """
24
+
25
+ prefix: str = ""
26
+ suffix: str = ""
27
+
28
+
29
class KeyValueOperationResult(NamedTuple):
    """
    Outcome of a key-value operation: the resulting key and its tensor.
    """

    new_key: str
    new_value: torch.Tensor
+
38
+
39
class KeyValueOperation(Protocol):
    """
    Callable protocol transforming one state-dict entry.

    Given a key and its tensor, returns the resulting entries (possibly
    several, e.g. when splitting a fused weight).
    """

    def __call__(self, tensor_key: str, tensor_value: torch.Tensor) -> list[KeyValueOperationResult]: ...
46
+
47
+
48
@dataclass(frozen=True, slots=True)
class SDKeyValueOperation:
    """
    Pairs a key matcher with the key-value operation to run on matching
    state-dict entries.
    """

    key_matcher: ContentMatching
    kv_operation: KeyValueOperation
57
+
58
+
59
@dataclass(frozen=True, slots=True)
class SDOps:
    """Immutable, named pipeline of state-dict key and value operations."""

    name: str
    mapping: tuple[
        ContentReplacement | ContentMatching | SDKeyValueOperation, ...
    ] = ()  # Immutable tuple of (key, value) pairs

    def with_replacement(self, content: str, replacement: str) -> "SDOps":
        """Return a copy with an additional substring replacement appended."""
        return replace(self, mapping=(*self.mapping, ContentReplacement(content, replacement)))

    def with_matching(self, prefix: str = "", suffix: str = "") -> "SDOps":
        """Return a copy with an additional prefix/suffix key matcher appended."""
        return replace(self, mapping=(*self.mapping, ContentMatching(prefix, suffix)))

    def with_kv_operation(
        self,
        operation: KeyValueOperation,
        key_prefix: str = "",
        key_suffix: str = "",
    ) -> "SDOps":
        """Return a copy with an additional key-value operation appended."""
        entry = SDKeyValueOperation(ContentMatching(key_prefix, key_suffix), operation)
        return replace(self, mapping=(*self.mapping, entry))

    def apply_to_key(self, key: str) -> str | None:
        """Rename ``key`` via the replacements, or None when no matcher accepts it."""
        accepted = False
        for entry in self.mapping:
            if not isinstance(entry, ContentMatching):
                continue
            if key.startswith(entry.prefix) and key.endswith(entry.suffix):
                accepted = True
                break
        if not accepted:
            # With no ContentMatching entry (or none that matches) the key is
            # filtered out entirely.
            return None

        for entry in self.mapping:
            if isinstance(entry, ContentReplacement) and entry.content in key:
                key = key.replace(entry.content, entry.replacement)
        return key

    def apply_to_key_value(self, key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
        """Run the first key-value operation whose matcher accepts ``key``."""
        for entry in self.mapping:
            if not isinstance(entry, SDKeyValueOperation):
                continue
            matcher = entry.key_matcher
            if key.startswith(matcher.prefix) and key.endswith(matcher.suffix):
                return entry.kv_operation(key, value)
        return [KeyValueOperationResult(key, value)]
114
+
115
+
116
# Predefined SDOps instances

# Accepts every key (empty matcher) and strips the ComfyUI "diffusion_model."
# prefix from LTXV LoRA keys.
# NOTE(review): the SDOps name string is "LTXV_LORA_COMFY_PREFIX_MAP" while
# the variable is ..._RENAMING_MAP — possibly intentional, but worth
# confirming since the name participates in registry cache ids.
LTXV_LORA_COMFY_RENAMING_MAP = (
    SDOps("LTXV_LORA_COMFY_PREFIX_MAP").with_matching().with_replacement("diffusion_model.", "")
)

# Maps ComfyUI LoRA keys onto the base-model weight keys they target by
# collapsing the lora_A/lora_B factor names into the plain ".weight" key.
LTXV_LORA_COMFY_TARGET_MAP = (
    SDOps("LTXV_LORA_COMFY_TARGET_MAP")
    .with_matching()
    .with_replacement("diffusion_model.", "")
    .with_replacement(".lora_A.weight", ".weight")
    .with_replacement(".lora_B.weight", ".weight")
)
packages/ltx-core/src/ltx_core/loader/sft_loader.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import safetensors
4
+ import torch
5
+
6
+ from ltx_core.loader.primitives import StateDict, StateDictLoader
7
+ from ltx_core.loader.sd_ops import SDOps
8
+
9
+
10
class SafetensorsStateDictLoader(StateDictLoader):
    """
    Loads weights from safetensors files without metadata support.

    Use this for loading raw weight files. For model files that include
    configuration metadata, use SafetensorsModelStateDictLoader instead.
    """

    def metadata(self, path: str) -> dict:
        """Not supported by this loader; use SafetensorsModelStateDictLoader."""
        raise NotImplementedError("Not implemented")

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """
        Load a state dict from one path or several shard paths and apply sd_ops.

        Fix: ``sd_ops`` now defaults to None, matching the StateDictLoader
        protocol (the body already handled None).

        Args:
            path: A checkpoint path or list of shard paths.
            sd_ops: Optional key/value transformations; keys that no matcher
                accepts are skipped entirely.
            device: Target device; defaults to CPU.

        Returns:
            StateDict with the loaded tensors and size/dtype bookkeeping.
        """
        sd = {}
        size = 0
        dtype = set()
        device = device or torch.device("cpu")
        model_paths = path if isinstance(path, list) else [path]
        for shard_path in model_paths:
            with safetensors.safe_open(shard_path, framework="pt", device=str(device)) as f:
                for name in f.keys():
                    expected_name = name if sd_ops is None else sd_ops.apply_to_key(name)
                    if expected_name is None:
                        # Key rejected by the sd_ops matchers — skip it.
                        continue
                    tensor = f.get_tensor(name).to(device=device, non_blocking=True, copy=False)
                    key_value_pairs = ((expected_name, tensor),)
                    if sd_ops is not None:
                        key_value_pairs = sd_ops.apply_to_key_value(expected_name, tensor)
                    for key, value in key_value_pairs:
                        size += value.nbytes
                        dtype.add(value.dtype)
                        sd[key] = value

        return StateDict(sd=sd, device=device, size=size, dtype=dtype)
46
+
47
+
48
class SafetensorsModelStateDictLoader(StateDictLoader):
    """
    Safetensors loader that also exposes model configuration metadata.

    Weight loading is delegated to a SafetensorsStateDictLoader; metadata()
    reads the JSON "config" entry from the safetensors file header.
    """

    def __init__(self, weight_loader: SafetensorsStateDictLoader | None = None):
        self.weight_loader = SafetensorsStateDictLoader() if weight_loader is None else weight_loader

    def metadata(self, path: str) -> dict:
        """Parse and return the JSON "config" entry of the file's metadata."""
        with safetensors.safe_open(path, framework="pt") as f:
            return json.loads(f.metadata()["config"])

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """Delegate weight loading to the wrapped loader."""
        return self.weight_loader.load(path, sd_ops, device)
packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from dataclasses import dataclass, field, replace
3
+ from typing import Generic
4
+
5
+ import torch
6
+
7
+ from ltx_core.loader.fuse_loras import apply_loras
8
+ from ltx_core.loader.module_ops import ModuleOps
9
+ from ltx_core.loader.primitives import (
10
+ LoRAAdaptableProtocol,
11
+ LoraPathStrengthAndSDOps,
12
+ LoraStateDictWithStrength,
13
+ ModelBuilderProtocol,
14
+ StateDict,
15
+ StateDictLoader,
16
+ )
17
+ from ltx_core.loader.registry import DummyRegistry, Registry
18
+ from ltx_core.loader.sd_ops import SDOps
19
+ from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader
20
+ from ltx_core.model.model_protocol import ModelConfigurator, ModelType
21
+
22
+ logger: logging.Logger = logging.getLogger(__name__)
23
+
24
+
25
@dataclass(frozen=True)
class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol):
    """
    Builder for PyTorch models residing on a single GPU.

    Construction flow (see ``build``): instantiate the architecture on the
    meta device, load the checkpoint state dict (optionally via a registry
    cache), fuse any configured LoRAs, then assign the weights and move the
    model to the target device.
    """

    model_class_configurator: type[ModelConfigurator[ModelType]]  # creates the model from a config dict
    model_path: str | tuple[str, ...]  # single checkpoint path or shard paths
    model_sd_ops: SDOps | None = None  # key/value transforms for the model state dict
    module_ops: tuple[ModuleOps, ...] = field(default_factory=tuple)  # module-level mutations (e.g. quantization)
    loras: tuple[LoraPathStrengthAndSDOps, ...] = field(default_factory=tuple)  # LoRAs fused at build time
    model_loader: StateDictLoader = field(default_factory=SafetensorsModelStateDictLoader)
    registry: Registry = field(default_factory=DummyRegistry)  # optional cache of loaded state dicts

    def lora(self, lora_path: str, strength: float = 1.0, sd_ops: SDOps | None = None) -> "SingleGPUModelBuilder":
        """Return a new builder (self is frozen) with an additional LoRA scheduled."""
        return replace(self, loras=(*self.loras, LoraPathStrengthAndSDOps(lora_path, strength, sd_ops)))

    def model_config(self) -> dict:
        """Read the model configuration from the (first shard of the) checkpoint metadata."""
        first_shard_path = self.model_path[0] if isinstance(self.model_path, tuple) else self.model_path
        return self.model_loader.metadata(first_shard_path)

    def meta_model(self, config: dict, module_ops: tuple[ModuleOps, ...]) -> ModelType:
        """Instantiate the architecture on the meta device (no weight memory) and apply module_ops."""
        with torch.device("meta"):
            model = self.model_class_configurator.from_config(config)
            for module_op in module_ops:
                if module_op.matcher(model):
                    model = module_op.mutator(model)
            return model

    def load_sd(
        self, paths: list[str], registry: Registry, device: torch.device | None, sd_ops: SDOps | None = None
    ) -> StateDict:
        """Load a state dict, consulting the registry cache first and filling it on a miss."""
        state_dict = registry.get(paths, sd_ops)
        if state_dict is None:
            state_dict = self.model_loader.load(paths, sd_ops=sd_ops, device=device)
            registry.add(paths, sd_ops=sd_ops, state_dict=state_dict)
        return state_dict

    def _return_model(self, meta_model: ModelType, device: torch.device) -> ModelType:
        # Any parameter/buffer still on "meta" was not covered by the loaded
        # state dict; warn and return the model as-is (moving meta tensors
        # with .to(device) would not produce initialized weights).
        uninitialized_params = [name for name, param in meta_model.named_parameters() if str(param.device) == "meta"]
        uninitialized_buffers = [name for name, buffer in meta_model.named_buffers() if str(buffer.device) == "meta"]
        if uninitialized_params or uninitialized_buffers:
            logger.warning(f"Uninitialized parameters or buffers: {uninitialized_params + uninitialized_buffers}")
            return meta_model
        retval = meta_model.to(device)
        return retval

    def build(self, device: torch.device | None = None, dtype: torch.dtype | None = None) -> ModelType:
        """Build the model: load weights, optionally fuse LoRAs, move to ``device``.

        Args:
            device: Target device; defaults to CUDA.
            dtype: Target dtype; None keeps the checkpoint's dtype.
        """
        device = torch.device("cuda") if device is None else device
        config = self.model_config()
        meta_model = self.meta_model(config, self.module_ops)
        model_paths = self.model_path if isinstance(self.model_path, tuple) else [self.model_path]
        model_state_dict = self.load_sd(model_paths, sd_ops=self.model_sd_ops, registry=self.registry, device=device)

        lora_strengths = [lora.strength for lora in self.loras]
        if not lora_strengths or (min(lora_strengths) == 0 and max(lora_strengths) == 0):
            # No effective LoRAs: assign the base weights directly (converted if needed).
            sd = model_state_dict.sd
            if dtype is not None:
                sd = {key: value.to(dtype=dtype) for key, value in model_state_dict.sd.items()}
            meta_model.load_state_dict(sd, strict=False, assign=True)
            return self._return_model(meta_model, device)

        lora_state_dicts = [
            self.load_sd([lora.path], sd_ops=lora.sd_ops, registry=self.registry, device=device) for lora in self.loras
        ]
        lora_sd_and_strengths = [
            LoraStateDictWithStrength(sd, strength)
            for sd, strength in zip(lora_state_dicts, lora_strengths, strict=True)
        ]
        # With a DummyRegistry the base state dict is not cached, so the LoRA
        # deltas may be fused into it in place; with a real registry a fresh
        # state dict is built to keep the cached copy pristine.
        final_sd = apply_loras(
            model_sd=model_state_dict,
            lora_sd_and_strengths=lora_sd_and_strengths,
            dtype=dtype,
            destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None,
        )
        meta_model.load_state_dict(final_sd.sd, strict=False, assign=True)
        return self._return_model(meta_model, device)
packages/ltx-core/src/ltx_core/model/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """Model definitions for LTX-2."""
2
+
3
+ from ltx_core.model.model_protocol import ModelConfigurator, ModelType
4
+
5
+ __all__ = [
6
+ "ModelConfigurator",
7
+ "ModelType",
8
+ ]
packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio VAE model components."""
2
+
3
+ from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder, decode_audio
4
+ from ltx_core.model.audio_vae.model_configurator import (
5
+ AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
6
+ AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
7
+ VOCODER_COMFY_KEYS_FILTER,
8
+ AudioDecoderConfigurator,
9
+ AudioEncoderConfigurator,
10
+ VocoderConfigurator,
11
+ )
12
+ from ltx_core.model.audio_vae.ops import AudioProcessor
13
+ from ltx_core.model.audio_vae.vocoder import Vocoder
14
+
15
+ __all__ = [
16
+ "AUDIO_VAE_DECODER_COMFY_KEYS_FILTER",
17
+ "AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER",
18
+ "VOCODER_COMFY_KEYS_FILTER",
19
+ "AudioDecoder",
20
+ "AudioDecoderConfigurator",
21
+ "AudioEncoder",
22
+ "AudioEncoderConfigurator",
23
+ "AudioProcessor",
24
+ "Vocoder",
25
+ "VocoderConfigurator",
26
+ "decode_audio",
27
+ ]
packages/ltx-core/src/ltx_core/model/audio_vae/attention.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.common.normalization import NormType, build_normalization_layer
6
+
7
+
8
class AttentionType(Enum):
    """Supported attention mechanism variants for the audio VAE blocks."""

    VANILLA = "vanilla"
    LINEAR = "linear"
    NONE = "none"
14
+
15
+
16
class AttnBlock(torch.nn.Module):
    """Single-head self-attention over the spatial positions of a 2D feature map."""

    def __init__(
        self,
        in_channels: int,
        norm_type: NormType = NormType.GROUP,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels

        # 1x1 convs act as per-position linear projections for q/k/v and output.
        self.norm = build_normalization_layer(in_channels, normtype=norm_type)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over h*w positions and add the result to the input (residual).

        Args:
            x: Feature map of shape (b, c, h, w).

        Returns:
            Tensor of the same shape as ``x``.
        """
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).contiguous()
        q = q.permute(0, 2, 1).contiguous()  # b,hw,c
        k = k.reshape(b, c, h * w).contiguous()  # b,c,hw
        w_ = torch.bmm(q, k).contiguous()  # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))  # scale logits by 1/sqrt(c)
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w).contiguous()
        w_ = w_.permute(0, 2, 1).contiguous()  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_).contiguous()  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w).contiguous()

        h_ = self.proj_out(h_)

        return x + h_
56
+
57
+
58
def make_attn(
    in_channels: int,
    attn_type: AttentionType = AttentionType.VANILLA,
    norm_type: NormType = NormType.GROUP,
) -> torch.nn.Module:
    """Build the attention module for ``attn_type`` (Identity when NONE)."""
    if attn_type == AttentionType.VANILLA:
        return AttnBlock(in_channels, norm_type=norm_type)
    if attn_type == AttentionType.NONE:
        return torch.nn.Identity()
    if attn_type == AttentionType.LINEAR:
        raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
    raise ValueError(f"Unknown attention type: {attn_type}")
packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Set, Tuple
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from ltx_core.components.patchifiers import AudioPatchifier
7
+ from ltx_core.model.audio_vae.attention import AttentionType, make_attn
8
+ from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
9
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
10
+ from ltx_core.model.audio_vae.downsample import build_downsampling_path
11
+ from ltx_core.model.audio_vae.ops import PerChannelStatistics
12
+ from ltx_core.model.audio_vae.resnet import ResnetBlock
13
+ from ltx_core.model.audio_vae.upsample import build_upsampling_path
14
+ from ltx_core.model.audio_vae.vocoder import Vocoder
15
+ from ltx_core.model.common.normalization import NormType, build_normalization_layer
16
+ from ltx_core.types import AudioLatentShape
17
+
18
+ LATENT_DOWNSAMPLE_FACTOR = 4
19
+
20
+
21
def build_mid_block(
    channels: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    add_attention: bool,
) -> torch.nn.Module:
    """Build the middle block: ResNet -> (optional) attention -> ResNet."""

    def _make_resnet() -> ResnetBlock:
        # Both mid ResNet blocks share identical configuration.
        return ResnetBlock(
            in_channels=channels,
            out_channels=channels,
            temb_channels=temb_channels,
            dropout=dropout,
            norm_type=norm_type,
            causality_axis=causality_axis,
        )

    mid = torch.nn.Module()
    mid.block_1 = _make_resnet()
    if add_attention:
        mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type)
    else:
        mid.attn_1 = torch.nn.Identity()
    mid.block_2 = _make_resnet()
    return mid
50
+
51
+
52
def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor:
    """Pass features through the mid block: block_1 -> attn_1 -> block_2."""
    out = mid.block_1(features, temb=None)
    out = mid.attn_1(out)
    out = mid.block_2(out, temb=None)
    return out
57
+
58
+
59
class AudioEncoder(torch.nn.Module):
    """
    Encoder that compresses audio spectrograms into latent representations.
    The encoder uses a series of downsampling blocks with residual connections,
    attention mechanisms, and configurable causal convolutions.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        ch: int,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int,
        attn_resolutions: Set[int],
        dropout: float = 0.0,
        resamp_with_conv: bool = True,
        in_channels: int,
        resolution: int,
        z_channels: int,
        double_z: bool = True,
        attn_type: AttentionType = AttentionType.VANILLA,
        mid_block_add_attention: bool = True,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        n_fft: int = 1024,
        is_causal: bool = True,
        mel_bins: int = 64,
        **_ignore_kwargs,
    ) -> None:
        """
        Initialize the Encoder.
        Args:
            Arguments are configuration parameters, loaded from the audio VAE checkpoint config
            (audio_vae.model.params.ddconfig):
            ch: Base number of feature channels used in the first convolution layer.
            ch_mult: Multiplicative factors for the number of channels at each resolution level.
            num_res_blocks: Number of residual blocks to use at each resolution level.
            attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention.
            resolution: Input spatial resolution of the spectrogram (height, width).
            z_channels: Number of channels in the latent representation.
            norm_type: Normalization layer type to use within the network (e.g., group, batch).
            causality_axis: Axis along which convolutions should be causal (e.g., time axis).
            sample_rate: Audio sample rate in Hz for the input signals.
            mel_hop_length: Hop length used when computing the mel spectrogram.
            n_fft: FFT size used to compute the spectrogram.
            mel_bins: Number of mel-frequency bins in the input spectrogram.
            in_channels: Number of channels in the input spectrogram tensor.
            double_z: If True, predict both mean and log-variance (doubling latent channels).
            is_causal: If True, use causal convolutions suitable for streaming setups.
            dropout: Dropout probability used in residual and mid blocks.
            attn_type: Type of attention mechanism to use in attention blocks.
            resamp_with_conv: If True, perform resolution changes using strided convolutions.
            mid_block_add_attention: If True, add an attention block in the mid-level of the encoder.
        """
        super().__init__()

        # NOTE(review): statistics are sized by `ch` (base width), not `z_channels` —
        # confirm this matches the per-channel stats stored in the checkpoint.
        self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.n_fft = n_fft
        self.is_causal = is_causal
        self.mel_bins = mel_bins

        # Patchifier is used only to reshape latents for per-channel normalization.
        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=sample_rate,
            hop_length=mel_hop_length,
            is_causal=is_causal,
        )

        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in this VAE
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.z_channels = z_channels
        self.double_z = double_z
        self.norm_type = norm_type
        self.causality_axis = causality_axis
        self.attn_type = attn_type

        # downsampling
        self.conv_in = make_conv2d(
            in_channels,
            self.ch,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

        self.non_linearity = torch.nn.SiLU()

        # `block_in` is the channel count at the deepest resolution level.
        self.down, block_in = build_downsampling_path(
            ch=ch,
            ch_mult=ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=num_res_blocks,
            resolution=resolution,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=attn_resolutions,
            resamp_with_conv=resamp_with_conv,
        )

        self.mid = build_mid_block(
            channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=mid_block_add_attention,
        )

        self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
        # double_z doubles the output channels so downstream code can split mean/log-variance.
        self.conv_out = make_conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

    def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
        """
        Encode audio spectrogram into latent representations.
        Args:
            spectrogram: Input spectrogram of shape (batch, channels, time, frequency)
        Returns:
            Encoded latent representation of shape (batch, channels, frames, mel_bins)
        """
        # Stem conv -> downsampling tower -> mid block -> projection -> normalization.
        h = self.conv_in(spectrogram)
        h = self._run_downsampling_path(h)
        h = run_mid_block(self.mid, h)
        h = self._finalize_output(h)

        return self._normalize_latents(h)

    def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor:
        # Apply each resolution level's residual (and optional attention) blocks,
        # downsampling between levels except after the last one.
        for level in range(self.num_resolutions):
            stage = self.down[level]
            for block_idx in range(self.num_res_blocks):
                h = stage.block[block_idx](h, temb=None)
                # `stage.attn` is empty unless this level's resolution was flagged
                # in attn_resolutions; then there is one attn per residual block.
                if stage.attn:
                    h = stage.attn[block_idx](h)

            if level != self.num_resolutions - 1:
                h = stage.downsample(h)

        return h

    def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
        # Final norm -> SiLU -> projection to (possibly doubled) latent channels.
        h = self.norm_out(h)
        h = self.non_linearity(h)
        return self.conv_out(h)

    def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor:
        """
        Normalize encoder latents using per-channel statistics.
        When the encoder is configured with ``double_z=True``, the final
        convolution produces twice the number of latent channels, typically
        interpreted as two concatenated tensors along the channel dimension
        (e.g., mean and variance or other auxiliary parameters).
        This method intentionally uses only the first half of the channels
        (the "mean" component) as input to the patchifier and normalization
        logic. The remaining channels are left unchanged by this method and
        are expected to be consumed elsewhere in the VAE pipeline.
        If ``double_z=False``, the encoder output already contains only the
        mean latents and the chunking operation simply returns that tensor.
        """
        means = torch.chunk(latent_output, 2, dim=1)[0]
        latent_shape = AudioLatentShape(
            batch=means.shape[0],
            channels=means.shape[1],
            frames=means.shape[2],
            mel_bins=means.shape[3],
        )
        # Patchify -> normalize per channel -> restore the original latent layout.
        latent_patched = self.patchifier.patchify(means)
        latent_normalized = self.per_channel_statistics.normalize(latent_patched)
        return self.patchifier.unpatchify(latent_normalized, latent_shape)
246
+
247
+
248
class AudioDecoder(torch.nn.Module):
    """
    Symmetric decoder that reconstructs audio spectrograms from latent features.
    The decoder mirrors the encoder structure with configurable channel multipliers,
    attention resolutions, and causal convolutions.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        ch: int,
        out_ch: int,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int,
        attn_resolutions: Set[int],
        resolution: int,
        z_channels: int,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
        dropout: float = 0.0,
        mid_block_add_attention: bool = True,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        is_causal: bool = True,
        mel_bins: int | None = None,
    ) -> None:
        """
        Initialize the Decoder.
        Args:
            Arguments are configuration parameters, loaded from the audio VAE checkpoint config
            (audio_vae.model.params.ddconfig):
            - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions
            - resolution, z_channels
            - norm_type, causality_axis
        """
        super().__init__()

        # Internal behavioural defaults that are not driven by the checkpoint.
        resamp_with_conv = True
        attn_type = AttentionType.VANILLA

        # Per-channel statistics for denormalizing latents
        # NOTE(review): sized by `ch`, not `z_channels` — confirm against checkpoint stats.
        self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.is_causal = is_causal
        self.mel_bins = mel_bins
        # Patchifier is used only to reshape latents for per-channel denormalization.
        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=sample_rate,
            hop_length=mel_hop_length,
            is_causal=is_causal,
        )

        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in this VAE
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.out_ch = out_ch
        self.give_pre_end = False  # if True, skip final norm/act/conv in _finalize_output
        self.tanh_out = False  # if True, squash the final output with tanh
        self.norm_type = norm_type
        self.z_channels = z_channels
        self.channel_multipliers = ch_mult
        self.attn_resolutions = attn_resolutions
        self.causality_axis = causality_axis
        self.attn_type = attn_type

        # Decoding starts at the deepest level: widest channels, smallest resolution.
        base_block_channels = ch * self.channel_multipliers[-1]
        base_resolution = resolution // (2 ** (self.num_resolutions - 1))
        self.z_shape = (1, z_channels, base_resolution, base_resolution)

        self.conv_in = make_conv2d(
            z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )
        self.non_linearity = torch.nn.SiLU()
        self.mid = build_mid_block(
            channels=base_block_channels,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=mid_block_add_attention,
        )
        self.up, final_block_channels = build_upsampling_path(
            ch=ch,
            ch_mult=ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=num_res_blocks,
            resolution=resolution,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=attn_resolutions,
            resamp_with_conv=resamp_with_conv,
            initial_block_channels=base_block_channels,
        )

        self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
        self.conv_out = make_conv2d(
            final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """
        Decode latent features back to audio spectrograms.
        Args:
            sample: Encoded latent representation of shape (batch, channels, frames, mel_bins)
        Returns:
            Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
        """
        # Undo per-channel normalization, then mirror the encoder in reverse.
        sample, target_shape = self._denormalize_latents(sample)

        h = self.conv_in(sample)
        h = run_mid_block(self.mid, h)
        h = self._run_upsampling_path(h)
        h = self._finalize_output(h)

        return self._adjust_output_shape(h, target_shape)

    def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]:
        # Returns the denormalized latents plus the spectrogram shape the decoded
        # output should be cropped/padded to.
        latent_shape = AudioLatentShape(
            batch=sample.shape[0],
            channels=sample.shape[1],
            frames=sample.shape[2],
            mel_bins=sample.shape[3],
        )

        sample_patched = self.patchifier.patchify(sample)
        sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
        sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)

        # Each latent frame expands to LATENT_DOWNSAMPLE_FACTOR output frames;
        # causal models lose (factor - 1) frames of lead-in.
        target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
        if self.causality_axis != CausalityAxis.NONE:
            target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)

        target_shape = AudioLatentShape(
            batch=latent_shape.batch,
            channels=self.out_ch,
            frames=target_frames,
            mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
        )

        return sample, target_shape

    def _adjust_output_shape(
        self,
        decoded_output: torch.Tensor,
        target_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Adjust output shape to match target dimensions for variable-length audio.
        This function handles the common case where decoded audio spectrograms need to be
        resized to match a specific target shape.
        Args:
            decoded_output: Tensor of shape (batch, channels, time, frequency)
            target_shape: AudioLatentShape describing (batch, channels, time, mel bins)
        Returns:
            Tensor adjusted to match target_shape exactly
        """
        # Current output shape: (batch, channels, time, frequency)
        _, _, current_time, current_freq = decoded_output.shape
        target_channels = target_shape.channels
        target_time = target_shape.frames
        target_freq = target_shape.mel_bins

        # Step 1: Crop first to avoid exceeding target dimensions
        decoded_output = decoded_output[
            :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
        ]

        # Step 2: Calculate padding needed for time and frequency dimensions
        time_padding_needed = target_time - decoded_output.shape[2]
        freq_padding_needed = target_freq - decoded_output.shape[3]

        # Step 3: Apply padding if needed
        if time_padding_needed > 0 or freq_padding_needed > 0:
            # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
            # For audio: pad_left/right = frequency, pad_top/bottom = time
            padding = (
                0,
                max(freq_padding_needed, 0),  # frequency padding (left, right)
                0,
                max(time_padding_needed, 0),  # time padding (top, bottom)
            )
            decoded_output = F.pad(decoded_output, padding)

        # Step 4: Final safety crop to ensure exact target shape
        decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]

        return decoded_output

    def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor:
        # Traverse levels deepest-first (reverse of the encoder's order).
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for block_idx, block in enumerate(stage.block):
                h = block(h, temb=None)
                # `stage.attn` is empty unless this level's resolution was flagged.
                if stage.attn:
                    h = stage.attn[block_idx](h)

            if level != 0 and hasattr(stage, "upsample"):
                h = stage.upsample(h)

        return h

    def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
        # give_pre_end short-circuits the output head (debug/feature-extraction hook).
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = self.non_linearity(h)
        h = self.conv_out(h)
        return torch.tanh(h) if self.tanh_out else h
466
+
467
+
468
def decode_audio(latent: torch.Tensor, audio_decoder: "AudioDecoder", vocoder: "Vocoder") -> torch.Tensor:
    """
    Decode an audio latent representation using the provided audio decoder and vocoder.
    Args:
        latent: Input audio latent tensor.
        audio_decoder: Model to decode the latent to waveform features.
        vocoder: Model to convert decoded features to audio waveform.
    Returns:
        Decoded audio as a float tensor.
    """
    # Latent -> spectrogram-like features -> waveform; drop the batch dim and
    # cast to float32 for downstream audio handling.
    features = audio_decoder(latent)
    waveform = vocoder(features)
    return waveform.squeeze(0).float()
packages/ltx-core/src/ltx_core/model/audio_vae/causal_conv_2d.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
5
+
6
+
7
class CausalConv2d(torch.nn.Module):
    """
    A causal 2D convolution.
    The output at a given position along the causality axis depends only on
    inputs at that position and earlier. This is achieved by applying
    asymmetric padding along the causal axis before the convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int],
        stride: int = 1,
        dilation: int | tuple[int, int] = 1,
        groups: int = 1,
        bias: bool = True,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        super().__init__()

        self.causality_axis = causality_axis

        # Normalize scalar kernel/dilation arguments to (h, w) pairs.
        kernel_size = torch.nn.modules.utils._pair(kernel_size)
        dilation = torch.nn.modules.utils._pair(dilation)

        # Total receptive-field padding needed per spatial dimension.
        total_h = (kernel_size[0] - 1) * dilation[0]
        total_w = (kernel_size[1] - 1) * dilation[1]

        # F.pad ordering: (pad_left, pad_right, pad_top, pad_bottom).
        if causality_axis == CausalityAxis.NONE:
            # Symmetric padding on both axes (non-causal).
            self.padding = (total_w // 2, total_w - total_w // 2, total_h // 2, total_h - total_h // 2)
        elif causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY):
            # All width padding goes on the left (past); height stays symmetric.
            self.padding = (total_w, 0, total_h // 2, total_h - total_h // 2)
        elif causality_axis == CausalityAxis.HEIGHT:
            # All height padding goes on top (past); width stays symmetric.
            self.padding = (total_w // 2, total_w - total_w // 2, total_h, 0)
        else:
            raise ValueError(f"Invalid causality_axis: {causality_axis}")

        # The inner convolution uses no padding of its own — padding is applied
        # manually in forward so it can be asymmetric.
        self.conv = torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pad causally, then convolve.
        padded = F.pad(x, self.padding)
        return self.conv(padded)
65
+
66
+
67
def make_conv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    stride: int = 1,
    padding: tuple[int, int, int, int] | None = None,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = True,
    causality_axis: CausalityAxis | None = None,
) -> torch.nn.Module:
    """
    Create a 2D convolution layer that can be either causal or non-causal.
    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        kernel_size: Size of the convolution kernel
        stride: Convolution stride
        padding: Padding (if None, will be calculated based on causal flag)
        dilation: Dilation rate
        groups: Number of groups for grouped convolution
        bias: Whether to use bias
        causality_axis: Dimension along which to apply causality.
    Returns:
        Either a regular Conv2d or CausalConv2d layer
    """
    # Causal case: CausalConv2d computes its own (asymmetric) padding.
    if causality_axis is not None:
        return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)

    # Non-causal case: default to "same"-style symmetric padding when none given.
    # NOTE(review): this default ignores dilation — confirm callers never pass
    # dilation > 1 on the non-causal path.
    if padding is None:
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)

    return torch.nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        bias,
    )
packages/ltx-core/src/ltx_core/model/audio_vae/causality_axis.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+
4
class CausalityAxis(Enum):
    """Enum for specifying the causality axis in causal convolutions."""

    # No causality: symmetric padding on both axes.
    NONE = None
    # Causal along the width (last) dimension.
    WIDTH = "width"
    # Causal along the height dimension.
    HEIGHT = "height"
    # Width-causal variant kept for checkpoint compatibility; uses a different
    # padding scheme than WIDTH (see CausalConv2d / Downsample).
    WIDTH_COMPATIBILITY = "width-compatibility"
packages/ltx-core/src/ltx_core/model/audio_vae/downsample.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Set, Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.audio_vae.attention import AttentionType, make_attn
6
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
7
+ from ltx_core.model.audio_vae.resnet import ResnetBlock
8
+ from ltx_core.model.common.normalization import NormType
9
+
10
+
11
class Downsample(torch.nn.Module):
    """
    Halve spatial resolution using either a strided 3x3 convolution or
    2x2 average pooling. The convolutional mode supports standard and
    causal (asymmetric) padding.
    """

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
    ) -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis

        # Pooling mode cannot honor a causality constraint.
        if not self.with_conv and self.causality_axis != CausalityAxis.NONE:
            raise ValueError("causality is only supported when `with_conv=True`.")

        if self.with_conv:
            # Strided conv with padding=0; torch has no asymmetric padding, so
            # the (possibly causal) padding is applied manually in forward.
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.with_conv:
            # Only reachable when causality_axis is NONE (enforced in __init__).
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)

        # Padding tuple ordering: (left, right, top, bottom).
        if self.causality_axis == CausalityAxis.NONE:
            pad = (0, 1, 0, 1)
        elif self.causality_axis == CausalityAxis.WIDTH:
            pad = (2, 0, 0, 1)
        elif self.causality_axis == CausalityAxis.HEIGHT:
            pad = (0, 1, 2, 0)
        elif self.causality_axis == CausalityAxis.WIDTH_COMPATIBILITY:
            pad = (1, 0, 0, 1)
        else:
            raise ValueError(f"Invalid causality_axis: {self.causality_axis}")

        padded = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        return self.conv(padded)
58
+
59
+
60
def build_downsampling_path(  # noqa: PLR0913
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
) -> tuple[torch.nn.ModuleList, int]:
    """Build the downsampling path with residual blocks, attention, and downsampling layers."""
    stages = torch.nn.ModuleList()
    current_resolution = resolution
    # Input multiplier for level i is the previous level's output multiplier (1 for level 0).
    input_multipliers = (1, *tuple(ch_mult))
    channels = ch

    for level in range(num_resolutions):
        res_blocks = torch.nn.ModuleList()
        attn_blocks = torch.nn.ModuleList()
        channels = ch * input_multipliers[level]
        level_out_channels = ch * ch_mult[level]

        for _ in range(num_res_blocks):
            res_blocks.append(
                ResnetBlock(
                    in_channels=channels,
                    out_channels=level_out_channels,
                    temb_channels=temb_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                    causality_axis=causality_axis,
                )
            )
            channels = level_out_channels
            # One attention block per residual block at flagged resolutions.
            if current_resolution in attn_resolutions:
                attn_blocks.append(make_attn(channels, attn_type=attn_type, norm_type=norm_type))

        stage = torch.nn.Module()
        stage.block = res_blocks
        stage.attn = attn_blocks
        # Every level except the deepest ends with a downsampling layer.
        if level != num_resolutions - 1:
            stage.downsample = Downsample(channels, resamp_with_conv, causality_axis=causality_axis)
            current_resolution = current_resolution // 2
        stages.append(stage)

    # `channels` is now the channel count at the deepest level.
    return stages, channels
packages/ltx-core/src/ltx_core/model/audio_vae/model_configurator.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ltx_core.loader.sd_ops import SDOps
2
+ from ltx_core.model.audio_vae.attention import AttentionType
3
+ from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder
4
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
5
+ from ltx_core.model.audio_vae.vocoder import Vocoder
6
+ from ltx_core.model.common.normalization import NormType
7
+ from ltx_core.model.model_protocol import ModelConfigurator
8
+
9
+
10
class VocoderConfigurator(ModelConfigurator[Vocoder]):
    """Builds a Vocoder from the ``vocoder`` section of a checkpoint config dict."""

    @classmethod
    def from_config(cls: "type[VocoderConfigurator]", config: dict) -> Vocoder:
        # NOTE(review): the fallback defaults presumably mirror the shipped
        # checkpoint config — confirm against the published LTX audio VAE.
        config = config.get("vocoder", {})
        return Vocoder(
            resblock_kernel_sizes=config.get("resblock_kernel_sizes", [3, 7, 11]),
            upsample_rates=config.get("upsample_rates", [6, 5, 2, 2, 2]),
            upsample_kernel_sizes=config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]),
            resblock_dilation_sizes=config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]),
            upsample_initial_channel=config.get("upsample_initial_channel", 1024),
            stereo=config.get("stereo", True),
            resblock=config.get("resblock", "1"),
            output_sample_rate=config.get("output_sample_rate", 24000),
        )
24
+
25
+
26
# State-dict key filter for ComfyUI-style checkpoints: keeps only keys under
# "vocoder." and strips that prefix so they match the Vocoder module's own
# parameter names (presumed SDOps semantics — see ltx_core.loader.sd_ops).
VOCODER_COMFY_KEYS_FILTER = (
    SDOps("VOCODER_COMFY_KEYS_FILTER").with_matching(prefix="vocoder.").with_replacement("vocoder.", "")
)
29
+
30
+
31
class AudioDecoderConfigurator(ModelConfigurator[AudioDecoder]):
    """Builds an AudioDecoder from the nested ``audio_vae`` checkpoint config."""

    @classmethod
    def from_config(cls: "type[AudioDecoderConfigurator]", config: dict) -> AudioDecoder:
        # Navigate the nested checkpoint layout; every level degrades to {} so
        # missing sections fall back to the defaults below.
        audio_vae_cfg = config.get("audio_vae", {})
        model_cfg = audio_vae_cfg.get("model", {})
        model_params = model_cfg.get("params", {})
        ddconfig = model_params.get("ddconfig", {})
        preprocessing_cfg = audio_vae_cfg.get("preprocessing", {})
        stft_cfg = preprocessing_cfg.get("stft", {})
        mel_cfg = preprocessing_cfg.get("mel", {})
        variables_cfg = audio_vae_cfg.get("variables", {})

        sample_rate = model_params.get("sampling_rate", 16000)
        mel_hop_length = stft_cfg.get("hop_length", 160)
        is_causal = stft_cfg.get("causal", True)
        # mel_bins may live in any of three config sections; first hit wins.
        mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins")

        return AudioDecoder(
            ch=ddconfig.get("ch", 128),
            out_ch=ddconfig.get("out_ch", 2),
            ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))),
            num_res_blocks=ddconfig.get("num_res_blocks", 2),
            attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}),
            resolution=ddconfig.get("resolution", 256),
            z_channels=ddconfig.get("z_channels", 8),
            norm_type=NormType(ddconfig.get("norm_type", "pixel")),
            causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")),
            dropout=ddconfig.get("dropout", 0.0),
            mid_block_add_attention=ddconfig.get("mid_block_add_attention", True),
            sample_rate=sample_rate,
            mel_hop_length=mel_hop_length,
            is_causal=is_causal,
            mel_bins=mel_bins,
        )
65
+
66
+
67
class AudioEncoderConfigurator(ModelConfigurator[AudioEncoder]):
    """Builds an AudioEncoder from the nested ``audio_vae`` checkpoint config."""

    @classmethod
    def from_config(cls: "type[AudioEncoderConfigurator]", config: dict) -> AudioEncoder:
        # Navigate the nested checkpoint layout; every level degrades to {} so
        # missing sections fall back to the defaults below.
        audio_vae_cfg = config.get("audio_vae", {})
        model_cfg = audio_vae_cfg.get("model", {})
        model_params = model_cfg.get("params", {})
        ddconfig = model_params.get("ddconfig", {})
        preprocessing_cfg = audio_vae_cfg.get("preprocessing", {})
        stft_cfg = preprocessing_cfg.get("stft", {})
        mel_cfg = preprocessing_cfg.get("mel", {})
        variables_cfg = audio_vae_cfg.get("variables", {})

        sample_rate = model_params.get("sampling_rate", 16000)
        mel_hop_length = stft_cfg.get("hop_length", 160)
        n_fft = stft_cfg.get("filter_length", 1024)
        is_causal = stft_cfg.get("causal", True)
        # mel_bins may live in any of three config sections; first hit wins.
        mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins")

        return AudioEncoder(
            ch=ddconfig.get("ch", 128),
            ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))),
            num_res_blocks=ddconfig.get("num_res_blocks", 2),
            attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}),
            resolution=ddconfig.get("resolution", 256),
            z_channels=ddconfig.get("z_channels", 8),
            double_z=ddconfig.get("double_z", True),
            dropout=ddconfig.get("dropout", 0.0),
            resamp_with_conv=ddconfig.get("resamp_with_conv", True),
            in_channels=ddconfig.get("in_channels", 2),
            attn_type=AttentionType(ddconfig.get("attn_type", "vanilla")),
            mid_block_add_attention=ddconfig.get("mid_block_add_attention", True),
            norm_type=NormType(ddconfig.get("norm_type", "pixel")),
            causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")),
            sample_rate=sample_rate,
            mel_hop_length=mel_hop_length,
            n_fft=n_fft,
            is_causal=is_causal,
            mel_bins=mel_bins,
        )
106
+
107
+
108
# State-dict key filter for ComfyUI-style checkpoints: keeps decoder weights and
# shared per-channel statistics, remapping their prefixes to the names the
# AudioDecoder module expects (presumed SDOps semantics — see loader.sd_ops).
AUDIO_VAE_DECODER_COMFY_KEYS_FILTER = (
    SDOps("AUDIO_VAE_DECODER_COMFY_KEYS_FILTER")
    .with_matching(prefix="audio_vae.decoder.")
    .with_matching(prefix="audio_vae.per_channel_statistics.")
    .with_replacement("audio_vae.decoder.", "")
    .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.")
)
115
+
116
+
117
# State-dict key filter for ComfyUI-style checkpoints: keeps encoder weights and
# shared per-channel statistics, remapping their prefixes to the names the
# AudioEncoder module expects (presumed SDOps semantics — see loader.sd_ops).
AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER = (
    SDOps("AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER")
    .with_matching(prefix="audio_vae.encoder.")
    .with_matching(prefix="audio_vae.per_channel_statistics.")
    .with_replacement("audio_vae.encoder.", "")
    .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.")
)
packages/ltx-core/src/ltx_core/model/audio_vae/ops.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ from torch import nn
4
+
5
+
6
class AudioProcessor(nn.Module):
    """Turns audio waveforms into log-mel spectrograms, resampling first when needed."""

    def __init__(
        self,
        sample_rate: int,
        mel_bins: int,
        mel_hop_length: int,
        n_fft: int,
    ) -> None:
        super().__init__()
        self.sample_rate = sample_rate
        # Slaney-style mel filterbank over the full [0, Nyquist] band;
        # magnitude spectrogram (power=1.0), centered frames, reflect padding.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=mel_hop_length,
            f_min=0.0,
            f_max=sample_rate / 2.0,
            n_mels=mel_bins,
            window_fn=torch.hann_window,
            center=True,
            pad_mode="reflect",
            power=1.0,
            mel_scale="slaney",
            norm="slaney",
        )

    def resample_waveform(
        self,
        waveform: torch.Tensor,
        source_rate: int,
        target_rate: int,
    ) -> torch.Tensor:
        """Resample ``waveform`` from ``source_rate`` to ``target_rate``; no-op if equal."""
        if source_rate == target_rate:
            return waveform
        out = torchaudio.functional.resample(waveform, source_rate, target_rate)
        # Keep the result on the caller's device and dtype.
        return out.to(device=waveform.device, dtype=waveform.dtype)

    def waveform_to_mel(
        self,
        waveform: torch.Tensor,
        waveform_sample_rate: int,
    ) -> torch.Tensor:
        """Convert a waveform to a log-mel spectrogram laid out as [batch, channels, time, n_mels]."""
        waveform = self.resample_waveform(waveform, waveform_sample_rate, self.sample_rate)

        # Magnitude mel spectrogram, clamped away from zero before the log.
        log_mel = torch.log(torch.clamp(self.mel_transform(waveform), min=1e-5))
        log_mel = log_mel.to(device=waveform.device, dtype=waveform.dtype)

        # (..., n_mels, frames) -> (..., frames, n_mels)
        return log_mel.permute(0, 1, 3, 2).contiguous()
59
+
60
+
61
class PerChannelStatistics(nn.Module):
    """
    Per-channel statistics for normalizing and denormalizing the latent representation.
    These statistics are precomputed over the entire dataset and stored in the model's
    checkpoint under the AudioVAE state_dict.
    """

    def __init__(self, latent_channels: int = 128) -> None:
        super().__init__()
        # The hyphenated buffer names match the checkpoint's key names exactly;
        # they are not valid attribute names, so access goes through get_buffer().
        self.register_buffer("std-of-means", torch.empty(latent_channels))
        self.register_buffer("mean-of-means", torch.empty(latent_channels))

    def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
        # Inverse of normalize(): x * std + mean, cast to x's device/dtype via .to(x).
        return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        # (x - mean) / std, broadcasting the per-channel vectors against x.
        return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
packages/ltx-core/src/ltx_core/model/audio_vae/resnet.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
6
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
7
+ from ltx_core.model.common.normalization import NormType, build_normalization_layer
8
+
9
# Negative-slope coefficient shared by the leaky-ReLU activations in the residual blocks.
LRELU_SLOPE = 0.1
10
+
11
+
12
+ class ResBlock1(torch.nn.Module):
13
+ def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)):
14
+ super(ResBlock1, self).__init__()
15
+ self.convs1 = torch.nn.ModuleList(
16
+ [
17
+ torch.nn.Conv1d(
18
+ channels,
19
+ channels,
20
+ kernel_size,
21
+ 1,
22
+ dilation=dilation[0],
23
+ padding="same",
24
+ ),
25
+ torch.nn.Conv1d(
26
+ channels,
27
+ channels,
28
+ kernel_size,
29
+ 1,
30
+ dilation=dilation[1],
31
+ padding="same",
32
+ ),
33
+ torch.nn.Conv1d(
34
+ channels,
35
+ channels,
36
+ kernel_size,
37
+ 1,
38
+ dilation=dilation[2],
39
+ padding="same",
40
+ ),
41
+ ]
42
+ )
43
+
44
+ self.convs2 = torch.nn.ModuleList(
45
+ [
46
+ torch.nn.Conv1d(
47
+ channels,
48
+ channels,
49
+ kernel_size,
50
+ 1,
51
+ dilation=1,
52
+ padding="same",
53
+ ),
54
+ torch.nn.Conv1d(
55
+ channels,
56
+ channels,
57
+ kernel_size,
58
+ 1,
59
+ dilation=1,
60
+ padding="same",
61
+ ),
62
+ torch.nn.Conv1d(
63
+ channels,
64
+ channels,
65
+ kernel_size,
66
+ 1,
67
+ dilation=1,
68
+ padding="same",
69
+ ),
70
+ ]
71
+ )
72
+
73
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
74
+ for conv1, conv2 in zip(self.convs1, self.convs2, strict=True):
75
+ xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
76
+ xt = conv1(xt)
77
+ xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
78
+ xt = conv2(xt)
79
+ x = xt + x
80
+ return x
81
+
82
+
83
+ class ResBlock2(torch.nn.Module):
84
+ def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)):
85
+ super(ResBlock2, self).__init__()
86
+ self.convs = torch.nn.ModuleList(
87
+ [
88
+ torch.nn.Conv1d(
89
+ channels,
90
+ channels,
91
+ kernel_size,
92
+ 1,
93
+ dilation=dilation[0],
94
+ padding="same",
95
+ ),
96
+ torch.nn.Conv1d(
97
+ channels,
98
+ channels,
99
+ kernel_size,
100
+ 1,
101
+ dilation=dilation[1],
102
+ padding="same",
103
+ ),
104
+ ]
105
+ )
106
+
107
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
108
+ for conv in self.convs:
109
+ xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
110
+ xt = conv(xt)
111
+ x = xt + x
112
+ return x
113
+
114
+
115
class ResnetBlock(torch.nn.Module):
    """
    Two-conv residual block with optional timestep-embedding injection and a
    learned shortcut when the channel count changes.

    Raises:
        ValueError: If a causal axis is combined with GroupNorm (unsupported).
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int | None = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        super().__init__()
        self.causality_axis = causality_axis
        if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP:
            raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")

        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # Main path: norm -> SiLU -> conv, twice, with dropout before the second conv.
        self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
        self.non_linearity = torch.nn.SiLU()
        self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        if temb_channels > 0:
            # Projection for the (optional) timestep embedding added after conv1.
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)

        # Shortcut path: only needed when channel counts differ.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                )
            else:
                self.nin_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
                )

    def forward(
        self,
        x: torch.Tensor,
        temb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the residual block; ``temb`` is an optional timestep embedding."""
        residual = x

        h = self.conv1(self.non_linearity(self.norm1(x)))
        if temb is not None:
            # Broadcast the projected embedding over the two spatial dims.
            h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
        h = self.conv2(self.dropout(self.non_linearity(self.norm2(h))))

        if self.in_channels != self.out_channels:
            residual = self.conv_shortcut(residual) if self.use_conv_shortcut else self.nin_shortcut(residual)

        return residual + h
packages/ltx-core/src/ltx_core/model/audio_vae/upsample.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Set, Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.audio_vae.attention import AttentionType, make_attn
6
+ from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
7
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
8
+ from ltx_core.model.audio_vae.resnet import ResnetBlock
9
+ from ltx_core.model.common.normalization import NormType
10
+
11
+
12
class Upsample(torch.nn.Module):
    """2x nearest-neighbor upsampling, optionally followed by a (possibly causal) conv."""

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis
        if self.with_conv:
            self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample by 2x (nearest), optionally convolve, and trim the causal axis."""
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
            # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
            # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
            # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
            # So the output elements rely on the following windows:
            # 0: [-,-,0]
            # 1: [-,0,0]
            # 2: [0,0,1]
            # 3: [0,1,1]
            # 4: [1,1,2]
            # 5: [1,2,2]
            # Notice that the first and second elements in the output rely only on the first element in the input,
            # while all other elements rely on two elements in the input.
            # So we can drop the first element to undo the padding (rather than the last element).
            # This is a no-op for non-causal convolutions.
            match self.causality_axis:
                case CausalityAxis.NONE:
                    pass  # x remains unchanged
                case CausalityAxis.HEIGHT:
                    x = x[:, :, 1:, :]
                case CausalityAxis.WIDTH:
                    x = x[:, :, :, 1:]
                case CausalityAxis.WIDTH_COMPATIBILITY:
                    # Legacy behavior: keep the extra element.  NOTE(review): presumably
                    # kept for checkpoints trained without the trim — confirm.
                    pass  # x remains unchanged
                case _:
                    raise ValueError(f"Invalid causality_axis: {self.causality_axis}")

        return x
56
+
57
+
58
def build_upsampling_path(  # noqa: PLR0913
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
    initial_block_channels: int,
) -> tuple[torch.nn.ModuleList, int]:
    """
    Build the decoder's upsampling path: per level, (num_res_blocks + 1) residual
    blocks with optional attention, plus an Upsample layer on all but the lowest level.

    Returns:
        (module list ordered from lowest to highest resolution, channel count after the last level)
    """
    stages = torch.nn.ModuleList()
    channels = initial_block_channels
    # Start at the coarsest resolution; doubled after each upsampling level.
    current_resolution = resolution // (2 ** (num_resolutions - 1))

    for level in reversed(range(num_resolutions)):
        stage = torch.nn.Module()
        stage.block = torch.nn.ModuleList()
        stage.attn = torch.nn.ModuleList()
        target_channels = ch * ch_mult[level]

        for _ in range(num_res_blocks + 1):
            stage.block.append(
                ResnetBlock(
                    in_channels=channels,
                    out_channels=target_channels,
                    temb_channels=temb_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                    causality_axis=causality_axis,
                )
            )
            channels = target_channels
            if current_resolution in attn_resolutions:
                stage.attn.append(make_attn(channels, attn_type=attn_type, norm_type=norm_type))

        if level != 0:
            stage.upsample = Upsample(channels, resamp_with_conv, causality_axis=causality_axis)
            current_resolution *= 2

        # Prepend so the returned list runs from lowest to highest resolution.
        stages.insert(0, stage)

    return stages, channels
packages/ltx-core/src/ltx_core/model/audio_vae/vocoder.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List
3
+
4
+ import einops
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ from ltx_core.model.audio_vae.resnet import LRELU_SLOPE, ResBlock1, ResBlock2
10
+
11
+
12
class Vocoder(torch.nn.Module):
    """
    Vocoder model for synthesizing audio from Mel spectrograms.
    Args:
        resblock_kernel_sizes: List of kernel sizes for the residual blocks.
            This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`.
        upsample_rates: List of upsampling rates.
            This value is read from the checkpoint at `config.vocoder.upsample_rates`.
        upsample_kernel_sizes: List of kernel sizes for the upsampling layers.
            This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`.
        resblock_dilation_sizes: List of dilation sizes for the residual blocks.
            This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`.
        upsample_initial_channel: Initial number of channels for the upsampling layers.
            This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`.
        stereo: Whether to use stereo output.
            This value is read from the checkpoint at `config.vocoder.stereo`.
        resblock: Type of residual block to use.
            This value is read from the checkpoint at `config.vocoder.resblock`.
        output_sample_rate: Waveform sample rate.
            This value is read from the checkpoint at `config.vocoder.output_sample_rate`.
    """

    def __init__(
        self,
        resblock_kernel_sizes: List[int] | None = None,
        upsample_rates: List[int] | None = None,
        upsample_kernel_sizes: List[int] | None = None,
        resblock_dilation_sizes: List[List[int]] | None = None,
        upsample_initial_channel: int = 1024,
        stereo: bool = True,
        resblock: str = "1",
        output_sample_rate: int = 24000,
    ):
        super().__init__()

        # Initialize default values if not provided. Note that mutable default values are not supported.
        if resblock_kernel_sizes is None:
            resblock_kernel_sizes = [3, 7, 11]
        if upsample_rates is None:
            upsample_rates = [6, 5, 2, 2, 2]
        if upsample_kernel_sizes is None:
            upsample_kernel_sizes = [16, 15, 8, 4, 4]
        if resblock_dilation_sizes is None:
            resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]

        self.output_sample_rate = output_sample_rate
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        # Stereo folds the two channels into the feature axis (2 * 64 mel features).
        in_channels = 128 if stereo else 64
        self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
        resblock_class = ResBlock1 if resblock == "1" else ResBlock2

        # Transposed convs that progressively upsample time; channels halve per stage.
        self.ups = nn.ModuleList()
        for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True)):
            self.ups.append(
                nn.ConvTranspose1d(
                    upsample_initial_channel // (2**i),
                    upsample_initial_channel // (2 ** (i + 1)),
                    kernel_size,
                    stride,
                    padding=(kernel_size - stride) // 2,
                )
            )

        # num_kernels residual blocks per upsampling stage (flattened list).
        self.resblocks = nn.ModuleList()
        for i, _ in enumerate(self.ups):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):
                self.resblocks.append(resblock_class(ch, kernel_size, dilations))

        out_channels = 2 if stereo else 1
        final_channels = upsample_initial_channel // (2**self.num_upsamples)
        self.conv_post = nn.Conv1d(final_channels, out_channels, 7, 1, padding=3)

        # Total time-axis expansion factor: product of all transposed-conv strides.
        self.upsample_factor = math.prod(layer.stride[0] for layer in self.ups)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the vocoder.
        Args:
            x: Input Mel spectrogram tensor. Can be either:
                - 3D: (batch_size, time, mel_bins) for mono
                - 4D: (batch_size, 2, time, mel_bins) for stereo
        Returns:
            Audio waveform tensor of shape (batch_size, out_channels, audio_length)
        """
        if x.dim() == 3:
            # Mono: (batch, time, mel_bins) -> (batch, mel_bins, time).
            # Fix: the previous unconditional transpose(2, 3) raised on 3D inputs,
            # so the documented mono path could never run.
            x = x.transpose(1, 2)
        else:
            # Stereo: (batch, 2, time, mel_bins) -> (batch, 2, mel_bins, time),
            # then fold the stereo axis into the feature axis for conv_pre.
            assert x.shape[1] == 2, "Input must have 2 channels for stereo"
            x = x.transpose(2, 3)
            x = einops.rearrange(x, "b s c t -> b (s c) t")

        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            start = i * self.num_kernels
            end = start + self.num_kernels

            # Evaluate all resblocks with the same input tensor so they can run
            # independently (and thus in parallel on accelerator hardware) before
            # aggregating their outputs via mean.
            block_outputs = torch.stack(
                [self.resblocks[idx](x) for idx in range(start, end)],
                dim=0,
            )

            x = block_outputs.mean(dim=0)

        # NOTE: default leaky_relu slope (0.01) here, not LRELU_SLOPE, matching
        # the original implementation.
        x = self.conv_post(F.leaky_relu(x))
        return torch.tanh(x)
packages/ltx-core/src/ltx_core/model/common/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """Common model utilities."""
2
+
3
+ from ltx_core.model.common.normalization import NormType, PixelNorm, build_normalization_layer
4
+
5
+ __all__ = [
6
+ "NormType",
7
+ "PixelNorm",
8
+ "build_normalization_layer",
9
+ ]
packages/ltx-core/src/ltx_core/model/common/normalization.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
class NormType(Enum):
    """Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""

    # torch.nn.GroupNorm with affine parameters.
    GROUP = "group"
    # PixelNorm: parameter-free RMS normalization over the channel dim.
    PIXEL = "pixel"
12
+
13
+
14
class PixelNorm(nn.Module):
    """
    Per-pixel (per-location) RMS normalization.

    Each location is divided by the root-mean-square of its values across the
    configured dimension:
        y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
    """

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        """
        Args:
            dim: Dimension along which the RMS is computed (typically channels).
            eps: Small constant added for numerical stability.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by its RMS along ``self.dim`` (broadcast via keepdim)."""
        rms = torch.sqrt(x.pow(2).mean(dim=self.dim, keepdim=True) + self.eps)
        return x / rms
41
+
42
+
43
+ def build_normalization_layer(
44
+ in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
45
+ ) -> nn.Module:
46
+ """
47
+ Create a normalization layer based on the normalization type.
48
+ Args:
49
+ in_channels: Number of input channels
50
+ num_groups: Number of groups for group normalization
51
+ normtype: Type of normalization: "group" or "pixel"
52
+ Returns:
53
+ A normalization layer
54
+ """
55
+ if normtype == NormType.GROUP:
56
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
57
+ if normtype == NormType.PIXEL:
58
+ return PixelNorm(dim=1, eps=1e-6)
59
+ raise ValueError(f"Invalid normalization type: {normtype}")
packages/ltx-core/src/ltx_core/model/model_protocol.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Protocol, TypeVar
2
+
3
+ ModelType = TypeVar("ModelType")
4
+
5
+
6
class ModelConfigurator(Protocol[ModelType]):
    """Protocol for model-loader classes that instantiate models from a configuration dictionary."""

    @classmethod
    def from_config(cls, config: dict) -> ModelType:
        """Build and return a model instance from ``config`` (checkpoint/config dict)."""
        ...
packages/ltx-core/src/ltx_core/model/transformer/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Transformer model components."""
2
+
3
+ from ltx_core.model.transformer.modality import Modality
4
+ from ltx_core.model.transformer.model import LTXModel, X0Model
5
+ from ltx_core.model.transformer.model_configurator import (
6
+ LTXV_MODEL_COMFY_RENAMING_MAP,
7
+ LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP,
8
+ UPCAST_DURING_INFERENCE,
9
+ LTXModelConfigurator,
10
+ LTXVideoOnlyModelConfigurator,
11
+ UpcastWithStochasticRounding,
12
+ )
13
+
14
+ __all__ = [
15
+ "LTXV_MODEL_COMFY_RENAMING_MAP",
16
+ "LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP",
17
+ "UPCAST_DURING_INFERENCE",
18
+ "LTXModel",
19
+ "LTXModelConfigurator",
20
+ "LTXVideoOnlyModelConfigurator",
21
+ "Modality",
22
+ "UpcastWithStochasticRounding",
23
+ "X0Model",
24
+ ]
packages/ltx-core/src/ltx_core/model/transformer/adaln.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTimestepSizeEmbeddings
6
+
7
+
8
class AdaLayerNormSingle(torch.nn.Module):
    r"""
    Adaptive layer-norm conditioning (adaLN-single), as proposed in PixArt-Alpha
    (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Produces a packed modulation vector (``embedding_coefficient`` chunks of size
    ``embedding_dim``) from the timestep embedding via a SiLU + Linear projection.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        embedding_coefficient (`int`): Number of modulation chunks packed into the output.
    """

    def __init__(self, embedding_dim: int, embedding_coefficient: int = 6):
        super().__init__()
        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim,
            size_emb_dim=embedding_dim // 3,
        )
        self.silu = torch.nn.SiLU()
        self.linear = torch.nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(modulation, embedded_timestep)`` for the given timesteps."""
        embedded = self.emb(timestep, hidden_dtype=hidden_dtype)
        modulation = self.linear(self.silu(embedded))
        return modulation, embedded
packages/ltx-core/src/ltx_core/model/transformer/attention.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from typing import Protocol
3
+
4
+ import torch
5
+
6
+ from ltx_core.model.transformer.rope import LTXRopeType, apply_rotary_emb
7
+
8
# Optional attention backends: probe availability at import time; each name is
# None when its package is missing, and the backend classes check for None.
memory_efficient_attention = None
flash_attn_interface = None
try:
    from xformers.ops import memory_efficient_attention
except ImportError:
    memory_efficient_attention = None

# Fix: this import was previously unguarded, so importing this module crashed
# whenever flash-attn was not installed, defeating the None pre-init above.
try:
    import flash_attn_interface
except ImportError:
    flash_attn_interface = None
16
+
17
class AttentionCallable(Protocol):
    """
    Structural type for attention backends.

    Implementations take packed q/k/v of shape (batch, tokens, heads * dim_head),
    the head count, and an optional additive mask, and return a tensor of the
    same packed shape.
    """

    def __call__(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
    ) -> torch.Tensor: ...
21
+
22
+
23
class PytorchAttention(AttentionCallable):
    """Attention backend backed by ``torch.nn.functional.scaled_dot_product_attention``."""

    def __call__(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        batch, _, inner_dim = q.shape
        head_dim = inner_dim // heads
        # (B, T, H*D) -> (B, H, T, D) for SDPA.
        q, k, v = (t.view(batch, -1, heads, head_dim).transpose(1, 2) for t in (q, k, v))

        if mask is not None:
            # Expand to 4D so the mask broadcasts over batch and heads.
            if mask.ndim == 2:
                mask = mask[None, None]
            elif mask.ndim == 3:
                mask = mask[:, None]

        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
        # (B, H, T, D) -> (B, T, H*D)
        return out.transpose(1, 2).reshape(batch, -1, heads * head_dim)
42
+
43
+
44
class XFormersAttention(AttentionCallable):
    """Attention backend backed by xformers' ``memory_efficient_attention``."""

    def __call__(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        heads: int,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if memory_efficient_attention is None:
            raise RuntimeError("XFormersAttention was selected but `xformers` is not installed.")

        b, _, dim_head = q.shape
        dim_head //= heads

        # xformers expects [B, M, H, K]
        q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v))

        if mask is not None:
            # add a singleton batch dimension
            if mask.ndim == 2:
                mask = mask.unsqueeze(0)
            # add a singleton heads dimension
            if mask.ndim == 3:
                mask = mask.unsqueeze(1)
            # pad the key dimension up to a multiple of 8
            pad = 8 - mask.shape[-1] % 8
            # xformers allows an attn_bias of shape (1, Nq, Nk), but with separated
            # heads the shape has to be (B, H, Nq, Nk) — for large models this
            # materialized matrix can exceed 1GB.  The bias is allocated padded...
            mask_out = torch.empty(
                [mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device
            )

            mask_out[..., : mask.shape[-1]] = mask
            # NOTE(review): ...and this slice immediately narrows the view back to
            # the un-padded width; presumably the wider allocation only exists so
            # the underlying storage/strides are 8-aligned for xformers — confirm
            # this is the intended trick rather than an accidental no-op.
            mask = mask_out[..., : mask.shape[-1]]
            mask = mask.expand(b, heads, -1, -1)

        out = memory_efficient_attention(q.to(v.dtype), k.to(v.dtype), v, attn_bias=mask, p=0.0)
        out = out.reshape(b, -1, heads * dim_head)
        return out
87
+
88
+
89
class FlashAttention3(AttentionCallable):
    """Attention backend backed by the FlashAttention-3 kernels (``flash_attn_interface``)."""

    def __call__(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        heads: int,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if flash_attn_interface is None:
            raise RuntimeError("FlashAttention3 was selected but `FlashAttention3` is not installed.")
        if mask is not None:
            raise NotImplementedError("Mask is not supported for FlashAttention3")

        batch, _, inner_dim = q.shape
        head_dim = inner_dim // heads
        # flash-attn expects the (B, T, H, D) layout.
        q, k, v = (t.view(batch, -1, heads, head_dim) for t in (q, k, v))

        out = flash_attn_interface.flash_attn_func(q.to(v.dtype), k.to(v.dtype), v)
        return out.reshape(batch, -1, heads * head_dim)
112
+
113
+
114
class AttentionFunction(Enum):
    """Selectable attention backends; DEFAULT dispatches to the fastest available one."""

    PYTORCH = "pytorch"
    XFORMERS = "xformers"
    FLASH_ATTENTION_3 = "flash_attention_3"
    DEFAULT = "default"

    def __call__(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        # Fix: honor an explicitly selected backend. Previously the member was
        # ignored entirely — AttentionFunction.PYTORCH still ran FlashAttention3
        # for unmasked inputs (and crashed when flash-attn was not installed).
        if self is AttentionFunction.PYTORCH:
            return PytorchAttention()(q, k, v, heads, mask)
        if self is AttentionFunction.XFORMERS:
            return XFormersAttention()(q, k, v, heads, mask)
        if self is AttentionFunction.FLASH_ATTENTION_3:
            return FlashAttention3()(q, k, v, heads, mask)

        # DEFAULT: FlashAttention-3 (no mask support) when available, then
        # xformers, then the always-available PyTorch SDPA fallback.
        if mask is None and flash_attn_interface is not None:
            return FlashAttention3()(q, k, v, heads, mask)
        if memory_efficient_attention is not None:
            return XFormersAttention()(q, k, v, heads, mask)
        return PytorchAttention()(q, k, v, heads, mask)
131
+
132
+
133
class Attention(torch.nn.Module):
    """
    Multi-head attention with RMS-normalized q/k, optional rotary embeddings,
    and a pluggable backend (AttentionFunction enum or custom callable).
    """

    def __init__(
        self,
        query_dim: int,
        context_dim: int | None = None,
        heads: int = 8,
        dim_head: int = 64,
        norm_eps: float = 1e-6,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
        attention_function: AttentionCallable | AttentionFunction = AttentionFunction.DEFAULT,
    ) -> None:
        super().__init__()
        self.rope_type = rope_type
        self.attention_function = attention_function
        self.heads = heads
        self.dim_head = dim_head

        inner_dim = dim_head * heads
        # Self-attention when no context dimension is given.
        context_dim = query_dim if context_dim is None else context_dim

        self.q_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps)
        self.k_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps)

        self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True)
        self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True)
        self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True)

        # Sequential-with-Identity layout kept so state-dict keys stay `to_out.0.*`.
        self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity())

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor | None = None,
        mask: torch.Tensor | None = None,
        pe: torch.Tensor | None = None,
        k_pe: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run self-attention (context=None) or cross-attention over ``x``."""
        kv_source = x if context is None else context

        q = self.q_norm(self.to_q(x))
        k = self.k_norm(self.to_k(kv_source))
        v = self.to_v(kv_source)

        if pe is not None:
            q = apply_rotary_emb(q, pe, self.rope_type)
            # Keys may carry their own positional embedding (cross-attention).
            k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type)

        # attention_function may be an AttentionFunction enum member *or* a custom callable.
        attn_out = self.attention_function(q, k, v, self.heads, mask)
        return self.to_out(attn_out)
packages/ltx-core/src/ltx_core/model/transformer/feed_forward.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.model.transformer.gelu_approx import GELUApprox
4
+
5
+
6
class FeedForward(torch.nn.Module):
    """Two-layer MLP with a tanh-approximate GELU projection: dim -> dim*mult -> dim_out."""

    def __init__(self, dim: int, dim_out: int, mult: int = 4) -> None:
        super().__init__()
        hidden_dim = int(dim * mult)
        # The Identity placeholder keeps the output Linear at index `net.2`,
        # matching the original module layout (state-dict keys).
        self.net = torch.nn.Sequential(
            GELUApprox(dim, hidden_dim),
            torch.nn.Identity(),
            torch.nn.Linear(hidden_dim, dim_out),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
packages/ltx-core/src/ltx_core/model/transformer/gelu_approx.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
class GELUApprox(torch.nn.Module):
    """Linear projection followed by the tanh-approximate GELU activation."""

    def __init__(self, dim_in: int, dim_out: int) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.proj(x)
        return torch.nn.functional.gelu(projected, approximate="tanh")
packages/ltx-core/src/ltx_core/model/transformer/modality.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+
6
@dataclass(frozen=True)
class Modality:
    """
    Input data for a single modality (video or audio) in the transformer.
    Bundles the latent tokens, timestep embeddings, positional information,
    and text conditioning context for processing by the diffusion transformer.
    """

    # Latent tokens, shape (B, T, D): batch, token count, input dimension.
    latent: torch.Tensor
    # Per-token diffusion timesteps, shape (B, T).
    timesteps: torch.Tensor
    # Token positions, shape (B, 3, T) for video (3 position axes per token).
    positions: torch.Tensor
    # Text conditioning context embeddings.
    context: torch.Tensor
    # Whether this modality participates in the forward pass.
    enabled: bool = True
    # Optional mask over context tokens; presumably consumed as an attention
    # mask by the transformer blocks — confirm at the usage site.
    context_mask: torch.Tensor | None = None