linoy committed
Commit ebfc6b3 · 0 Parent(s)

initial commit

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50)
  1. app.py +301 -0
  2. packages/ltx-core/README.md +1 -0
  3. packages/ltx-core/pyproject.toml +38 -0
  4. packages/ltx-core/src/ltx_core/__init__.py +0 -0
  5. packages/ltx-core/src/ltx_core/__pycache__/__init__.cpython-310.pyc +0 -0
  6. packages/ltx-core/src/ltx_core/__pycache__/tiling.cpython-310.pyc +0 -0
  7. packages/ltx-core/src/ltx_core/__pycache__/utils.cpython-310.pyc +0 -0
  8. packages/ltx-core/src/ltx_core/guidance/__init__.py +0 -0
  9. packages/ltx-core/src/ltx_core/guidance/__pycache__/__init__.cpython-310.pyc +0 -0
  10. packages/ltx-core/src/ltx_core/guidance/__pycache__/perturbations.cpython-310.pyc +0 -0
  11. packages/ltx-core/src/ltx_core/guidance/perturbations.py +74 -0
  12. packages/ltx-core/src/ltx_core/legacy_tiling.py +258 -0
  13. packages/ltx-core/src/ltx_core/loader/.ipynb_checkpoints/sd_ops-checkpoint.py +107 -0
  14. packages/ltx-core/src/ltx_core/loader/__init__.py +0 -0
  15. packages/ltx-core/src/ltx_core/loader/__pycache__/__init__.cpython-310.pyc +0 -0
  16. packages/ltx-core/src/ltx_core/loader/__pycache__/fuse_loras.cpython-310.pyc +0 -0
  17. packages/ltx-core/src/ltx_core/loader/__pycache__/kernels.cpython-310.pyc +0 -0
  18. packages/ltx-core/src/ltx_core/loader/__pycache__/module_ops.cpython-310.pyc +0 -0
  19. packages/ltx-core/src/ltx_core/loader/__pycache__/primitives.cpython-310.pyc +0 -0
  20. packages/ltx-core/src/ltx_core/loader/__pycache__/registry.cpython-310.pyc +0 -0
  21. packages/ltx-core/src/ltx_core/loader/__pycache__/sd_ops.cpython-310.pyc +0 -0
  22. packages/ltx-core/src/ltx_core/loader/__pycache__/sft_loader.cpython-310.pyc +0 -0
  23. packages/ltx-core/src/ltx_core/loader/__pycache__/single_gpu_model_builder.cpython-310.pyc +0 -0
  24. packages/ltx-core/src/ltx_core/loader/fuse_loras.py +102 -0
  25. packages/ltx-core/src/ltx_core/loader/kernels.py +74 -0
  26. packages/ltx-core/src/ltx_core/loader/module_ops.py +11 -0
  27. packages/ltx-core/src/ltx_core/loader/primitives.py +63 -0
  28. packages/ltx-core/src/ltx_core/loader/registry.py +68 -0
  29. packages/ltx-core/src/ltx_core/loader/sd_ops.py +107 -0
  30. packages/ltx-core/src/ltx_core/loader/sft_loader.py +53 -0
  31. packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py +99 -0
  32. packages/ltx-core/src/ltx_core/model/.ipynb_checkpoints/model_ledger-checkpoint.py +253 -0
  33. packages/ltx-core/src/ltx_core/model/__init__.py +0 -0
  34. packages/ltx-core/src/ltx_core/model/__pycache__/__init__.cpython-310.pyc +0 -0
  35. packages/ltx-core/src/ltx_core/model/__pycache__/model_ledger.cpython-310.pyc +0 -0
  36. packages/ltx-core/src/ltx_core/model/__pycache__/model_protocol.cpython-310.pyc +0 -0
  37. packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py +0 -0
  38. packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/__init__.cpython-310.pyc +0 -0
  39. packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/attention.cpython-310.pyc +0 -0
  40. packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/audio_vae.cpython-310.pyc +0 -0
  41. packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/causal_conv_2d.cpython-310.pyc +0 -0
  42. packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/causality_axis.cpython-310.pyc +0 -0
  43. packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/downsample.cpython-310.pyc +0 -0
  44. packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/model_configurator.cpython-310.pyc +0 -0
  45. packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/ops.cpython-310.pyc +0 -0
  46. packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/resnet.cpython-310.pyc +0 -0
  47. packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/upsample.cpython-310.pyc +0 -0
  48. packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/vocoder.cpython-310.pyc +0 -0
  49. packages/ltx-core/src/ltx_core/model/audio_vae/attention.py +71 -0
  50. packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py +483 -0
app.py ADDED
@@ -0,0 +1,301 @@
+"""
+Simple Gradio app for LTX-2 inference based on ltx2_two_stage.py example
+"""
+
+import sys
+from pathlib import Path
+
+# Add packages to Python path
+current_dir = Path(__file__).parent
+sys.path.insert(0, str(current_dir / "packages" / "ltx-pipelines" / "src"))
+sys.path.insert(0, str(current_dir / "packages" / "ltx-core" / "src"))
+
+import gradio as gr
+from typing import Optional
+from huggingface_hub import hf_hub_download
+from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline
+from ltx_core.tiling import TilingConfig
+from ltx_pipelines.constants import (
+    DEFAULT_SEED,
+    DEFAULT_HEIGHT,
+    DEFAULT_WIDTH,
+    DEFAULT_NUM_FRAMES,
+    DEFAULT_FRAME_RATE,
+    DEFAULT_NUM_INFERENCE_STEPS,
+    DEFAULT_CFG_GUIDANCE_SCALE,
+    DEFAULT_LORA_STRENGTH,
+)
+
+# Custom negative prompt
+DEFAULT_NEGATIVE_PROMPT = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static"
+
+# Default prompt from docstring example
+DEFAULT_PROMPT = "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot."
+
+# HuggingFace Hub defaults
+DEFAULT_REPO_ID = "LTX-Colab/LTX-Video-Preview"
+DEFAULT_GEMMA_REPO_ID = "google/gemma-3-12b-it-qat-q4_0-unquantized"
+DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev-rc1.safetensors"
+DEFAULT_DISTILLED_LORA_FILENAME = "ltx-2-19b-distilled-lora-384-rc1.safetensors"
+DEFAULT_SPATIAL_UPSAMPLER_FILENAME = "ltx-2-spatial-upscaler-x2-1.0-rc1.safetensors"
+
+
+def get_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None):
+    """Download from HuggingFace Hub or use local checkpoint."""
+    if repo_id is None and filename is None:
+        raise ValueError("Please supply at least one of `repo_id` or `filename`")
+
+    if repo_id is not None:
+        if filename is None:
+            raise ValueError("If repo_id is specified, filename must also be specified.")
+        print(f"Downloading {filename} from {repo_id}...")
+        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
+        print(f"Downloaded to {ckpt_path}")
+    else:
+        ckpt_path = filename
+
+    return ckpt_path
+
+
+# Initialize pipeline at startup
+print("=" * 80)
+print("Loading LTX-2 2-stage pipeline...")
+print("=" * 80)
+
+checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
+distilled_lora_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_DISTILLED_LORA_FILENAME)
+spatial_upsampler_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_SPATIAL_UPSAMPLER_FILENAME)
+
+print("Initializing pipeline with:")
+print(f"  checkpoint_path={checkpoint_path}")
+print(f"  distilled_lora_path={distilled_lora_path}")
+print(f"  spatial_upsampler_path={spatial_upsampler_path}")
+print(f"  gemma_root={DEFAULT_GEMMA_REPO_ID}")
+
+pipeline = TI2VidTwoStagesPipeline(
+    checkpoint_path=checkpoint_path,
+    distilled_lora_path=distilled_lora_path,
+    distilled_lora_strength=DEFAULT_LORA_STRENGTH,
+    spatial_upsampler_path=spatial_upsampler_path,
+    gemma_root=DEFAULT_GEMMA_REPO_ID,
+    loras=[],
+    fp8transformer=False,
+    local_files_only=False,
+)
+
+print("=" * 80)
+print("Warming up pipeline (loading Gemma text encoder)...")
+print("=" * 80)
+
+# Do a dummy warmup to load all models including Gemma
+import tempfile
+import os
+
+warmup_output = tempfile.mktemp(suffix=".mp4")
+try:
+    pipeline(
+        prompt="warmup",
+        negative_prompt="",
+        output_path=warmup_output,
+        seed=42,
+        height=256,
+        width=256,
+        num_frames=9,
+        frame_rate=8,
+        num_inference_steps=1,
+        cfg_guidance_scale=1.0,
+        images=[],
+        tiling_config=TilingConfig.default(),
+    )
+    # Clean up warmup output
+    if os.path.exists(warmup_output):
+        os.remove(warmup_output)
+except Exception as e:
+    print(f"Warmup completed with note: {e}")
+
+print("=" * 80)
+print("Pipeline fully loaded and ready!")
+print("=" * 80)
+
+
+def generate_video(
+    input_image,
+    prompt: str,
+    duration: float,
+    negative_prompt: str,
+    seed: int,
+    randomize_seed: bool,
+    num_inference_steps: int,
+    cfg_guidance_scale: float,
+    height: int,
+    width: int,
+    progress=gr.Progress(),
+):
+    """Generate a video based on the given parameters."""
+    try:
+        # Randomize seed if checkbox is enabled
+        if randomize_seed:
+            import random
+            seed = random.randint(0, 1000000)
+
+        # Calculate num_frames from duration (using fixed 24 fps)
+        frame_rate = 24.0
+        num_frames = int(duration * frame_rate) + 1  # +1 to ensure we meet the duration
+
+        # Create output directory if it doesn't exist
+        output_dir = Path("outputs")
+        output_dir.mkdir(exist_ok=True)
+        output_path = output_dir / f"video_{seed}.mp4"
+
+        # Handle image input
+        images = []
+        if input_image is not None:
+            # Save uploaded image temporarily
+            temp_image_path = output_dir / f"temp_input_{seed}.jpg"
+            if hasattr(input_image, 'save'):
+                input_image.save(temp_image_path)
+            else:
+                # If it's a file path already
+                temp_image_path = input_image
+            # Format: (image_path, frame_idx, strength)
+            images = [(str(temp_image_path), 0, 1.0)]
+
+        # Run inference
+        progress(0, desc="Generating video (2-stage)...")
+        pipeline(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            output_path=str(output_path),
+            seed=seed,
+            height=height,
+            width=width,
+            num_frames=num_frames,
+            frame_rate=frame_rate,
+            num_inference_steps=num_inference_steps,
+            cfg_guidance_scale=cfg_guidance_scale,
+            images=images,
+            tiling_config=TilingConfig.default(),
+        )
+
+        progress(1.0, desc="Done!")
+        return str(output_path)
+
+    except Exception as e:
+        import traceback
+        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
+        print(error_msg)
+        return None
+
+
+# Create Gradio interface
+with gr.Blocks(title="LTX-2 Image-to-Video") as demo:
+    gr.Markdown("# LTX-2 Image-to-Video Generation")
+    gr.Markdown("Transform images into videos using the LTX-2 2-stage pipeline")
+
+    with gr.Row():
+        with gr.Column():
+            input_image = gr.Image(
+                label="Input Image",
+                type="pil",
+                sources=["upload"]
+            )
+
+            prompt = gr.Textbox(
+                label="Prompt",
+                value="Make this image come alive with cinematic motion, smooth animation",
+                lines=3,
+                placeholder="Describe the motion and animation you want..."
+            )
+
+            duration = gr.Slider(
+                label="Duration (seconds)",
+                minimum=1.0,
+                maximum=10.0,
+                value=5.0,
+                step=0.1
+            )
+
+            generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
+
+            with gr.Accordion("Advanced Settings", open=False):
+                negative_prompt = gr.Textbox(
+                    label="Negative Prompt",
+                    value=DEFAULT_NEGATIVE_PROMPT,
+                    lines=2
+                )
+
+                seed = gr.Slider(
+                    label="Seed",
+                    minimum=0,
+                    maximum=1000000,
+                    value=DEFAULT_SEED,
+                    step=1
+                )
+
+                randomize_seed = gr.Checkbox(
+                    label="Randomize Seed",
+                    value=True
+                )
+
+                num_inference_steps = gr.Slider(
+                    label="Inference Steps",
+                    minimum=1,
+                    maximum=100,
+                    value=DEFAULT_NUM_INFERENCE_STEPS,
+                    step=1
+                )
+
+                cfg_guidance_scale = gr.Slider(
+                    label="CFG Guidance Scale",
+                    minimum=1.0,
+                    maximum=10.0,
+                    value=DEFAULT_CFG_GUIDANCE_SCALE,
+                    step=0.1
+                )
+
+                with gr.Row():
+                    width = gr.Number(
+                        label="Width",
+                        value=DEFAULT_WIDTH,
+                        precision=0
+                    )
+                    height = gr.Number(
+                        label="Height",
+                        value=DEFAULT_HEIGHT,
+                        precision=0
+                    )
+
+        with gr.Column():
+            output_video = gr.Video(label="Generated Video", autoplay=True)
+
+    generate_btn.click(
+        fn=generate_video,
+        inputs=[
+            input_image,
+            prompt,
+            duration,
+            negative_prompt,
+            seed,
+            randomize_seed,
+            num_inference_steps,
+            cfg_guidance_scale,
+            height,
+            width,
+        ],
+        outputs=output_video,
+    )
+
+    # Add example
+    gr.Examples(
+        examples=[
+            [
+                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg",
+                "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot.",
                5.0,
+            ]
+        ],
+        inputs=[input_image, prompt, duration],
+        label="Example"
+    )
+
+
+if __name__ == "__main__":
+    demo.launch(share=True)
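
For readers who want to drive the pipeline without the Gradio UI, the objects app.py builds at import time can be called directly. A minimal headless sketch, assuming the same keyword interface app.py itself uses (importing app triggers the checkpoint downloads and warmup; the output path is hypothetical):

from app import pipeline, DEFAULT_NEGATIVE_PROMPT, DEFAULT_PROMPT
from ltx_core.tiling import TilingConfig
from ltx_pipelines.constants import (
    DEFAULT_CFG_GUIDANCE_SCALE,
    DEFAULT_HEIGHT,
    DEFAULT_NUM_INFERENCE_STEPS,
    DEFAULT_SEED,
    DEFAULT_WIDTH,
)

pipeline(
    prompt=DEFAULT_PROMPT,
    negative_prompt=DEFAULT_NEGATIVE_PROMPT,
    output_path="outputs/headless.mp4",  # hypothetical destination
    seed=DEFAULT_SEED,
    height=DEFAULT_HEIGHT,
    width=DEFAULT_WIDTH,
    num_frames=121,  # 5 seconds at 24 fps, plus one frame, as in generate_video()
    frame_rate=24.0,
    num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS,
    cfg_guidance_scale=DEFAULT_CFG_GUIDANCE_SCALE,
    images=[],  # text-to-video; pass [(image_path, 0, 1.0)] for image conditioning
    tiling_config=TilingConfig.default(),
)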
packages/ltx-core/README.md ADDED
@@ -0,0 +1 @@
+# LTX-2 Core
packages/ltx-core/pyproject.toml ADDED
@@ -0,0 +1,38 @@
+[project]
+name = "ltx-core"
+version = "0.1.0"
+description = "Core implementation of Lightricks' LTX-2 model"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+    "torch~=2.7",
+    "torchaudio",
+    "einops",
+    "numpy",
+    "transformers",
+    "safetensors",
+    "accelerate",
+    "scipy>=1.14",
+]
+
+[project.optional-dependencies]
+flashpack = ["flashpack==0.1.2"]
+xformers = ["xformers"]
+
+[tool.uv.sources]
+xformers = { index = "pytorch" }
+
+[[tool.uv.index]]
+name = "pytorch"
+url = "https://download.pytorch.org/whl/cu129"
+explicit = true
+
+[build-system]
+requires = ["uv_build>=0.9.8,<0.10.0"]
+build-backend = "uv_build"
+
+[dependency-groups]
+dev = [
+    "scikit-image>=0.25.2",
+]
packages/ltx-core/src/ltx_core/__init__.py ADDED
File without changes
packages/ltx-core/src/ltx_core/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (176 Bytes).
packages/ltx-core/src/ltx_core/__pycache__/tiling.cpython-310.pyc ADDED
Binary file (10.7 kB).
packages/ltx-core/src/ltx_core/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.85 kB).
packages/ltx-core/src/ltx_core/guidance/__init__.py ADDED
File without changes
packages/ltx-core/src/ltx_core/guidance/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (185 Bytes).
packages/ltx-core/src/ltx_core/guidance/__pycache__/perturbations.cpython-310.pyc ADDED
Binary file (3.82 kB).
packages/ltx-core/src/ltx_core/guidance/perturbations.py ADDED
@@ -0,0 +1,74 @@
+# Copyright (c) 2025 Lightricks. All rights reserved.
+# Created by Andrew Kvochko
+
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+from torch._prims_common import DeviceLikeType
+
+
+class PerturbationType(Enum):
+    SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn"
+    SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn"
+    SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn"
+    SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn"
+
+
+@dataclass(frozen=True)
+class Perturbation:
+    type: PerturbationType
+    blocks: list[int] | None  # None means all blocks
+
+    def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
+        if self.type != perturbation_type:
+            return False
+
+        if self.blocks is None:
+            return True
+
+        return block in self.blocks
+
+
+@dataclass(frozen=True)
+class PerturbationConfig:
+    perturbations: list[Perturbation] | None
+
+    def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
+        if self.perturbations is None:
+            return False
+
+        return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)
+
+    @staticmethod
+    def empty() -> "PerturbationConfig":
+        return PerturbationConfig([])
+
+
+@dataclass(frozen=True)
+class BatchedPerturbationConfig:
+    perturbations: list[PerturbationConfig]
+
+    def mask(
+        self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype
+    ) -> torch.Tensor:
+        mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype)
+        for batch_idx, perturbation in enumerate(self.perturbations):
+            if perturbation.is_perturbed(perturbation_type, block):
+                mask[batch_idx] = 0
+
+        return mask
+
+    def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor:
+        mask = self.mask(perturbation_type, block, values.device, values.dtype)
+        return mask.view(mask.numel(), *([1] * len(values.shape[1:])))
+
+    def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
+        return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)
+
+    def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
+        return all(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)
+
+    @staticmethod
+    def empty(batch_size: int) -> "BatchedPerturbationConfig":
+        return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)])
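
The mask helpers above are easiest to see on a concrete batch. A minimal sketch (the block index 3 is arbitrary): batch item 0 skips video self-attention in every block, item 1 is unperturbed, so `mask` zeroes only the first entry:

import torch

from ltx_core.guidance.perturbations import (
    BatchedPerturbationConfig,
    Perturbation,
    PerturbationConfig,
    PerturbationType,
)

cfg = BatchedPerturbationConfig(
    perturbations=[
        PerturbationConfig([Perturbation(PerturbationType.SKIP_VIDEO_SELF_ATTN, blocks=None)]),
        PerturbationConfig.empty(),
    ]
)
mask = cfg.mask(PerturbationType.SKIP_VIDEO_SELF_ATTN, block=3, device="cpu", dtype=torch.float32)
print(mask)  # tensor([0., 1.]): item 0 is perturbed (zeroed out), item 1 passes through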
packages/ltx-core/src/ltx_core/legacy_tiling.py ADDED
@@ -0,0 +1,258 @@
+import logging
+from collections.abc import Generator
+
+import torch
+
+from ltx_core.model.video_vae.video_vae import Decoder
+
+
+def compute_chunk_boundaries(
+    chunk_start: int,
+    temporal_tile_length: int,
+    temporal_overlap: int,
+    total_latent_frames: int,
+) -> tuple[int, int]:
+    """Compute chunk boundaries for temporal tiling.
+
+    Args:
+        chunk_start: Starting frame index for the current chunk
+        temporal_tile_length: Length of each temporal tile
+        temporal_overlap: Number of frames to overlap between chunks
+        total_latent_frames: Total number of latent frames
+
+    Returns:
+        Tuple of (overlap_start, chunk_end)
+    """
+    if chunk_start == 0:
+        # First chunk: no overlap needed
+        chunk_end = min(chunk_start + temporal_tile_length, total_latent_frames)
+        overlap_start = chunk_start
+    else:
+        # Subsequent chunks: include overlap from previous chunk.
+        # -1 because we need one extra frame to overlap, which is decoded to a single frame;
+        # never overlap with the first latent frame.
+        overlap_start = max(1, chunk_start - temporal_overlap - 1)
+        extra_frames = chunk_start - overlap_start
+        chunk_end = min(
+            chunk_start + temporal_tile_length - extra_frames,
+            total_latent_frames,
+        )
+
+    return overlap_start, chunk_end
+
+
+def spatial_decode(  # noqa
+    decoder: Decoder,
+    samples: torch.Tensor,
+    horizontal_tiles: int,
+    vertical_tiles: int,
+    overlap: int,
+    last_frame_fix: bool,
+    scale_factors: tuple[float, float, float],
+    timestep: float,
+    generator: torch.Generator,
+) -> torch.Tensor:
+    if last_frame_fix:
+        # Repeat the last frame along dimension 2 (frames)
+        # samples shape - [batch, channels, frames, height, width]
+        last_frame = samples[:, :, -1:, :, :]
+        samples = torch.cat([samples, last_frame], dim=2)
+
+    batch, _, frames, height, width = samples.shape
+    time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
+    image_frames = 1 + (frames - 1) * time_scale_factor
+
+    # Calculate output image dimensions
+    output_height = height * height_scale_factor
+    output_width = width * width_scale_factor
+
+    # Calculate tile sizes with overlap
+    base_tile_height = (height + (vertical_tiles - 1) * overlap) // vertical_tiles
+    base_tile_width = (width + (horizontal_tiles - 1) * overlap) // horizontal_tiles
+
+    # Initialize output tensor and weight tensor
+    # VAE decode returns images in format [batch, height, width, channels]
+    target_device = samples.device
+    target_dtype = samples.dtype
+
+    output = torch.zeros(
+        (
+            batch,
+            3,
+            image_frames,
+            output_height,
+            output_width,
+        ),
+        device=target_device,
+        dtype=target_dtype,
+    )
+    weights = torch.zeros(
+        (batch, 1, image_frames, output_height, output_width),
+        device=target_device,
+        dtype=target_dtype,
+    )
+
+    # Process each tile
+    for v in range(vertical_tiles):
+        for h in range(horizontal_tiles):
+            # Calculate tile boundaries
+            h_start = h * (base_tile_width - overlap)
+            v_start = v * (base_tile_height - overlap)
+
+            # Adjust end positions for edge tiles
+            h_end = min(h_start + base_tile_width, width) if h < horizontal_tiles - 1 else width
+            v_end = min(v_start + base_tile_height, height) if v < vertical_tiles - 1 else height
+
+            # Calculate actual tile dimensions
+            tile_height = v_end - v_start
+            tile_width = h_end - h_start
+
+            logging.info(f"Processing VAE decode tile at row {v}, col {h}:")
+            logging.info(f"  Position: ({v_start}:{v_end}, {h_start}:{h_end})")
+            logging.info(f"  Size: {tile_height}x{tile_width}")
+
+            # Extract tile
+            tile = samples[:, :, :, v_start:v_end, h_start:h_end]
+
+            # Decode the tile
+            decoded_tile = decoder.decode(tile, timestep, generator)
+
+            # Calculate output tile boundaries
+            out_h_start = v_start * height_scale_factor
+            out_h_end = v_end * height_scale_factor
+            out_w_start = h_start * width_scale_factor
+            out_w_end = h_end * width_scale_factor
+
+            # Create weight mask for this tile
+            tile_out_height = out_h_end - out_h_start
+            tile_out_width = out_w_end - out_w_start
+            tile_weights = torch.ones(
+                (batch, 1, image_frames, tile_out_height, tile_out_width),
+                device=decoded_tile.device,
+                dtype=decoded_tile.dtype,
+            )
+
+            # Calculate overlap regions in output space
+            overlap_out_h = overlap * height_scale_factor
+            overlap_out_w = overlap * width_scale_factor
+
+            # Apply horizontal blending weights
+            if h > 0:  # Left overlap
+                h_blend = torch.linspace(0, 1, overlap_out_w, device=decoded_tile.device)
+                tile_weights[:, :, :, :, :overlap_out_w] *= h_blend
+            if h < horizontal_tiles - 1:  # Right overlap
+                h_blend = torch.linspace(1, 0, overlap_out_w, device=decoded_tile.device)
+                tile_weights[:, :, :, :, -overlap_out_w:] *= h_blend
+
+            # Apply vertical blending weights
+            if v > 0:  # Top overlap
+                v_blend = torch.linspace(0, 1, overlap_out_h, device=decoded_tile.device)
+                tile_weights[:, :, :, :overlap_out_h, :] *= v_blend.view(1, 1, 1, -1, 1)
+            if v < vertical_tiles - 1:  # Bottom overlap
+                v_blend = torch.linspace(1, 0, overlap_out_h, device=decoded_tile.device)
+                tile_weights[:, :, :, -overlap_out_h:, :] *= v_blend.view(1, 1, 1, -1, 1)
+
+            # Add weighted tile to output
+            output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += (decoded_tile * tile_weights).to(
+                target_device, target_dtype
+            )
+
+            # Add weights to weight tensor
+            weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += tile_weights.to(
+                target_device, target_dtype
+            )
+
+    # Normalize by weights
+    output /= weights + 1e-8
+    # LT_INTERNAL: changed from output[:-time_scale_factor, :, :]!
+    if last_frame_fix:
+        output = output[:, :, :-time_scale_factor, :, :]
+
+    return output
+
+
+def decode_spatial_temporal(
+    decoder: Decoder,
+    samples: torch.Tensor,
+    timestep: float,
+    generator: torch.Generator,
+    scale_factors: tuple[float, float, float],
+    spatial_tiles: int = 4,
+    spatial_overlap: int = 1,
+    temporal_tile_length: int = 16,
+    temporal_overlap: int = 1,
+    last_frame_fix: bool = False,
+) -> Generator[torch.Tensor, None, None]:
+    if temporal_tile_length < temporal_overlap + 1:
+        raise ValueError("Temporal tile length must be greater than temporal overlap + 1")
+
+    _, _, frames, _, _ = samples.shape
+    time_scale_factor, _, _ = scale_factors
+
+    # Process temporal chunks similar to reference function
+    total_latent_frames = frames
+    chunk_start = 0
+
+    previous_tile = None
+    while chunk_start < total_latent_frames:
+        # Calculate chunk boundaries
+        overlap_start, chunk_end = compute_chunk_boundaries(
+            chunk_start, temporal_tile_length, temporal_overlap, total_latent_frames
+        )
+
+        # units are latent frames
+        chunk_frames = chunk_end - overlap_start
+        logging.info(f"Processing temporal chunk: {overlap_start}:{chunk_end} ({chunk_frames} latent frames)")
+
+        # Extract tile
+        tile = samples[:, :, overlap_start:chunk_end]
+
+        # Decode the tile
+        decoded_tile = spatial_decode(
+            decoder,
+            tile,
+            spatial_tiles,
+            spatial_tiles,
+            spatial_overlap,
+            last_frame_fix,
+            scale_factors,
+            timestep,
+            generator,
+        )
+
+        if previous_tile is None:
+            previous_tile = decoded_tile
+        else:
+            # Drop first frame if needed (overlap)
+            if decoded_tile.shape[2] == 1:
+                raise ValueError("Dropping first frame but tile has only 1 frame")
+            decoded_tile = decoded_tile[:, :, 1:]  # Drop first frame
+
+            # Create weight mask for this tile
+            # -1 is for dropped frame above
+            overlap_frames = temporal_overlap * time_scale_factor
+            frame_weights = torch.linspace(
+                0,
+                1,
+                overlap_frames + 2,
+                device=decoded_tile.device,
+                dtype=decoded_tile.dtype,
+            )[1:-1]
+            tile_weights = frame_weights.view(1, 1, -1, 1, 1)
+
+            previous_tile[:, :, -overlap_frames:] = (
+                previous_tile[:, :, -overlap_frames:] * (1 - tile_weights)
+                + decoded_tile[:, :, :overlap_frames] * tile_weights
+            )
+            resulting_tile = previous_tile[:, :, :-overlap_frames]
+            decoded_tile[:, :, :overlap_frames] = previous_tile[:, :, -overlap_frames:]
+            yield resulting_tile
+            previous_tile = decoded_tile
+
+        # Move to next chunk
+        chunk_start = chunk_end
+
+    yield decoded_tile
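
The chunking arithmetic in compute_chunk_boundaries is worth tracing once by hand. A short sketch with the default tile length of 16 and overlap of 1, over a hypothetical 40 latent frames:

from ltx_core.legacy_tiling import compute_chunk_boundaries

total, tile_len, overlap = 40, 16, 1
chunk_start = 0
while chunk_start < total:
    overlap_start, chunk_end = compute_chunk_boundaries(chunk_start, tile_len, overlap, total)
    print(overlap_start, chunk_end)
    chunk_start = chunk_end
# Prints (0, 16), (14, 30), (28, 40): every chunk after the first re-decodes two
# extra latent frames so decode_spatial_temporal has material to cross-fade over,
# and the overlap never reaches back to latent frame 0.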
packages/ltx-core/src/ltx_core/loader/.ipynb_checkpoints/sd_ops-checkpoint.py ADDED
@@ -0,0 +1,107 @@
+# Copyright (c) 2025 Lightricks. All rights reserved.
+# Created by Alexey Kravtsov
+
+from dataclasses import dataclass, replace
+# from typing import NamedTuple, Protocol, Self
+from typing import NamedTuple, Protocol
+from typing_extensions import Self
+
+import torch
+
+
+@dataclass(frozen=True, slots=True)
+class ContentReplacement:
+    content: str
+    replacement: str
+
+
+@dataclass(frozen=True, slots=True)
+class ContentMatching:
+    prefix: str = ""
+    suffix: str = ""
+
+
+class KeyValueOperationResult(NamedTuple):
+    new_key: str
+    new_value: torch.Tensor
+
+
+class KeyValueOperation(Protocol):
+    def __call__(self, tensor_key: str, tensor_value: torch.Tensor) -> list[KeyValueOperationResult]: ...
+
+
+@dataclass(frozen=True, slots=True)
+class SDKeyValueOperation:
+    key_matcher: ContentMatching
+    kv_operation: KeyValueOperation
+
+
+@dataclass(frozen=True, slots=True)
+class SDOps:
+    """Immutable class representing state dict key operations."""
+
+    name: str
+    mapping: tuple[
+        ContentReplacement | ContentMatching | SDKeyValueOperation, ...
+    ] = ()  # Immutable tuple of (key, value) pairs
+
+    def with_replacement(self, content: str, replacement: str) -> Self:
+        """Create a new SDOps instance with the specified replacement added to the mapping."""
+        new_mapping = (*self.mapping, ContentReplacement(content, replacement))
+        return replace(self, mapping=new_mapping)
+
+    def with_matching(self, prefix: str = "", suffix: str = "") -> Self:
+        """Create a new SDOps instance with the specified prefix and suffix matching added to the mapping."""
+        new_mapping = (*self.mapping, ContentMatching(prefix, suffix))
+        return replace(self, mapping=new_mapping)
+
+    def with_kv_operation(
+        self,
+        operation: KeyValueOperation,
+        key_prefix: str = "",
+        key_suffix: str = "",
+    ) -> Self:
+        """Create a new SDOps instance with the specified value operation added to the mapping."""
+        key_matcher = ContentMatching(key_prefix, key_suffix)
+        sd_kv_operation = SDKeyValueOperation(key_matcher, operation)
+        new_mapping = (*self.mapping, sd_kv_operation)
+        return replace(self, mapping=new_mapping)
+
+    def apply_to_key(self, key: str) -> str | None:
+        """Apply the mapping to the given name."""
+        matchers = [content for content in self.mapping if isinstance(content, ContentMatching)]
+        valid = any(key.startswith(f.prefix) and key.endswith(f.suffix) for f in matchers)
+        if not valid:
+            return None
+
+        for replacement in self.mapping:
+            if not isinstance(replacement, ContentReplacement):
+                continue
+            if replacement.content in key:
+                key = key.replace(replacement.content, replacement.replacement)
+        return key
+
+    def apply_to_key_value(self, key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
+        """Apply the value operation to the given name and associated value."""
+        for operation in self.mapping:
+            if not isinstance(operation, SDKeyValueOperation):
+                continue
+            if key.startswith(operation.key_matcher.prefix) and key.endswith(operation.key_matcher.suffix):
+                return operation.kv_operation(key, value)
+        return [KeyValueOperationResult(key, value)]
+
+
+# Predefined SDOps instances
+LTXV_LORA_COMFY_RENAMING_MAP = (
+    SDOps("LTXV_LORA_COMFY_PREFIX_MAP").with_matching().with_replacement("diffusion_model.", "")
+)
+
+LTXV_LORA_COMFY_TARGET_MAP = (
+    SDOps("LTXV_LORA_COMFY_TARGET_MAP")
+    .with_matching()
+    .with_replacement("diffusion_model.", "")
+    .with_replacement(".lora_A.weight", ".weight")
+    .with_replacement(".lora_B.weight", ".weight")
+)
packages/ltx-core/src/ltx_core/loader/__init__.py ADDED
File without changes
packages/ltx-core/src/ltx_core/loader/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (183 Bytes).
packages/ltx-core/src/ltx_core/loader/__pycache__/fuse_loras.cpython-310.pyc ADDED
Binary file (2.75 kB).
packages/ltx-core/src/ltx_core/loader/__pycache__/kernels.cpython-310.pyc ADDED
Binary file (1.58 kB).
packages/ltx-core/src/ltx_core/loader/__pycache__/module_ops.cpython-310.pyc ADDED
Binary file (558 Bytes).
packages/ltx-core/src/ltx_core/loader/__pycache__/primitives.cpython-310.pyc ADDED
Binary file (3.09 kB).
packages/ltx-core/src/ltx_core/loader/__pycache__/registry.cpython-310.pyc ADDED
Binary file (3.53 kB).
packages/ltx-core/src/ltx_core/loader/__pycache__/sd_ops.cpython-310.pyc ADDED
Binary file (4.36 kB).
packages/ltx-core/src/ltx_core/loader/__pycache__/sft_loader.cpython-310.pyc ADDED
Binary file (2.65 kB).
packages/ltx-core/src/ltx_core/loader/__pycache__/single_gpu_model_builder.cpython-310.pyc ADDED
Binary file (5.04 kB).
packages/ltx-core/src/ltx_core/loader/fuse_loras.py ADDED
@@ -0,0 +1,102 @@
+# Copyright (c) 2025 Lightricks. All rights reserved.
+# Created by Alexey Kravtsov
+import torch
+import triton
+
+from ltx_core.loader.kernels import fused_add_round_kernel
+from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict
+
+BLOCK_SIZE = 1024
+
+
+def fused_add_round_launch(target_weight: torch.Tensor, original_weight: torch.Tensor, seed: int) -> torch.Tensor:
+    if original_weight.dtype == torch.float8_e4m3fn:
+        exponent_bits, mantissa_bits, exponent_bias = 4, 3, 7
+    elif original_weight.dtype == torch.float8_e5m2:
+        exponent_bits, mantissa_bits, exponent_bias = 5, 2, 15  # noqa: F841
+    else:
+        raise ValueError("Unsupported dtype")
+
+    if target_weight.dtype != torch.bfloat16:
+        raise ValueError("target_weight dtype must be bfloat16")
+
+    # Calculate grid and block sizes
+    n_elements = original_weight.numel()
+    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
+
+    # Launch kernel
+    fused_add_round_kernel[grid](
+        original_weight,
+        target_weight,
+        seed,
+        n_elements,
+        exponent_bias,
+        mantissa_bits,
+        BLOCK_SIZE,
+    )
+    return target_weight
+
+
+def calculate_weight_float8_(target_weights: torch.Tensor, original_weights: torch.Tensor) -> torch.Tensor:
+    result = fused_add_round_launch(target_weights, original_weights, seed=0).to(target_weights.dtype)
+    target_weights.copy_(result, non_blocking=True)
+    return target_weights
+
+
+def _prepare_deltas(
+    lora_sd_and_strengths: list[LoraStateDictWithStrength], key: str, dtype: torch.dtype, device: torch.device
+) -> torch.Tensor | None:
+    deltas = []
+    prefix = key[: -len(".weight")]
+    key_a = f"{prefix}.lora_A.weight"
+    key_b = f"{prefix}.lora_B.weight"
+    for lsd, coef in lora_sd_and_strengths:
+        if key_a not in lsd.sd or key_b not in lsd.sd:
+            continue
+        product = torch.matmul(lsd.sd[key_b] * coef, lsd.sd[key_a])
+        deltas.append(product.to(dtype=dtype, device=device))
+    if len(deltas) == 0:
+        return None
+    elif len(deltas) == 1:
+        return deltas[0]
+    return torch.sum(torch.stack(deltas, dim=0), dim=0)
+
+
+def apply_loras(
+    model_sd: StateDict,
+    lora_sd_and_strengths: list[LoraStateDictWithStrength],
+    dtype: torch.dtype,
+    destination_sd: StateDict | None = None,
+) -> StateDict:
+    sd = {}
+    if destination_sd is not None:
+        sd = destination_sd.sd
+    size = 0
+    device = torch.device("meta")
+    inner_dtypes = set()
+    for key, weight in model_sd.sd.items():
+        if weight is None:
+            continue
+        device = weight.device
+        target_dtype = dtype if dtype is not None else weight.dtype
+        deltas_dtype = target_dtype if target_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2] else torch.bfloat16
+        deltas = _prepare_deltas(lora_sd_and_strengths, key, deltas_dtype, device)
+        if deltas is None:
+            if key in sd:
+                continue
+            deltas = weight.clone().to(dtype=target_dtype, device=device)
+        elif weight.dtype == torch.float8_e4m3fn:
+            if str(device).startswith("cuda"):
+                deltas = calculate_weight_float8_(deltas, weight)
+            else:
+                deltas.add_(weight.to(dtype=deltas.dtype, device=device))
+        elif weight.dtype == torch.bfloat16:
+            deltas.add_(weight)
+        else:
+            raise ValueError(f"Unsupported dtype: {weight.dtype}")
+        sd[key] = deltas.to(dtype=target_dtype)
+        inner_dtypes.add(target_dtype)
+        size += deltas.nbytes
+    if destination_sd is not None:
+        return destination_sd
+    return StateDict(sd, device, size, inner_dtypes)
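
The delta computation in _prepare_deltas is the standard LoRA update: for each targeted `.weight`, it accumulates strength * (B @ A) across all supplied adapters before the sum is folded into the base weight. A minimal numeric sketch of that product, independent of the loader machinery (the shapes are arbitrary):

import torch

out_features, in_features, rank = 8, 4, 2
lora_A = torch.randn(rank, in_features)   # corresponds to "<prefix>.lora_A.weight"
lora_B = torch.randn(out_features, rank)  # corresponds to "<prefix>.lora_B.weight"
strength = 0.8

delta = torch.matmul(lora_B * strength, lora_A)  # same expression as _prepare_deltas
assert delta.shape == (out_features, in_features)  # matches the base weight's shape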
packages/ltx-core/src/ltx_core/loader/kernels.py ADDED
@@ -0,0 +1,74 @@
+# ruff: noqa: ANN001, ANN201, ERA001, N803, N806
+# Copyright (c) 2025 Lightricks. All rights reserved.
+# Created by Alexey Kravtsov
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def fused_add_round_kernel(
+    x_ptr,
+    output_ptr,  # contents will be added to the output
+    seed,
+    n_elements,
+    EXPONENT_BIAS,
+    MANTISSA_BITS,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    A kernel to upcast 8-bit quantized weights to bfloat16 with stochastic rounding
+    and add them to bfloat16 output weights. Might be used to upcast original model weights
+    and to further add them to precalculated deltas coming from LoRAs.
+    """
+    # Get program ID and compute offsets
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    # Load data
+    x = tl.load(x_ptr + offsets, mask=mask)
+    rand_vals = tl.rand(seed, offsets) - 0.5
+
+    x = tl.cast(x, tl.float16)
+    delta = tl.load(output_ptr + offsets, mask=mask)
+    delta = tl.cast(delta, tl.float16)
+    x = x + delta
+
+    x_bits = tl.cast(x, tl.int16, bitcast=True)
+
+    # Calculate the exponent. Unbiased fp16 exponent is ((x_bits & 0x7C00) >> 10) - 15 for
+    # normal numbers and -14 for subnormals.
+    fp16_exponent_bits = (x_bits & 0x7C00) >> 10
+    fp16_normals = fp16_exponent_bits > 0
+    fp16_exponent = tl.where(fp16_normals, fp16_exponent_bits - 15, -14)
+
+    # Add the target dtype's exponent bias and clamp to the target dtype's exponent range.
+    exponent = fp16_exponent + EXPONENT_BIAS
+    MAX_EXPONENT = 2 * EXPONENT_BIAS + 1
+    exponent = tl.where(exponent > MAX_EXPONENT, MAX_EXPONENT, exponent)
+    exponent = tl.where(exponent < 0, 0, exponent)
+
+    # Normal ULP exponent, expressed as an fp16 exponent field:
+    # (exponent - EXPONENT_BIAS - MANTISSA_BITS) + 15
+    # Simplifies to: fp16_exponent - MANTISSA_BITS + 15
+    # See https://en.wikipedia.org/wiki/Unit_in_the_last_place
+    eps_exp = tl.maximum(0, tl.minimum(31, exponent - EXPONENT_BIAS - MANTISSA_BITS + 15))
+
+    # Calculate epsilon in the target dtype
+    eps_normal = tl.cast(tl.cast(eps_exp << 10, tl.int16), tl.float16, bitcast=True)
+
+    # Subnormal ULP: 2^(1 - EXPONENT_BIAS - MANTISSA_BITS) ->
+    # fp16 exponent bits: (1 - EXPONENT_BIAS - MANTISSA_BITS) + 15 =
+    # 16 - EXPONENT_BIAS - MANTISSA_BITS
+    eps_subnormal = tl.cast(tl.cast((16 - EXPONENT_BIAS - MANTISSA_BITS) << 10, tl.int16), tl.float16, bitcast=True)
+    eps = tl.where(exponent > 0, eps_normal, eps_subnormal)
+
+    # Apply zero mask to epsilon
+    eps = tl.where(x == 0, 0.0, eps)
+
+    # Apply stochastic rounding
+    output = tl.cast(x + rand_vals * eps, tl.bfloat16)
+
+    # Store the result
+    tl.store(output_ptr + offsets, output, mask=mask)
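
The kernel is driven from fused_add_round_launch in fuse_loras.py. A short sketch of that host-side path (requires a CUDA device and Triton; the tensor shapes are arbitrary):

import torch

from ltx_core.loader.fuse_loras import fused_add_round_launch

base_fp8 = torch.randn(4096, device="cuda").to(torch.float8_e4m3fn)  # frozen base weights
deltas = torch.randn(4096, device="cuda", dtype=torch.bfloat16)      # precomputed LoRA deltas

# In place: deltas becomes base + deltas, with the perturbation for stochastic
# rounding scaled to the ULP of the fp8 base rather than simple truncation.
fused = fused_add_round_launch(deltas, base_fp8, seed=0)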
packages/ltx-core/src/ltx_core/loader/module_ops.py ADDED
@@ -0,0 +1,11 @@
+# Copyright (c) 2025 Lightricks. All rights reserved.
+# Created by Alexey Kravtsov
+from typing import Callable, NamedTuple
+
+import torch
+
+
+class ModuleOps(NamedTuple):
+    name: str
+    matcher: Callable[[torch.nn.Module], bool]
+    mutator: Callable[[torch.nn.Module], torch.nn.Module]
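
ModuleOps pairs a predicate with a rewrite; SingleGPUModelBuilder.meta_model (below) applies the mutator whenever the matcher fires on the freshly instantiated meta-device model. A hypothetical op, just to show the shape of the contract:

import torch

from ltx_core.loader.module_ops import ModuleOps

# Hypothetical example: cast any model that contains a Linear layer to bfloat16.
to_bf16 = ModuleOps(
    name="to_bf16",
    matcher=lambda module: any(isinstance(m, torch.nn.Linear) for m in module.modules()),
    mutator=lambda module: module.to(torch.bfloat16),
)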
packages/ltx-core/src/ltx_core/loader/primitives.py ADDED
@@ -0,0 +1,63 @@
+# Copyright (c) 2025 Lightricks. All rights reserved.
+# Created by Alexey Kravtsov
+from dataclasses import dataclass
+from typing import NamedTuple, Protocol
+
+import torch
+
+from ltx_core.loader.module_ops import ModuleOps
+from ltx_core.loader.sd_ops import SDOps
+from ltx_core.model.model_protocol import ModelType
+
+
+@dataclass(frozen=True)
+class StateDict:
+    sd: dict
+    device: torch.device
+    size: int
+    dtype: set[torch.dtype]
+
+    def footprint(self) -> tuple[int, torch.device]:
+        return self.size, self.device
+
+
+class StateDictLoader(Protocol):
+    def metadata(self, path: str) -> dict:
+        """
+        Load metadata from path
+        """
+
+    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
+        """
+        Load state dict from path or paths (for sharded model storage) and apply sd_ops
+        """
+
+
+class ModelBuilderProtocol(Protocol[ModelType]):
+    def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType: ...
+
+    def build(self, dtype: torch.dtype | None = None) -> ModelType:
+        """
+        Build the model
+        Args:
+            dtype: Target dtype for the model; if None, uses the dtype of the model_path model
+        Returns:
+            Model instance
+        """
+        ...
+
+
+class LoRAAdaptableProtocol(Protocol):
+    def lora(self, lora_path: str, strength: float) -> "LoRAAdaptableProtocol":
+        pass
+
+
+class LoraPathStrengthAndSDOps(NamedTuple):
+    path: str
+    strength: float
+    sd_ops: SDOps
+
+
+class LoraStateDictWithStrength(NamedTuple):
+    state_dict: StateDict
+    strength: float
packages/ltx-core/src/ltx_core/loader/registry.py ADDED
@@ -0,0 +1,68 @@
+# Copyright (c) 2025 Lightricks. All rights reserved.
+# Created by Alexey Kravtsov
+import hashlib
+import threading
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Protocol
+
+from ltx_core.loader.primitives import StateDict
+from ltx_core.loader.sd_ops import SDOps
+
+
+class Registry(Protocol):
+    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None: ...
+
+    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ...
+
+    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ...
+
+    def clear(self) -> None: ...
+
+
+class DummyRegistry(Registry):
+    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None:
+        pass
+
+    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
+        pass
+
+    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
+        pass
+
+    def clear(self) -> None:
+        pass
+
+
+@dataclass
+class StateDictRegistry(Registry):
+    _state_dicts: dict[str, StateDict] = field(default_factory=dict)
+    _lock: threading.Lock = field(default_factory=threading.Lock)
+
+    def _generate_id(self, paths: list[str], sd_ops: SDOps | None) -> str:
+        m = hashlib.sha256()
+        parts = [str(Path(p).resolve()) for p in paths]
+        if sd_ops is not None:
+            parts.append(sd_ops.name)
+        m.update("\0".join(parts).encode("utf-8"))
+        return m.hexdigest()
+
+    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> str:
+        sd_id = self._generate_id(paths, sd_ops)
+        with self._lock:
+            if sd_id in self._state_dicts:
+                raise ValueError(f"State dict retrieved from {paths} with {sd_ops} already added, check with get first")
+            self._state_dicts[sd_id] = state_dict
+        return sd_id
+
+    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
+        with self._lock:
+            return self._state_dicts.pop(self._generate_id(paths, sd_ops), None)
+
+    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
+        with self._lock:
+            return self._state_dicts.get(self._generate_id(paths, sd_ops), None)
+
+    def clear(self) -> None:
+        with self._lock:
+            self._state_dicts.clear()
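
The registry keys on the resolved file paths plus the SDOps name, so the same checkpoint loaded with different key operations is cached separately. A small sketch of the get-then-add pattern that SingleGPUModelBuilder.load_sd uses (the checkpoint path is hypothetical):

import torch

from ltx_core.loader.registry import StateDictRegistry
from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader

registry = StateDictRegistry()
loader = SafetensorsModelStateDictLoader()
paths = ["/checkpoints/ltx-2-19b-dev-rc1.safetensors"]  # hypothetical location

state_dict = registry.get(paths, sd_ops=None)
if state_dict is None:
    state_dict = loader.load(paths, sd_ops=None, device=torch.device("cpu"))
    registry.add(paths, sd_ops=None, state_dict=state_dict)  # a second add would raise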
packages/ltx-core/src/ltx_core/loader/sd_ops.py ADDED
@@ -0,0 +1,107 @@
+# Copyright (c) 2025 Lightricks. All rights reserved.
+# Created by Alexey Kravtsov
+
+from dataclasses import dataclass, replace
+# from typing import NamedTuple, Protocol, Self
+from typing import NamedTuple, Protocol
+from typing_extensions import Self
+
+import torch
+
+
+@dataclass(frozen=True, slots=True)
+class ContentReplacement:
+    content: str
+    replacement: str
+
+
+@dataclass(frozen=True, slots=True)
+class ContentMatching:
+    prefix: str = ""
+    suffix: str = ""
+
+
+class KeyValueOperationResult(NamedTuple):
+    new_key: str
+    new_value: torch.Tensor
+
+
+class KeyValueOperation(Protocol):
+    def __call__(self, tensor_key: str, tensor_value: torch.Tensor) -> list[KeyValueOperationResult]: ...
+
+
+@dataclass(frozen=True, slots=True)
+class SDKeyValueOperation:
+    key_matcher: ContentMatching
+    kv_operation: KeyValueOperation
+
+
+@dataclass(frozen=True, slots=True)
+class SDOps:
+    """Immutable class representing state dict key operations."""
+
+    name: str
+    mapping: tuple[
+        ContentReplacement | ContentMatching | SDKeyValueOperation, ...
+    ] = ()  # Immutable tuple of (key, value) pairs
+
+    def with_replacement(self, content: str, replacement: str) -> Self:
+        """Create a new SDOps instance with the specified replacement added to the mapping."""
+        new_mapping = (*self.mapping, ContentReplacement(content, replacement))
+        return replace(self, mapping=new_mapping)
+
+    def with_matching(self, prefix: str = "", suffix: str = "") -> Self:
+        """Create a new SDOps instance with the specified prefix and suffix matching added to the mapping."""
+        new_mapping = (*self.mapping, ContentMatching(prefix, suffix))
+        return replace(self, mapping=new_mapping)
+
+    def with_kv_operation(
+        self,
+        operation: KeyValueOperation,
+        key_prefix: str = "",
+        key_suffix: str = "",
+    ) -> Self:
+        """Create a new SDOps instance with the specified value operation added to the mapping."""
+        key_matcher = ContentMatching(key_prefix, key_suffix)
+        sd_kv_operation = SDKeyValueOperation(key_matcher, operation)
+        new_mapping = (*self.mapping, sd_kv_operation)
+        return replace(self, mapping=new_mapping)
+
+    def apply_to_key(self, key: str) -> str | None:
+        """Apply the mapping to the given name."""
+        matchers = [content for content in self.mapping if isinstance(content, ContentMatching)]
+        valid = any(key.startswith(f.prefix) and key.endswith(f.suffix) for f in matchers)
+        if not valid:
+            return None
+
+        for replacement in self.mapping:
+            if not isinstance(replacement, ContentReplacement):
+                continue
+            if replacement.content in key:
+                key = key.replace(replacement.content, replacement.replacement)
+        return key
+
+    def apply_to_key_value(self, key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
+        """Apply the value operation to the given name and associated value."""
+        for operation in self.mapping:
+            if not isinstance(operation, SDKeyValueOperation):
+                continue
+            if key.startswith(operation.key_matcher.prefix) and key.endswith(operation.key_matcher.suffix):
+                return operation.kv_operation(key, value)
+        return [KeyValueOperationResult(key, value)]
+
+
+# Predefined SDOps instances
+LTXV_LORA_COMFY_RENAMING_MAP = (
+    SDOps("LTXV_LORA_COMFY_PREFIX_MAP").with_matching().with_replacement("diffusion_model.", "")
+)
+
+LTXV_LORA_COMFY_TARGET_MAP = (
+    SDOps("LTXV_LORA_COMFY_TARGET_MAP")
+    .with_matching()
+    .with_replacement("diffusion_model.", "")
+    .with_replacement(".lora_A.weight", ".weight")
+    .with_replacement(".lora_B.weight", ".weight")
+)
packages/ltx-core/src/ltx_core/loader/sft_loader.py ADDED
@@ -0,0 +1,53 @@
+# Copyright (c) 2025 Lightricks. All rights reserved.
+# Created by Alexey Kravtsov
+import json
+
+import safetensors
+import torch
+
+from ltx_core.loader.primitives import StateDict, StateDictLoader
+from ltx_core.loader.sd_ops import SDOps
+
+
+class SafetensorsStateDictLoader(StateDictLoader):
+    def metadata(self, path: str) -> dict:
+        raise NotImplementedError("Not implemented")
+
+    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
+        """
+        Load state dict from path or paths (for sharded model storage) and apply sd_ops
+        """
+        sd = {}
+        size = 0
+        dtype = set()
+        device = device or torch.device("cpu")
+        model_paths = path if isinstance(path, list) else [path]
+        for shard_path in model_paths:
+            with safetensors.safe_open(shard_path, framework="pt", device=str(device)) as f:
+                safetensor_keys = f.keys()
+                for name in safetensor_keys:
+                    expected_name = name if sd_ops is None else sd_ops.apply_to_key(name)
+                    if expected_name is None:
+                        continue
+                    value = f.get_tensor(name).to(device=device, non_blocking=True, copy=False)
+                    key_value_pairs = ((expected_name, value),)
+                    if sd_ops is not None:
+                        key_value_pairs = sd_ops.apply_to_key_value(expected_name, value)
+                    for key, value in key_value_pairs:
+                        size += value.nbytes
+                        dtype.add(value.dtype)
+                        sd[key] = value
+
+        return StateDict(sd=sd, device=device, size=size, dtype=dtype)
+
+
+class SafetensorsModelStateDictLoader(StateDictLoader):
+    def __init__(self, weight_loader: SafetensorsStateDictLoader | None = None):
+        self.weight_loader = weight_loader if weight_loader is not None else SafetensorsStateDictLoader()
+
+    def metadata(self, path: str) -> dict:
+        with safetensors.safe_open(path, framework="pt") as f:
+            return json.loads(f.metadata()["config"])
+
+    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
+        return self.weight_loader.load(path, sd_ops, device)
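
A minimal load sketch; the path is hypothetical and sd_ops=None keeps every key unchanged:

import torch

from ltx_core.loader.sft_loader import SafetensorsStateDictLoader

loader = SafetensorsStateDictLoader()
state = loader.load("/checkpoints/model.safetensors", sd_ops=None, device=torch.device("cpu"))
print(len(state.sd), "tensors,", state.size, "bytes, dtypes:", state.dtype)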
packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # Copyright (c) 2025 Lightricks. All rights reserved.
+ # Created by Alexey Kravtsov
+ import logging
+ from dataclasses import dataclass, field, replace
+ from typing import Generic
+
+ import torch
+
+ from ltx_core.loader.fuse_loras import apply_loras
+ from ltx_core.loader.module_ops import ModuleOps
+ from ltx_core.loader.primitives import (
+     LoRAAdaptableProtocol,
+     LoraPathStrengthAndSDOps,
+     LoraStateDictWithStrength,
+     ModelBuilderProtocol,
+     StateDict,
+     StateDictLoader,
+ )
+ from ltx_core.loader.registry import DummyRegistry, Registry
+ from ltx_core.loader.sd_ops import SDOps
+ from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader
+ from ltx_core.model.model_protocol import ModelConfigurator, ModelType
+
+ logger: logging.Logger = logging.getLogger(__name__)
+
+
+ @dataclass(frozen=True)
+ class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol):
+     model_class_configurator: type[ModelConfigurator[ModelType]]
+     model_path: str | tuple[str, ...]
+     model_sd_ops: SDOps | None = None
+     module_ops: tuple[ModuleOps, ...] = field(default_factory=tuple)
+     loras: tuple[LoraPathStrengthAndSDOps, ...] = field(default_factory=tuple)
+     model_loader: StateDictLoader = field(default_factory=SafetensorsModelStateDictLoader)
+     registry: Registry = field(default_factory=DummyRegistry)
+
+     def lora(self, lora_path: str, strength: float = 1.0, sd_ops: SDOps | None = None) -> "SingleGPUModelBuilder":
+         return replace(self, loras=(*self.loras, LoraPathStrengthAndSDOps(lora_path, strength, sd_ops)))
+
+     def model_config(self) -> dict:
+         first_shard_path = self.model_path[0] if isinstance(self.model_path, tuple) else self.model_path
+         return self.model_loader.metadata(first_shard_path)
+
+     def meta_model(self, config: dict, module_ops: tuple[ModuleOps, ...]) -> ModelType:
+         with torch.device("meta"):
+             model = self.model_class_configurator.from_config(config)
+             for module_op in module_ops:
+                 if module_op.matcher(model):
+                     model = module_op.mutator(model)
+         return model
+
+     def load_sd(
+         self, paths: list[str], registry: Registry, device: torch.device | None, sd_ops: SDOps | None = None
+     ) -> StateDict:
+         state_dict = registry.get(paths, sd_ops)
+         if state_dict is None:
+             state_dict = self.model_loader.load(paths, sd_ops=sd_ops, device=device)
+             registry.add(paths, sd_ops=sd_ops, state_dict=state_dict)
+         return state_dict
+
+     def _return_model(self, meta_model: ModelType, device: torch.device) -> ModelType:
+         uninitialized_params = [name for name, param in meta_model.named_parameters() if str(param.device) == "meta"]
+         uninitialized_buffers = [name for name, buffer in meta_model.named_buffers() if str(buffer.device) == "meta"]
+         if uninitialized_params or uninitialized_buffers:
+             logger.warning(f"Uninitialized parameters or buffers: {uninitialized_params + uninitialized_buffers}")
+             return meta_model
+         retval = meta_model.to(device)
+         return retval
+
+     def build(self, device: torch.device | None = None, dtype: torch.dtype | None = None) -> ModelType:
+         device = torch.device("cuda") if device is None else device
+         config = self.model_config()
+         meta_model = self.meta_model(config, self.module_ops)
+         model_paths = list(self.model_path) if isinstance(self.model_path, tuple) else [self.model_path]
+         model_state_dict = self.load_sd(model_paths, sd_ops=self.model_sd_ops, registry=self.registry, device=device)
+
+         lora_strengths = [lora.strength for lora in self.loras]
+         if not lora_strengths or (min(lora_strengths) == 0 and max(lora_strengths) == 0):
+             sd = model_state_dict.sd
+             if dtype is not None:
+                 sd = {key: value.to(dtype=dtype) for key, value in model_state_dict.sd.items()}
+             meta_model.load_state_dict(sd, strict=False, assign=True)
+             return self._return_model(meta_model, device)
+
+         lora_state_dicts = [
+             self.load_sd([lora.path], sd_ops=lora.sd_ops, registry=self.registry, device=device) for lora in self.loras
+         ]
+         lora_sd_and_strengths = [
+             LoraStateDictWithStrength(sd, strength)
+             for sd, strength in zip(lora_state_dicts, lora_strengths, strict=True)
+         ]
+         final_sd = apply_loras(
+             model_sd=model_state_dict,
+             lora_sd_and_strengths=lora_sd_and_strengths,
+             dtype=dtype,
+             destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None,
+         )
+         meta_model.load_state_dict(final_sd.sd, strict=False, assign=True)
+         return self._return_model(meta_model, device)
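
For orientation only (this sketch is not part of the commit), here is a minimal usage example of SingleGPUModelBuilder; the checkpoint and LoRA paths are placeholders, and LTXModelConfigurator / LTXV_MODEL_COMFY_RENAMING_MAP come from the transformer configurator module added elsewhere in this commit:

import torch

from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
from ltx_core.model.transformer.model_configurator import (
    LTXV_MODEL_COMFY_RENAMING_MAP,
    LTXModelConfigurator,
)

# Hypothetical paths. The builder is a frozen dataclass, so lora() returns a
# new builder with the LoRA appended rather than mutating this one.
builder = SingleGPUModelBuilder(
    model_class_configurator=LTXModelConfigurator,
    model_path="checkpoints/ltxv.safetensors",
    model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP,
)
lora_builder = builder.lora("loras/style.safetensors", strength=0.8)

# build() instantiates the model on the meta device, loads (and, if any
# strength is non-zero, LoRA-fuses) the state dict, then moves it to `device`.
model = lora_builder.build(device=torch.device("cuda"), dtype=torch.bfloat16)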
packages/ltx-core/src/ltx_core/model/.ipynb_checkpoints/model_ledger-checkpoint.py ADDED
@@ -0,0 +1,253 @@
+ from dataclasses import replace
+ from functools import cached_property
+ # from typing import Self
+ from typing_extensions import Self
+
+ import torch
+
+ from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
+ from ltx_core.loader.registry import DummyRegistry, Registry
+ from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
+ from ltx_core.model.audio_vae.audio_vae import Decoder as AudioDecoder
+ from ltx_core.model.audio_vae.model_configurator import (
+     AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
+     VOCODER_COMFY_KEYS_FILTER,
+     VocoderConfigurator,
+ )
+ from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator
+ from ltx_core.model.audio_vae.vocoder import Vocoder
+ from ltx_core.model.clip.gemma.encoders.av_encoder import (
+     AV_GEMMA_TEXT_ENCODER_KEY_OPS,
+     AVGemmaTextEncoderModel,
+     AVGemmaTextEncoderModelConfigurator,
+ )
+ from ltx_core.model.clip.gemma.encoders.base_encoder import module_ops_from_gemma_root
+ from ltx_core.model.transformer.model import X0Model
+ from ltx_core.model.transformer.model_configurator import (
+     LTXV_MODEL_COMFY_RENAMING_MAP,
+     LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP,
+     UPCAST_DURING_INFERENCE,
+     LTXModelConfigurator,
+ )
+ from ltx_core.model.upsampler.model import LatentUpsampler
+ from ltx_core.model.upsampler.model_configurator import LatentUpsamplerConfigurator
+ from ltx_core.model.video_vae.model_configurator import (
+     VAE_DECODER_COMFY_KEYS_FILTER,
+     VAE_ENCODER_COMFY_KEYS_FILTER,
+     VAEDecoderConfigurator,
+     VAEEncoderConfigurator,
+ )
+ from ltx_core.model.video_vae.video_vae import Decoder as VideoDecoder
+ from ltx_core.model.video_vae.video_vae import Encoder as VideoEncoder
+
+
+ class ModelLedger:
+     """
+     Central coordinator for loading, caching, and freeing models used in an LTX pipeline.
+     The ledger wires together multiple model builders (transformer, video VAE encoder/decoder,
+     audio VAE decoder, vocoder, text encoder, and optional latent upsampler) and exposes
+     the resulting models as lazily constructed, cached attributes.
+
+     ### Caching behavior
+
+     Each model attribute (e.g. :attr:`transformer`, :attr:`video_decoder`, :attr:`text_encoder`)
+     is implemented as a :func:`functools.cached_property`. The first time one of these
+     attributes is accessed, the corresponding builder loads weights from the
+     :class:`~ltx_core.loader.registry.StateDictRegistry`, instantiates the model on CPU with
+     the configured ``dtype``, moves it to ``self.device``, and stores the result in
+     the instance ``__dict__``. Subsequent accesses reuse the same model instance until it is
+     explicitly cleared via :meth:`clear_vram`.
+
+     ### Constructor parameters
+
+     dtype:
+         Torch dtype used when constructing all models (e.g. ``torch.float16``).
+     device:
+         Target device to which models are moved after construction (e.g. ``torch.device("cuda")``).
+     checkpoint_path:
+         Path to a checkpoint directory or file containing the core model weights
+         (transformer, video VAE, audio VAE, text encoder, vocoder). If ``None``, the
+         corresponding builders are not created and accessing those properties will raise
+         a :class:`ValueError`.
+     gemma_root_path:
+         Base path to Gemma-compatible CLIP/text encoder weights. Required to
+         initialize the text encoder builder; if omitted, :attr:`text_encoder` cannot be used.
+     spatial_upsampler_path:
+         Optional path to a latent upsampler checkpoint. If provided, the
+         :attr:`spatial_upsampler` property becomes available; otherwise accessing it raises
+         a :class:`ValueError`.
+     loras:
+         Optional collection of LoRA configurations (paths, strengths, and key operations)
+         that are applied on top of the base transformer weights when building the model.
+     registry:
+         Optional state-dict cache shared by all builders; defaults to a
+         :class:`DummyRegistry` that performs no caching.
+     fp8transformer:
+         If True, build the transformer with FP8 linear downcasting plus
+         upcast-during-inference module ops.
+     local_files_only:
+         Forwarded to the Gemma module ops loader to restrict weight resolution
+         to local files.
+
+     ### Memory management
+
+     ``clear_ram()``
+         Clears the underlying :class:`Registry` cache of state dicts and triggers a
+         Python garbage collection pass. Use this when you no longer need to construct new
+         models from the currently loaded checkpoints and want to free host (CPU) memory.
+     ``clear_vram()``
+         Drops the cached model instances stored by the ``@cached_property`` attributes from
+         this ledger (by removing them from ``self.__dict__``) and calls
+         :func:`torch.cuda.empty_cache`. Use this when you want to release GPU memory;
+         subsequent access to a model property will rebuild the model from the registry
+         while keeping the existing builder configuration.
+     """
+
+     def __init__(
+         self,
+         dtype: torch.dtype,
+         device: torch.device,
+         checkpoint_path: str | None = None,
+         gemma_root_path: str | None = None,
+         spatial_upsampler_path: str | None = None,
+         loras: tuple[LoraPathStrengthAndSDOps, ...] | None = None,
+         registry: Registry | None = None,
+         fp8transformer: bool = False,
+         local_files_only: bool = True,
+     ):
+         self.dtype = dtype
+         self.device = device
+         self.checkpoint_path = checkpoint_path
+         self.gemma_root_path = gemma_root_path
+         self.spatial_upsampler_path = spatial_upsampler_path
+         self.loras = loras or ()
+         self.registry = registry or DummyRegistry()
+         self.fp8transformer = fp8transformer
+         self.local_files_only = local_files_only
+         self.build_model_builders()
+
+     def build_model_builders(self) -> None:
+         if self.checkpoint_path is not None:
+             self.transformer_builder = Builder(
+                 model_path=self.checkpoint_path,
+                 model_class_configurator=LTXModelConfigurator,
+                 model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP,
+                 loras=tuple(self.loras),
+                 registry=self.registry,
+             )
+
+             self.vae_decoder_builder = Builder(
+                 model_path=self.checkpoint_path,
+                 model_class_configurator=VAEDecoderConfigurator,
+                 model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER,
+                 registry=self.registry,
+             )
+
+             self.vae_encoder_builder = Builder(
+                 model_path=self.checkpoint_path,
+                 model_class_configurator=VAEEncoderConfigurator,
+                 model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER,
+                 registry=self.registry,
+             )
+
+             self.audio_decoder_builder = Builder(
+                 model_path=self.checkpoint_path,
+                 model_class_configurator=AudioDecoderConfigurator,
+                 model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
+                 registry=self.registry,
+             )
+
+             self.vocoder_builder = Builder(
+                 model_path=self.checkpoint_path,
+                 model_class_configurator=VocoderConfigurator,
+                 model_sd_ops=VOCODER_COMFY_KEYS_FILTER,
+                 registry=self.registry,
+             )
+
+             if self.gemma_root_path is not None:
+                 self.text_encoder_builder = Builder(
+                     model_path=self.checkpoint_path,
+                     model_class_configurator=AVGemmaTextEncoderModelConfigurator,
+                     model_sd_ops=AV_GEMMA_TEXT_ENCODER_KEY_OPS,
+                     registry=self.registry,
+                     module_ops=module_ops_from_gemma_root(self.gemma_root_path, self.local_files_only),
+                 )
+
+         if self.spatial_upsampler_path is not None:
+             self.upsampler_builder = Builder(
+                 model_path=self.spatial_upsampler_path,
+                 model_class_configurator=LatentUpsamplerConfigurator,
+                 registry=self.registry,
+             )
+
+     def _target_device(self) -> torch.device:
+         if isinstance(self.registry, DummyRegistry) or self.registry is None:
+             return self.device
+         else:
+             return torch.device("cpu")
+
+     def with_loras(self, loras: tuple[LoraPathStrengthAndSDOps, ...]) -> Self:
+         return ModelLedger(
+             dtype=self.dtype,
+             device=self.device,
+             checkpoint_path=self.checkpoint_path,
+             gemma_root_path=self.gemma_root_path,
+             spatial_upsampler_path=self.spatial_upsampler_path,
+             loras=(*self.loras, *loras),
+             registry=self.registry,
+             fp8transformer=self.fp8transformer,
+             local_files_only=self.local_files_only,
+         )
+
+     @cached_property
+     def transformer(self) -> X0Model:
+         if not hasattr(self, "transformer_builder"):
+             raise ValueError(
+                 "Transformer not initialized. Please provide a checkpoint path to the ModelLedger constructor."
+             )
+         if self.fp8transformer:
+             fp8_builder = replace(
+                 self.transformer_builder,
+                 module_ops=(UPCAST_DURING_INFERENCE,),
+                 model_sd_ops=LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP,
+             )
+             return X0Model(fp8_builder.build(device=self._target_device())).to(self.device)
+         else:
+             return X0Model(self.transformer_builder.build(device=self._target_device(), dtype=self.dtype)).to(
+                 self.device
+             )
+
+     @cached_property
+     def video_decoder(self) -> VideoDecoder:
+         if not hasattr(self, "vae_decoder_builder"):
+             raise ValueError(
+                 "Video decoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
+             )
+
+         return self.vae_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device)
+
+     @cached_property
+     def video_encoder(self) -> VideoEncoder:
+         if not hasattr(self, "vae_encoder_builder"):
+             raise ValueError(
+                 "Video encoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
+             )
+
+         return self.vae_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device)
+
+     @cached_property
+     def text_encoder(self) -> AVGemmaTextEncoderModel:
+         if not hasattr(self, "text_encoder_builder"):
+             raise ValueError(
+                 "Text encoder not initialized. Please provide a checkpoint path and gemma root path to the "
+                 "ModelLedger constructor."
+             )
+
+         return self.text_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device)
+
+     @cached_property
+     def audio_decoder(self) -> AudioDecoder:
+         if not hasattr(self, "audio_decoder_builder"):
+             raise ValueError(
+                 "Audio decoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
+             )
+
+         return self.audio_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device)
+
+     @cached_property
+     def vocoder(self) -> Vocoder:
+         if not hasattr(self, "vocoder_builder"):
+             raise ValueError(
+                 "Vocoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
+             )
+
+         return self.vocoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device)
+
+     @cached_property
+     def spatial_upsampler(self) -> LatentUpsampler:
+         if not hasattr(self, "upsampler_builder"):
+             raise ValueError("Upsampler not initialized. Please provide upsampler path to the ModelLedger constructor.")
+
+         return self.upsampler_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device)
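
Again for orientation (not part of the commit), a sketch of how the ledger is meant to be used, assuming the cached-property behavior its docstring describes and placeholder checkpoint paths:

import torch

from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
from ltx_core.model.model_ledger import ModelLedger

ledger = ModelLedger(
    dtype=torch.bfloat16,
    device=torch.device("cuda"),
    checkpoint_path="checkpoints/ltxv.safetensors",  # placeholder
    gemma_root_path="checkpoints/gemma",  # placeholder
)

# First access lazily builds the transformer and caches it on the instance;
# the second access returns the same object without rebuilding.
transformer = ledger.transformer
assert ledger.transformer is transformer

# A derived ledger with an extra LoRA; the original ledger is left untouched.
styled = ledger.with_loras((LoraPathStrengthAndSDOps("loras/style.safetensors", 0.8, None),))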
packages/ltx-core/src/ltx_core/model/__init__.py ADDED
File without changes
packages/ltx-core/src/ltx_core/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (182 Bytes). View file
 
packages/ltx-core/src/ltx_core/model/__pycache__/model_ledger.cpython-310.pyc ADDED
Binary file (9.16 kB). View file
 
packages/ltx-core/src/ltx_core/model/__pycache__/model_protocol.cpython-310.pyc ADDED
Binary file (744 Bytes). View file
 
packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py ADDED
File without changes
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (192 Bytes). View file
 
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/attention.cpython-310.pyc ADDED
Binary file (2.37 kB). View file
 
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/audio_vae.cpython-310.pyc ADDED
Binary file (13.6 kB). View file
 
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/causal_conv_2d.cpython-310.pyc ADDED
Binary file (3.28 kB). View file
 
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/causality_axis.cpython-310.pyc ADDED
Binary file (574 Bytes). View file
 
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/downsample.cpython-310.pyc ADDED
Binary file (3.18 kB). View file
 
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/model_configurator.cpython-310.pyc ADDED
Binary file (4.41 kB). View file
 
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/ops.cpython-310.pyc ADDED
Binary file (3 kB). View file
 
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (4.06 kB). View file
 
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/upsample.cpython-310.pyc ADDED
Binary file (2.88 kB). View file
 
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/vocoder.cpython-310.pyc ADDED
Binary file (4.89 kB). View file
 
packages/ltx-core/src/ltx_core/model/audio_vae/attention.py ADDED
@@ -0,0 +1,71 @@
+ from enum import Enum
+
+ import torch
+
+ from ltx_core.model.common.normalization import NormType, build_normalization_layer
+
+
+ class AttentionType(Enum):
+     """Enum for specifying the attention mechanism type."""
+
+     VANILLA = "vanilla"
+     LINEAR = "linear"
+     NONE = "none"
+
+
+ class AttnBlock(torch.nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         norm_type: NormType = NormType.GROUP,
+     ) -> None:
+         super().__init__()
+         self.in_channels = in_channels
+
+         self.norm = build_normalization_layer(in_channels, normtype=norm_type)
+         self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         h_ = x
+         h_ = self.norm(h_)
+         q = self.q(h_)
+         k = self.k(h_)
+         v = self.v(h_)
+
+         # compute attention
+         b, c, h, w = q.shape
+         q = q.reshape(b, c, h * w).contiguous()
+         q = q.permute(0, 2, 1).contiguous()  # b,hw,c
+         k = k.reshape(b, c, h * w).contiguous()  # b,c,hw
+         w_ = torch.bmm(q, k).contiguous()  # b,hw,hw  w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
+         w_ = w_ * (int(c) ** (-0.5))
+         w_ = torch.nn.functional.softmax(w_, dim=2)
+
+         # attend to values
+         v = v.reshape(b, c, h * w).contiguous()
+         w_ = w_.permute(0, 2, 1).contiguous()  # b,hw,hw (first hw of k, second of q)
+         h_ = torch.bmm(v, w_).contiguous()  # b,c,hw (hw of q)  h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+         h_ = h_.reshape(b, c, h, w).contiguous()
+
+         h_ = self.proj_out(h_)
+
+         return x + h_
+
+
+ def make_attn(
+     in_channels: int,
+     attn_type: AttentionType = AttentionType.VANILLA,
+     norm_type: NormType = NormType.GROUP,
+ ) -> torch.nn.Module:
+     match attn_type:
+         case AttentionType.VANILLA:
+             return AttnBlock(in_channels, norm_type=norm_type)
+         case AttentionType.NONE:
+             return torch.nn.Identity()
+         case AttentionType.LINEAR:
+             raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
+         case _:
+             raise ValueError(f"Unknown attention type: {attn_type}")
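
A quick sanity sketch (not part of the commit): the two bmm calls in AttnBlock.forward implement single-head scaled dot-product attention over the flattened h*w positions, so on the same q/k/v they match torch.nn.functional.scaled_dot_product_attention:

import torch
import torch.nn.functional as F

b, c, h, w = 2, 8, 4, 4
q, k, v = (torch.randn(b, c, h, w) for _ in range(3))

# bmm path, as in AttnBlock.forward
q_ = q.reshape(b, c, h * w).permute(0, 2, 1)              # b, hw, c
k_ = k.reshape(b, c, h * w)                               # b, c, hw
w_ = torch.softmax(torch.bmm(q_, k_) * c ** -0.5, dim=2)  # b, hw, hw
out_bmm = torch.bmm(v.reshape(b, c, h * w), w_.permute(0, 2, 1)).reshape(b, c, h, w)

# equivalent SDPA call on the (b, hw, c) layout
out_sdpa = F.scaled_dot_product_attention(
    q_, k_.permute(0, 2, 1), v.reshape(b, c, h * w).permute(0, 2, 1)
).permute(0, 2, 1).reshape(b, c, h, w)

assert torch.allclose(out_bmm, out_sdpa, atol=1e-5)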
packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py ADDED
@@ -0,0 +1,483 @@
+ # Copyright (c) 2025 Lightricks. All rights reserved.
+ # Created by Ivan Zorin
+
+
+ from typing import Set, Tuple
+
+ import torch
+ import torch.nn.functional as F
+
+ from ltx_core.model.audio_vae.attention import AttentionType, make_attn
+ from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
+ from ltx_core.model.audio_vae.downsample import build_downsampling_path
+ from ltx_core.model.audio_vae.ops import PerChannelStatistics
+ from ltx_core.model.audio_vae.resnet import ResnetBlock
+ from ltx_core.model.audio_vae.upsample import build_upsampling_path
+ from ltx_core.model.common.normalization import NormType, build_normalization_layer
+ from ltx_core.pipeline.components.patchifiers import AudioPatchifier
+ from ltx_core.pipeline.components.protocols import AudioLatentShape
+
+ LATENT_DOWNSAMPLE_FACTOR = 4
+
+
+ def build_mid_block(
+     channels: int,
+     temb_channels: int,
+     dropout: float,
+     norm_type: NormType,
+     causality_axis: CausalityAxis,
+     attn_type: AttentionType,
+     add_attention: bool,
+ ) -> torch.nn.Module:
+     """Build the middle block with two ResNet blocks and optional attention."""
+     mid = torch.nn.Module()
+     mid.block_1 = ResnetBlock(
+         in_channels=channels,
+         out_channels=channels,
+         temb_channels=temb_channels,
+         dropout=dropout,
+         norm_type=norm_type,
+         causality_axis=causality_axis,
+     )
+     mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else torch.nn.Identity()
+     mid.block_2 = ResnetBlock(
+         in_channels=channels,
+         out_channels=channels,
+         temb_channels=temb_channels,
+         dropout=dropout,
+         norm_type=norm_type,
+         causality_axis=causality_axis,
+     )
+     return mid
+
+
+ def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor:
+     """Run features through the middle block."""
+     features = mid.block_1(features, temb=None)
+     features = mid.attn_1(features)
+     return mid.block_2(features, temb=None)
+
+
+ class Encoder(torch.nn.Module):
+     """
+     Encoder that compresses audio spectrograms into latent representations.
+
+     The encoder uses a series of downsampling blocks with residual connections,
+     attention mechanisms, and configurable causal convolutions.
+     """
+
+     def __init__(  # noqa: PLR0913
+         self,
+         *,
+         ch: int,
+         ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
+         num_res_blocks: int,
+         attn_resolutions: Set[int],
+         dropout: float = 0.0,
+         resamp_with_conv: bool = True,
+         in_channels: int,
+         resolution: int,
+         z_channels: int,
+         double_z: bool = True,
+         attn_type: AttentionType = AttentionType.VANILLA,
+         mid_block_add_attention: bool = True,
+         norm_type: NormType = NormType.GROUP,
+         causality_axis: CausalityAxis = CausalityAxis.WIDTH,
+         sample_rate: int = 16000,
+         mel_hop_length: int = 160,
+         n_fft: int = 1024,
+         is_causal: bool = True,
+         mel_bins: int = 64,
+         **_ignore_kwargs,
+     ) -> None:
+         """
+         Initialize the Encoder.
+
+         Args:
+             Arguments are configuration parameters, loaded from the audio VAE checkpoint config
+             (audio_vae.model.params.ddconfig):
+
+             ch: Base number of feature channels used in the first convolution layer.
+             ch_mult: Multiplicative factors for the number of channels at each resolution level.
+             num_res_blocks: Number of residual blocks to use at each resolution level.
+             attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention.
+             resolution: Input spatial resolution of the spectrogram (height, width).
+             z_channels: Number of channels in the latent representation.
+             norm_type: Normalization layer type to use within the network (e.g., group, batch).
+             causality_axis: Axis along which convolutions should be causal (e.g., time axis).
+             sample_rate: Audio sample rate in Hz for the input signals.
+             mel_hop_length: Hop length used when computing the mel spectrogram.
+             n_fft: FFT size used to compute the spectrogram.
+             mel_bins: Number of mel-frequency bins in the input spectrogram.
+             in_channels: Number of channels in the input spectrogram tensor.
+             double_z: If True, predict both mean and log-variance (doubling latent channels).
+             is_causal: If True, use causal convolutions suitable for streaming setups.
+             dropout: Dropout probability used in residual and mid blocks.
+             attn_type: Type of attention mechanism to use in attention blocks.
+             resamp_with_conv: If True, perform resolution changes using strided convolutions.
+             mid_block_add_attention: If True, add an attention block in the mid-level of the encoder.
+         """
+         super().__init__()
+
+         self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
+         self.sample_rate = sample_rate
+         self.mel_hop_length = mel_hop_length
+         self.n_fft = n_fft
+         self.is_causal = is_causal
+         self.mel_bins = mel_bins
+
+         self.patchifier = AudioPatchifier(
+             patch_size=1,
+             audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
+             sample_rate=sample_rate,
+             hop_length=mel_hop_length,
+             is_causal=is_causal,
+         )
+
+         self.ch = ch
+         self.temb_ch = 0
+         self.num_resolutions = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.in_channels = in_channels
+         self.z_channels = z_channels
+         self.double_z = double_z
+         self.norm_type = norm_type
+         self.causality_axis = causality_axis
+         self.attn_type = attn_type
+
+         # downsampling
+         self.conv_in = make_conv2d(
+             in_channels,
+             self.ch,
+             kernel_size=3,
+             stride=1,
+             causality_axis=self.causality_axis,
+         )
+
+         self.non_linearity = torch.nn.SiLU()
+
+         self.down, block_in = build_downsampling_path(
+             ch=ch,
+             ch_mult=ch_mult,
+             num_resolutions=self.num_resolutions,
+             num_res_blocks=num_res_blocks,
+             resolution=resolution,
+             temb_channels=self.temb_ch,
+             dropout=dropout,
+             norm_type=self.norm_type,
+             causality_axis=self.causality_axis,
+             attn_type=self.attn_type,
+             attn_resolutions=attn_resolutions,
+             resamp_with_conv=resamp_with_conv,
+         )
+
+         self.mid = build_mid_block(
+             channels=block_in,
+             temb_channels=self.temb_ch,
+             dropout=dropout,
+             norm_type=self.norm_type,
+             causality_axis=self.causality_axis,
+             attn_type=self.attn_type,
+             add_attention=mid_block_add_attention,
+         )
+
+         self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
+         self.conv_out = make_conv2d(
+             block_in,
+             2 * z_channels if double_z else z_channels,
+             kernel_size=3,
+             stride=1,
+             causality_axis=self.causality_axis,
+         )
+
+     def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
+         """
+         Encode audio spectrogram into latent representations.
+
+         Args:
+             spectrogram: Input spectrogram of shape (batch, channels, time, frequency)
+
+         Returns:
+             Encoded latent representation of shape (batch, channels, frames, mel_bins)
+         """
+         h = self.conv_in(spectrogram)
+         h = self._run_downsampling_path(h)
+         h = run_mid_block(self.mid, h)
+         h = self._finalize_output(h)
+
+         return self._normalize_latents(h)
+
+     def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor:
+         for level in range(self.num_resolutions):
+             stage = self.down[level]
+             for block_idx in range(self.num_res_blocks):
+                 h = stage.block[block_idx](h, temb=None)
+                 if stage.attn:
+                     h = stage.attn[block_idx](h)
+
+             if level != self.num_resolutions - 1:
+                 h = stage.downsample(h)
+
+         return h
+
+     def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
+         h = self.norm_out(h)
+         h = self.non_linearity(h)
+         return self.conv_out(h)
+
+     def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor:
+         """
+         Normalize encoder latents using per-channel statistics.
+
+         When the encoder is configured with ``double_z=True``, the final
+         convolution produces twice the number of latent channels, typically
+         interpreted as two concatenated tensors along the channel dimension
+         (e.g., mean and variance or other auxiliary parameters).
+
+         This method intentionally uses only the first half of the channels
+         (the "mean" component) as input to the patchifier and normalization
+         logic. The remaining channels are left unchanged by this method and
+         are expected to be consumed elsewhere in the VAE pipeline.
+
+         If ``double_z=False``, the encoder output already contains only the
+         mean latents and the chunking operation simply returns that tensor.
+         """
+         means = torch.chunk(latent_output, 2, dim=1)[0]
+         latent_shape = AudioLatentShape(
+             batch=means.shape[0],
+             channels=means.shape[1],
+             frames=means.shape[2],
+             mel_bins=means.shape[3],
+         )
+         latent_patched = self.patchifier.patchify(means)
+         latent_normalized = self.per_channel_statistics.normalize(latent_patched)
+         return self.patchifier.unpatchify(latent_normalized, latent_shape)
+
+
+ class Decoder(torch.nn.Module):
+     """
+     Symmetric decoder that reconstructs audio spectrograms from latent features.
+
+     The decoder mirrors the encoder structure with configurable channel multipliers,
+     attention resolutions, and causal convolutions.
+     """
+
+     def __init__(  # noqa: PLR0913
+         self,
+         *,
+         ch: int,
+         out_ch: int,
+         ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
+         num_res_blocks: int,
+         attn_resolutions: Set[int],
+         resolution: int,
+         z_channels: int,
+         norm_type: NormType = NormType.GROUP,
+         causality_axis: CausalityAxis = CausalityAxis.WIDTH,
+         dropout: float = 0.0,
+         mid_block_add_attention: bool = True,
+         sample_rate: int = 16000,
+         mel_hop_length: int = 160,
+         is_causal: bool = True,
+         mel_bins: int | None = None,
+     ) -> None:
+         """
+         Initialize the Decoder.
+
+         Args:
+             Arguments are configuration parameters, loaded from the audio VAE checkpoint config
+             (audio_vae.model.params.ddconfig):
+             - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions
+             - resolution, z_channels
+             - norm_type, causality_axis
+         """
+         super().__init__()
+
+         # Internal behavioural defaults that are not driven by the checkpoint.
+         resamp_with_conv = True
+         attn_type = AttentionType.VANILLA
+
+         # Per-channel statistics for denormalizing latents
+         self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
+         self.sample_rate = sample_rate
+         self.mel_hop_length = mel_hop_length
+         self.is_causal = is_causal
+         self.mel_bins = mel_bins
+         self.patchifier = AudioPatchifier(
+             patch_size=1,
+             audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
+             sample_rate=sample_rate,
+             hop_length=mel_hop_length,
+             is_causal=is_causal,
+         )
+
+         self.ch = ch
+         self.temb_ch = 0
+         self.num_resolutions = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.out_ch = out_ch
+         self.give_pre_end = False
+         self.tanh_out = False
+         self.norm_type = norm_type
+         self.z_channels = z_channels
+         self.channel_multipliers = ch_mult
+         self.attn_resolutions = attn_resolutions
+         self.causality_axis = causality_axis
+         self.attn_type = attn_type
+
+         base_block_channels = ch * self.channel_multipliers[-1]
+         base_resolution = resolution // (2 ** (self.num_resolutions - 1))
+         self.z_shape = (1, z_channels, base_resolution, base_resolution)
+
+         self.conv_in = make_conv2d(
+             z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
+         )
+         self.non_linearity = torch.nn.SiLU()
+         self.mid = build_mid_block(
+             channels=base_block_channels,
+             temb_channels=self.temb_ch,
+             dropout=dropout,
+             norm_type=self.norm_type,
+             causality_axis=self.causality_axis,
+             attn_type=self.attn_type,
+             add_attention=mid_block_add_attention,
+         )
+         self.up, final_block_channels = build_upsampling_path(
+             ch=ch,
+             ch_mult=ch_mult,
+             num_resolutions=self.num_resolutions,
+             num_res_blocks=num_res_blocks,
+             resolution=resolution,
+             temb_channels=self.temb_ch,
+             dropout=dropout,
+             norm_type=self.norm_type,
+             causality_axis=self.causality_axis,
+             attn_type=self.attn_type,
+             attn_resolutions=attn_resolutions,
+             resamp_with_conv=resamp_with_conv,
+             initial_block_channels=base_block_channels,
+         )
+
+         self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
+         self.conv_out = make_conv2d(
+             final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
+         )
+
+     def forward(self, sample: torch.Tensor) -> torch.Tensor:
+         """
+         Decode latent features back to audio spectrograms.
+
+         Args:
+             sample: Encoded latent representation of shape (batch, channels, frames, mel_bins)
+
+         Returns:
+             Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
+         """
+         sample, target_shape = self._denormalize_latents(sample)
+
+         h = self.conv_in(sample)
+         h = run_mid_block(self.mid, h)
+         h = self._run_upsampling_path(h)
+         h = self._finalize_output(h)
+
+         return self._adjust_output_shape(h, target_shape)
+
+     def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]:
+         latent_shape = AudioLatentShape(
+             batch=sample.shape[0],
+             channels=sample.shape[1],
+             frames=sample.shape[2],
+             mel_bins=sample.shape[3],
+         )
+
+         sample_patched = self.patchifier.patchify(sample)
+         sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
+         sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)
+
+         target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
+         if self.causality_axis != CausalityAxis.NONE:
+             target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
+
+         target_shape = AudioLatentShape(
+             batch=latent_shape.batch,
+             channels=self.out_ch,
+             frames=target_frames,
+             mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
+         )
+
+         return sample, target_shape
+
+     def _adjust_output_shape(
+         self,
+         decoded_output: torch.Tensor,
+         target_shape: AudioLatentShape,
+     ) -> torch.Tensor:
+         """
+         Adjust output shape to match target dimensions for variable-length audio.
+
+         This function handles the common case where decoded audio spectrograms need to be
+         resized to match a specific target shape.
+
+         Args:
+             decoded_output: Tensor of shape (batch, channels, time, frequency)
+             target_shape: AudioLatentShape describing (batch, channels, time, mel bins)
+
+         Returns:
+             Tensor adjusted to match target_shape exactly
+         """
+         # Current output shape: (batch, channels, time, frequency)
+         _, _, current_time, current_freq = decoded_output.shape
+         target_channels = target_shape.channels
+         target_time = target_shape.frames
+         target_freq = target_shape.mel_bins
+
+         # Step 1: Crop first to avoid exceeding target dimensions
+         decoded_output = decoded_output[
+             :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
+         ]
+
+         # Step 2: Calculate padding needed for time and frequency dimensions
+         time_padding_needed = target_time - decoded_output.shape[2]
+         freq_padding_needed = target_freq - decoded_output.shape[3]
+
+         # Step 3: Apply padding if needed
+         if time_padding_needed > 0 or freq_padding_needed > 0:
+             # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
+             # For audio: pad_left/right = frequency, pad_top/bottom = time
+             padding = (
+                 0,
+                 max(freq_padding_needed, 0),  # frequency padding (left, right)
+                 0,
+                 max(time_padding_needed, 0),  # time padding (top, bottom)
+             )
+             decoded_output = F.pad(decoded_output, padding)
+
+         # Step 4: Final safety crop to ensure exact target shape
+         decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
+
+         return decoded_output
+
+     def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor:
+         for level in reversed(range(self.num_resolutions)):
+             stage = self.up[level]
+             for block_idx, block in enumerate(stage.block):
+                 h = block(h, temb=None)
+                 if stage.attn:
+                     h = stage.attn[block_idx](h)
+
+             if level != 0 and hasattr(stage, "upsample"):
+                 h = stage.upsample(h)
+
+         return h
+
+     def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
+         if self.give_pre_end:
+             return h
+
+         h = self.norm_out(h)
+         h = self.non_linearity(h)
+         h = self.conv_out(h)
+         return torch.tanh(h) if self.tanh_out else h
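
To make the crop-then-pad logic in Decoder._adjust_output_shape concrete (not part of the commit), a standalone sketch with toy shapes:

import torch
import torch.nn.functional as F

decoded = torch.randn(1, 2, 37, 70)       # (batch, channels, time, freq)
target_c, target_t, target_f = 1, 40, 64  # toy target shape

# Step 1: crop any dimension that overshoots the target.
out = decoded[:, :target_c, : min(37, target_t), : min(70, target_f)]

# Step 2: zero-pad any dimension that falls short; for a 4D tensor the F.pad
# order is (freq_left, freq_right, time_top, time_bottom).
out = F.pad(out, (0, max(target_f - out.shape[3], 0), 0, max(target_t - out.shape[2], 0)))

assert out.shape == (1, target_c, target_t, target_f)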