Spaces:
Running
on
Zero
Running
on
Zero
linoy
commited on
Commit
·
ebfc6b3
0
Parent(s):
inital commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +301 -0
- packages/ltx-core/README.md +1 -0
- packages/ltx-core/pyproject.toml +38 -0
- packages/ltx-core/src/ltx_core/__init__.py +0 -0
- packages/ltx-core/src/ltx_core/__pycache__/__init__.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/__pycache__/tiling.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/__pycache__/utils.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/guidance/__init__.py +0 -0
- packages/ltx-core/src/ltx_core/guidance/__pycache__/__init__.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/guidance/__pycache__/perturbations.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/guidance/perturbations.py +74 -0
- packages/ltx-core/src/ltx_core/legacy_tiling.py +258 -0
- packages/ltx-core/src/ltx_core/loader/.ipynb_checkpoints/sd_ops-checkpoint.py +107 -0
- packages/ltx-core/src/ltx_core/loader/__init__.py +0 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/__init__.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/fuse_loras.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/kernels.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/module_ops.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/primitives.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/registry.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/sd_ops.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/sft_loader.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/single_gpu_model_builder.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/fuse_loras.py +102 -0
- packages/ltx-core/src/ltx_core/loader/kernels.py +74 -0
- packages/ltx-core/src/ltx_core/loader/module_ops.py +11 -0
- packages/ltx-core/src/ltx_core/loader/primitives.py +63 -0
- packages/ltx-core/src/ltx_core/loader/registry.py +68 -0
- packages/ltx-core/src/ltx_core/loader/sd_ops.py +107 -0
- packages/ltx-core/src/ltx_core/loader/sft_loader.py +53 -0
- packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py +99 -0
- packages/ltx-core/src/ltx_core/model/.ipynb_checkpoints/model_ledger-checkpoint.py +253 -0
- packages/ltx-core/src/ltx_core/model/__init__.py +0 -0
- packages/ltx-core/src/ltx_core/model/__pycache__/__init__.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/__pycache__/model_ledger.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/__pycache__/model_protocol.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py +0 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/__init__.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/attention.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/audio_vae.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/causal_conv_2d.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/causality_axis.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/downsample.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/model_configurator.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/ops.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/resnet.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/upsample.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/vocoder.cpython-310.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/attention.py +71 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py +483 -0
app.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple Gradio app for LTX-2 inference based on ltx2_two_stage.py example
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
# Add packages to Python path
|
| 9 |
+
current_dir = Path(__file__).parent
|
| 10 |
+
sys.path.insert(0, str(current_dir / "packages" / "ltx-pipelines" / "src"))
|
| 11 |
+
sys.path.insert(0, str(current_dir / "packages" / "ltx-core" / "src"))
|
| 12 |
+
|
| 13 |
+
import gradio as gr
|
| 14 |
+
from typing import Optional
|
| 15 |
+
from huggingface_hub import hf_hub_download
|
| 16 |
+
from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline
|
| 17 |
+
from ltx_core.tiling import TilingConfig
|
| 18 |
+
from ltx_pipelines.constants import (
|
| 19 |
+
DEFAULT_SEED,
|
| 20 |
+
DEFAULT_HEIGHT,
|
| 21 |
+
DEFAULT_WIDTH,
|
| 22 |
+
DEFAULT_NUM_FRAMES,
|
| 23 |
+
DEFAULT_FRAME_RATE,
|
| 24 |
+
DEFAULT_NUM_INFERENCE_STEPS,
|
| 25 |
+
DEFAULT_CFG_GUIDANCE_SCALE,
|
| 26 |
+
DEFAULT_LORA_STRENGTH,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# Custom negative prompt
|
| 30 |
+
DEFAULT_NEGATIVE_PROMPT = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static"
|
| 31 |
+
|
| 32 |
+
# Default prompt from docstring example
|
| 33 |
+
DEFAULT_PROMPT = "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot."
|
| 34 |
+
|
| 35 |
+
# HuggingFace Hub defaults
|
| 36 |
+
DEFAULT_REPO_ID = "LTX-Colab/LTX-Video-Preview"
|
| 37 |
+
DEFAULT_GEMMA_REPO_ID = "google/gemma-3-12b-it-qat-q4_0-unquantized"
|
| 38 |
+
DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev-rc1.safetensors"
|
| 39 |
+
DEFAULT_DISTILLED_LORA_FILENAME = "ltx-2-19b-distilled-lora-384-rc1.safetensors"
|
| 40 |
+
DEFAULT_SPATIAL_UPSAMPLER_FILENAME = "ltx-2-spatial-upscaler-x2-1.0-rc1.safetensors"
|
| 41 |
+
|
| 42 |
+
def get_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None):
|
| 43 |
+
"""Download from HuggingFace Hub or use local checkpoint."""
|
| 44 |
+
if repo_id is None and filename is None:
|
| 45 |
+
raise ValueError("Please supply at least one of `repo_id` or `filename`")
|
| 46 |
+
|
| 47 |
+
if repo_id is not None:
|
| 48 |
+
if filename is None:
|
| 49 |
+
raise ValueError("If repo_id is specified, filename must also be specified.")
|
| 50 |
+
print(f"Downloading {filename} from {repo_id}...")
|
| 51 |
+
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 52 |
+
print(f"Downloaded to {ckpt_path}")
|
| 53 |
+
else:
|
| 54 |
+
ckpt_path = filename
|
| 55 |
+
|
| 56 |
+
return ckpt_path
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Initialize pipeline at startup
|
| 60 |
+
print("=" * 80)
|
| 61 |
+
print("Loading LTX-2 2-stage pipeline...")
|
| 62 |
+
print("=" * 80)
|
| 63 |
+
|
| 64 |
+
checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
|
| 65 |
+
distilled_lora_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_DISTILLED_LORA_FILENAME)
|
| 66 |
+
spatial_upsampler_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_SPATIAL_UPSAMPLER_FILENAME)
|
| 67 |
+
|
| 68 |
+
print(f"Initializing pipeline with:")
|
| 69 |
+
print(f" checkpoint_path={checkpoint_path}")
|
| 70 |
+
print(f" distilled_lora_path={distilled_lora_path}")
|
| 71 |
+
print(f" spatial_upsampler_path={spatial_upsampler_path}")
|
| 72 |
+
print(f" gemma_root={DEFAULT_GEMMA_REPO_ID}")
|
| 73 |
+
|
| 74 |
+
pipeline = TI2VidTwoStagesPipeline(
|
| 75 |
+
checkpoint_path=checkpoint_path,
|
| 76 |
+
distilled_lora_path=distilled_lora_path,
|
| 77 |
+
distilled_lora_strength=DEFAULT_LORA_STRENGTH,
|
| 78 |
+
spatial_upsampler_path=spatial_upsampler_path,
|
| 79 |
+
gemma_root=DEFAULT_GEMMA_REPO_ID,
|
| 80 |
+
loras=[],
|
| 81 |
+
fp8transformer=False,
|
| 82 |
+
local_files_only=False
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
print("=" * 80)
|
| 86 |
+
print("Warming up pipeline (loading Gemma text encoder)...")
|
| 87 |
+
print("=" * 80)
|
| 88 |
+
|
| 89 |
+
# Do a dummy warmup to load all models including Gemma
|
| 90 |
+
import tempfile
|
| 91 |
+
import os
|
| 92 |
+
warmup_output = tempfile.mktemp(suffix=".mp4")
|
| 93 |
+
try:
|
| 94 |
+
pipeline(
|
| 95 |
+
prompt="warmup",
|
| 96 |
+
negative_prompt="",
|
| 97 |
+
output_path=warmup_output,
|
| 98 |
+
seed=42,
|
| 99 |
+
height=256,
|
| 100 |
+
width=256,
|
| 101 |
+
num_frames=9,
|
| 102 |
+
frame_rate=8,
|
| 103 |
+
num_inference_steps=1,
|
| 104 |
+
cfg_guidance_scale=1.0,
|
| 105 |
+
images=[],
|
| 106 |
+
tiling_config=TilingConfig.default(),
|
| 107 |
+
)
|
| 108 |
+
# Clean up warmup output
|
| 109 |
+
if os.path.exists(warmup_output):
|
| 110 |
+
os.remove(warmup_output)
|
| 111 |
+
except Exception as e:
|
| 112 |
+
print(f"Warmup completed with note: {e}")
|
| 113 |
+
|
| 114 |
+
print("=" * 80)
|
| 115 |
+
print("Pipeline fully loaded and ready!")
|
| 116 |
+
print("=" * 80)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def generate_video(
|
| 120 |
+
input_image,
|
| 121 |
+
prompt: str,
|
| 122 |
+
duration: float,
|
| 123 |
+
negative_prompt: str,
|
| 124 |
+
seed: int,
|
| 125 |
+
randomize_seed: bool,
|
| 126 |
+
num_inference_steps: int,
|
| 127 |
+
cfg_guidance_scale: float,
|
| 128 |
+
height: int,
|
| 129 |
+
width: int,
|
| 130 |
+
progress=gr.Progress()
|
| 131 |
+
):
|
| 132 |
+
"""Generate a video based on the given parameters."""
|
| 133 |
+
try:
|
| 134 |
+
# Randomize seed if checkbox is enabled
|
| 135 |
+
if randomize_seed:
|
| 136 |
+
import random
|
| 137 |
+
seed = random.randint(0, 1000000)
|
| 138 |
+
|
| 139 |
+
# Calculate num_frames from duration (using fixed 24 fps)
|
| 140 |
+
frame_rate = 24.0
|
| 141 |
+
num_frames = int(duration * frame_rate) + 1 # +1 to ensure we meet the duration
|
| 142 |
+
|
| 143 |
+
# Create output directory if it doesn't exist
|
| 144 |
+
output_dir = Path("outputs")
|
| 145 |
+
output_dir.mkdir(exist_ok=True)
|
| 146 |
+
output_path = output_dir / f"video_{seed}.mp4"
|
| 147 |
+
|
| 148 |
+
# Handle image input
|
| 149 |
+
images = []
|
| 150 |
+
if input_image is not None:
|
| 151 |
+
# Save uploaded image temporarily
|
| 152 |
+
temp_image_path = output_dir / f"temp_input_{seed}.jpg"
|
| 153 |
+
if hasattr(input_image, 'save'):
|
| 154 |
+
input_image.save(temp_image_path)
|
| 155 |
+
else:
|
| 156 |
+
# If it's a file path already
|
| 157 |
+
temp_image_path = input_image
|
| 158 |
+
# Format: (image_path, frame_idx, strength)
|
| 159 |
+
images = [(str(temp_image_path), 0, 1.0)]
|
| 160 |
+
|
| 161 |
+
# Run inference
|
| 162 |
+
progress(0, desc="Generating video (2-stage)...")
|
| 163 |
+
pipeline(
|
| 164 |
+
prompt=prompt,
|
| 165 |
+
negative_prompt=negative_prompt,
|
| 166 |
+
output_path=str(output_path),
|
| 167 |
+
seed=seed,
|
| 168 |
+
height=height,
|
| 169 |
+
width=width,
|
| 170 |
+
num_frames=num_frames,
|
| 171 |
+
frame_rate=frame_rate,
|
| 172 |
+
num_inference_steps=num_inference_steps,
|
| 173 |
+
cfg_guidance_scale=cfg_guidance_scale,
|
| 174 |
+
images=images,
|
| 175 |
+
tiling_config=TilingConfig.default(),
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
progress(1.0, desc="Done!")
|
| 179 |
+
return str(output_path)
|
| 180 |
+
|
| 181 |
+
except Exception as e:
|
| 182 |
+
import traceback
|
| 183 |
+
error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
|
| 184 |
+
print(error_msg)
|
| 185 |
+
return None
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# Create Gradio interface
|
| 189 |
+
with gr.Blocks(title="LTX-2 Image-to-Video") as demo:
|
| 190 |
+
gr.Markdown("# LTX-2 Image-to-Video Generation")
|
| 191 |
+
gr.Markdown("Transform images into videos using the LTX-2 2-stage pipeline")
|
| 192 |
+
|
| 193 |
+
with gr.Row():
|
| 194 |
+
with gr.Column():
|
| 195 |
+
input_image = gr.Image(
|
| 196 |
+
label="Input Image",
|
| 197 |
+
type="pil",
|
| 198 |
+
sources=["upload"]
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
prompt = gr.Textbox(
|
| 202 |
+
label="Prompt",
|
| 203 |
+
value="Make this image come alive with cinematic motion, smooth animation",
|
| 204 |
+
lines=3,
|
| 205 |
+
placeholder="Describe the motion and animation you want..."
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
duration = gr.Slider(
|
| 209 |
+
label="Duration (seconds)",
|
| 210 |
+
minimum=1.0,
|
| 211 |
+
maximum=10.0,
|
| 212 |
+
value=5.0,
|
| 213 |
+
step=0.1
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
|
| 217 |
+
|
| 218 |
+
with gr.Accordion("Advanced Settings", open=False):
|
| 219 |
+
negative_prompt = gr.Textbox(
|
| 220 |
+
label="Negative Prompt",
|
| 221 |
+
value=DEFAULT_NEGATIVE_PROMPT,
|
| 222 |
+
lines=2
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
seed = gr.Slider(
|
| 226 |
+
label="Seed",
|
| 227 |
+
minimum=0,
|
| 228 |
+
maximum=1000000,
|
| 229 |
+
value=DEFAULT_SEED,
|
| 230 |
+
step=1
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
randomize_seed = gr.Checkbox(
|
| 234 |
+
label="Randomize Seed",
|
| 235 |
+
value=True
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
num_inference_steps = gr.Slider(
|
| 239 |
+
label="Inference Steps",
|
| 240 |
+
minimum=1,
|
| 241 |
+
maximum=100,
|
| 242 |
+
value=DEFAULT_NUM_INFERENCE_STEPS,
|
| 243 |
+
step=1
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
cfg_guidance_scale = gr.Slider(
|
| 247 |
+
label="CFG Guidance Scale",
|
| 248 |
+
minimum=1.0,
|
| 249 |
+
maximum=10.0,
|
| 250 |
+
value=DEFAULT_CFG_GUIDANCE_SCALE,
|
| 251 |
+
step=0.1
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
with gr.Row():
|
| 255 |
+
width = gr.Number(
|
| 256 |
+
label="Width",
|
| 257 |
+
value=DEFAULT_WIDTH,
|
| 258 |
+
precision=0
|
| 259 |
+
)
|
| 260 |
+
height = gr.Number(
|
| 261 |
+
label="Height",
|
| 262 |
+
value=DEFAULT_HEIGHT,
|
| 263 |
+
precision=0
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
with gr.Column():
|
| 267 |
+
output_video = gr.Video(label="Generated Video", autoplay=True)
|
| 268 |
+
|
| 269 |
+
generate_btn.click(
|
| 270 |
+
fn=generate_video,
|
| 271 |
+
inputs=[
|
| 272 |
+
input_image,
|
| 273 |
+
prompt,
|
| 274 |
+
duration,
|
| 275 |
+
negative_prompt,
|
| 276 |
+
seed,
|
| 277 |
+
randomize_seed,
|
| 278 |
+
num_inference_steps,
|
| 279 |
+
cfg_guidance_scale,
|
| 280 |
+
height,
|
| 281 |
+
width,
|
| 282 |
+
],
|
| 283 |
+
outputs=output_video
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# Add example
|
| 287 |
+
gr.Examples(
|
| 288 |
+
examples=[
|
| 289 |
+
[
|
| 290 |
+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg",
|
| 291 |
+
"An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot.",
|
| 292 |
+
5.0,
|
| 293 |
+
]
|
| 294 |
+
],
|
| 295 |
+
inputs=[input_image, prompt, duration],
|
| 296 |
+
label="Example"
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
if __name__ == "__main__":
|
| 301 |
+
demo.launch(share=True)
|
packages/ltx-core/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# LTX-2 Core
|
packages/ltx-core/pyproject.toml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "ltx-core"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Core implementation of Lightricks' LTX-2 model"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.12"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"torch~=2.7",
|
| 9 |
+
"torchaudio",
|
| 10 |
+
"einops",
|
| 11 |
+
"numpy",
|
| 12 |
+
"transformers",
|
| 13 |
+
"safetensors",
|
| 14 |
+
"accelerate",
|
| 15 |
+
"scipy>=1.14",
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
[project.optional-dependencies]
|
| 19 |
+
flashpack = ["flashpack==0.1.2"]
|
| 20 |
+
xformers = ["xformers"]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
[tool.uv.sources]
|
| 24 |
+
xformers = { index = "pytorch" }
|
| 25 |
+
|
| 26 |
+
[[tool.uv.index]]
|
| 27 |
+
name = "pytorch"
|
| 28 |
+
url = "https://download.pytorch.org/whl/cu129"
|
| 29 |
+
explicit = true
|
| 30 |
+
|
| 31 |
+
[build-system]
|
| 32 |
+
requires = ["uv_build>=0.9.8,<0.10.0"]
|
| 33 |
+
build-backend = "uv_build"
|
| 34 |
+
|
| 35 |
+
[dependency-groups]
|
| 36 |
+
dev = [
|
| 37 |
+
"scikit-image>=0.25.2",
|
| 38 |
+
]
|
packages/ltx-core/src/ltx_core/__init__.py
ADDED
|
File without changes
|
packages/ltx-core/src/ltx_core/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (176 Bytes). View file
|
|
|
packages/ltx-core/src/ltx_core/__pycache__/tiling.cpython-310.pyc
ADDED
|
Binary file (10.7 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (1.85 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/guidance/__init__.py
ADDED
|
File without changes
|
packages/ltx-core/src/ltx_core/guidance/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (185 Bytes). View file
|
|
|
packages/ltx-core/src/ltx_core/guidance/__pycache__/perturbations.cpython-310.pyc
ADDED
|
Binary file (3.82 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/guidance/perturbations.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
+
# Created by Andrew Kvochko
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from enum import Enum
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch._prims_common import DeviceLikeType
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PerturbationType(Enum):
|
| 12 |
+
SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn"
|
| 13 |
+
SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn"
|
| 14 |
+
SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn"
|
| 15 |
+
SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass(frozen=True)
|
| 19 |
+
class Perturbation:
|
| 20 |
+
type: PerturbationType
|
| 21 |
+
blocks: list[int] | None # None means all blocks
|
| 22 |
+
|
| 23 |
+
def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
|
| 24 |
+
if self.type != perturbation_type:
|
| 25 |
+
return False
|
| 26 |
+
|
| 27 |
+
if self.blocks is None:
|
| 28 |
+
return True
|
| 29 |
+
|
| 30 |
+
return block in self.blocks
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass(frozen=True)
|
| 34 |
+
class PerturbationConfig:
|
| 35 |
+
perturbations: list[Perturbation] | None
|
| 36 |
+
|
| 37 |
+
def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
|
| 38 |
+
if self.perturbations is None:
|
| 39 |
+
return False
|
| 40 |
+
|
| 41 |
+
return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)
|
| 42 |
+
|
| 43 |
+
@staticmethod
|
| 44 |
+
def empty() -> "PerturbationConfig":
|
| 45 |
+
return PerturbationConfig([])
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass(frozen=True)
|
| 49 |
+
class BatchedPerturbationConfig:
|
| 50 |
+
perturbations: list[PerturbationConfig]
|
| 51 |
+
|
| 52 |
+
def mask(
|
| 53 |
+
self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype
|
| 54 |
+
) -> torch.Tensor:
|
| 55 |
+
mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype)
|
| 56 |
+
for batch_idx, perturbation in enumerate(self.perturbations):
|
| 57 |
+
if perturbation.is_perturbed(perturbation_type, block):
|
| 58 |
+
mask[batch_idx] = 0
|
| 59 |
+
|
| 60 |
+
return mask
|
| 61 |
+
|
| 62 |
+
def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor:
|
| 63 |
+
mask = self.mask(perturbation_type, block, values.device, values.dtype)
|
| 64 |
+
return mask.view(mask.numel(), *([1] * len(values.shape[1:])))
|
| 65 |
+
|
| 66 |
+
def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
|
| 67 |
+
return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)
|
| 68 |
+
|
| 69 |
+
def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
|
| 70 |
+
return all(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)
|
| 71 |
+
|
| 72 |
+
@staticmethod
|
| 73 |
+
def empty(batch_size: int) -> "BatchedPerturbationConfig":
|
| 74 |
+
return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)])
|
packages/ltx-core/src/ltx_core/legacy_tiling.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from collections.abc import Generator
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from ltx_core.model.video_vae.video_vae import Decoder
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def compute_chunk_boundaries(
|
| 10 |
+
chunk_start: int,
|
| 11 |
+
temporal_tile_length: int,
|
| 12 |
+
temporal_overlap: int,
|
| 13 |
+
total_latent_frames: int,
|
| 14 |
+
) -> tuple[int, int]:
|
| 15 |
+
"""Compute chunk boundaries for temporal tiling.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
chunk_start: Starting frame index for the current chunk
|
| 19 |
+
temporal_tile_length: Length of each temporal tile
|
| 20 |
+
temporal_overlap: Number of frames to overlap between chunks
|
| 21 |
+
total_latent_frames: Total number of latent frames
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
Tuple of (overlap_start, chunk_end)
|
| 25 |
+
"""
|
| 26 |
+
if chunk_start == 0:
|
| 27 |
+
# First chunk: no overlap needed
|
| 28 |
+
chunk_end = min(chunk_start + temporal_tile_length, total_latent_frames)
|
| 29 |
+
overlap_start = chunk_start
|
| 30 |
+
else:
|
| 31 |
+
# Subsequent chunks: include overlap from previous chunk
|
| 32 |
+
# -1 because we need one extra frame to overlap, which is decoded to a single frame
|
| 33 |
+
# never overlap with the first latent frame
|
| 34 |
+
overlap_start = max(1, chunk_start - temporal_overlap - 1)
|
| 35 |
+
extra_frames = chunk_start - overlap_start
|
| 36 |
+
chunk_end = min(
|
| 37 |
+
chunk_start + temporal_tile_length - extra_frames,
|
| 38 |
+
total_latent_frames,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
return overlap_start, chunk_end
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def spatial_decode( # noqa
|
| 45 |
+
decoder: Decoder,
|
| 46 |
+
samples: torch.Tensor,
|
| 47 |
+
horizontal_tiles: int,
|
| 48 |
+
vertical_tiles: int,
|
| 49 |
+
overlap: int,
|
| 50 |
+
last_frame_fix: bool,
|
| 51 |
+
scale_factors: tuple[float, float, float],
|
| 52 |
+
timestep: float,
|
| 53 |
+
generator: torch.Generator,
|
| 54 |
+
) -> torch.Tensor:
|
| 55 |
+
if last_frame_fix:
|
| 56 |
+
# Repeat the last frame along dimension 2 (frames)
|
| 57 |
+
# samples shape - [batch, channels, frames, height, width]
|
| 58 |
+
last_frame = samples[:, :, -1:, :, :]
|
| 59 |
+
samples = torch.cat([samples, last_frame], dim=2)
|
| 60 |
+
|
| 61 |
+
batch, _, frames, height, width = samples.shape
|
| 62 |
+
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
|
| 63 |
+
image_frames = 1 + (frames - 1) * time_scale_factor
|
| 64 |
+
|
| 65 |
+
# Calculate output image dimensions
|
| 66 |
+
output_height = height * height_scale_factor
|
| 67 |
+
output_width = width * width_scale_factor
|
| 68 |
+
|
| 69 |
+
# Calculate tile sizes with overlap
|
| 70 |
+
base_tile_height = (height + (vertical_tiles - 1) * overlap) // vertical_tiles
|
| 71 |
+
base_tile_width = (width + (horizontal_tiles - 1) * overlap) // horizontal_tiles
|
| 72 |
+
|
| 73 |
+
# Initialize output tensor and weight tensor
|
| 74 |
+
# VAE decode returns images in format [batch, height, width, channels]
|
| 75 |
+
output = None
|
| 76 |
+
weights = None
|
| 77 |
+
|
| 78 |
+
target_device = samples.device
|
| 79 |
+
target_dtype = samples.dtype
|
| 80 |
+
|
| 81 |
+
output = torch.zeros(
|
| 82 |
+
(
|
| 83 |
+
batch,
|
| 84 |
+
3,
|
| 85 |
+
image_frames,
|
| 86 |
+
output_height,
|
| 87 |
+
output_width,
|
| 88 |
+
),
|
| 89 |
+
device=target_device,
|
| 90 |
+
dtype=target_dtype,
|
| 91 |
+
)
|
| 92 |
+
weights = torch.zeros(
|
| 93 |
+
(batch, 1, image_frames, output_height, output_width),
|
| 94 |
+
device=target_device,
|
| 95 |
+
dtype=target_dtype,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# Process each tile
|
| 99 |
+
for v in range(vertical_tiles):
|
| 100 |
+
for h in range(horizontal_tiles):
|
| 101 |
+
# Calculate tile boundaries
|
| 102 |
+
h_start = h * (base_tile_width - overlap)
|
| 103 |
+
v_start = v * (base_tile_height - overlap)
|
| 104 |
+
|
| 105 |
+
# Adjust end positions for edge tiles
|
| 106 |
+
h_end = min(h_start + base_tile_width, width) if h < horizontal_tiles - 1 else width
|
| 107 |
+
v_end = min(v_start + base_tile_height, height) if v < vertical_tiles - 1 else height
|
| 108 |
+
|
| 109 |
+
# Calculate actual tile dimensions
|
| 110 |
+
tile_height = v_end - v_start
|
| 111 |
+
tile_width = h_end - h_start
|
| 112 |
+
|
| 113 |
+
logging.info(f"Processing VAE decode tile at row {v}, col {h}:")
|
| 114 |
+
logging.info(f" Position: ({v_start}:{v_end}, {h_start}:{h_end})")
|
| 115 |
+
logging.info(f" Size: {tile_height}x{tile_width}")
|
| 116 |
+
|
| 117 |
+
# Extract tile
|
| 118 |
+
tile = samples[:, :, :, v_start:v_end, h_start:h_end]
|
| 119 |
+
|
| 120 |
+
# Decode the tile
|
| 121 |
+
decoded_tile = decoder.decode(tile, timestep, generator)
|
| 122 |
+
|
| 123 |
+
# Calculate output tile boundaries
|
| 124 |
+
out_h_start = v_start * height_scale_factor
|
| 125 |
+
out_h_end = v_end * height_scale_factor
|
| 126 |
+
out_w_start = h_start * width_scale_factor
|
| 127 |
+
out_w_end = h_end * width_scale_factor
|
| 128 |
+
|
| 129 |
+
# Create weight mask for this tile
|
| 130 |
+
tile_out_height = out_h_end - out_h_start
|
| 131 |
+
tile_out_width = out_w_end - out_w_start
|
| 132 |
+
tile_weights = torch.ones(
|
| 133 |
+
(batch, 1, image_frames, tile_out_height, tile_out_width),
|
| 134 |
+
device=decoded_tile.device,
|
| 135 |
+
dtype=decoded_tile.dtype,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# Calculate overlap regions in output space
|
| 139 |
+
overlap_out_h = overlap * height_scale_factor
|
| 140 |
+
overlap_out_w = overlap * width_scale_factor
|
| 141 |
+
|
| 142 |
+
# Apply horizontal blending weights
|
| 143 |
+
if h > 0: # Left overlap
|
| 144 |
+
h_blend = torch.linspace(0, 1, overlap_out_w, device=decoded_tile.device)
|
| 145 |
+
tile_weights[:, :, :, :, :overlap_out_w] *= h_blend
|
| 146 |
+
if h < horizontal_tiles - 1: # Right overlap
|
| 147 |
+
h_blend = torch.linspace(1, 0, overlap_out_w, device=decoded_tile.device)
|
| 148 |
+
tile_weights[:, :, :, :, -overlap_out_w:] *= h_blend
|
| 149 |
+
|
| 150 |
+
# Apply vertical blending weights
|
| 151 |
+
if v > 0: # Top overlap
|
| 152 |
+
v_blend = torch.linspace(0, 1, overlap_out_h, device=decoded_tile.device)
|
| 153 |
+
tile_weights[:, :, :, :overlap_out_h, :] *= v_blend.view(1, 1, 1, -1, 1)
|
| 154 |
+
if v < vertical_tiles - 1: # Bottom overlap
|
| 155 |
+
v_blend = torch.linspace(1, 0, overlap_out_h, device=decoded_tile.device)
|
| 156 |
+
tile_weights[:, :, :, -overlap_out_h:, :] *= v_blend.view(1, 1, 1, -1, 1)
|
| 157 |
+
|
| 158 |
+
# Add weighted tile to output
|
| 159 |
+
output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += (decoded_tile * tile_weights).to(
|
| 160 |
+
target_device, target_dtype
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# Add weights to weight tensor
|
| 164 |
+
weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += tile_weights.to(
|
| 165 |
+
target_device, target_dtype
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
# Normalize by weights
|
| 169 |
+
output /= weights + 1e-8
|
| 170 |
+
# LT_INTERNAL: changed from output[:-time_scale_factor, :, :]!
|
| 171 |
+
if last_frame_fix:
|
| 172 |
+
output = output[:, :, :-time_scale_factor, :, :]
|
| 173 |
+
|
| 174 |
+
return output
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def decode_spatial_temporal(
|
| 178 |
+
decoder: Decoder,
|
| 179 |
+
samples: torch.ensor,
|
| 180 |
+
timestep: float,
|
| 181 |
+
generator: torch.Generator,
|
| 182 |
+
scale_factors: tuple[float, float, float],
|
| 183 |
+
spatial_tiles: int = 4,
|
| 184 |
+
spatial_overlap: int = 1,
|
| 185 |
+
temporal_tile_length: int = 16,
|
| 186 |
+
temporal_overlap: int = 1,
|
| 187 |
+
last_frame_fix: bool = False,
|
| 188 |
+
) -> Generator[torch.Tensor, None, None]:
|
| 189 |
+
if temporal_tile_length < temporal_overlap + 1:
|
| 190 |
+
raise ValueError("Temporal tile length must be greater than temporal overlap + 1")
|
| 191 |
+
|
| 192 |
+
_, _, frames, _, _ = samples.shape
|
| 193 |
+
time_scale_factor, _, _ = scale_factors
|
| 194 |
+
|
| 195 |
+
# Process temporal chunks similar to reference function
|
| 196 |
+
total_latent_frames = frames
|
| 197 |
+
chunk_start = 0
|
| 198 |
+
|
| 199 |
+
previous_tile = None
|
| 200 |
+
while chunk_start < total_latent_frames:
|
| 201 |
+
# Calculate chunk boundaries
|
| 202 |
+
overlap_start, chunk_end = compute_chunk_boundaries(
|
| 203 |
+
chunk_start, temporal_tile_length, temporal_overlap, total_latent_frames
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# units are latent frames
|
| 207 |
+
chunk_frames = chunk_end - overlap_start
|
| 208 |
+
logging.info(f"Processing temporal chunk: {overlap_start}:{chunk_end} ({chunk_frames} latent frames)")
|
| 209 |
+
|
| 210 |
+
# Extract tile
|
| 211 |
+
tile = samples[:, :, overlap_start:chunk_end]
|
| 212 |
+
|
| 213 |
+
# Decode the tile
|
| 214 |
+
decoded_tile = spatial_decode(
|
| 215 |
+
decoder,
|
| 216 |
+
tile,
|
| 217 |
+
spatial_tiles,
|
| 218 |
+
spatial_tiles,
|
| 219 |
+
spatial_overlap,
|
| 220 |
+
last_frame_fix,
|
| 221 |
+
scale_factors,
|
| 222 |
+
timestep,
|
| 223 |
+
generator,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
if previous_tile is None:
|
| 227 |
+
previous_tile = decoded_tile
|
| 228 |
+
else:
|
| 229 |
+
# Drop first frame if needed (overlap)
|
| 230 |
+
if decoded_tile.shape[2] == 1:
|
| 231 |
+
raise ValueError("Dropping first frame but tile has only 1 frame")
|
| 232 |
+
decoded_tile = decoded_tile[:, :, 1:] # Drop first frame
|
| 233 |
+
|
| 234 |
+
# Create weight mask for this tile
|
| 235 |
+
# -1 is for dropped frame above
|
| 236 |
+
overlap_frames = temporal_overlap * time_scale_factor
|
| 237 |
+
frame_weights = torch.linspace(
|
| 238 |
+
0,
|
| 239 |
+
1,
|
| 240 |
+
overlap_frames + 2,
|
| 241 |
+
device=decoded_tile.device,
|
| 242 |
+
dtype=decoded_tile.dtype,
|
| 243 |
+
)[1:-1]
|
| 244 |
+
tile_weights = frame_weights.view(1, 1, -1, 1, 1)
|
| 245 |
+
|
| 246 |
+
previous_tile[:, :, -overlap_frames:] = (
|
| 247 |
+
previous_tile[:, :, -overlap_frames:] * (1 - tile_weights)
|
| 248 |
+
+ decoded_tile[:, :, :overlap_frames] * tile_weights
|
| 249 |
+
)
|
| 250 |
+
resulting_tile = previous_tile[:, :, :-overlap_frames]
|
| 251 |
+
decoded_tile[:, :, :overlap_frames] = previous_tile[:, :, -overlap_frames:]
|
| 252 |
+
yield resulting_tile
|
| 253 |
+
previous_tile = decoded_tile
|
| 254 |
+
|
| 255 |
+
# Move to next chunk
|
| 256 |
+
chunk_start = chunk_end
|
| 257 |
+
|
| 258 |
+
yield decoded_tile
|
packages/ltx-core/src/ltx_core/loader/.ipynb_checkpoints/sd_ops-checkpoint.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
+
# Created by Alexey Kravtsov
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass, replace
|
| 5 |
+
#from typing import NamedTuple, Protocol, Self
|
| 6 |
+
from typing import NamedTuple, Protocol
|
| 7 |
+
from typing_extensions import Self
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass(frozen=True, slots=True)
|
| 13 |
+
class ContentReplacement:
|
| 14 |
+
content: str
|
| 15 |
+
replacement: str
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass(frozen=True, slots=True)
|
| 19 |
+
class ContentMatching:
|
| 20 |
+
prefix: str = ""
|
| 21 |
+
suffix: str = ""
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class KeyValueOperationResult(NamedTuple):
|
| 25 |
+
new_key: str
|
| 26 |
+
new_value: torch.Tensor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class KeyValueOperation(Protocol):
|
| 30 |
+
def __call__(self, tensor_key: str, tensor_value: torch.Tensor) -> list[KeyValueOperationResult]: ...
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass(frozen=True, slots=True)
|
| 34 |
+
class SDKeyValueOperation:
|
| 35 |
+
key_matcher: ContentMatching
|
| 36 |
+
kv_operation: KeyValueOperation
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass(frozen=True, slots=True)
|
| 40 |
+
class SDOps:
|
| 41 |
+
"""Immutable class representing state dict key operations."""
|
| 42 |
+
|
| 43 |
+
name: str
|
| 44 |
+
mapping: tuple[
|
| 45 |
+
ContentReplacement | ContentMatching | SDKeyValueOperation, ...
|
| 46 |
+
] = () # Immutable tuple of (key, value) pairs
|
| 47 |
+
|
| 48 |
+
def with_replacement(self, content: str, replacement: str) -> Self:
|
| 49 |
+
"""Create a new SDOps instance with the specified replacement added to the mapping."""
|
| 50 |
+
|
| 51 |
+
new_mapping = (*self.mapping, ContentReplacement(content, replacement))
|
| 52 |
+
return replace(self, mapping=new_mapping)
|
| 53 |
+
|
| 54 |
+
def with_matching(self, prefix: str = "", suffix: str = "") -> Self:
|
| 55 |
+
"""Create a new SDOps instance with the specified prefix and suffix matching added to the mapping."""
|
| 56 |
+
|
| 57 |
+
new_mapping = (*self.mapping, ContentMatching(prefix, suffix))
|
| 58 |
+
return replace(self, mapping=new_mapping)
|
| 59 |
+
|
| 60 |
+
def with_kv_operation(
|
| 61 |
+
self,
|
| 62 |
+
operation: KeyValueOperation,
|
| 63 |
+
key_prefix: str = "",
|
| 64 |
+
key_suffix: str = "",
|
| 65 |
+
) -> Self:
|
| 66 |
+
"""Create a new SDOps instance with the specified value operation added to the mapping."""
|
| 67 |
+
key_matcher = ContentMatching(key_prefix, key_suffix)
|
| 68 |
+
sd_kv_operation = SDKeyValueOperation(key_matcher, operation)
|
| 69 |
+
new_mapping = (*self.mapping, sd_kv_operation)
|
| 70 |
+
return replace(self, mapping=new_mapping)
|
| 71 |
+
|
| 72 |
+
def apply_to_key(self, key: str) -> str | None:
|
| 73 |
+
"""Apply the mapping to the given name."""
|
| 74 |
+
matchers = [content for content in self.mapping if isinstance(content, ContentMatching)]
|
| 75 |
+
valid = any(key.startswith(f.prefix) and key.endswith(f.suffix) for f in matchers)
|
| 76 |
+
if not valid:
|
| 77 |
+
return None
|
| 78 |
+
|
| 79 |
+
for replacement in self.mapping:
|
| 80 |
+
if not isinstance(replacement, ContentReplacement):
|
| 81 |
+
continue
|
| 82 |
+
if replacement.content in key:
|
| 83 |
+
key = key.replace(replacement.content, replacement.replacement)
|
| 84 |
+
return key
|
| 85 |
+
|
| 86 |
+
def apply_to_key_value(self, key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
|
| 87 |
+
"""Apply the value operation to the given name and associated value."""
|
| 88 |
+
for operation in self.mapping:
|
| 89 |
+
if not isinstance(operation, SDKeyValueOperation):
|
| 90 |
+
continue
|
| 91 |
+
if key.startswith(operation.key_matcher.prefix) and key.endswith(operation.key_matcher.suffix):
|
| 92 |
+
return operation.kv_operation(key, value)
|
| 93 |
+
return [KeyValueOperationResult(key, value)]
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# Predefined SDOps instances
|
| 97 |
+
LTXV_LORA_COMFY_RENAMING_MAP = (
|
| 98 |
+
SDOps("LTXV_LORA_COMFY_PREFIX_MAP").with_matching().with_replacement("diffusion_model.", "")
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
LTXV_LORA_COMFY_TARGET_MAP = (
|
| 102 |
+
SDOps("LTXV_LORA_COMFY_TARGET_MAP")
|
| 103 |
+
.with_matching()
|
| 104 |
+
.with_replacement("diffusion_model.", "")
|
| 105 |
+
.with_replacement(".lora_A.weight", ".weight")
|
| 106 |
+
.with_replacement(".lora_B.weight", ".weight")
|
| 107 |
+
)
|
packages/ltx-core/src/ltx_core/loader/__init__.py
ADDED
|
File without changes
|
packages/ltx-core/src/ltx_core/loader/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (183 Bytes). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/__pycache__/fuse_loras.cpython-310.pyc
ADDED
|
Binary file (2.75 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/__pycache__/kernels.cpython-310.pyc
ADDED
|
Binary file (1.58 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/__pycache__/module_ops.cpython-310.pyc
ADDED
|
Binary file (558 Bytes). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/__pycache__/primitives.cpython-310.pyc
ADDED
|
Binary file (3.09 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/__pycache__/registry.cpython-310.pyc
ADDED
|
Binary file (3.53 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/__pycache__/sd_ops.cpython-310.pyc
ADDED
|
Binary file (4.36 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/__pycache__/sft_loader.cpython-310.pyc
ADDED
|
Binary file (2.65 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/__pycache__/single_gpu_model_builder.cpython-310.pyc
ADDED
|
Binary file (5.04 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/fuse_loras.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
+
# Created by Alexey Kravtsov
|
| 3 |
+
import torch
|
| 4 |
+
import triton
|
| 5 |
+
|
| 6 |
+
from ltx_core.loader.kernels import fused_add_round_kernel
|
| 7 |
+
from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict
|
| 8 |
+
|
| 9 |
+
BLOCK_SIZE = 1024
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def fused_add_round_launch(target_weight: torch.Tensor, original_weight: torch.Tensor, seed: int) -> torch.Tensor:
|
| 13 |
+
if original_weight.dtype == torch.float8_e4m3fn:
|
| 14 |
+
exponent_bits, mantissa_bits, exponent_bias = 4, 3, 7
|
| 15 |
+
elif original_weight.dtype == torch.float8_e5m2:
|
| 16 |
+
exponent_bits, mantissa_bits, exponent_bias = 5, 2, 15 # noqa: F841
|
| 17 |
+
else:
|
| 18 |
+
raise ValueError("Unsupported dtype")
|
| 19 |
+
|
| 20 |
+
if target_weight.dtype != torch.bfloat16:
|
| 21 |
+
raise ValueError("target_weight dtype must be bfloat16")
|
| 22 |
+
|
| 23 |
+
# Calculate grid and block sizes
|
| 24 |
+
n_elements = original_weight.numel()
|
| 25 |
+
grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
|
| 26 |
+
|
| 27 |
+
# Launch kernel
|
| 28 |
+
fused_add_round_kernel[grid](
|
| 29 |
+
original_weight,
|
| 30 |
+
target_weight,
|
| 31 |
+
seed,
|
| 32 |
+
n_elements,
|
| 33 |
+
exponent_bias,
|
| 34 |
+
mantissa_bits,
|
| 35 |
+
BLOCK_SIZE,
|
| 36 |
+
)
|
| 37 |
+
return target_weight
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def calculate_weight_float8_(target_weights: torch.Tensor, original_weights: torch.Tensor) -> torch.Tensor:
|
| 41 |
+
result = fused_add_round_launch(target_weights, original_weights, seed=0).to(target_weights.dtype)
|
| 42 |
+
target_weights.copy_(result, non_blocking=True)
|
| 43 |
+
return target_weights
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _prepare_deltas(
|
| 47 |
+
lora_sd_and_strengths: list[LoraStateDictWithStrength], key: str, dtype: torch.dtype, device: torch.device
|
| 48 |
+
) -> torch.Tensor | None:
|
| 49 |
+
deltas = []
|
| 50 |
+
prefix = key[: -len(".weight")]
|
| 51 |
+
key_a = f"{prefix}.lora_A.weight"
|
| 52 |
+
key_b = f"{prefix}.lora_B.weight"
|
| 53 |
+
for lsd, coef in lora_sd_and_strengths:
|
| 54 |
+
if key_a not in lsd.sd or key_b not in lsd.sd:
|
| 55 |
+
continue
|
| 56 |
+
product = torch.matmul(lsd.sd[key_b] * coef, lsd.sd[key_a])
|
| 57 |
+
deltas.append(product.to(dtype=dtype, device=device))
|
| 58 |
+
if len(deltas) == 0:
|
| 59 |
+
return None
|
| 60 |
+
elif len(deltas) == 1:
|
| 61 |
+
return deltas[0]
|
| 62 |
+
return torch.sum(torch.stack(deltas, dim=0), dim=0)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def apply_loras(
|
| 66 |
+
model_sd: StateDict,
|
| 67 |
+
lora_sd_and_strengths: list[LoraStateDictWithStrength],
|
| 68 |
+
dtype: torch.dtype,
|
| 69 |
+
destination_sd: StateDict | None = None,
|
| 70 |
+
) -> StateDict:
|
| 71 |
+
sd = {}
|
| 72 |
+
if destination_sd is not None:
|
| 73 |
+
sd = destination_sd.sd
|
| 74 |
+
size = 0
|
| 75 |
+
device = torch.device("meta")
|
| 76 |
+
inner_dtypes = set()
|
| 77 |
+
for key, weight in model_sd.sd.items():
|
| 78 |
+
if weight is None:
|
| 79 |
+
continue
|
| 80 |
+
device = weight.device
|
| 81 |
+
target_dtype = dtype if dtype is not None else weight.dtype
|
| 82 |
+
deltas_dtype = target_dtype if target_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2] else torch.bfloat16
|
| 83 |
+
deltas = _prepare_deltas(lora_sd_and_strengths, key, deltas_dtype, device)
|
| 84 |
+
if deltas is None:
|
| 85 |
+
if key in sd:
|
| 86 |
+
continue
|
| 87 |
+
deltas = weight.clone().to(dtype=target_dtype, device=device)
|
| 88 |
+
elif weight.dtype == torch.float8_e4m3fn:
|
| 89 |
+
if str(device).startswith("cuda"):
|
| 90 |
+
deltas = calculate_weight_float8_(deltas, weight)
|
| 91 |
+
else:
|
| 92 |
+
deltas.add_(weight.to(dtype=deltas.dtype, device=device))
|
| 93 |
+
elif weight.dtype == torch.bfloat16:
|
| 94 |
+
deltas.add_(weight)
|
| 95 |
+
else:
|
| 96 |
+
raise ValueError(f"Unsupported dtype: {weight.dtype}")
|
| 97 |
+
sd[key] = deltas.to(dtype=target_dtype)
|
| 98 |
+
inner_dtypes.add(target_dtype)
|
| 99 |
+
size += deltas.nbytes
|
| 100 |
+
if destination_sd is not None:
|
| 101 |
+
return destination_sd
|
| 102 |
+
return StateDict(sd, device, size, inner_dtypes)
|
packages/ltx-core/src/ltx_core/loader/kernels.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ruff: noqa: ANN001, ANN201, ERA001, N803, N806
|
| 2 |
+
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 3 |
+
# Created by Alexey Kravtsov
|
| 4 |
+
import triton
|
| 5 |
+
import triton.language as tl
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@triton.jit
|
| 9 |
+
def fused_add_round_kernel(
|
| 10 |
+
x_ptr,
|
| 11 |
+
output_ptr, # contents will be added to the output
|
| 12 |
+
seed,
|
| 13 |
+
n_elements,
|
| 14 |
+
EXPONENT_BIAS,
|
| 15 |
+
MANTISSA_BITS,
|
| 16 |
+
BLOCK_SIZE: tl.constexpr,
|
| 17 |
+
):
|
| 18 |
+
"""
|
| 19 |
+
A kernel to upcast 8bit quantized weights to bfloat16 with stochastic rounding
|
| 20 |
+
and add them to bfloat16 output weights. Might be used to upcast original model weights
|
| 21 |
+
and to further add them to precalculated deltas coming from LoRAs.
|
| 22 |
+
"""
|
| 23 |
+
# Get program ID and compute offsets
|
| 24 |
+
pid = tl.program_id(axis=0)
|
| 25 |
+
block_start = pid * BLOCK_SIZE
|
| 26 |
+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
| 27 |
+
mask = offsets < n_elements
|
| 28 |
+
|
| 29 |
+
# Load data
|
| 30 |
+
x = tl.load(x_ptr + offsets, mask=mask)
|
| 31 |
+
rand_vals = tl.rand(seed, offsets) - 0.5
|
| 32 |
+
|
| 33 |
+
x = tl.cast(x, tl.float16)
|
| 34 |
+
delta = tl.load(output_ptr + offsets, mask=mask)
|
| 35 |
+
delta = tl.cast(delta, tl.float16)
|
| 36 |
+
x = x + delta
|
| 37 |
+
|
| 38 |
+
x_bits = tl.cast(x, tl.int16, bitcast=True)
|
| 39 |
+
|
| 40 |
+
# Calculate the exponent. Unbiased fp16 exponent is ((x_bits & 0x7C00) >> 10) - 15 for
|
| 41 |
+
# normal numbers and -14 for subnormals.
|
| 42 |
+
fp16_exponent_bits = (x_bits & 0x7C00) >> 10
|
| 43 |
+
fp16_normals = fp16_exponent_bits > 0
|
| 44 |
+
fp16_exponent = tl.where(fp16_normals, fp16_exponent_bits - 15, -14)
|
| 45 |
+
|
| 46 |
+
# Add the target dtype's exponent bias and clamp to the target dtype's exponent range.
|
| 47 |
+
exponent = fp16_exponent + EXPONENT_BIAS
|
| 48 |
+
MAX_EXPONENT = 2 * EXPONENT_BIAS + 1
|
| 49 |
+
exponent = tl.where(exponent > MAX_EXPONENT, MAX_EXPONENT, exponent)
|
| 50 |
+
exponent = tl.where(exponent < 0, 0, exponent)
|
| 51 |
+
|
| 52 |
+
# Normal ULP exponent, expressed as an fp16 exponent field:
|
| 53 |
+
# (exponent - EXPONENT_BIAS - MANTISSA_BITS) + 15
|
| 54 |
+
# Simplifies to: fp16_exponent - MANTISSA_BITS + 15
|
| 55 |
+
# See https://en.wikipedia.org/wiki/Unit_in_the_last_place
|
| 56 |
+
eps_exp = tl.maximum(0, tl.minimum(31, exponent - EXPONENT_BIAS - MANTISSA_BITS + 15))
|
| 57 |
+
|
| 58 |
+
# Calculate epsilon in the target dtype
|
| 59 |
+
eps_normal = tl.cast(tl.cast(eps_exp << 10, tl.int16), tl.float16, bitcast=True)
|
| 60 |
+
|
| 61 |
+
# Subnormal ULP: 2^(1 - EXPONENT_BIAS - MANTISSA_BITS) ->
|
| 62 |
+
# fp16 exponent bits: (1 - EXPONENT_BIAS - MANTISSA_BITS) + 15 =
|
| 63 |
+
# 16 - EXPONENT_BIAS - MANTISSA_BITS
|
| 64 |
+
eps_subnormal = tl.cast(tl.cast((16 - EXPONENT_BIAS - MANTISSA_BITS) << 10, tl.int16), tl.float16, bitcast=True)
|
| 65 |
+
eps = tl.where(exponent > 0, eps_normal, eps_subnormal)
|
| 66 |
+
|
| 67 |
+
# Apply zero mask to epsilon
|
| 68 |
+
eps = tl.where(x == 0, 0.0, eps)
|
| 69 |
+
|
| 70 |
+
# Apply stochastic rounding
|
| 71 |
+
output = tl.cast(x + rand_vals * eps, tl.bfloat16)
|
| 72 |
+
|
| 73 |
+
# Store the result
|
| 74 |
+
tl.store(output_ptr + offsets, output, mask=mask)
|
packages/ltx-core/src/ltx_core/loader/module_ops.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
+
# Created by Alexey Kravtsov
|
| 3 |
+
from typing import Callable, NamedTuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ModuleOps(NamedTuple):
|
| 9 |
+
name: str
|
| 10 |
+
matcher: Callable[[torch.nn.Module], bool]
|
| 11 |
+
mutator: Callable[[torch.nn.Module], torch.nn.Module]
|
packages/ltx-core/src/ltx_core/loader/primitives.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
+
# Created by Alexey Kravtsov
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import NamedTuple, Protocol
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from ltx_core.loader.module_ops import ModuleOps
|
| 9 |
+
from ltx_core.loader.sd_ops import SDOps
|
| 10 |
+
from ltx_core.model.model_protocol import ModelType
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass(frozen=True)
|
| 14 |
+
class StateDict:
|
| 15 |
+
sd: dict
|
| 16 |
+
device: torch.device
|
| 17 |
+
size: int
|
| 18 |
+
dtype: set[torch.dtype]
|
| 19 |
+
|
| 20 |
+
def footprint(self) -> tuple[int, torch.device]:
|
| 21 |
+
return self.size, self.device
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class StateDictLoader(Protocol):
|
| 25 |
+
def metadata(self, path: str) -> dict:
|
| 26 |
+
"""
|
| 27 |
+
Load metadata from path
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
|
| 31 |
+
"""
|
| 32 |
+
Load state dict from path or paths (for sharded model storage) and apply sd_ops
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ModelBuilderProtocol(Protocol[ModelType]):
|
| 37 |
+
def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType: ...
|
| 38 |
+
|
| 39 |
+
def build(self, dtype: torch.dtype | None = None) -> ModelType:
|
| 40 |
+
"""
|
| 41 |
+
Build the model
|
| 42 |
+
Args:
|
| 43 |
+
dtype: Target dtype for the model, if None, uses the dtype of the model_path model
|
| 44 |
+
Returns:
|
| 45 |
+
Model instance
|
| 46 |
+
"""
|
| 47 |
+
...
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class LoRAAdaptableProtocol(Protocol):
|
| 51 |
+
def lora(self, lora_path: str, strength: float) -> "LoRAAdaptableProtocol":
|
| 52 |
+
pass
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class LoraPathStrengthAndSDOps(NamedTuple):
|
| 56 |
+
path: str
|
| 57 |
+
strength: float
|
| 58 |
+
sd_ops: SDOps
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class LoraStateDictWithStrength(NamedTuple):
|
| 62 |
+
state_dict: StateDict
|
| 63 |
+
strength: float
|
packages/ltx-core/src/ltx_core/loader/registry.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
+
# Created by Alexey Kravtsov
|
| 3 |
+
import hashlib
|
| 4 |
+
import threading
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Protocol
|
| 8 |
+
|
| 9 |
+
from ltx_core.loader.primitives import StateDict
|
| 10 |
+
from ltx_core.loader.sd_ops import SDOps
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Registry(Protocol):
|
| 14 |
+
def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None: ...
|
| 15 |
+
|
| 16 |
+
def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ...
|
| 17 |
+
|
| 18 |
+
def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ...
|
| 19 |
+
|
| 20 |
+
def clear(self) -> None: ...
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DummyRegistry(Registry):
|
| 24 |
+
def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None:
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
|
| 28 |
+
pass
|
| 29 |
+
|
| 30 |
+
def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
def clear(self) -> None:
|
| 34 |
+
pass
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class StateDictRegistry(Registry):
|
| 39 |
+
_state_dicts: dict[str, StateDict] = field(default_factory=dict)
|
| 40 |
+
_lock: threading.Lock = field(default_factory=threading.Lock)
|
| 41 |
+
|
| 42 |
+
def _generate_id(self, paths: list[str], sd_ops: SDOps) -> str:
|
| 43 |
+
m = hashlib.sha256()
|
| 44 |
+
parts = [str(Path(p).resolve()) for p in paths]
|
| 45 |
+
if sd_ops is not None:
|
| 46 |
+
parts.append(sd_ops.name)
|
| 47 |
+
m.update("\0".join(parts).encode("utf-8"))
|
| 48 |
+
return m.hexdigest()
|
| 49 |
+
|
| 50 |
+
def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> str:
|
| 51 |
+
sd_id = self._generate_id(paths, sd_ops)
|
| 52 |
+
with self._lock:
|
| 53 |
+
if sd_id in self._state_dicts:
|
| 54 |
+
raise ValueError(f"State dict retrieved from {paths} with {sd_ops} already added, check with get first")
|
| 55 |
+
self._state_dicts[sd_id] = state_dict
|
| 56 |
+
return sd_id
|
| 57 |
+
|
| 58 |
+
def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
|
| 59 |
+
with self._lock:
|
| 60 |
+
return self._state_dicts.pop(self._generate_id(paths, sd_ops), None)
|
| 61 |
+
|
| 62 |
+
def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
|
| 63 |
+
with self._lock:
|
| 64 |
+
return self._state_dicts.get(self._generate_id(paths, sd_ops), None)
|
| 65 |
+
|
| 66 |
+
def clear(self) -> None:
|
| 67 |
+
with self._lock:
|
| 68 |
+
self._state_dicts.clear()
|
packages/ltx-core/src/ltx_core/loader/sd_ops.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
+
# Created by Alexey Kravtsov
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass, replace
|
| 5 |
+
#from typing import NamedTuple, Protocol, Self
|
| 6 |
+
from typing import NamedTuple, Protocol
|
| 7 |
+
from typing_extensions import Self
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass(frozen=True, slots=True)
|
| 13 |
+
class ContentReplacement:
|
| 14 |
+
content: str
|
| 15 |
+
replacement: str
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass(frozen=True, slots=True)
|
| 19 |
+
class ContentMatching:
|
| 20 |
+
prefix: str = ""
|
| 21 |
+
suffix: str = ""
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class KeyValueOperationResult(NamedTuple):
|
| 25 |
+
new_key: str
|
| 26 |
+
new_value: torch.Tensor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class KeyValueOperation(Protocol):
|
| 30 |
+
def __call__(self, tensor_key: str, tensor_value: torch.Tensor) -> list[KeyValueOperationResult]: ...
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass(frozen=True, slots=True)
|
| 34 |
+
class SDKeyValueOperation:
|
| 35 |
+
key_matcher: ContentMatching
|
| 36 |
+
kv_operation: KeyValueOperation
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass(frozen=True, slots=True)
|
| 40 |
+
class SDOps:
|
| 41 |
+
"""Immutable class representing state dict key operations."""
|
| 42 |
+
|
| 43 |
+
name: str
|
| 44 |
+
mapping: tuple[
|
| 45 |
+
ContentReplacement | ContentMatching | SDKeyValueOperation, ...
|
| 46 |
+
] = () # Immutable tuple of (key, value) pairs
|
| 47 |
+
|
| 48 |
+
def with_replacement(self, content: str, replacement: str) -> Self:
|
| 49 |
+
"""Create a new SDOps instance with the specified replacement added to the mapping."""
|
| 50 |
+
|
| 51 |
+
new_mapping = (*self.mapping, ContentReplacement(content, replacement))
|
| 52 |
+
return replace(self, mapping=new_mapping)
|
| 53 |
+
|
| 54 |
+
def with_matching(self, prefix: str = "", suffix: str = "") -> Self:
|
| 55 |
+
"""Create a new SDOps instance with the specified prefix and suffix matching added to the mapping."""
|
| 56 |
+
|
| 57 |
+
new_mapping = (*self.mapping, ContentMatching(prefix, suffix))
|
| 58 |
+
return replace(self, mapping=new_mapping)
|
| 59 |
+
|
| 60 |
+
def with_kv_operation(
|
| 61 |
+
self,
|
| 62 |
+
operation: KeyValueOperation,
|
| 63 |
+
key_prefix: str = "",
|
| 64 |
+
key_suffix: str = "",
|
| 65 |
+
) -> Self:
|
| 66 |
+
"""Create a new SDOps instance with the specified value operation added to the mapping."""
|
| 67 |
+
key_matcher = ContentMatching(key_prefix, key_suffix)
|
| 68 |
+
sd_kv_operation = SDKeyValueOperation(key_matcher, operation)
|
| 69 |
+
new_mapping = (*self.mapping, sd_kv_operation)
|
| 70 |
+
return replace(self, mapping=new_mapping)
|
| 71 |
+
|
| 72 |
+
def apply_to_key(self, key: str) -> str | None:
|
| 73 |
+
"""Apply the mapping to the given name."""
|
| 74 |
+
matchers = [content for content in self.mapping if isinstance(content, ContentMatching)]
|
| 75 |
+
valid = any(key.startswith(f.prefix) and key.endswith(f.suffix) for f in matchers)
|
| 76 |
+
if not valid:
|
| 77 |
+
return None
|
| 78 |
+
|
| 79 |
+
for replacement in self.mapping:
|
| 80 |
+
if not isinstance(replacement, ContentReplacement):
|
| 81 |
+
continue
|
| 82 |
+
if replacement.content in key:
|
| 83 |
+
key = key.replace(replacement.content, replacement.replacement)
|
| 84 |
+
return key
|
| 85 |
+
|
| 86 |
+
def apply_to_key_value(self, key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
|
| 87 |
+
"""Apply the value operation to the given name and associated value."""
|
| 88 |
+
for operation in self.mapping:
|
| 89 |
+
if not isinstance(operation, SDKeyValueOperation):
|
| 90 |
+
continue
|
| 91 |
+
if key.startswith(operation.key_matcher.prefix) and key.endswith(operation.key_matcher.suffix):
|
| 92 |
+
return operation.kv_operation(key, value)
|
| 93 |
+
return [KeyValueOperationResult(key, value)]
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# Predefined SDOps instances
|
| 97 |
+
LTXV_LORA_COMFY_RENAMING_MAP = (
|
| 98 |
+
SDOps("LTXV_LORA_COMFY_PREFIX_MAP").with_matching().with_replacement("diffusion_model.", "")
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
LTXV_LORA_COMFY_TARGET_MAP = (
|
| 102 |
+
SDOps("LTXV_LORA_COMFY_TARGET_MAP")
|
| 103 |
+
.with_matching()
|
| 104 |
+
.with_replacement("diffusion_model.", "")
|
| 105 |
+
.with_replacement(".lora_A.weight", ".weight")
|
| 106 |
+
.with_replacement(".lora_B.weight", ".weight")
|
| 107 |
+
)
|
packages/ltx-core/src/ltx_core/loader/sft_loader.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
+
# Created by Alexey Kravtsov
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
import safetensors
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from ltx_core.loader.primitives import StateDict, StateDictLoader
|
| 9 |
+
from ltx_core.loader.sd_ops import SDOps
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SafetensorsStateDictLoader(StateDictLoader):
|
| 13 |
+
def metadata(self, path: str) -> dict:
|
| 14 |
+
raise NotImplementedError("Not implemented")
|
| 15 |
+
|
| 16 |
+
def load(self, path: str | list[str], sd_ops: SDOps, device: torch.device | None = None) -> StateDict:
|
| 17 |
+
"""
|
| 18 |
+
Load state dict from path or paths (for sharded model storage) and apply sd_ops
|
| 19 |
+
"""
|
| 20 |
+
sd = {}
|
| 21 |
+
size = 0
|
| 22 |
+
dtype = set()
|
| 23 |
+
device = device or torch.device("cpu")
|
| 24 |
+
model_paths = path if isinstance(path, list) else [path]
|
| 25 |
+
for shard_path in model_paths:
|
| 26 |
+
with safetensors.safe_open(shard_path, framework="pt", device=str(device)) as f:
|
| 27 |
+
safetensor_keys = f.keys()
|
| 28 |
+
for name in safetensor_keys:
|
| 29 |
+
expected_name = name if sd_ops is None else sd_ops.apply_to_key(name)
|
| 30 |
+
if expected_name is None:
|
| 31 |
+
continue
|
| 32 |
+
value = f.get_tensor(name).to(device=device, non_blocking=True, copy=False)
|
| 33 |
+
key_value_pairs = ((expected_name, value),)
|
| 34 |
+
if sd_ops is not None:
|
| 35 |
+
key_value_pairs = sd_ops.apply_to_key_value(expected_name, value)
|
| 36 |
+
for key, value in key_value_pairs:
|
| 37 |
+
size += value.nbytes
|
| 38 |
+
dtype.add(value.dtype)
|
| 39 |
+
sd[key] = value
|
| 40 |
+
|
| 41 |
+
return StateDict(sd=sd, device=device, size=size, dtype=dtype)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class SafetensorsModelStateDictLoader(StateDictLoader):
|
| 45 |
+
def __init__(self, weight_loader: SafetensorsStateDictLoader | None = None):
|
| 46 |
+
self.weight_loader = weight_loader if weight_loader is not None else SafetensorsStateDictLoader()
|
| 47 |
+
|
| 48 |
+
def metadata(self, path: str) -> dict:
|
| 49 |
+
with safetensors.safe_open(path, framework="pt") as f:
|
| 50 |
+
return json.loads(f.metadata()["config"])
|
| 51 |
+
|
| 52 |
+
def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
|
| 53 |
+
return self.weight_loader.load(path, sd_ops, device)
|
packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
+
# Created by Alexey Kravtsov
|
| 3 |
+
import logging
|
| 4 |
+
from dataclasses import dataclass, field, replace
|
| 5 |
+
from typing import Generic
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from ltx_core.loader.fuse_loras import apply_loras
|
| 10 |
+
from ltx_core.loader.module_ops import ModuleOps
|
| 11 |
+
from ltx_core.loader.primitives import (
|
| 12 |
+
LoRAAdaptableProtocol,
|
| 13 |
+
LoraPathStrengthAndSDOps,
|
| 14 |
+
LoraStateDictWithStrength,
|
| 15 |
+
ModelBuilderProtocol,
|
| 16 |
+
StateDict,
|
| 17 |
+
StateDictLoader,
|
| 18 |
+
)
|
| 19 |
+
from ltx_core.loader.registry import DummyRegistry, Registry
|
| 20 |
+
from ltx_core.loader.sd_ops import SDOps
|
| 21 |
+
from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader
|
| 22 |
+
from ltx_core.model.model_protocol import ModelConfigurator, ModelType
|
| 23 |
+
|
| 24 |
+
logger: logging.Logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass(frozen=True)
|
| 28 |
+
class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol):
|
| 29 |
+
model_class_configurator: type[ModelConfigurator[ModelType]]
|
| 30 |
+
model_path: str | tuple[str, ...]
|
| 31 |
+
model_sd_ops: SDOps | None = None
|
| 32 |
+
module_ops: tuple[ModuleOps, ...] = field(default_factory=tuple)
|
| 33 |
+
loras: tuple[LoraPathStrengthAndSDOps, ...] = field(default_factory=tuple)
|
| 34 |
+
model_loader: StateDictLoader = field(default_factory=SafetensorsModelStateDictLoader)
|
| 35 |
+
registry: Registry = field(default_factory=DummyRegistry)
|
| 36 |
+
|
| 37 |
+
def lora(self, lora_path: str, strength: float = 1.0, sd_ops: SDOps | None = None) -> "SingleGPUModelBuilder":
|
| 38 |
+
return replace(self, loras=(*self.loras, LoraPathStrengthAndSDOps(lora_path, strength, sd_ops)))
|
| 39 |
+
|
| 40 |
+
def model_config(self) -> dict:
|
| 41 |
+
first_shard_path = self.model_path[0] if isinstance(self.model_path, tuple) else self.model_path
|
| 42 |
+
return self.model_loader.metadata(first_shard_path)
|
| 43 |
+
|
| 44 |
+
def meta_model(self, config: dict, module_ops: tuple[ModuleOps, ...]) -> ModelType:
|
| 45 |
+
with torch.device("meta"):
|
| 46 |
+
model = self.model_class_configurator.from_config(config)
|
| 47 |
+
for module_op in module_ops:
|
| 48 |
+
if module_op.matcher(model):
|
| 49 |
+
model = module_op.mutator(model)
|
| 50 |
+
return model
|
| 51 |
+
|
| 52 |
+
def load_sd(
|
| 53 |
+
self, paths: list[str], registry: Registry, device: torch.device | None, sd_ops: SDOps | None = None
|
| 54 |
+
) -> StateDict:
|
| 55 |
+
state_dict = registry.get(paths, sd_ops)
|
| 56 |
+
if state_dict is None:
|
| 57 |
+
state_dict = self.model_loader.load(paths, sd_ops=sd_ops, device=device)
|
| 58 |
+
registry.add(paths, sd_ops=sd_ops, state_dict=state_dict)
|
| 59 |
+
return state_dict
|
| 60 |
+
|
| 61 |
+
def _return_model(self, meta_model: ModelType, device: torch.device) -> ModelType:
|
| 62 |
+
uninitialized_params = [name for name, param in meta_model.named_parameters() if str(param.device) == "meta"]
|
| 63 |
+
uninitialized_buffers = [name for name, buffer in meta_model.named_buffers() if str(buffer.device) == "meta"]
|
| 64 |
+
if uninitialized_params or uninitialized_buffers:
|
| 65 |
+
logger.warning(f"Uninitialized parameters or buffers: {uninitialized_params + uninitialized_buffers}")
|
| 66 |
+
return meta_model
|
| 67 |
+
retval = meta_model.to(device)
|
| 68 |
+
return retval
|
| 69 |
+
|
| 70 |
+
def build(self, device: torch.device | None = None, dtype: torch.dtype | None = None) -> ModelType:
|
| 71 |
+
device = torch.device("cuda") if device is None else device
|
| 72 |
+
config = self.model_config()
|
| 73 |
+
meta_model = self.meta_model(config, self.module_ops)
|
| 74 |
+
model_paths = self.model_path if isinstance(self.model_path, tuple) else [self.model_path]
|
| 75 |
+
model_state_dict = self.load_sd(model_paths, sd_ops=self.model_sd_ops, registry=self.registry, device=device)
|
| 76 |
+
|
| 77 |
+
lora_strengths = [lora.strength for lora in self.loras]
|
| 78 |
+
if not lora_strengths or (min(lora_strengths) == 0 and max(lora_strengths) == 0):
|
| 79 |
+
sd = model_state_dict.sd
|
| 80 |
+
if dtype is not None:
|
| 81 |
+
sd = {key: value.to(dtype=dtype) for key, value in model_state_dict.sd.items()}
|
| 82 |
+
meta_model.load_state_dict(sd, strict=False, assign=True)
|
| 83 |
+
return self._return_model(meta_model, device)
|
| 84 |
+
|
| 85 |
+
lora_state_dicts = [
|
| 86 |
+
self.load_sd([lora.path], sd_ops=lora.sd_ops, registry=self.registry, device=device) for lora in self.loras
|
| 87 |
+
]
|
| 88 |
+
lora_sd_and_strengths = [
|
| 89 |
+
LoraStateDictWithStrength(sd, strength)
|
| 90 |
+
for sd, strength in zip(lora_state_dicts, lora_strengths, strict=True)
|
| 91 |
+
]
|
| 92 |
+
final_sd = apply_loras(
|
| 93 |
+
model_sd=model_state_dict,
|
| 94 |
+
lora_sd_and_strengths=lora_sd_and_strengths,
|
| 95 |
+
dtype=dtype,
|
| 96 |
+
destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None,
|
| 97 |
+
)
|
| 98 |
+
meta_model.load_state_dict(final_sd.sd, strict=False, assign=True)
|
| 99 |
+
return self._return_model(meta_model, device)
|
packages/ltx-core/src/ltx_core/model/.ipynb_checkpoints/model_ledger-checkpoint.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import replace
|
| 2 |
+
# from typing import Self
|
| 3 |
+
from typing_extensions import Self
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
|
| 8 |
+
from ltx_core.loader.registry import DummyRegistry, Registry
|
| 9 |
+
from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
|
| 10 |
+
from ltx_core.model.audio_vae.audio_vae import Decoder as AudioDecoder
|
| 11 |
+
from ltx_core.model.audio_vae.model_configurator import (
|
| 12 |
+
AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
|
| 13 |
+
VOCODER_COMFY_KEYS_FILTER,
|
| 14 |
+
VocoderConfigurator,
|
| 15 |
+
)
|
| 16 |
+
from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator
|
| 17 |
+
from ltx_core.model.audio_vae.vocoder import Vocoder
|
| 18 |
+
from ltx_core.model.clip.gemma.encoders.av_encoder import (
|
| 19 |
+
AV_GEMMA_TEXT_ENCODER_KEY_OPS,
|
| 20 |
+
AVGemmaTextEncoderModel,
|
| 21 |
+
AVGemmaTextEncoderModelConfigurator,
|
| 22 |
+
)
|
| 23 |
+
from ltx_core.model.clip.gemma.encoders.base_encoder import module_ops_from_gemma_root
|
| 24 |
+
from ltx_core.model.transformer.model import X0Model
|
| 25 |
+
from ltx_core.model.transformer.model_configurator import (
|
| 26 |
+
LTXV_MODEL_COMFY_RENAMING_MAP,
|
| 27 |
+
LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP,
|
| 28 |
+
UPCAST_DURING_INFERENCE,
|
| 29 |
+
LTXModelConfigurator,
|
| 30 |
+
)
|
| 31 |
+
from ltx_core.model.upsampler.model import LatentUpsampler
|
| 32 |
+
from ltx_core.model.upsampler.model_configurator import LatentUpsamplerConfigurator
|
| 33 |
+
from ltx_core.model.video_vae.model_configurator import (
|
| 34 |
+
VAE_DECODER_COMFY_KEYS_FILTER,
|
| 35 |
+
VAE_ENCODER_COMFY_KEYS_FILTER,
|
| 36 |
+
VAEDecoderConfigurator,
|
| 37 |
+
VAEEncoderConfigurator,
|
| 38 |
+
)
|
| 39 |
+
from ltx_core.model.video_vae.video_vae import Decoder as VideoDecoder
|
| 40 |
+
from ltx_core.model.video_vae.video_vae import Encoder as VideoEncoder
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ModelLedger:
|
| 44 |
+
"""
|
| 45 |
+
Central coordinator for loading, caching, and freeing models used in an LTX pipeline.
|
| 46 |
+
The ledger wires together multiple model builders (transformer, video VAE encoder/decoder,
|
| 47 |
+
audio VAE decoder, vocoder, text encoder, and optional latent upsampler) and exposes
|
| 48 |
+
the resulting models as lazily constructed, cached attributes.
|
| 49 |
+
|
| 50 |
+
### Caching behavior
|
| 51 |
+
|
| 52 |
+
Each model attribute (e.g. :attr:`transformer`, :attr:`video_decoder`, :attr:`text_encoder`)
|
| 53 |
+
is implemented as a :func:`functools.cached_property`. The first time one of these
|
| 54 |
+
attributes is accessed, the corresponding builder loads weights from the
|
| 55 |
+
:class:`~ltx_core.loader.registry.StateDictRegistry`, instantiates the model on CPU with
|
| 56 |
+
the configured ``dtype``, moves it to ``self.device``, and stores the result in
|
| 57 |
+
the instance ``__dict__``. Subsequent accesses reuse the same model instance until it is
|
| 58 |
+
explicitly cleared via :meth:`clear_vram`.
|
| 59 |
+
|
| 60 |
+
### Constructor parameters
|
| 61 |
+
|
| 62 |
+
dtype:
|
| 63 |
+
Torch dtype used when constructing all models (e.g. ``torch.float16``).
|
| 64 |
+
device:
|
| 65 |
+
Target device to which models are moved after construction (e.g. ``torch.device("cuda")``).
|
| 66 |
+
checkpoint_path:
|
| 67 |
+
Path to a checkpoint directory or file containing the core model weights
|
| 68 |
+
(transformer, video VAE, audio VAE, text encoder, vocoder). If ``None``, the
|
| 69 |
+
corresponding builders are not created and accessing those properties will raise
|
| 70 |
+
a :class:`ValueError`.
|
| 71 |
+
gemma_root_path:
|
| 72 |
+
Base path to Gemma-compatible CLIP/text encoder weights. Required to
|
| 73 |
+
initialize the text encoder builder; if omitted, :attr:`text_encoder` cannot be used.
|
| 74 |
+
spatial_upsampler_path:
|
| 75 |
+
Optional path to a latent upsampler checkpoint. If provided, the
|
| 76 |
+
:attr:`upsampler` property becomes available; otherwise accessing it raises
|
| 77 |
+
a :class:`ValueError`.
|
| 78 |
+
loras:
|
| 79 |
+
Optional collection of LoRA configurations (paths, strengths, and key operations)
|
| 80 |
+
that are applied on top of the base transformer weights when building the model.
|
| 81 |
+
|
| 82 |
+
### Memory management
|
| 83 |
+
|
| 84 |
+
``clear_ram()``
|
| 85 |
+
Clears the underlying :class:`Registry` cache of state dicts and triggers a
|
| 86 |
+
Python garbage collection pass. Use this when you no longer need to construct new
|
| 87 |
+
models from the currently loaded checkpoints and want to free host (CPU) memory.
|
| 88 |
+
``clear_vram()``
|
| 89 |
+
Drops the cached model instances stored by the ``@cached_property`` attributes from
|
| 90 |
+
this ledger (by removing them from ``self.__dict__``) and calls
|
| 91 |
+
:func:`torch.cuda.empty_cache`. Use this when you want to release GPU memory;
|
| 92 |
+
subsequent access to a model property will rebuild the model from the registry
|
| 93 |
+
while keeping the existing builder configuration.
|
| 94 |
+
"""
|
| 95 |
+
|
| 96 |
+
def __init__(
|
| 97 |
+
self,
|
| 98 |
+
dtype: torch.dtype,
|
| 99 |
+
device: torch.device,
|
| 100 |
+
checkpoint_path: str | None = None,
|
| 101 |
+
gemma_root_path: str | None = None,
|
| 102 |
+
spatial_upsampler_path: str | None = None,
|
| 103 |
+
loras: LoraPathStrengthAndSDOps | None = None,
|
| 104 |
+
registry: Registry | None = None,
|
| 105 |
+
fp8transformer: bool = False,
|
| 106 |
+
local_files_only: bool = True
|
| 107 |
+
):
|
| 108 |
+
self.dtype = dtype
|
| 109 |
+
self.device = device
|
| 110 |
+
self.checkpoint_path = checkpoint_path
|
| 111 |
+
self.gemma_root_path = gemma_root_path
|
| 112 |
+
self.spatial_upsampler_path = spatial_upsampler_path
|
| 113 |
+
self.loras = loras or ()
|
| 114 |
+
self.registry = registry or DummyRegistry()
|
| 115 |
+
self.fp8transformer = fp8transformer
|
| 116 |
+
self.local_files_only = local_files_only
|
| 117 |
+
self.build_model_builders()
|
| 118 |
+
|
| 119 |
+
def build_model_builders(self) -> None:
|
| 120 |
+
if self.checkpoint_path is not None:
|
| 121 |
+
self.transformer_builder = Builder(
|
| 122 |
+
model_path=self.checkpoint_path,
|
| 123 |
+
model_class_configurator=LTXModelConfigurator,
|
| 124 |
+
model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP,
|
| 125 |
+
loras=tuple(self.loras),
|
| 126 |
+
registry=self.registry,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
self.vae_decoder_builder = Builder(
|
| 130 |
+
model_path=self.checkpoint_path,
|
| 131 |
+
model_class_configurator=VAEDecoderConfigurator,
|
| 132 |
+
model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER,
|
| 133 |
+
registry=self.registry,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
self.vae_encoder_builder = Builder(
|
| 137 |
+
model_path=self.checkpoint_path,
|
| 138 |
+
model_class_configurator=VAEEncoderConfigurator,
|
| 139 |
+
model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER,
|
| 140 |
+
registry=self.registry,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
self.audio_decoder_builder = Builder(
|
| 144 |
+
model_path=self.checkpoint_path,
|
| 145 |
+
model_class_configurator=AudioDecoderConfigurator,
|
| 146 |
+
model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
|
| 147 |
+
registry=self.registry,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
self.vocoder_builder = Builder(
|
| 151 |
+
model_path=self.checkpoint_path,
|
| 152 |
+
model_class_configurator=VocoderConfigurator,
|
| 153 |
+
model_sd_ops=VOCODER_COMFY_KEYS_FILTER,
|
| 154 |
+
registry=self.registry,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
if self.gemma_root_path is not None:
|
| 158 |
+
self.text_encoder_builder = Builder(
|
| 159 |
+
model_path=self.checkpoint_path,
|
| 160 |
+
model_class_configurator=AVGemmaTextEncoderModelConfigurator,
|
| 161 |
+
model_sd_ops=AV_GEMMA_TEXT_ENCODER_KEY_OPS,
|
| 162 |
+
registry=self.registry,
|
| 163 |
+
module_ops=module_ops_from_gemma_root(self.gemma_root_path, self.local_files_only),
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
if self.spatial_upsampler_path is not None:
|
| 167 |
+
self.upsampler_builder = Builder(
|
| 168 |
+
model_path=self.spatial_upsampler_path,
|
| 169 |
+
model_class_configurator=LatentUpsamplerConfigurator,
|
| 170 |
+
registry=self.registry,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
def _target_device(self) -> torch.device:
|
| 174 |
+
if isinstance(self.registry, DummyRegistry) or self.registry is None:
|
| 175 |
+
return self.device
|
| 176 |
+
else:
|
| 177 |
+
return torch.device("cpu")
|
| 178 |
+
|
| 179 |
+
def with_loras(self, loras: LoraPathStrengthAndSDOps) -> Self:
|
| 180 |
+
return ModelLedger(
|
| 181 |
+
dtype=self.dtype,
|
| 182 |
+
device=self.device,
|
| 183 |
+
checkpoint_path=self.checkpoint_path,
|
| 184 |
+
gemma_root_path=self.gemma_root_path,
|
| 185 |
+
spatial_upsampler_path=self.spatial_upsampler_path,
|
| 186 |
+
loras=(*self.loras, *loras),
|
| 187 |
+
registry=self.registry,
|
| 188 |
+
fp8transformer=self.fp8transformer,
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
def transformer(self) -> X0Model:
|
| 192 |
+
if not hasattr(self, "transformer_builder"):
|
| 193 |
+
raise ValueError(
|
| 194 |
+
"Transformer not initialized. Please provide a checkpoint path to the ModelLedger constructor."
|
| 195 |
+
)
|
| 196 |
+
if self.fp8transformer:
|
| 197 |
+
fp8_builder = replace(
|
| 198 |
+
self.transformer_builder,
|
| 199 |
+
module_ops=(UPCAST_DURING_INFERENCE,),
|
| 200 |
+
model_sd_ops=LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP,
|
| 201 |
+
)
|
| 202 |
+
return X0Model(fp8_builder.build(device=self._target_device())).to(self.device)
|
| 203 |
+
else:
|
| 204 |
+
return X0Model(self.transformer_builder.build(device=self._target_device(), dtype=self.dtype)).to(
|
| 205 |
+
self.device
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
def video_decoder(self) -> VideoDecoder:
|
| 209 |
+
if not hasattr(self, "vae_decoder_builder"):
|
| 210 |
+
raise ValueError(
|
| 211 |
+
"Video decoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
return self.vae_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device)
|
| 215 |
+
|
| 216 |
+
def video_encoder(self) -> VideoEncoder:
|
| 217 |
+
if not hasattr(self, "vae_encoder_builder"):
|
| 218 |
+
raise ValueError(
|
| 219 |
+
"Video encoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
return self.vae_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device)
|
| 223 |
+
|
| 224 |
+
def text_encoder(self) -> AVGemmaTextEncoderModel:
|
| 225 |
+
if not hasattr(self, "text_encoder_builder"):
|
| 226 |
+
raise ValueError(
|
| 227 |
+
"Text encoder not initialized. Please provide a checkpoint path and gemma root path to the "
|
| 228 |
+
"ModelLedger constructor."
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
return self.text_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device)
|
| 232 |
+
|
| 233 |
+
def audio_decoder(self) -> AudioDecoder:
|
| 234 |
+
if not hasattr(self, "audio_decoder_builder"):
|
| 235 |
+
raise ValueError(
|
| 236 |
+
"Audio decoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
return self.audio_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device)
|
| 240 |
+
|
| 241 |
+
def vocoder(self) -> Vocoder:
|
| 242 |
+
if not hasattr(self, "vocoder_builder"):
|
| 243 |
+
raise ValueError(
|
| 244 |
+
"Vocoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
return self.vocoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device)
|
| 248 |
+
|
| 249 |
+
def spatial_upsampler(self) -> LatentUpsampler:
|
| 250 |
+
if not hasattr(self, "upsampler_builder"):
|
| 251 |
+
raise ValueError("Upsampler not initialized. Please provide upsampler path to the ModelLedger constructor.")
|
| 252 |
+
|
| 253 |
+
return self.upsampler_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device)
|
packages/ltx-core/src/ltx_core/model/__init__.py
ADDED
|
File without changes
|
packages/ltx-core/src/ltx_core/model/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (182 Bytes). View file
|
|
|
packages/ltx-core/src/ltx_core/model/__pycache__/model_ledger.cpython-310.pyc
ADDED
|
Binary file (9.16 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/__pycache__/model_protocol.cpython-310.pyc
ADDED
|
Binary file (744 Bytes). View file
|
|
|
packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py
ADDED
|
File without changes
|
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (192 Bytes). View file
|
|
|
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/attention.cpython-310.pyc
ADDED
|
Binary file (2.37 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/audio_vae.cpython-310.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/causal_conv_2d.cpython-310.pyc
ADDED
|
Binary file (3.28 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/causality_axis.cpython-310.pyc
ADDED
|
Binary file (574 Bytes). View file
|
|
|
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/downsample.cpython-310.pyc
ADDED
|
Binary file (3.18 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/model_configurator.cpython-310.pyc
ADDED
|
Binary file (4.41 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/ops.cpython-310.pyc
ADDED
|
Binary file (3 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/resnet.cpython-310.pyc
ADDED
|
Binary file (4.06 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/upsample.cpython-310.pyc
ADDED
|
Binary file (2.88 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/vocoder.cpython-310.pyc
ADDED
|
Binary file (4.89 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/audio_vae/attention.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.model.common.normalization import NormType, build_normalization_layer
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AttentionType(Enum):
|
| 9 |
+
"""Enum for specifying the attention mechanism type."""
|
| 10 |
+
|
| 11 |
+
VANILLA = "vanilla"
|
| 12 |
+
LINEAR = "linear"
|
| 13 |
+
NONE = "none"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class AttnBlock(torch.nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
in_channels: int,
|
| 20 |
+
norm_type: NormType = NormType.GROUP,
|
| 21 |
+
) -> None:
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.in_channels = in_channels
|
| 24 |
+
|
| 25 |
+
self.norm = build_normalization_layer(in_channels, normtype=norm_type)
|
| 26 |
+
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 27 |
+
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 28 |
+
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 29 |
+
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 30 |
+
|
| 31 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 32 |
+
h_ = x
|
| 33 |
+
h_ = self.norm(h_)
|
| 34 |
+
q = self.q(h_)
|
| 35 |
+
k = self.k(h_)
|
| 36 |
+
v = self.v(h_)
|
| 37 |
+
|
| 38 |
+
# compute attention
|
| 39 |
+
b, c, h, w = q.shape
|
| 40 |
+
q = q.reshape(b, c, h * w).contiguous()
|
| 41 |
+
q = q.permute(0, 2, 1).contiguous() # b,hw,c
|
| 42 |
+
k = k.reshape(b, c, h * w).contiguous() # b,c,hw
|
| 43 |
+
w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
| 44 |
+
w_ = w_ * (int(c) ** (-0.5))
|
| 45 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
| 46 |
+
|
| 47 |
+
# attend to values
|
| 48 |
+
v = v.reshape(b, c, h * w).contiguous()
|
| 49 |
+
w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
|
| 50 |
+
h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
| 51 |
+
h_ = h_.reshape(b, c, h, w).contiguous()
|
| 52 |
+
|
| 53 |
+
h_ = self.proj_out(h_)
|
| 54 |
+
|
| 55 |
+
return x + h_
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def make_attn(
|
| 59 |
+
in_channels: int,
|
| 60 |
+
attn_type: AttentionType = AttentionType.VANILLA,
|
| 61 |
+
norm_type: NormType = NormType.GROUP,
|
| 62 |
+
) -> torch.nn.Module:
|
| 63 |
+
match attn_type:
|
| 64 |
+
case AttentionType.VANILLA:
|
| 65 |
+
return AttnBlock(in_channels, norm_type=norm_type)
|
| 66 |
+
case AttentionType.NONE:
|
| 67 |
+
return torch.nn.Identity()
|
| 68 |
+
case AttentionType.LINEAR:
|
| 69 |
+
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
|
| 70 |
+
case _:
|
| 71 |
+
raise ValueError(f"Unknown attention type: {attn_type}")
|
packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py
ADDED
|
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
+
# Created by Ivan Zorin
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
from typing import Set, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from ltx_core.model.audio_vae.attention import AttentionType, make_attn
|
| 11 |
+
from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
|
| 12 |
+
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
|
| 13 |
+
from ltx_core.model.audio_vae.downsample import build_downsampling_path
|
| 14 |
+
from ltx_core.model.audio_vae.ops import PerChannelStatistics
|
| 15 |
+
from ltx_core.model.audio_vae.resnet import ResnetBlock
|
| 16 |
+
from ltx_core.model.audio_vae.upsample import build_upsampling_path
|
| 17 |
+
from ltx_core.model.common.normalization import NormType, build_normalization_layer
|
| 18 |
+
from ltx_core.pipeline.components.patchifiers import AudioPatchifier
|
| 19 |
+
from ltx_core.pipeline.components.protocols import AudioLatentShape
|
| 20 |
+
|
| 21 |
+
LATENT_DOWNSAMPLE_FACTOR = 4
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def build_mid_block(
|
| 25 |
+
channels: int,
|
| 26 |
+
temb_channels: int,
|
| 27 |
+
dropout: float,
|
| 28 |
+
norm_type: NormType,
|
| 29 |
+
causality_axis: CausalityAxis,
|
| 30 |
+
attn_type: AttentionType,
|
| 31 |
+
add_attention: bool,
|
| 32 |
+
) -> torch.nn.Module:
|
| 33 |
+
"""Build the middle block with two ResNet blocks and optional attention."""
|
| 34 |
+
mid = torch.nn.Module()
|
| 35 |
+
mid.block_1 = ResnetBlock(
|
| 36 |
+
in_channels=channels,
|
| 37 |
+
out_channels=channels,
|
| 38 |
+
temb_channels=temb_channels,
|
| 39 |
+
dropout=dropout,
|
| 40 |
+
norm_type=norm_type,
|
| 41 |
+
causality_axis=causality_axis,
|
| 42 |
+
)
|
| 43 |
+
mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else torch.nn.Identity()
|
| 44 |
+
mid.block_2 = ResnetBlock(
|
| 45 |
+
in_channels=channels,
|
| 46 |
+
out_channels=channels,
|
| 47 |
+
temb_channels=temb_channels,
|
| 48 |
+
dropout=dropout,
|
| 49 |
+
norm_type=norm_type,
|
| 50 |
+
causality_axis=causality_axis,
|
| 51 |
+
)
|
| 52 |
+
return mid
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor:
|
| 56 |
+
"""Run features through the middle block."""
|
| 57 |
+
features = mid.block_1(features, temb=None)
|
| 58 |
+
features = mid.attn_1(features)
|
| 59 |
+
return mid.block_2(features, temb=None)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class Encoder(torch.nn.Module):
|
| 63 |
+
"""
|
| 64 |
+
Encoder that compresses audio spectrograms into latent representations.
|
| 65 |
+
|
| 66 |
+
The encoder uses a series of downsampling blocks with residual connections,
|
| 67 |
+
attention mechanisms, and configurable causal convolutions.
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
def __init__( # noqa: PLR0913
|
| 71 |
+
self,
|
| 72 |
+
*,
|
| 73 |
+
ch: int,
|
| 74 |
+
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
| 75 |
+
num_res_blocks: int,
|
| 76 |
+
attn_resolutions: Set[int],
|
| 77 |
+
dropout: float = 0.0,
|
| 78 |
+
resamp_with_conv: bool = True,
|
| 79 |
+
in_channels: int,
|
| 80 |
+
resolution: int,
|
| 81 |
+
z_channels: int,
|
| 82 |
+
double_z: bool = True,
|
| 83 |
+
attn_type: AttentionType = AttentionType.VANILLA,
|
| 84 |
+
mid_block_add_attention: bool = True,
|
| 85 |
+
norm_type: NormType = NormType.GROUP,
|
| 86 |
+
causality_axis: CausalityAxis = CausalityAxis.WIDTH,
|
| 87 |
+
sample_rate: int = 16000,
|
| 88 |
+
mel_hop_length: int = 160,
|
| 89 |
+
n_fft: int = 1024,
|
| 90 |
+
is_causal: bool = True,
|
| 91 |
+
mel_bins: int = 64,
|
| 92 |
+
**_ignore_kwargs,
|
| 93 |
+
) -> None:
|
| 94 |
+
"""
|
| 95 |
+
Initialize the Encoder.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
Arguments are configuration parameters, loaded from the audio VAE checkpoint config
|
| 99 |
+
(audio_vae.model.params.ddconfig):
|
| 100 |
+
|
| 101 |
+
ch: Base number of feature channels used in the first convolution layer.
|
| 102 |
+
ch_mult: Multiplicative factors for the number of channels at each resolution level.
|
| 103 |
+
num_res_blocks: Number of residual blocks to use at each resolution level.
|
| 104 |
+
attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention.
|
| 105 |
+
resolution: Input spatial resolution of the spectrogram (height, width).
|
| 106 |
+
z_channels: Number of channels in the latent representation.
|
| 107 |
+
norm_type: Normalization layer type to use within the network (e.g., group, batch).
|
| 108 |
+
causality_axis: Axis along which convolutions should be causal (e.g., time axis).
|
| 109 |
+
sample_rate: Audio sample rate in Hz for the input signals.
|
| 110 |
+
mel_hop_length: Hop length used when computing the mel spectrogram.
|
| 111 |
+
n_fft: FFT size used to compute the spectrogram.
|
| 112 |
+
mel_bins: Number of mel-frequency bins in the input spectrogram.
|
| 113 |
+
in_channels: Number of channels in the input spectrogram tensor.
|
| 114 |
+
double_z: If True, predict both mean and log-variance (doubling latent channels).
|
| 115 |
+
is_causal: If True, use causal convolutions suitable for streaming setups.
|
| 116 |
+
dropout: Dropout probability used in residual and mid blocks.
|
| 117 |
+
attn_type: Type of attention mechanism to use in attention blocks.
|
| 118 |
+
resamp_with_conv: If True, perform resolution changes using strided convolutions.
|
| 119 |
+
mid_block_add_attention: If True, add an attention block in the mid-level of the encoder.
|
| 120 |
+
"""
|
| 121 |
+
super().__init__()
|
| 122 |
+
|
| 123 |
+
self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
|
| 124 |
+
self.sample_rate = sample_rate
|
| 125 |
+
self.mel_hop_length = mel_hop_length
|
| 126 |
+
self.n_fft = n_fft
|
| 127 |
+
self.is_causal = is_causal
|
| 128 |
+
self.mel_bins = mel_bins
|
| 129 |
+
|
| 130 |
+
self.patchifier = AudioPatchifier(
|
| 131 |
+
patch_size=1,
|
| 132 |
+
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
| 133 |
+
sample_rate=sample_rate,
|
| 134 |
+
hop_length=mel_hop_length,
|
| 135 |
+
is_causal=is_causal,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
self.ch = ch
|
| 139 |
+
self.temb_ch = 0
|
| 140 |
+
self.num_resolutions = len(ch_mult)
|
| 141 |
+
self.num_res_blocks = num_res_blocks
|
| 142 |
+
self.resolution = resolution
|
| 143 |
+
self.in_channels = in_channels
|
| 144 |
+
self.z_channels = z_channels
|
| 145 |
+
self.double_z = double_z
|
| 146 |
+
self.norm_type = norm_type
|
| 147 |
+
self.causality_axis = causality_axis
|
| 148 |
+
self.attn_type = attn_type
|
| 149 |
+
|
| 150 |
+
# downsampling
|
| 151 |
+
self.conv_in = make_conv2d(
|
| 152 |
+
in_channels,
|
| 153 |
+
self.ch,
|
| 154 |
+
kernel_size=3,
|
| 155 |
+
stride=1,
|
| 156 |
+
causality_axis=self.causality_axis,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
self.non_linearity = torch.nn.SiLU()
|
| 160 |
+
|
| 161 |
+
self.down, block_in = build_downsampling_path(
|
| 162 |
+
ch=ch,
|
| 163 |
+
ch_mult=ch_mult,
|
| 164 |
+
num_resolutions=self.num_resolutions,
|
| 165 |
+
num_res_blocks=num_res_blocks,
|
| 166 |
+
resolution=resolution,
|
| 167 |
+
temb_channels=self.temb_ch,
|
| 168 |
+
dropout=dropout,
|
| 169 |
+
norm_type=self.norm_type,
|
| 170 |
+
causality_axis=self.causality_axis,
|
| 171 |
+
attn_type=self.attn_type,
|
| 172 |
+
attn_resolutions=attn_resolutions,
|
| 173 |
+
resamp_with_conv=resamp_with_conv,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
self.mid = build_mid_block(
|
| 177 |
+
channels=block_in,
|
| 178 |
+
temb_channels=self.temb_ch,
|
| 179 |
+
dropout=dropout,
|
| 180 |
+
norm_type=self.norm_type,
|
| 181 |
+
causality_axis=self.causality_axis,
|
| 182 |
+
attn_type=self.attn_type,
|
| 183 |
+
add_attention=mid_block_add_attention,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
|
| 187 |
+
self.conv_out = make_conv2d(
|
| 188 |
+
block_in,
|
| 189 |
+
2 * z_channels if double_z else z_channels,
|
| 190 |
+
kernel_size=3,
|
| 191 |
+
stride=1,
|
| 192 |
+
causality_axis=self.causality_axis,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
|
| 196 |
+
"""
|
| 197 |
+
Encode audio spectrogram into latent representations.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
spectrogram: Input spectrogram of shape (batch, channels, time, frequency)
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
Encoded latent representation of shape (batch, channels, frames, mel_bins)
|
| 204 |
+
"""
|
| 205 |
+
h = self.conv_in(spectrogram)
|
| 206 |
+
h = self._run_downsampling_path(h)
|
| 207 |
+
h = run_mid_block(self.mid, h)
|
| 208 |
+
h = self._finalize_output(h)
|
| 209 |
+
|
| 210 |
+
return self._normalize_latents(h)
|
| 211 |
+
|
| 212 |
+
def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor:
|
| 213 |
+
for level in range(self.num_resolutions):
|
| 214 |
+
stage = self.down[level]
|
| 215 |
+
for block_idx in range(self.num_res_blocks):
|
| 216 |
+
h = stage.block[block_idx](h, temb=None)
|
| 217 |
+
if stage.attn:
|
| 218 |
+
h = stage.attn[block_idx](h)
|
| 219 |
+
|
| 220 |
+
if level != self.num_resolutions - 1:
|
| 221 |
+
h = stage.downsample(h)
|
| 222 |
+
|
| 223 |
+
return h
|
| 224 |
+
|
| 225 |
+
def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
|
| 226 |
+
h = self.norm_out(h)
|
| 227 |
+
h = self.non_linearity(h)
|
| 228 |
+
return self.conv_out(h)
|
| 229 |
+
|
| 230 |
+
def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor:
|
| 231 |
+
"""
|
| 232 |
+
Normalize encoder latents using per-channel statistics.
|
| 233 |
+
|
| 234 |
+
When the encoder is configured with ``double_z=True``, the final
|
| 235 |
+
convolution produces twice the number of latent channels, typically
|
| 236 |
+
interpreted as two concatenated tensors along the channel dimension
|
| 237 |
+
(e.g., mean and variance or other auxiliary parameters).
|
| 238 |
+
|
| 239 |
+
This method intentionally uses only the first half of the channels
|
| 240 |
+
(the "mean" component) as input to the patchifier and normalization
|
| 241 |
+
logic. The remaining channels are left unchanged by this method and
|
| 242 |
+
are expected to be consumed elsewhere in the VAE pipeline.
|
| 243 |
+
|
| 244 |
+
If ``double_z=False``, the encoder output already contains only the
|
| 245 |
+
mean latents and the chunking operation simply returns that tensor.
|
| 246 |
+
"""
|
| 247 |
+
means = torch.chunk(latent_output, 2, dim=1)[0]
|
| 248 |
+
latent_shape = AudioLatentShape(
|
| 249 |
+
batch=means.shape[0],
|
| 250 |
+
channels=means.shape[1],
|
| 251 |
+
frames=means.shape[2],
|
| 252 |
+
mel_bins=means.shape[3],
|
| 253 |
+
)
|
| 254 |
+
latent_patched = self.patchifier.patchify(means)
|
| 255 |
+
latent_normalized = self.per_channel_statistics.normalize(latent_patched)
|
| 256 |
+
return self.patchifier.unpatchify(latent_normalized, latent_shape)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class Decoder(torch.nn.Module):
|
| 260 |
+
"""
|
| 261 |
+
Symmetric decoder that reconstructs audio spectrograms from latent features.
|
| 262 |
+
|
| 263 |
+
The decoder mirrors the encoder structure with configurable channel multipliers,
|
| 264 |
+
attention resolutions, and causal convolutions.
|
| 265 |
+
"""
|
| 266 |
+
|
| 267 |
+
def __init__( # noqa: PLR0913
|
| 268 |
+
self,
|
| 269 |
+
*,
|
| 270 |
+
ch: int,
|
| 271 |
+
out_ch: int,
|
| 272 |
+
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
| 273 |
+
num_res_blocks: int,
|
| 274 |
+
attn_resolutions: Set[int],
|
| 275 |
+
resolution: int,
|
| 276 |
+
z_channels: int,
|
| 277 |
+
norm_type: NormType = NormType.GROUP,
|
| 278 |
+
causality_axis: CausalityAxis = CausalityAxis.WIDTH,
|
| 279 |
+
dropout: float = 0.0,
|
| 280 |
+
mid_block_add_attention: bool = True,
|
| 281 |
+
sample_rate: int = 16000,
|
| 282 |
+
mel_hop_length: int = 160,
|
| 283 |
+
is_causal: bool = True,
|
| 284 |
+
mel_bins: int | None = None,
|
| 285 |
+
) -> None:
|
| 286 |
+
"""
|
| 287 |
+
Initialize the Decoder.
|
| 288 |
+
|
| 289 |
+
Args:
|
| 290 |
+
Arguments are configuration parameters, loaded from the audio VAE checkpoint config
|
| 291 |
+
(audio_vae.model.params.ddconfig):
|
| 292 |
+
- ch, out_ch, ch_mult, num_res_blocks, attn_resolutions
|
| 293 |
+
- resolution, z_channels
|
| 294 |
+
- norm_type, causality_axis
|
| 295 |
+
"""
|
| 296 |
+
super().__init__()
|
| 297 |
+
|
| 298 |
+
# Internal behavioural defaults that are not driven by the checkpoint.
|
| 299 |
+
resamp_with_conv = True
|
| 300 |
+
attn_type = AttentionType.VANILLA
|
| 301 |
+
|
| 302 |
+
# Per-channel statistics for denormalizing latents
|
| 303 |
+
self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
|
| 304 |
+
self.sample_rate = sample_rate
|
| 305 |
+
self.mel_hop_length = mel_hop_length
|
| 306 |
+
self.is_causal = is_causal
|
| 307 |
+
self.mel_bins = mel_bins
|
| 308 |
+
self.patchifier = AudioPatchifier(
|
| 309 |
+
patch_size=1,
|
| 310 |
+
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
| 311 |
+
sample_rate=sample_rate,
|
| 312 |
+
hop_length=mel_hop_length,
|
| 313 |
+
is_causal=is_causal,
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
self.ch = ch
|
| 317 |
+
self.temb_ch = 0
|
| 318 |
+
self.num_resolutions = len(ch_mult)
|
| 319 |
+
self.num_res_blocks = num_res_blocks
|
| 320 |
+
self.resolution = resolution
|
| 321 |
+
self.out_ch = out_ch
|
| 322 |
+
self.give_pre_end = False
|
| 323 |
+
self.tanh_out = False
|
| 324 |
+
self.norm_type = norm_type
|
| 325 |
+
self.z_channels = z_channels
|
| 326 |
+
self.channel_multipliers = ch_mult
|
| 327 |
+
self.attn_resolutions = attn_resolutions
|
| 328 |
+
self.causality_axis = causality_axis
|
| 329 |
+
self.attn_type = attn_type
|
| 330 |
+
|
| 331 |
+
base_block_channels = ch * self.channel_multipliers[-1]
|
| 332 |
+
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
|
| 333 |
+
self.z_shape = (1, z_channels, base_resolution, base_resolution)
|
| 334 |
+
|
| 335 |
+
self.conv_in = make_conv2d(
|
| 336 |
+
z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
| 337 |
+
)
|
| 338 |
+
self.non_linearity = torch.nn.SiLU()
|
| 339 |
+
self.mid = build_mid_block(
|
| 340 |
+
channels=base_block_channels,
|
| 341 |
+
temb_channels=self.temb_ch,
|
| 342 |
+
dropout=dropout,
|
| 343 |
+
norm_type=self.norm_type,
|
| 344 |
+
causality_axis=self.causality_axis,
|
| 345 |
+
attn_type=self.attn_type,
|
| 346 |
+
add_attention=mid_block_add_attention,
|
| 347 |
+
)
|
| 348 |
+
self.up, final_block_channels = build_upsampling_path(
|
| 349 |
+
ch=ch,
|
| 350 |
+
ch_mult=ch_mult,
|
| 351 |
+
num_resolutions=self.num_resolutions,
|
| 352 |
+
num_res_blocks=num_res_blocks,
|
| 353 |
+
resolution=resolution,
|
| 354 |
+
temb_channels=self.temb_ch,
|
| 355 |
+
dropout=dropout,
|
| 356 |
+
norm_type=self.norm_type,
|
| 357 |
+
causality_axis=self.causality_axis,
|
| 358 |
+
attn_type=self.attn_type,
|
| 359 |
+
attn_resolutions=attn_resolutions,
|
| 360 |
+
resamp_with_conv=resamp_with_conv,
|
| 361 |
+
initial_block_channels=base_block_channels,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
|
| 365 |
+
self.conv_out = make_conv2d(
|
| 366 |
+
final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
def forward(self, sample: torch.Tensor) -> torch.Tensor:
|
| 370 |
+
"""
|
| 371 |
+
Decode latent features back to audio spectrograms.
|
| 372 |
+
|
| 373 |
+
Args:
|
| 374 |
+
sample: Encoded latent representation of shape (batch, channels, frames, mel_bins)
|
| 375 |
+
|
| 376 |
+
Returns:
|
| 377 |
+
Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
|
| 378 |
+
"""
|
| 379 |
+
sample, target_shape = self._denormalize_latents(sample)
|
| 380 |
+
|
| 381 |
+
h = self.conv_in(sample)
|
| 382 |
+
h = run_mid_block(self.mid, h)
|
| 383 |
+
h = self._run_upsampling_path(h)
|
| 384 |
+
h = self._finalize_output(h)
|
| 385 |
+
|
| 386 |
+
return self._adjust_output_shape(h, target_shape)
|
| 387 |
+
|
| 388 |
+
def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]:
|
| 389 |
+
latent_shape = AudioLatentShape(
|
| 390 |
+
batch=sample.shape[0],
|
| 391 |
+
channels=sample.shape[1],
|
| 392 |
+
frames=sample.shape[2],
|
| 393 |
+
mel_bins=sample.shape[3],
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
sample_patched = self.patchifier.patchify(sample)
|
| 397 |
+
sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
|
| 398 |
+
sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)
|
| 399 |
+
|
| 400 |
+
target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
|
| 401 |
+
if self.causality_axis != CausalityAxis.NONE:
|
| 402 |
+
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
|
| 403 |
+
|
| 404 |
+
target_shape = AudioLatentShape(
|
| 405 |
+
batch=latent_shape.batch,
|
| 406 |
+
channels=self.out_ch,
|
| 407 |
+
frames=target_frames,
|
| 408 |
+
mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
return sample, target_shape
|
| 412 |
+
|
| 413 |
+
def _adjust_output_shape(
|
| 414 |
+
self,
|
| 415 |
+
decoded_output: torch.Tensor,
|
| 416 |
+
target_shape: AudioLatentShape,
|
| 417 |
+
) -> torch.Tensor:
|
| 418 |
+
"""
|
| 419 |
+
Adjust output shape to match target dimensions for variable-length audio.
|
| 420 |
+
|
| 421 |
+
This function handles the common case where decoded audio spectrograms need to be
|
| 422 |
+
resized to match a specific target shape.
|
| 423 |
+
|
| 424 |
+
Args:
|
| 425 |
+
decoded_output: Tensor of shape (batch, channels, time, frequency)
|
| 426 |
+
target_shape: AudioLatentShape describing (batch, channels, time, mel bins)
|
| 427 |
+
|
| 428 |
+
Returns:
|
| 429 |
+
Tensor adjusted to match target_shape exactly
|
| 430 |
+
"""
|
| 431 |
+
# Current output shape: (batch, channels, time, frequency)
|
| 432 |
+
_, _, current_time, current_freq = decoded_output.shape
|
| 433 |
+
target_channels = target_shape.channels
|
| 434 |
+
target_time = target_shape.frames
|
| 435 |
+
target_freq = target_shape.mel_bins
|
| 436 |
+
|
| 437 |
+
# Step 1: Crop first to avoid exceeding target dimensions
|
| 438 |
+
decoded_output = decoded_output[
|
| 439 |
+
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
|
| 440 |
+
]
|
| 441 |
+
|
| 442 |
+
# Step 2: Calculate padding needed for time and frequency dimensions
|
| 443 |
+
time_padding_needed = target_time - decoded_output.shape[2]
|
| 444 |
+
freq_padding_needed = target_freq - decoded_output.shape[3]
|
| 445 |
+
|
| 446 |
+
# Step 3: Apply padding if needed
|
| 447 |
+
if time_padding_needed > 0 or freq_padding_needed > 0:
|
| 448 |
+
# PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
|
| 449 |
+
# For audio: pad_left/right = frequency, pad_top/bottom = time
|
| 450 |
+
padding = (
|
| 451 |
+
0,
|
| 452 |
+
max(freq_padding_needed, 0), # frequency padding (left, right)
|
| 453 |
+
0,
|
| 454 |
+
max(time_padding_needed, 0), # time padding (top, bottom)
|
| 455 |
+
)
|
| 456 |
+
decoded_output = F.pad(decoded_output, padding)
|
| 457 |
+
|
| 458 |
+
# Step 4: Final safety crop to ensure exact target shape
|
| 459 |
+
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
|
| 460 |
+
|
| 461 |
+
return decoded_output
|
| 462 |
+
|
| 463 |
+
def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor:
|
| 464 |
+
for level in reversed(range(self.num_resolutions)):
|
| 465 |
+
stage = self.up[level]
|
| 466 |
+
for block_idx, block in enumerate(stage.block):
|
| 467 |
+
h = block(h, temb=None)
|
| 468 |
+
if stage.attn:
|
| 469 |
+
h = stage.attn[block_idx](h)
|
| 470 |
+
|
| 471 |
+
if level != 0 and hasattr(stage, "upsample"):
|
| 472 |
+
h = stage.upsample(h)
|
| 473 |
+
|
| 474 |
+
return h
|
| 475 |
+
|
| 476 |
+
def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
|
| 477 |
+
if self.give_pre_end:
|
| 478 |
+
return h
|
| 479 |
+
|
| 480 |
+
h = self.norm_out(h)
|
| 481 |
+
h = self.non_linearity(h)
|
| 482 |
+
h = self.conv_out(h)
|
| 483 |
+
return torch.tanh(h) if self.tanh_out else h
|