diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..86464063b49745f267f0d70d4b82ca29a6b9258b --- /dev/null +++ b/app.py @@ -0,0 +1,301 @@ +""" +Simple Gradio app for LTX-2 inference based on ltx2_two_stage.py example +""" + +import sys +from pathlib import Path + +# Add packages to Python path +current_dir = Path(__file__).parent +sys.path.insert(0, str(current_dir / "packages" / "ltx-pipelines" / "src")) +sys.path.insert(0, str(current_dir / "packages" / "ltx-core" / "src")) + +import gradio as gr +from typing import Optional +from huggingface_hub import hf_hub_download +from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline +from ltx_core.tiling import TilingConfig +from ltx_pipelines.constants import ( + DEFAULT_SEED, + DEFAULT_HEIGHT, + DEFAULT_WIDTH, + DEFAULT_NUM_FRAMES, + DEFAULT_FRAME_RATE, + DEFAULT_NUM_INFERENCE_STEPS, + DEFAULT_CFG_GUIDANCE_SCALE, + DEFAULT_LORA_STRENGTH, +) + +# Custom negative prompt +DEFAULT_NEGATIVE_PROMPT = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static" + +# Default prompt from docstring example +DEFAULT_PROMPT = "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot." 
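+# NOTE: this astronaut prompt is reused verbatim in the gr.Examples entry at the bottom of this file.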
+ +# HuggingFace Hub defaults +DEFAULT_REPO_ID = "LTX-Colab/LTX-Video-Preview" +DEFAULT_GEMMA_REPO_ID = "google/gemma-3-12b-it-qat-q4_0-unquantized" +DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev-rc1.safetensors" +DEFAULT_DISTILLED_LORA_FILENAME = "ltx-2-19b-distilled-lora-384-rc1.safetensors" +DEFAULT_SPATIAL_UPSAMPLER_FILENAME = "ltx-2-spatial-upscaler-x2-1.0-rc1.safetensors" + +def get_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None): + """Download from HuggingFace Hub or use local checkpoint.""" + if repo_id is None and filename is None: + raise ValueError("Please supply at least one of `repo_id` or `filename`") + + if repo_id is not None: + if filename is None: + raise ValueError("If repo_id is specified, filename must also be specified.") + print(f"Downloading {filename} from {repo_id}...") + ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) + print(f"Downloaded to {ckpt_path}") + else: + ckpt_path = filename + + return ckpt_path + + +# Initialize pipeline at startup +print("=" * 80) +print("Loading LTX-2 2-stage pipeline...") +print("=" * 80) + +checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME) +distilled_lora_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_DISTILLED_LORA_FILENAME) +spatial_upsampler_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_SPATIAL_UPSAMPLER_FILENAME) + +print(f"Initializing pipeline with:") +print(f" checkpoint_path={checkpoint_path}") +print(f" distilled_lora_path={distilled_lora_path}") +print(f" spatial_upsampler_path={spatial_upsampler_path}") +print(f" gemma_root={DEFAULT_GEMMA_REPO_ID}") + +pipeline = TI2VidTwoStagesPipeline( + checkpoint_path=checkpoint_path, + distilled_lora_path=distilled_lora_path, + distilled_lora_strength=DEFAULT_LORA_STRENGTH, + spatial_upsampler_path=spatial_upsampler_path, + gemma_root=DEFAULT_GEMMA_REPO_ID, + loras=[], + fp8transformer=False, + local_files_only=False +) + +print("=" * 80) +print("Warming up pipeline (loading Gemma text encoder)...") +print("=" * 80) + +# Do a dummy warmup to load all models including Gemma +import tempfile +import os +warmup_output = tempfile.mktemp(suffix=".mp4") +try: + pipeline( + prompt="warmup", + negative_prompt="", + output_path=warmup_output, + seed=42, + height=256, + width=256, + num_frames=9, + frame_rate=8, + num_inference_steps=1, + cfg_guidance_scale=1.0, + images=[], + tiling_config=TilingConfig.default(), + ) + # Clean up warmup output + if os.path.exists(warmup_output): + os.remove(warmup_output) +except Exception as e: + print(f"Warmup completed with note: {e}") + +print("=" * 80) +print("Pipeline fully loaded and ready!") +print("=" * 80) + + +def generate_video( + input_image, + prompt: str, + duration: float, + negative_prompt: str, + seed: int, + randomize_seed: bool, + num_inference_steps: int, + cfg_guidance_scale: float, + height: int, + width: int, + progress=gr.Progress() +): + """Generate a video based on the given parameters.""" + try: + # Randomize seed if checkbox is enabled + if randomize_seed: + import random + seed = random.randint(0, 1000000) + + # Calculate num_frames from duration (using fixed 24 fps) + frame_rate = 24.0 + num_frames = int(duration * frame_rate) + 1 # +1 to ensure we meet the duration + + # Create output directory if it doesn't exist + output_dir = Path("outputs") + output_dir.mkdir(exist_ok=True) + output_path = output_dir / f"video_{seed}.mp4" + + # Handle image input + images = [] + if input_image is not None: + # Save 
uploaded image temporarily + temp_image_path = output_dir / f"temp_input_{seed}.jpg" + if hasattr(input_image, 'save'): + input_image.save(temp_image_path) + else: + # If it's a file path already + temp_image_path = input_image + # Format: (image_path, frame_idx, strength) + images = [(str(temp_image_path), 0, 1.0)] + + # Run inference + progress(0, desc="Generating video (2-stage)...") + pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + output_path=str(output_path), + seed=seed, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + num_inference_steps=num_inference_steps, + cfg_guidance_scale=cfg_guidance_scale, + images=images, + tiling_config=TilingConfig.default(), + ) + + progress(1.0, desc="Done!") + return str(output_path) + + except Exception as e: + import traceback + error_msg = f"Error: {str(e)}\n{traceback.format_exc()}" + print(error_msg) + return None + + +# Create Gradio interface +with gr.Blocks(title="LTX-2 Image-to-Video") as demo: + gr.Markdown("# LTX-2 Image-to-Video Generation") + gr.Markdown("Transform images into videos using the LTX-2 2-stage pipeline") + + with gr.Row(): + with gr.Column(): + input_image = gr.Image( + label="Input Image", + type="pil", + sources=["upload"] + ) + + prompt = gr.Textbox( + label="Prompt", + value="Make this image come alive with cinematic motion, smooth animation", + lines=3, + placeholder="Describe the motion and animation you want..." + ) + + duration = gr.Slider( + label="Duration (seconds)", + minimum=1.0, + maximum=10.0, + value=5.0, + step=0.1 + ) + + generate_btn = gr.Button("Generate Video", variant="primary", size="lg") + + with gr.Accordion("Advanced Settings", open=False): + negative_prompt = gr.Textbox( + label="Negative Prompt", + value=DEFAULT_NEGATIVE_PROMPT, + lines=2 + ) + + seed = gr.Slider( + label="Seed", + minimum=0, + maximum=1000000, + value=DEFAULT_SEED, + step=1 + ) + + randomize_seed = gr.Checkbox( + label="Randomize Seed", + value=True + ) + + num_inference_steps = gr.Slider( + label="Inference Steps", + minimum=1, + maximum=100, + value=DEFAULT_NUM_INFERENCE_STEPS, + step=1 + ) + + cfg_guidance_scale = gr.Slider( + label="CFG Guidance Scale", + minimum=1.0, + maximum=10.0, + value=DEFAULT_CFG_GUIDANCE_SCALE, + step=0.1 + ) + + with gr.Row(): + width = gr.Number( + label="Width", + value=DEFAULT_WIDTH, + precision=0 + ) + height = gr.Number( + label="Height", + value=DEFAULT_HEIGHT, + precision=0 + ) + + with gr.Column(): + output_video = gr.Video(label="Generated Video", autoplay=True) + + generate_btn.click( + fn=generate_video, + inputs=[ + input_image, + prompt, + duration, + negative_prompt, + seed, + randomize_seed, + num_inference_steps, + cfg_guidance_scale, + height, + width, + ], + outputs=output_video + ) + + # Add example + gr.Examples( + examples=[ + [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg", + "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. 
The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot.", + 5.0, + ] + ], + inputs=[input_image, prompt, duration], + label="Example" + ) + + +if __name__ == "__main__": + demo.launch(share=True) diff --git a/packages/ltx-core/README.md b/packages/ltx-core/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a53fbe0af49abce26d921ee9112dcaeb9eb5af63 --- /dev/null +++ b/packages/ltx-core/README.md @@ -0,0 +1 @@ +# LTX-2 Core diff --git a/packages/ltx-core/pyproject.toml b/packages/ltx-core/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..6aeef433213f89e3538726486ed16f329e17162d --- /dev/null +++ b/packages/ltx-core/pyproject.toml @@ -0,0 +1,38 @@ +[project] +name = "ltx-core" +version = "0.1.0" +description = "Core implementation of Lightricks' LTX-2 model" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "torch~=2.7", + "torchaudio", + "einops", + "numpy", + "transformers", + "safetensors", + "accelerate", + "scipy>=1.14", +] + +[project.optional-dependencies] +flashpack = ["flashpack==0.1.2"] +xformers = ["xformers"] + + +[tool.uv.sources] +xformers = { index = "pytorch" } + +[[tool.uv.index]] +name = "pytorch" +url = "https://download.pytorch.org/whl/cu129" +explicit = true + +[build-system] +requires = ["uv_build>=0.9.8,<0.10.0"] +build-backend = "uv_build" + +[dependency-groups] +dev = [ + "scikit-image>=0.25.2", +] diff --git a/packages/ltx-core/src/ltx_core/__init__.py b/packages/ltx-core/src/ltx_core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-core/src/ltx_core/__pycache__/__init__.cpython-310.pyc b/packages/ltx-core/src/ltx_core/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c454086a82a0be1a880bab44a524d0b4a0781201 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/__pycache__/__init__.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/__pycache__/tiling.cpython-310.pyc b/packages/ltx-core/src/ltx_core/__pycache__/tiling.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b26f2567cc4e00c6d93837d518d5610e6132529 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/__pycache__/tiling.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/__pycache__/utils.cpython-310.pyc b/packages/ltx-core/src/ltx_core/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae6569df9ca1a6ff8dde4f94925c1075489fad80 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/__pycache__/utils.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/guidance/__init__.py b/packages/ltx-core/src/ltx_core/guidance/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-core/src/ltx_core/guidance/__pycache__/__init__.cpython-310.pyc b/packages/ltx-core/src/ltx_core/guidance/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8af91c74e4528480af7d35a2f3b895398f405331 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/guidance/__pycache__/__init__.cpython-310.pyc differ diff 
--git a/packages/ltx-core/src/ltx_core/guidance/__pycache__/perturbations.cpython-310.pyc b/packages/ltx-core/src/ltx_core/guidance/__pycache__/perturbations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48660639912b96de11460de01840080cc57062e1 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/guidance/__pycache__/perturbations.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/guidance/perturbations.py b/packages/ltx-core/src/ltx_core/guidance/perturbations.py new file mode 100644 index 0000000000000000000000000000000000000000..8c46f78ba9bc524cf1a35f6979b0f392fe68d6a1 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/guidance/perturbations.py @@ -0,0 +1,74 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +from dataclasses import dataclass +from enum import Enum + +import torch +from torch._prims_common import DeviceLikeType + + +class PerturbationType(Enum): + SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn" + SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn" + SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn" + SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn" + + +@dataclass(frozen=True) +class Perturbation: + type: PerturbationType + blocks: list[int] | None # None means all blocks + + def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool: + if self.type != perturbation_type: + return False + + if self.blocks is None: + return True + + return block in self.blocks + + +@dataclass(frozen=True) +class PerturbationConfig: + perturbations: list[Perturbation] | None + + def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool: + if self.perturbations is None: + return False + + return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + @staticmethod + def empty() -> "PerturbationConfig": + return PerturbationConfig([]) + + +@dataclass(frozen=True) +class BatchedPerturbationConfig: + perturbations: list[PerturbationConfig] + + def mask( + self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype + ) -> torch.Tensor: + mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype) + for batch_idx, perturbation in enumerate(self.perturbations): + if perturbation.is_perturbed(perturbation_type, block): + mask[batch_idx] = 0 + + return mask + + def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor: + mask = self.mask(perturbation_type, block, values.device, values.dtype) + return mask.view(mask.numel(), *([1] * len(values.shape[1:]))) + + def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool: + return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool: + return all(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + @staticmethod + def empty(batch_size: int) -> "BatchedPerturbationConfig": + return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)]) diff --git a/packages/ltx-core/src/ltx_core/legacy_tiling.py b/packages/ltx-core/src/ltx_core/legacy_tiling.py new file mode 100644 index 0000000000000000000000000000000000000000..cec9c37a6f644b772519c2a4b948b1aa9cc3ac5c --- /dev/null +++ b/packages/ltx-core/src/ltx_core/legacy_tiling.py @@ -0,0 +1,258 @@ +import logging +from collections.abc 
import Generator + +import torch + +from ltx_core.model.video_vae.video_vae import Decoder + + +def compute_chunk_boundaries( + chunk_start: int, + temporal_tile_length: int, + temporal_overlap: int, + total_latent_frames: int, +) -> tuple[int, int]: + """Compute chunk boundaries for temporal tiling. + + Args: + chunk_start: Starting frame index for the current chunk + temporal_tile_length: Length of each temporal tile + temporal_overlap: Number of frames to overlap between chunks + total_latent_frames: Total number of latent frames + + Returns: + Tuple of (overlap_start, chunk_end) + """ + if chunk_start == 0: + # First chunk: no overlap needed + chunk_end = min(chunk_start + temporal_tile_length, total_latent_frames) + overlap_start = chunk_start + else: + # Subsequent chunks: include overlap from previous chunk + # -1 because we need one extra frame to overlap, which is decoded to a single frame + # never overlap with the first latent frame + overlap_start = max(1, chunk_start - temporal_overlap - 1) + extra_frames = chunk_start - overlap_start + chunk_end = min( + chunk_start + temporal_tile_length - extra_frames, + total_latent_frames, + ) + + return overlap_start, chunk_end + + +def spatial_decode( # noqa + decoder: Decoder, + samples: torch.Tensor, + horizontal_tiles: int, + vertical_tiles: int, + overlap: int, + last_frame_fix: bool, + scale_factors: tuple[float, float, float], + timestep: float, + generator: torch.Generator, +) -> torch.Tensor: + if last_frame_fix: + # Repeat the last frame along dimension 2 (frames) + # samples shape - [batch, channels, frames, height, width] + last_frame = samples[:, :, -1:, :, :] + samples = torch.cat([samples, last_frame], dim=2) + + batch, _, frames, height, width = samples.shape + time_scale_factor, width_scale_factor, height_scale_factor = scale_factors + image_frames = 1 + (frames - 1) * time_scale_factor + + # Calculate output image dimensions + output_height = height * height_scale_factor + output_width = width * width_scale_factor + + # Calculate tile sizes with overlap + base_tile_height = (height + (vertical_tiles - 1) * overlap) // vertical_tiles + base_tile_width = (width + (horizontal_tiles - 1) * overlap) // horizontal_tiles + + # Initialize output tensor and weight tensor + # VAE decode returns images in format [batch, height, width, channels] + output = None + weights = None + + target_device = samples.device + target_dtype = samples.dtype + + output = torch.zeros( + ( + batch, + 3, + image_frames, + output_height, + output_width, + ), + device=target_device, + dtype=target_dtype, + ) + weights = torch.zeros( + (batch, 1, image_frames, output_height, output_width), + device=target_device, + dtype=target_dtype, + ) + + # Process each tile + for v in range(vertical_tiles): + for h in range(horizontal_tiles): + # Calculate tile boundaries + h_start = h * (base_tile_width - overlap) + v_start = v * (base_tile_height - overlap) + + # Adjust end positions for edge tiles + h_end = min(h_start + base_tile_width, width) if h < horizontal_tiles - 1 else width + v_end = min(v_start + base_tile_height, height) if v < vertical_tiles - 1 else height + + # Calculate actual tile dimensions + tile_height = v_end - v_start + tile_width = h_end - h_start + + logging.info(f"Processing VAE decode tile at row {v}, col {h}:") + logging.info(f" Position: ({v_start}:{v_end}, {h_start}:{h_end})") + logging.info(f" Size: {tile_height}x{tile_width}") + + # Extract tile + tile = samples[:, :, :, v_start:v_end, h_start:h_end] + + # Decode the tile + 
decoded_tile = decoder.decode(tile, timestep, generator) + + # Calculate output tile boundaries + out_h_start = v_start * height_scale_factor + out_h_end = v_end * height_scale_factor + out_w_start = h_start * width_scale_factor + out_w_end = h_end * width_scale_factor + + # Create weight mask for this tile + tile_out_height = out_h_end - out_h_start + tile_out_width = out_w_end - out_w_start + tile_weights = torch.ones( + (batch, 1, image_frames, tile_out_height, tile_out_width), + device=decoded_tile.device, + dtype=decoded_tile.dtype, + ) + + # Calculate overlap regions in output space + overlap_out_h = overlap * height_scale_factor + overlap_out_w = overlap * width_scale_factor + + # Apply horizontal blending weights + if h > 0: # Left overlap + h_blend = torch.linspace(0, 1, overlap_out_w, device=decoded_tile.device) + tile_weights[:, :, :, :, :overlap_out_w] *= h_blend + if h < horizontal_tiles - 1: # Right overlap + h_blend = torch.linspace(1, 0, overlap_out_w, device=decoded_tile.device) + tile_weights[:, :, :, :, -overlap_out_w:] *= h_blend + + # Apply vertical blending weights + if v > 0: # Top overlap + v_blend = torch.linspace(0, 1, overlap_out_h, device=decoded_tile.device) + tile_weights[:, :, :, :overlap_out_h, :] *= v_blend.view(1, 1, 1, -1, 1) + if v < vertical_tiles - 1: # Bottom overlap + v_blend = torch.linspace(1, 0, overlap_out_h, device=decoded_tile.device) + tile_weights[:, :, :, -overlap_out_h:, :] *= v_blend.view(1, 1, 1, -1, 1) + + # Add weighted tile to output + output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += (decoded_tile * tile_weights).to( + target_device, target_dtype + ) + + # Add weights to weight tensor + weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += tile_weights.to( + target_device, target_dtype + ) + + # Normalize by weights + output /= weights + 1e-8 + # LT_INTERNAL: changed from output[:-time_scale_factor, :, :]! 
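+ # last_frame_fix duplicated the final latent frame before decoding; the trim below removes the extra decoded frames.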
+ if last_frame_fix: + output = output[:, :, :-time_scale_factor, :, :] + + return output + + +def decode_spatial_temporal( + decoder: Decoder, + samples: torch.Tensor, + timestep: float, + generator: torch.Generator, + scale_factors: tuple[float, float, float], + spatial_tiles: int = 4, + spatial_overlap: int = 1, + temporal_tile_length: int = 16, + temporal_overlap: int = 1, + last_frame_fix: bool = False, +) -> Generator[torch.Tensor, None, None]: + if temporal_tile_length < temporal_overlap + 1: + raise ValueError("Temporal tile length must be at least temporal overlap + 1") + + _, _, frames, _, _ = samples.shape + time_scale_factor, _, _ = scale_factors + + # Process temporal chunks similar to reference function + total_latent_frames = frames + chunk_start = 0 + + previous_tile = None + while chunk_start < total_latent_frames: + # Calculate chunk boundaries + overlap_start, chunk_end = compute_chunk_boundaries( + chunk_start, temporal_tile_length, temporal_overlap, total_latent_frames + ) + + # units are latent frames + chunk_frames = chunk_end - overlap_start + logging.info(f"Processing temporal chunk: {overlap_start}:{chunk_end} ({chunk_frames} latent frames)") + + # Extract tile + tile = samples[:, :, overlap_start:chunk_end] + + # Decode the tile + decoded_tile = spatial_decode( + decoder, + tile, + spatial_tiles, + spatial_tiles, + spatial_overlap, + last_frame_fix, + scale_factors, + timestep, + generator, + ) + + if previous_tile is None: + previous_tile = decoded_tile + else: + # Drop first frame if needed (overlap) + if decoded_tile.shape[2] == 1: + raise ValueError("Dropping first frame but tile has only 1 frame") + decoded_tile = decoded_tile[:, :, 1:] # Drop first frame + + # Create weight mask for this tile + # -1 is for dropped frame above + overlap_frames = temporal_overlap * time_scale_factor + frame_weights = torch.linspace( + 0, + 1, + overlap_frames + 2, + device=decoded_tile.device, + dtype=decoded_tile.dtype, + )[1:-1] + tile_weights = frame_weights.view(1, 1, -1, 1, 1) + + previous_tile[:, :, -overlap_frames:] = ( + previous_tile[:, :, -overlap_frames:] * (1 - tile_weights) + + decoded_tile[:, :, :overlap_frames] * tile_weights + ) + resulting_tile = previous_tile[:, :, :-overlap_frames] + decoded_tile[:, :, :overlap_frames] = previous_tile[:, :, -overlap_frames:] + yield resulting_tile + previous_tile = decoded_tile + + # Move to next chunk + chunk_start = chunk_end + + yield decoded_tile diff --git a/packages/ltx-core/src/ltx_core/loader/.ipynb_checkpoints/sd_ops-checkpoint.py b/packages/ltx-core/src/ltx_core/loader/.ipynb_checkpoints/sd_ops-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..a48b3683cff367e7528e11c30703cd9ece1e55ef --- /dev/null +++ b/packages/ltx-core/src/ltx_core/loader/.ipynb_checkpoints/sd_ops-checkpoint.py @@ -0,0 +1,107 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
+# Created by Alexey Kravtsov + +from dataclasses import dataclass, replace +#from typing import NamedTuple, Protocol, Self +from typing import NamedTuple, Protocol +from typing_extensions import Self + +import torch + + +@dataclass(frozen=True, slots=True) +class ContentReplacement: + content: str + replacement: str + + +@dataclass(frozen=True, slots=True) +class ContentMatching: + prefix: str = "" + suffix: str = "" + + +class KeyValueOperationResult(NamedTuple): + new_key: str + new_value: torch.Tensor + + +class KeyValueOperation(Protocol): + def __call__(self, tensor_key: str, tensor_value: torch.Tensor) -> list[KeyValueOperationResult]: ... + + +@dataclass(frozen=True, slots=True) +class SDKeyValueOperation: + key_matcher: ContentMatching + kv_operation: KeyValueOperation + + +@dataclass(frozen=True, slots=True) +class SDOps: + """Immutable class representing state dict key operations.""" + + name: str + mapping: tuple[ + ContentReplacement | ContentMatching | SDKeyValueOperation, ... + ] = () # Immutable tuple of (key, value) pairs + + def with_replacement(self, content: str, replacement: str) -> Self: + """Create a new SDOps instance with the specified replacement added to the mapping.""" + + new_mapping = (*self.mapping, ContentReplacement(content, replacement)) + return replace(self, mapping=new_mapping) + + def with_matching(self, prefix: str = "", suffix: str = "") -> Self: + """Create a new SDOps instance with the specified prefix and suffix matching added to the mapping.""" + + new_mapping = (*self.mapping, ContentMatching(prefix, suffix)) + return replace(self, mapping=new_mapping) + + def with_kv_operation( + self, + operation: KeyValueOperation, + key_prefix: str = "", + key_suffix: str = "", + ) -> Self: + """Create a new SDOps instance with the specified value operation added to the mapping.""" + key_matcher = ContentMatching(key_prefix, key_suffix) + sd_kv_operation = SDKeyValueOperation(key_matcher, operation) + new_mapping = (*self.mapping, sd_kv_operation) + return replace(self, mapping=new_mapping) + + def apply_to_key(self, key: str) -> str | None: + """Apply the mapping to the given name.""" + matchers = [content for content in self.mapping if isinstance(content, ContentMatching)] + valid = any(key.startswith(f.prefix) and key.endswith(f.suffix) for f in matchers) + if not valid: + return None + + for replacement in self.mapping: + if not isinstance(replacement, ContentReplacement): + continue + if replacement.content in key: + key = key.replace(replacement.content, replacement.replacement) + return key + + def apply_to_key_value(self, key: str, value: torch.Tensor) -> list[KeyValueOperationResult]: + """Apply the value operation to the given name and associated value.""" + for operation in self.mapping: + if not isinstance(operation, SDKeyValueOperation): + continue + if key.startswith(operation.key_matcher.prefix) and key.endswith(operation.key_matcher.suffix): + return operation.kv_operation(key, value) + return [KeyValueOperationResult(key, value)] + + +# Predefined SDOps instances +LTXV_LORA_COMFY_RENAMING_MAP = ( + SDOps("LTXV_LORA_COMFY_PREFIX_MAP").with_matching().with_replacement("diffusion_model.", "") +) + +LTXV_LORA_COMFY_TARGET_MAP = ( + SDOps("LTXV_LORA_COMFY_TARGET_MAP") + .with_matching() + .with_replacement("diffusion_model.", "") + .with_replacement(".lora_A.weight", ".weight") + .with_replacement(".lora_B.weight", ".weight") +) diff --git a/packages/ltx-core/src/ltx_core/loader/__init__.py b/packages/ltx-core/src/ltx_core/loader/__init__.py new 
file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-core/src/ltx_core/loader/__pycache__/__init__.cpython-310.pyc b/packages/ltx-core/src/ltx_core/loader/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1da4dfeed43f52b69497bd9b3d0f7c418030ef7d Binary files /dev/null and b/packages/ltx-core/src/ltx_core/loader/__pycache__/__init__.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/loader/__pycache__/fuse_loras.cpython-310.pyc b/packages/ltx-core/src/ltx_core/loader/__pycache__/fuse_loras.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f39b92b5b9ab2ed79ef696efda540abe16170bf Binary files /dev/null and b/packages/ltx-core/src/ltx_core/loader/__pycache__/fuse_loras.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/loader/__pycache__/kernels.cpython-310.pyc b/packages/ltx-core/src/ltx_core/loader/__pycache__/kernels.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..341e4a0f3573d866cfb10ac081d0bea8c66efd0c Binary files /dev/null and b/packages/ltx-core/src/ltx_core/loader/__pycache__/kernels.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/loader/__pycache__/module_ops.cpython-310.pyc b/packages/ltx-core/src/ltx_core/loader/__pycache__/module_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d92d432b4a4dd90a0726ec880a196790860fe867 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/loader/__pycache__/module_ops.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/loader/__pycache__/primitives.cpython-310.pyc b/packages/ltx-core/src/ltx_core/loader/__pycache__/primitives.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfe093fc61d1187f05820b67770e7e173c65ccd4 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/loader/__pycache__/primitives.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/loader/__pycache__/registry.cpython-310.pyc b/packages/ltx-core/src/ltx_core/loader/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b9b4d4e22f3fa4ca7ea3f096095a09dd37ba74c Binary files /dev/null and b/packages/ltx-core/src/ltx_core/loader/__pycache__/registry.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/loader/__pycache__/sd_ops.cpython-310.pyc b/packages/ltx-core/src/ltx_core/loader/__pycache__/sd_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..880d26d47319a5c63cb62d989e78ca3b449210c0 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/loader/__pycache__/sd_ops.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/loader/__pycache__/sft_loader.cpython-310.pyc b/packages/ltx-core/src/ltx_core/loader/__pycache__/sft_loader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df6a72e748cb0c4f7de40b46626d499ff874440b Binary files /dev/null and b/packages/ltx-core/src/ltx_core/loader/__pycache__/sft_loader.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/loader/__pycache__/single_gpu_model_builder.cpython-310.pyc b/packages/ltx-core/src/ltx_core/loader/__pycache__/single_gpu_model_builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a413c54b6e6b70da2dfcd9a988879eda1fba7fab Binary files /dev/null and 
b/packages/ltx-core/src/ltx_core/loader/__pycache__/single_gpu_model_builder.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/loader/fuse_loras.py b/packages/ltx-core/src/ltx_core/loader/fuse_loras.py new file mode 100644 index 0000000000000000000000000000000000000000..2588e78fa845c40fcb1147b9c52fc063fc8c66fa --- /dev/null +++ b/packages/ltx-core/src/ltx_core/loader/fuse_loras.py @@ -0,0 +1,102 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Alexey Kravtsov +import torch +import triton + +from ltx_core.loader.kernels import fused_add_round_kernel +from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict + +BLOCK_SIZE = 1024 + + +def fused_add_round_launch(target_weight: torch.Tensor, original_weight: torch.Tensor, seed: int) -> torch.Tensor: + if original_weight.dtype == torch.float8_e4m3fn: + exponent_bits, mantissa_bits, exponent_bias = 4, 3, 7 + elif original_weight.dtype == torch.float8_e5m2: + exponent_bits, mantissa_bits, exponent_bias = 5, 2, 15 # noqa: F841 + else: + raise ValueError("Unsupported dtype") + + if target_weight.dtype != torch.bfloat16: + raise ValueError("target_weight dtype must be bfloat16") + + # Calculate grid and block sizes + n_elements = original_weight.numel() + grid = (triton.cdiv(n_elements, BLOCK_SIZE),) + + # Launch kernel + fused_add_round_kernel[grid]( + original_weight, + target_weight, + seed, + n_elements, + exponent_bias, + mantissa_bits, + BLOCK_SIZE, + ) + return target_weight + + +def calculate_weight_float8_(target_weights: torch.Tensor, original_weights: torch.Tensor) -> torch.Tensor: + result = fused_add_round_launch(target_weights, original_weights, seed=0).to(target_weights.dtype) + target_weights.copy_(result, non_blocking=True) + return target_weights + + +def _prepare_deltas( + lora_sd_and_strengths: list[LoraStateDictWithStrength], key: str, dtype: torch.dtype, device: torch.device +) -> torch.Tensor | None: + deltas = [] + prefix = key[: -len(".weight")] + key_a = f"{prefix}.lora_A.weight" + key_b = f"{prefix}.lora_B.weight" + for lsd, coef in lora_sd_and_strengths: + if key_a not in lsd.sd or key_b not in lsd.sd: + continue + product = torch.matmul(lsd.sd[key_b] * coef, lsd.sd[key_a]) + deltas.append(product.to(dtype=dtype, device=device)) + if len(deltas) == 0: + return None + elif len(deltas) == 1: + return deltas[0] + return torch.sum(torch.stack(deltas, dim=0), dim=0) + + +def apply_loras( + model_sd: StateDict, + lora_sd_and_strengths: list[LoraStateDictWithStrength], + dtype: torch.dtype, + destination_sd: StateDict | None = None, +) -> StateDict: + sd = {} + if destination_sd is not None: + sd = destination_sd.sd + size = 0 + device = torch.device("meta") + inner_dtypes = set() + for key, weight in model_sd.sd.items(): + if weight is None: + continue + device = weight.device + target_dtype = dtype if dtype is not None else weight.dtype + deltas_dtype = target_dtype if target_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2] else torch.bfloat16 + deltas = _prepare_deltas(lora_sd_and_strengths, key, deltas_dtype, device) + if deltas is None: + if key in sd: + continue + deltas = weight.clone().to(dtype=target_dtype, device=device) + elif weight.dtype == torch.float8_e4m3fn: + if str(device).startswith("cuda"): + deltas = calculate_weight_float8_(deltas, weight) + else: + deltas.add_(weight.to(dtype=deltas.dtype, device=device)) + elif weight.dtype == torch.bfloat16: + deltas.add_(weight) + else: + raise ValueError(f"Unsupported dtype: {weight.dtype}") + sd[key] = 
deltas.to(dtype=target_dtype) + inner_dtypes.add(target_dtype) + size += deltas.nbytes + if destination_sd is not None: + return destination_sd + return StateDict(sd, device, size, inner_dtypes) diff --git a/packages/ltx-core/src/ltx_core/loader/kernels.py b/packages/ltx-core/src/ltx_core/loader/kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..adaeee6e4e04563b1fd605787f2fc63aec483fb1 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/loader/kernels.py @@ -0,0 +1,74 @@ +# ruff: noqa: ANN001, ANN201, ERA001, N803, N806 +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Alexey Kravtsov +import triton +import triton.language as tl + + +@triton.jit +def fused_add_round_kernel( + x_ptr, + output_ptr, # contents will be added to the output + seed, + n_elements, + EXPONENT_BIAS, + MANTISSA_BITS, + BLOCK_SIZE: tl.constexpr, +): + """ + A kernel to upcast 8bit quantized weights to bfloat16 with stochastic rounding + and add them to bfloat16 output weights. Might be used to upcast original model weights + and to further add them to precalculated deltas coming from LoRAs. + """ + # Get program ID and compute offsets + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + # Load data + x = tl.load(x_ptr + offsets, mask=mask) + rand_vals = tl.rand(seed, offsets) - 0.5 + + x = tl.cast(x, tl.float16) + delta = tl.load(output_ptr + offsets, mask=mask) + delta = tl.cast(delta, tl.float16) + x = x + delta + + x_bits = tl.cast(x, tl.int16, bitcast=True) + + # Calculate the exponent. Unbiased fp16 exponent is ((x_bits & 0x7C00) >> 10) - 15 for + # normal numbers and -14 for subnormals. + fp16_exponent_bits = (x_bits & 0x7C00) >> 10 + fp16_normals = fp16_exponent_bits > 0 + fp16_exponent = tl.where(fp16_normals, fp16_exponent_bits - 15, -14) + + # Add the target dtype's exponent bias and clamp to the target dtype's exponent range. + exponent = fp16_exponent + EXPONENT_BIAS + MAX_EXPONENT = 2 * EXPONENT_BIAS + 1 + exponent = tl.where(exponent > MAX_EXPONENT, MAX_EXPONENT, exponent) + exponent = tl.where(exponent < 0, 0, exponent) + + # Normal ULP exponent, expressed as an fp16 exponent field: + # (exponent - EXPONENT_BIAS - MANTISSA_BITS) + 15 + # Simplifies to: fp16_exponent - MANTISSA_BITS + 15 + # See https://en.wikipedia.org/wiki/Unit_in_the_last_place + eps_exp = tl.maximum(0, tl.minimum(31, exponent - EXPONENT_BIAS - MANTISSA_BITS + 15)) + + # Calculate epsilon in the target dtype + eps_normal = tl.cast(tl.cast(eps_exp << 10, tl.int16), tl.float16, bitcast=True) + + # Subnormal ULP: 2^(1 - EXPONENT_BIAS - MANTISSA_BITS) -> + # fp16 exponent bits: (1 - EXPONENT_BIAS - MANTISSA_BITS) + 15 = + # 16 - EXPONENT_BIAS - MANTISSA_BITS + eps_subnormal = tl.cast(tl.cast((16 - EXPONENT_BIAS - MANTISSA_BITS) << 10, tl.int16), tl.float16, bitcast=True) + eps = tl.where(exponent > 0, eps_normal, eps_subnormal) + + # Apply zero mask to epsilon + eps = tl.where(x == 0, 0.0, eps) + + # Apply stochastic rounding + output = tl.cast(x + rand_vals * eps, tl.bfloat16) + + # Store the result + tl.store(output_ptr + offsets, output, mask=mask) diff --git a/packages/ltx-core/src/ltx_core/loader/module_ops.py b/packages/ltx-core/src/ltx_core/loader/module_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c2c2722ca265a8d8de71fa1ff8472cb1d4d266dd --- /dev/null +++ b/packages/ltx-core/src/ltx_core/loader/module_ops.py @@ -0,0 +1,11 @@ +# Copyright (c) 2025 Lightricks. 
All rights reserved. +# Created by Alexey Kravtsov +from typing import Callable, NamedTuple + +import torch + + +class ModuleOps(NamedTuple): + name: str + matcher: Callable[[torch.nn.Module], bool] + mutator: Callable[[torch.nn.Module], torch.nn.Module] diff --git a/packages/ltx-core/src/ltx_core/loader/primitives.py b/packages/ltx-core/src/ltx_core/loader/primitives.py new file mode 100644 index 0000000000000000000000000000000000000000..2591b1af9dc686cc9e606ff5cc541762767ec177 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/loader/primitives.py @@ -0,0 +1,63 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Alexey Kravtsov +from dataclasses import dataclass +from typing import NamedTuple, Protocol + +import torch + +from ltx_core.loader.module_ops import ModuleOps +from ltx_core.loader.sd_ops import SDOps +from ltx_core.model.model_protocol import ModelType + + +@dataclass(frozen=True) +class StateDict: + sd: dict + device: torch.device + size: int + dtype: set[torch.dtype] + + def footprint(self) -> tuple[int, torch.device]: + return self.size, self.device + + +class StateDictLoader(Protocol): + def metadata(self, path: str) -> dict: + """ + Load metadata from path + """ + + def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict: + """ + Load state dict from path or paths (for sharded model storage) and apply sd_ops + """ + + +class ModelBuilderProtocol(Protocol[ModelType]): + def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType: ... + + def build(self, dtype: torch.dtype | None = None) -> ModelType: + """ + Build the model + Args: + dtype: Target dtype for the model, if None, uses the dtype of the model_path model + Returns: + Model instance + """ + ... + + +class LoRAAdaptableProtocol(Protocol): + def lora(self, lora_path: str, strength: float) -> "LoRAAdaptableProtocol": + pass + + +class LoraPathStrengthAndSDOps(NamedTuple): + path: str + strength: float + sd_ops: SDOps + + +class LoraStateDictWithStrength(NamedTuple): + state_dict: StateDict + strength: float diff --git a/packages/ltx-core/src/ltx_core/loader/registry.py b/packages/ltx-core/src/ltx_core/loader/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..e966bcf9c0d876e5a6048d9da79780eddfa421f4 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/loader/registry.py @@ -0,0 +1,68 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Alexey Kravtsov +import hashlib +import threading +from dataclasses import dataclass, field +from pathlib import Path +from typing import Protocol + +from ltx_core.loader.primitives import StateDict +from ltx_core.loader.sd_ops import SDOps + + +class Registry(Protocol): + def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None: ... + + def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ... + + def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ... + + def clear(self) -> None: ... 
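+ +# DummyRegistry below is a deliberate no-op cache: add discards, get/pop always miss, so state dicts are re-read from disk on every build.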
+ + +class DummyRegistry(Registry): + def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None: + pass + + def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: + pass + + def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: + pass + + def clear(self) -> None: + pass + + +@dataclass +class StateDictRegistry(Registry): + _state_dicts: dict[str, StateDict] = field(default_factory=dict) + _lock: threading.Lock = field(default_factory=threading.Lock) + + def _generate_id(self, paths: list[str], sd_ops: SDOps) -> str: + m = hashlib.sha256() + parts = [str(Path(p).resolve()) for p in paths] + if sd_ops is not None: + parts.append(sd_ops.name) + m.update("\0".join(parts).encode("utf-8")) + return m.hexdigest() + + def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> str: + sd_id = self._generate_id(paths, sd_ops) + with self._lock: + if sd_id in self._state_dicts: + raise ValueError(f"State dict retrieved from {paths} with {sd_ops} already added, check with get first") + self._state_dicts[sd_id] = state_dict + return sd_id + + def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: + with self._lock: + return self._state_dicts.pop(self._generate_id(paths, sd_ops), None) + + def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: + with self._lock: + return self._state_dicts.get(self._generate_id(paths, sd_ops), None) + + def clear(self) -> None: + with self._lock: + self._state_dicts.clear() diff --git a/packages/ltx-core/src/ltx_core/loader/sd_ops.py b/packages/ltx-core/src/ltx_core/loader/sd_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a48b3683cff367e7528e11c30703cd9ece1e55ef --- /dev/null +++ b/packages/ltx-core/src/ltx_core/loader/sd_ops.py @@ -0,0 +1,107 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Alexey Kravtsov + +from dataclasses import dataclass, replace +#from typing import NamedTuple, Protocol, Self +from typing import NamedTuple, Protocol +from typing_extensions import Self + +import torch + + +@dataclass(frozen=True, slots=True) +class ContentReplacement: + content: str + replacement: str + + +@dataclass(frozen=True, slots=True) +class ContentMatching: + prefix: str = "" + suffix: str = "" + + +class KeyValueOperationResult(NamedTuple): + new_key: str + new_value: torch.Tensor + + +class KeyValueOperation(Protocol): + def __call__(self, tensor_key: str, tensor_value: torch.Tensor) -> list[KeyValueOperationResult]: ... + + +@dataclass(frozen=True, slots=True) +class SDKeyValueOperation: + key_matcher: ContentMatching + kv_operation: KeyValueOperation + + +@dataclass(frozen=True, slots=True) +class SDOps: + """Immutable class representing state dict key operations.""" + + name: str + mapping: tuple[ + ContentReplacement | ContentMatching | SDKeyValueOperation, ... 
+ ] = () # Immutable tuple of (key, value) pairs + + def with_replacement(self, content: str, replacement: str) -> Self: + """Create a new SDOps instance with the specified replacement added to the mapping.""" + + new_mapping = (*self.mapping, ContentReplacement(content, replacement)) + return replace(self, mapping=new_mapping) + + def with_matching(self, prefix: str = "", suffix: str = "") -> Self: + """Create a new SDOps instance with the specified prefix and suffix matching added to the mapping.""" + + new_mapping = (*self.mapping, ContentMatching(prefix, suffix)) + return replace(self, mapping=new_mapping) + + def with_kv_operation( + self, + operation: KeyValueOperation, + key_prefix: str = "", + key_suffix: str = "", + ) -> Self: + """Create a new SDOps instance with the specified value operation added to the mapping.""" + key_matcher = ContentMatching(key_prefix, key_suffix) + sd_kv_operation = SDKeyValueOperation(key_matcher, operation) + new_mapping = (*self.mapping, sd_kv_operation) + return replace(self, mapping=new_mapping) + + def apply_to_key(self, key: str) -> str | None: + """Apply the mapping to the given name.""" + matchers = [content for content in self.mapping if isinstance(content, ContentMatching)] + valid = any(key.startswith(f.prefix) and key.endswith(f.suffix) for f in matchers) + if not valid: + return None + + for replacement in self.mapping: + if not isinstance(replacement, ContentReplacement): + continue + if replacement.content in key: + key = key.replace(replacement.content, replacement.replacement) + return key + + def apply_to_key_value(self, key: str, value: torch.Tensor) -> list[KeyValueOperationResult]: + """Apply the value operation to the given name and associated value.""" + for operation in self.mapping: + if not isinstance(operation, SDKeyValueOperation): + continue + if key.startswith(operation.key_matcher.prefix) and key.endswith(operation.key_matcher.suffix): + return operation.kv_operation(key, value) + return [KeyValueOperationResult(key, value)] + + +# Predefined SDOps instances +LTXV_LORA_COMFY_RENAMING_MAP = ( + SDOps("LTXV_LORA_COMFY_PREFIX_MAP").with_matching().with_replacement("diffusion_model.", "") +) + +LTXV_LORA_COMFY_TARGET_MAP = ( + SDOps("LTXV_LORA_COMFY_TARGET_MAP") + .with_matching() + .with_replacement("diffusion_model.", "") + .with_replacement(".lora_A.weight", ".weight") + .with_replacement(".lora_B.weight", ".weight") +) diff --git a/packages/ltx-core/src/ltx_core/loader/sft_loader.py b/packages/ltx-core/src/ltx_core/loader/sft_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..ea6b057c0f26145cf13cbdd28c7b7d5d1b85425c --- /dev/null +++ b/packages/ltx-core/src/ltx_core/loader/sft_loader.py @@ -0,0 +1,53 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
+# Created by Alexey Kravtsov +import json + +import safetensors +import torch + +from ltx_core.loader.primitives import StateDict, StateDictLoader +from ltx_core.loader.sd_ops import SDOps + + +class SafetensorsStateDictLoader(StateDictLoader): + def metadata(self, path: str) -> dict: + raise NotImplementedError("Not implemented") + + def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict: + """ + Load state dict from path or paths (for sharded model storage) and apply sd_ops + """ + sd = {} + size = 0 + dtype = set() + device = device or torch.device("cpu") + model_paths = path if isinstance(path, list) else [path] + for shard_path in model_paths: + with safetensors.safe_open(shard_path, framework="pt", device=str(device)) as f: + safetensor_keys = f.keys() + for name in safetensor_keys: + expected_name = name if sd_ops is None else sd_ops.apply_to_key(name) + if expected_name is None: + continue + value = f.get_tensor(name).to(device=device, non_blocking=True, copy=False) + key_value_pairs = ((expected_name, value),) + if sd_ops is not None: + key_value_pairs = sd_ops.apply_to_key_value(expected_name, value) + for key, value in key_value_pairs: + size += value.nbytes + dtype.add(value.dtype) + sd[key] = value + + return StateDict(sd=sd, device=device, size=size, dtype=dtype) + + +class SafetensorsModelStateDictLoader(StateDictLoader): + def __init__(self, weight_loader: SafetensorsStateDictLoader | None = None): + self.weight_loader = weight_loader if weight_loader is not None else SafetensorsStateDictLoader() + + def metadata(self, path: str) -> dict: + with safetensors.safe_open(path, framework="pt") as f: + return json.loads(f.metadata()["config"]) + + def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict: + return self.weight_loader.load(path, sd_ops, device) diff --git a/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py b/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..c3f11a498a9c33f0f8b11a919112363321818352 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py @@ -0,0 +1,99 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Alexey Kravtsov +import logging +from dataclasses import dataclass, field, replace +from typing import Generic + +import torch + +from ltx_core.loader.fuse_loras import apply_loras +from ltx_core.loader.module_ops import ModuleOps +from ltx_core.loader.primitives import ( + LoRAAdaptableProtocol, + LoraPathStrengthAndSDOps, + LoraStateDictWithStrength, + ModelBuilderProtocol, + StateDict, + StateDictLoader, +) +from ltx_core.loader.registry import DummyRegistry, Registry +from ltx_core.loader.sd_ops import SDOps +from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader +from ltx_core.model.model_protocol import ModelConfigurator, ModelType + +logger: logging.Logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol): + model_class_configurator: type[ModelConfigurator[ModelType]] + model_path: str | tuple[str, ...] + model_sd_ops: SDOps | None = None + module_ops: tuple[ModuleOps, ...] = field(default_factory=tuple) + loras: tuple[LoraPathStrengthAndSDOps, ...] 
= field(default_factory=tuple) + model_loader: StateDictLoader = field(default_factory=SafetensorsModelStateDictLoader) + registry: Registry = field(default_factory=DummyRegistry) + + def lora(self, lora_path: str, strength: float = 1.0, sd_ops: SDOps | None = None) -> "SingleGPUModelBuilder": + return replace(self, loras=(*self.loras, LoraPathStrengthAndSDOps(lora_path, strength, sd_ops))) + + def model_config(self) -> dict: + first_shard_path = self.model_path[0] if isinstance(self.model_path, tuple) else self.model_path + return self.model_loader.metadata(first_shard_path) + + def meta_model(self, config: dict, module_ops: tuple[ModuleOps, ...]) -> ModelType: + with torch.device("meta"): + model = self.model_class_configurator.from_config(config) + for module_op in module_ops: + if module_op.matcher(model): + model = module_op.mutator(model) + return model + + def load_sd( + self, paths: list[str], registry: Registry, device: torch.device | None, sd_ops: SDOps | None = None + ) -> StateDict: + state_dict = registry.get(paths, sd_ops) + if state_dict is None: + state_dict = self.model_loader.load(paths, sd_ops=sd_ops, device=device) + registry.add(paths, sd_ops=sd_ops, state_dict=state_dict) + return state_dict + + def _return_model(self, meta_model: ModelType, device: torch.device) -> ModelType: + uninitialized_params = [name for name, param in meta_model.named_parameters() if str(param.device) == "meta"] + uninitialized_buffers = [name for name, buffer in meta_model.named_buffers() if str(buffer.device) == "meta"] + if uninitialized_params or uninitialized_buffers: + logger.warning(f"Uninitialized parameters or buffers: {uninitialized_params + uninitialized_buffers}") + return meta_model + retval = meta_model.to(device) + return retval + + def build(self, device: torch.device | None = None, dtype: torch.dtype | None = None) -> ModelType: + device = torch.device("cuda") if device is None else device + config = self.model_config() + meta_model = self.meta_model(config, self.module_ops) + model_paths = self.model_path if isinstance(self.model_path, tuple) else [self.model_path] + model_state_dict = self.load_sd(model_paths, sd_ops=self.model_sd_ops, registry=self.registry, device=device) + + lora_strengths = [lora.strength for lora in self.loras] + if not lora_strengths or (min(lora_strengths) == 0 and max(lora_strengths) == 0): + sd = model_state_dict.sd + if dtype is not None: + sd = {key: value.to(dtype=dtype) for key, value in model_state_dict.sd.items()} + meta_model.load_state_dict(sd, strict=False, assign=True) + return self._return_model(meta_model, device) + + lora_state_dicts = [ + self.load_sd([lora.path], sd_ops=lora.sd_ops, registry=self.registry, device=device) for lora in self.loras + ] + lora_sd_and_strengths = [ + LoraStateDictWithStrength(sd, strength) + for sd, strength in zip(lora_state_dicts, lora_strengths, strict=True) + ] + final_sd = apply_loras( + model_sd=model_state_dict, + lora_sd_and_strengths=lora_sd_and_strengths, + dtype=dtype, + destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None, + ) + meta_model.load_state_dict(final_sd.sd, strict=False, assign=True) + return self._return_model(meta_model, device) diff --git a/packages/ltx-core/src/ltx_core/model/.ipynb_checkpoints/model_ledger-checkpoint.py b/packages/ltx-core/src/ltx_core/model/.ipynb_checkpoints/model_ledger-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..37c8a5ea78f2eb134a8b01a4a0e09fe0a8dceab2 --- /dev/null +++ 
b/packages/ltx-core/src/ltx_core/model/.ipynb_checkpoints/model_ledger-checkpoint.py @@ -0,0 +1,253 @@ +from dataclasses import replace +# from typing import Self +from typing_extensions import Self + +import torch + +from ltx_core.loader.primitives import LoraPathStrengthAndSDOps +from ltx_core.loader.registry import DummyRegistry, Registry +from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder +from ltx_core.model.audio_vae.audio_vae import Decoder as AudioDecoder +from ltx_core.model.audio_vae.model_configurator import ( + AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, + VOCODER_COMFY_KEYS_FILTER, + VocoderConfigurator, +) +from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator +from ltx_core.model.audio_vae.vocoder import Vocoder +from ltx_core.model.clip.gemma.encoders.av_encoder import ( + AV_GEMMA_TEXT_ENCODER_KEY_OPS, + AVGemmaTextEncoderModel, + AVGemmaTextEncoderModelConfigurator, +) +from ltx_core.model.clip.gemma.encoders.base_encoder import module_ops_from_gemma_root +from ltx_core.model.transformer.model import X0Model +from ltx_core.model.transformer.model_configurator import ( + LTXV_MODEL_COMFY_RENAMING_MAP, + LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP, + UPCAST_DURING_INFERENCE, + LTXModelConfigurator, +) +from ltx_core.model.upsampler.model import LatentUpsampler +from ltx_core.model.upsampler.model_configurator import LatentUpsamplerConfigurator +from ltx_core.model.video_vae.model_configurator import ( + VAE_DECODER_COMFY_KEYS_FILTER, + VAE_ENCODER_COMFY_KEYS_FILTER, + VAEDecoderConfigurator, + VAEEncoderConfigurator, +) +from ltx_core.model.video_vae.video_vae import Decoder as VideoDecoder +from ltx_core.model.video_vae.video_vae import Encoder as VideoEncoder + + +class ModelLedger: + """ + Central coordinator for loading, caching, and freeing models used in an LTX pipeline. + The ledger wires together multiple model builders (transformer, video VAE encoder/decoder, + audio VAE decoder, vocoder, text encoder, and optional latent upsampler) and exposes + the resulting models as lazily constructed, cached attributes. + + ### Caching behavior + + Each model attribute (e.g. :attr:`transformer`, :attr:`video_decoder`, :attr:`text_encoder`) + is implemented as a :func:`functools.cached_property`. The first time one of these + attributes is accessed, the corresponding builder loads weights from the + :class:`~ltx_core.loader.registry.StateDictRegistry`, instantiates the model on CPU with + the configured ``dtype``, moves it to ``self.device``, and stores the result in + the instance ``__dict__``. Subsequent accesses reuse the same model instance until it is + explicitly cleared via :meth:`clear_vram`. + + ### Constructor parameters + + dtype: + Torch dtype used when constructing all models (e.g. ``torch.float16``). + device: + Target device to which models are moved after construction (e.g. ``torch.device("cuda")``). + checkpoint_path: + Path to a checkpoint directory or file containing the core model weights + (transformer, video VAE, audio VAE, text encoder, vocoder). If ``None``, the + corresponding builders are not created and accessing those properties will raise + a :class:`ValueError`. + gemma_root_path: + Base path to Gemma-compatible CLIP/text encoder weights. Required to + initialize the text encoder builder; if omitted, :attr:`text_encoder` cannot be used. + spatial_upsampler_path: + Optional path to a latent upsampler checkpoint. 
If provided, the + :attr:`upsampler` property becomes available; otherwise accessing it raises + a :class:`ValueError`. + loras: + Optional collection of LoRA configurations (paths, strengths, and key operations) + that are applied on top of the base transformer weights when building the model. + + ### Memory management + + ``clear_ram()`` + Clears the underlying :class:`Registry` cache of state dicts and triggers a + Python garbage collection pass. Use this when you no longer need to construct new + models from the currently loaded checkpoints and want to free host (CPU) memory. + ``clear_vram()`` + Drops the cached model instances stored by the ``@cached_property`` attributes from + this ledger (by removing them from ``self.__dict__``) and calls + :func:`torch.cuda.empty_cache`. Use this when you want to release GPU memory; + subsequent access to a model property will rebuild the model from the registry + while keeping the existing builder configuration. + """ + + def __init__( + self, + dtype: torch.dtype, + device: torch.device, + checkpoint_path: str | None = None, + gemma_root_path: str | None = None, + spatial_upsampler_path: str | None = None, + loras: LoraPathStrengthAndSDOps | None = None, + registry: Registry | None = None, + fp8transformer: bool = False, + local_files_only: bool = True + ): + self.dtype = dtype + self.device = device + self.checkpoint_path = checkpoint_path + self.gemma_root_path = gemma_root_path + self.spatial_upsampler_path = spatial_upsampler_path + self.loras = loras or () + self.registry = registry or DummyRegistry() + self.fp8transformer = fp8transformer + self.local_files_only = local_files_only + self.build_model_builders() + + def build_model_builders(self) -> None: + if self.checkpoint_path is not None: + self.transformer_builder = Builder( + model_path=self.checkpoint_path, + model_class_configurator=LTXModelConfigurator, + model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP, + loras=tuple(self.loras), + registry=self.registry, + ) + + self.vae_decoder_builder = Builder( + model_path=self.checkpoint_path, + model_class_configurator=VAEDecoderConfigurator, + model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER, + registry=self.registry, + ) + + self.vae_encoder_builder = Builder( + model_path=self.checkpoint_path, + model_class_configurator=VAEEncoderConfigurator, + model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER, + registry=self.registry, + ) + + self.audio_decoder_builder = Builder( + model_path=self.checkpoint_path, + model_class_configurator=AudioDecoderConfigurator, + model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, + registry=self.registry, + ) + + self.vocoder_builder = Builder( + model_path=self.checkpoint_path, + model_class_configurator=VocoderConfigurator, + model_sd_ops=VOCODER_COMFY_KEYS_FILTER, + registry=self.registry, + ) + + if self.gemma_root_path is not None: + self.text_encoder_builder = Builder( + model_path=self.checkpoint_path, + model_class_configurator=AVGemmaTextEncoderModelConfigurator, + model_sd_ops=AV_GEMMA_TEXT_ENCODER_KEY_OPS, + registry=self.registry, + module_ops=module_ops_from_gemma_root(self.gemma_root_path, self.local_files_only), + ) + + if self.spatial_upsampler_path is not None: + self.upsampler_builder = Builder( + model_path=self.spatial_upsampler_path, + model_class_configurator=LatentUpsamplerConfigurator, + registry=self.registry, + ) + + def _target_device(self) -> torch.device: + if isinstance(self.registry, DummyRegistry) or self.registry is None: + return self.device + else: + return torch.device("cpu") + + def 
with_loras(self, loras: LoraPathStrengthAndSDOps) -> Self: + return ModelLedger( + dtype=self.dtype, + device=self.device, + checkpoint_path=self.checkpoint_path, + gemma_root_path=self.gemma_root_path, + spatial_upsampler_path=self.spatial_upsampler_path, + loras=(*self.loras, *loras), + registry=self.registry, + fp8transformer=self.fp8transformer, + ) + + def transformer(self) -> X0Model: + if not hasattr(self, "transformer_builder"): + raise ValueError( + "Transformer not initialized. Please provide a checkpoint path to the ModelLedger constructor." + ) + if self.fp8transformer: + fp8_builder = replace( + self.transformer_builder, + module_ops=(UPCAST_DURING_INFERENCE,), + model_sd_ops=LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP, + ) + return X0Model(fp8_builder.build(device=self._target_device())).to(self.device) + else: + return X0Model(self.transformer_builder.build(device=self._target_device(), dtype=self.dtype)).to( + self.device + ) + + def video_decoder(self) -> VideoDecoder: + if not hasattr(self, "vae_decoder_builder"): + raise ValueError( + "Video decoder not initialized. Please provide a checkpoint path to the ModelLedger constructor." + ) + + return self.vae_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device) + + def video_encoder(self) -> VideoEncoder: + if not hasattr(self, "vae_encoder_builder"): + raise ValueError( + "Video encoder not initialized. Please provide a checkpoint path to the ModelLedger constructor." + ) + + return self.vae_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device) + + def text_encoder(self) -> AVGemmaTextEncoderModel: + if not hasattr(self, "text_encoder_builder"): + raise ValueError( + "Text encoder not initialized. Please provide a checkpoint path and gemma root path to the " + "ModelLedger constructor." + ) + + return self.text_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device) + + def audio_decoder(self) -> AudioDecoder: + if not hasattr(self, "audio_decoder_builder"): + raise ValueError( + "Audio decoder not initialized. Please provide a checkpoint path to the ModelLedger constructor." + ) + + return self.audio_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device) + + def vocoder(self) -> Vocoder: + if not hasattr(self, "vocoder_builder"): + raise ValueError( + "Vocoder not initialized. Please provide a checkpoint path to the ModelLedger constructor." + ) + + return self.vocoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device) + + def spatial_upsampler(self) -> LatentUpsampler: + if not hasattr(self, "upsampler_builder"): + raise ValueError("Upsampler not initialized. 
Please provide upsampler path to the ModelLedger constructor.") + + return self.upsampler_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device) diff --git a/packages/ltx-core/src/ltx_core/model/__init__.py b/packages/ltx-core/src/ltx_core/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-core/src/ltx_core/model/__pycache__/__init__.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6969f1f494be8028dfbc9e6d4ac546d087b4e44 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/__pycache__/__init__.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/__pycache__/model_ledger.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/__pycache__/model_ledger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a241299122cd5cf2b69d4b7cec2f6bc1eaf2cf69 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/__pycache__/model_ledger.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/__pycache__/model_protocol.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/__pycache__/model_protocol.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d675ffb236b1b44ccca0cfd8910658da0962ca57 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/__pycache__/model_protocol.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py b/packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/__init__.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b80365d77125b2cd74721947e3a484cacbb447f7 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/__init__.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/attention.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfab69d9db495ab7b915b8e7b6cf6c2d786f2afd Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/attention.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/audio_vae.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/audio_vae.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85915330954fe10cea58afeb39daf09bd6d0f358 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/audio_vae.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/causal_conv_2d.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/causal_conv_2d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d7331f2cbd481cbcf43ec6b5011647af17b93e7 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/causal_conv_2d.cpython-310.pyc differ diff --git 
a/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/causality_axis.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/causality_axis.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1a8675fb303e042e8c62acd1e323b98fdd315d6 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/causality_axis.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/downsample.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/downsample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf97673689c96cde0fa391ca502bc52c65ff6e85 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/downsample.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/model_configurator.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/model_configurator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4f35229be2a52acbac25d24b164dc52ba2fcd43 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/model_configurator.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/ops.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1fbe34b058cc8684c3b432f7613b41367a0f8aa Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/ops.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/resnet.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2eb28a9f5e4ab0b079b49984c4417bb0ed22bf97 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/resnet.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/upsample.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/upsample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbd53d87639fa688cc01d265f368811407fbb566 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/upsample.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/vocoder.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/vocoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54e940ea72440812014b38d4d474ea9196600d6c Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/audio_vae/__pycache__/vocoder.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/attention.py b/packages/ltx-core/src/ltx_core/model/audio_vae/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..46d5ebb29d340d75cd1907ce4e4bd7e14e80f394 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/audio_vae/attention.py @@ -0,0 +1,71 @@ +from enum import Enum + +import torch + +from ltx_core.model.common.normalization import NormType, build_normalization_layer + + +class AttentionType(Enum): + """Enum for specifying the attention mechanism type.""" + + VANILLA = "vanilla" + LINEAR = "linear" + NONE = "none" + + +class 
AttnBlock(torch.nn.Module): + def __init__( + self, + in_channels: int, + norm_type: NormType = NormType.GROUP, + ) -> None: + super().__init__() + self.in_channels = in_channels + + self.norm = build_normalization_layer(in_channels, normtype=norm_type) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn( + in_channels: int, + attn_type: AttentionType = AttentionType.VANILLA, + norm_type: NormType = NormType.GROUP, +) -> torch.nn.Module: + match attn_type: + case AttentionType.VANILLA: + return AttnBlock(in_channels, norm_type=norm_type) + case AttentionType.NONE: + return torch.nn.Identity() + case AttentionType.LINEAR: + raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.") + case _: + raise ValueError(f"Unknown attention type: {attn_type}") diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py b/packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..19a971e8f3249287761e91a9865667c812c46c4d --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py @@ -0,0 +1,483 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
+# Created by Ivan Zorin + + +from typing import Set, Tuple + +import torch +import torch.nn.functional as F + +from ltx_core.model.audio_vae.attention import AttentionType, make_attn +from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d +from ltx_core.model.audio_vae.causality_axis import CausalityAxis +from ltx_core.model.audio_vae.downsample import build_downsampling_path +from ltx_core.model.audio_vae.ops import PerChannelStatistics +from ltx_core.model.audio_vae.resnet import ResnetBlock +from ltx_core.model.audio_vae.upsample import build_upsampling_path +from ltx_core.model.common.normalization import NormType, build_normalization_layer +from ltx_core.pipeline.components.patchifiers import AudioPatchifier +from ltx_core.pipeline.components.protocols import AudioLatentShape + +LATENT_DOWNSAMPLE_FACTOR = 4 + + +def build_mid_block( + channels: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + add_attention: bool, +) -> torch.nn.Module: + """Build the middle block with two ResNet blocks and optional attention.""" + mid = torch.nn.Module() + mid.block_1 = ResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else torch.nn.Identity() + mid.block_2 = ResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + return mid + + +def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor: + """Run features through the middle block.""" + features = mid.block_1(features, temb=None) + features = mid.attn_1(features) + return mid.block_2(features, temb=None) + + +class Encoder(torch.nn.Module): + """ + Encoder that compresses audio spectrograms into latent representations. + + The encoder uses a series of downsampling blocks with residual connections, + attention mechanisms, and configurable causal convolutions. + """ + + def __init__( # noqa: PLR0913 + self, + *, + ch: int, + ch_mult: Tuple[int, ...] = (1, 2, 4, 8), + num_res_blocks: int, + attn_resolutions: Set[int], + dropout: float = 0.0, + resamp_with_conv: bool = True, + in_channels: int, + resolution: int, + z_channels: int, + double_z: bool = True, + attn_type: AttentionType = AttentionType.VANILLA, + mid_block_add_attention: bool = True, + norm_type: NormType = NormType.GROUP, + causality_axis: CausalityAxis = CausalityAxis.WIDTH, + sample_rate: int = 16000, + mel_hop_length: int = 160, + n_fft: int = 1024, + is_causal: bool = True, + mel_bins: int = 64, + **_ignore_kwargs, + ) -> None: + """ + Initialize the Encoder. + + Args: + Arguments are configuration parameters, loaded from the audio VAE checkpoint config + (audio_vae.model.params.ddconfig): + + ch: Base number of feature channels used in the first convolution layer. + ch_mult: Multiplicative factors for the number of channels at each resolution level. + num_res_blocks: Number of residual blocks to use at each resolution level. + attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention. + resolution: Input spatial resolution of the spectrogram (height, width). + z_channels: Number of channels in the latent representation. + norm_type: Normalization layer type to use within the network (e.g., group, batch). 
+ causality_axis: Axis along which convolutions should be causal (e.g., time axis). + sample_rate: Audio sample rate in Hz for the input signals. + mel_hop_length: Hop length used when computing the mel spectrogram. + n_fft: FFT size used to compute the spectrogram. + mel_bins: Number of mel-frequency bins in the input spectrogram. + in_channels: Number of channels in the input spectrogram tensor. + double_z: If True, predict both mean and log-variance (doubling latent channels). + is_causal: If True, use causal convolutions suitable for streaming setups. + dropout: Dropout probability used in residual and mid blocks. + attn_type: Type of attention mechanism to use in attention blocks. + resamp_with_conv: If True, perform resolution changes using strided convolutions. + mid_block_add_attention: If True, add an attention block in the mid-level of the encoder. + """ + super().__init__() + + self.per_channel_statistics = PerChannelStatistics(latent_channels=ch) + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.n_fft = n_fft + self.is_causal = is_causal + self.mel_bins = mel_bins + + self.patchifier = AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.z_channels = z_channels + self.double_z = double_z + self.norm_type = norm_type + self.causality_axis = causality_axis + self.attn_type = attn_type + + # downsampling + self.conv_in = make_conv2d( + in_channels, + self.ch, + kernel_size=3, + stride=1, + causality_axis=self.causality_axis, + ) + + self.non_linearity = torch.nn.SiLU() + + self.down, block_in = build_downsampling_path( + ch=ch, + ch_mult=ch_mult, + num_resolutions=self.num_resolutions, + num_res_blocks=num_res_blocks, + resolution=resolution, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + attn_resolutions=attn_resolutions, + resamp_with_conv=resamp_with_conv, + ) + + self.mid = build_mid_block( + channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + add_attention=mid_block_add_attention, + ) + + self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type) + self.conv_out = make_conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + causality_axis=self.causality_axis, + ) + + def forward(self, spectrogram: torch.Tensor) -> torch.Tensor: + """ + Encode audio spectrogram into latent representations. 
+ + Args: + spectrogram: Input spectrogram of shape (batch, channels, time, frequency) + + Returns: + Encoded latent representation of shape (batch, channels, frames, mel_bins) + """ + h = self.conv_in(spectrogram) + h = self._run_downsampling_path(h) + h = run_mid_block(self.mid, h) + h = self._finalize_output(h) + + return self._normalize_latents(h) + + def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor: + for level in range(self.num_resolutions): + stage = self.down[level] + for block_idx in range(self.num_res_blocks): + h = stage.block[block_idx](h, temb=None) + if stage.attn: + h = stage.attn[block_idx](h) + + if level != self.num_resolutions - 1: + h = stage.downsample(h) + + return h + + def _finalize_output(self, h: torch.Tensor) -> torch.Tensor: + h = self.norm_out(h) + h = self.non_linearity(h) + return self.conv_out(h) + + def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor: + """ + Normalize encoder latents using per-channel statistics. + + When the encoder is configured with ``double_z=True``, the final + convolution produces twice the number of latent channels, typically + interpreted as two concatenated tensors along the channel dimension + (e.g., mean and variance or other auxiliary parameters). + + This method intentionally uses only the first half of the channels + (the "mean" component) as input to the patchifier and normalization + logic. The remaining channels are left unchanged by this method and + are expected to be consumed elsewhere in the VAE pipeline. + + If ``double_z=False``, the encoder output already contains only the + mean latents and the chunking operation simply returns that tensor. + """ + means = torch.chunk(latent_output, 2, dim=1)[0] + latent_shape = AudioLatentShape( + batch=means.shape[0], + channels=means.shape[1], + frames=means.shape[2], + mel_bins=means.shape[3], + ) + latent_patched = self.patchifier.patchify(means) + latent_normalized = self.per_channel_statistics.normalize(latent_patched) + return self.patchifier.unpatchify(latent_normalized, latent_shape) + + +class Decoder(torch.nn.Module): + """ + Symmetric decoder that reconstructs audio spectrograms from latent features. + + The decoder mirrors the encoder structure with configurable channel multipliers, + attention resolutions, and causal convolutions. + """ + + def __init__( # noqa: PLR0913 + self, + *, + ch: int, + out_ch: int, + ch_mult: Tuple[int, ...] = (1, 2, 4, 8), + num_res_blocks: int, + attn_resolutions: Set[int], + resolution: int, + z_channels: int, + norm_type: NormType = NormType.GROUP, + causality_axis: CausalityAxis = CausalityAxis.WIDTH, + dropout: float = 0.0, + mid_block_add_attention: bool = True, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: int | None = None, + ) -> None: + """ + Initialize the Decoder. + + Args: + Arguments are configuration parameters, loaded from the audio VAE checkpoint config + (audio_vae.model.params.ddconfig): + - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions + - resolution, z_channels + - norm_type, causality_axis + """ + super().__init__() + + # Internal behavioural defaults that are not driven by the checkpoint. 
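+ # These values match the Encoder's corresponding constructor defaults (resamp_with_conv=True, AttentionType.VANILLA).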
+ resamp_with_conv = True + attn_type = AttentionType.VANILLA + + # Per-channel statistics for denormalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=ch) + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + self.patchifier = AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.out_ch = out_ch + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.z_channels = z_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = causality_axis + self.attn_type = attn_type + + base_block_channels = ch * self.channel_multipliers[-1] + base_resolution = resolution // (2 ** (self.num_resolutions - 1)) + self.z_shape = (1, z_channels, base_resolution, base_resolution) + + self.conv_in = make_conv2d( + z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + self.non_linearity = torch.nn.SiLU() + self.mid = build_mid_block( + channels=base_block_channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + add_attention=mid_block_add_attention, + ) + self.up, final_block_channels = build_upsampling_path( + ch=ch, + ch_mult=ch_mult, + num_resolutions=self.num_resolutions, + num_res_blocks=num_res_blocks, + resolution=resolution, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + attn_resolutions=attn_resolutions, + resamp_with_conv=resamp_with_conv, + initial_block_channels=base_block_channels, + ) + + self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type) + self.conv_out = make_conv2d( + final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + """ + Decode latent features back to audio spectrograms. 
+ + Args: + sample: Encoded latent representation of shape (batch, channels, frames, mel_bins) + + Returns: + Reconstructed audio spectrogram of shape (batch, channels, time, frequency) + """ + sample, target_shape = self._denormalize_latents(sample) + + h = self.conv_in(sample) + h = run_mid_block(self.mid, h) + h = self._run_upsampling_path(h) + h = self._finalize_output(h) + + return self._adjust_output_shape(h, target_shape) + + def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]: + latent_shape = AudioLatentShape( + batch=sample.shape[0], + channels=sample.shape[1], + frames=sample.shape[2], + mel_bins=sample.shape[3], + ) + + sample_patched = self.patchifier.patchify(sample) + sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched) + sample = self.patchifier.unpatchify(sample_denormalized, latent_shape) + + target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR + if self.causality_axis != CausalityAxis.NONE: + target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) + + target_shape = AudioLatentShape( + batch=latent_shape.batch, + channels=self.out_ch, + frames=target_frames, + mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins, + ) + + return sample, target_shape + + def _adjust_output_shape( + self, + decoded_output: torch.Tensor, + target_shape: AudioLatentShape, + ) -> torch.Tensor: + """ + Adjust output shape to match target dimensions for variable-length audio. + + This function handles the common case where decoded audio spectrograms need to be + resized to match a specific target shape. + + Args: + decoded_output: Tensor of shape (batch, channels, time, frequency) + target_shape: AudioLatentShape describing (batch, channels, time, mel bins) + + Returns: + Tensor adjusted to match target_shape exactly + """ + # Current output shape: (batch, channels, time, frequency) + _, _, current_time, current_freq = decoded_output.shape + target_channels = target_shape.channels + target_time = target_shape.frames + target_freq = target_shape.mel_bins + + # Step 1: Crop first to avoid exceeding target dimensions + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + # Step 2: Calculate padding needed for time and frequency dimensions + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + # Step 3: Apply padding if needed + if time_padding_needed > 0 or freq_padding_needed > 0: + # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom) + # For audio: pad_left/right = frequency, pad_top/bottom = time + padding = ( + 0, + max(freq_padding_needed, 0), # frequency padding (left, right) + 0, + max(time_padding_needed, 0), # time padding (top, bottom) + ) + decoded_output = F.pad(decoded_output, padding) + + # Step 4: Final safety crop to ensure exact target shape + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor: + for level in reversed(range(self.num_resolutions)): + stage = self.up[level] + for block_idx, block in enumerate(stage.block): + h = block(h, temb=None) + if stage.attn: + h = stage.attn[block_idx](h) + + if level != 0 and hasattr(stage, "upsample"): + h = stage.upsample(h) + + return h + + def _finalize_output(self, h: torch.Tensor) -> torch.Tensor: + if self.give_pre_end: + return h + 
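+ # Otherwise apply the final normalization, non-linearity, and output convolution (with optional tanh).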
+ h = self.norm_out(h) + h = self.non_linearity(h) + h = self.conv_out(h) + return torch.tanh(h) if self.tanh_out else h diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/causal_conv_2d.py b/packages/ltx-core/src/ltx_core/model/audio_vae/causal_conv_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..69b798fc74bb7ea464522a16d78967b7deb534f1 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/audio_vae/causal_conv_2d.py @@ -0,0 +1,113 @@ +import torch +import torch.nn.functional as F + +from ltx_core.model.audio_vae.causality_axis import CausalityAxis + + +class CausalConv2d(torch.nn.Module): + """ + A causal 2D convolution. + + This layer ensures that the output at time `t` only depends on inputs + at time `t` and earlier. It achieves this by applying asymmetric padding + to the time dimension (width) before the convolution. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int = 1, + dilation: int | tuple[int, int] = 1, + groups: int = 1, + bias: bool = True, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + + self.causality_axis = causality_axis + + # Ensure kernel_size and dilation are tuples + kernel_size = torch.nn.modules.utils._pair(kernel_size) + dilation = torch.nn.modules.utils._pair(dilation) + + # Calculate padding dimensions + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + # The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY: + self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.HEIGHT: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + case _: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + # The internal convolution layer uses no padding, as we handle it manually + self.conv = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Apply causal padding before convolution + x = F.pad(x, self.padding) + return self.conv(x) + + +def make_conv2d( + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int = 1, + padding: tuple[int, int, int, int] | None = None, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causality_axis: CausalityAxis | None = None, +) -> torch.nn.Module: + """ + Create a 2D convolution layer that can be either causal or non-causal. + + Args: + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Size of the convolution kernel + stride: Convolution stride + padding: Padding (if None, will be calculated based on causal flag) + dilation: Dilation rate + groups: Number of groups for grouped convolution + bias: Whether to use bias + causality_axis: Dimension along which to apply causality. 
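+ If None, a standard non-causal torch.nn.Conv2d with symmetric padding is created instead.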
+ + Returns: + Either a regular Conv2d or CausalConv2d layer + """ + if causality_axis is not None: + # For causal convolution, padding is handled internally by CausalConv2d + return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis) + else: + # For non-causal convolution, use symmetric padding if not specified + if padding is None: + padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size) + + return torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/causality_axis.py b/packages/ltx-core/src/ltx_core/model/audio_vae/causality_axis.py new file mode 100644 index 0000000000000000000000000000000000000000..b99f83550f3e73658b05b4c467d78ecb330b1822 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/audio_vae/causality_axis.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class CausalityAxis(Enum): + """Enum for specifying the causality axis in causal convolutions.""" + + NONE = None + WIDTH = "width" + HEIGHT = "height" + WIDTH_COMPATIBILITY = "width-compatibility" diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/downsample.py b/packages/ltx-core/src/ltx_core/model/audio_vae/downsample.py new file mode 100644 index 0000000000000000000000000000000000000000..336735bcb4392bbce38015c2ed51a935b03c6260 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/audio_vae/downsample.py @@ -0,0 +1,110 @@ +from typing import Set, Tuple + +import torch + +from ltx_core.model.audio_vae.attention import AttentionType, make_attn +from ltx_core.model.audio_vae.causality_axis import CausalityAxis +from ltx_core.model.audio_vae.resnet import ResnetBlock +from ltx_core.model.common.normalization import NormType + + +class Downsample(torch.nn.Module): + """ + A downsampling layer that can use either a strided convolution + or average pooling. Supports standard and causal padding for the + convolutional mode. + """ + + def __init__( + self, + in_channels: int, + with_conv: bool, + causality_axis: CausalityAxis = CausalityAxis.WIDTH, + ) -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and not self.with_conv: + raise ValueError("causality is only supported when `with_conv=True`.") + + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.with_conv: + # Padding tuple is in the order: (left, right, top, bottom). + match self.causality_axis: + case CausalityAxis.NONE: + pad = (0, 1, 0, 1) + case CausalityAxis.WIDTH: + pad = (2, 0, 0, 1) + case CausalityAxis.HEIGHT: + pad = (0, 1, 2, 0) + case CausalityAxis.WIDTH_COMPATIBILITY: + pad = (1, 0, 0, 1) + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # This branch is only taken if with_conv=False, which implies causality_axis is NONE. 
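+ # Average pooling halves both spatial dimensions without any causal padding.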
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + + return x + + +def build_downsampling_path( # noqa: PLR0913 + *, + ch: int, + ch_mult: Tuple[int, ...], + num_resolutions: int, + num_res_blocks: int, + resolution: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + attn_resolutions: Set[int], + resamp_with_conv: bool, +) -> tuple[torch.nn.ModuleList, int]: + """Build the downsampling path with residual blocks, attention, and downsampling layers.""" + down_modules = torch.nn.ModuleList() + curr_res = resolution + in_ch_mult = (1, *tuple(ch_mult)) + block_in = ch + + for i_level in range(num_resolutions): + block = torch.nn.ModuleList() + attn = torch.nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + + for _ in range(num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type)) + + down = torch.nn.Module() + down.block = block + down.attn = attn + if i_level != num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res // 2 + down_modules.append(down) + + return down_modules, block_in diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/model_configurator.py b/packages/ltx-core/src/ltx_core/model/audio_vae/model_configurator.py new file mode 100644 index 0000000000000000000000000000000000000000..ab89cc7cbada36c21d441a761dfd5b23525f0d34 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/audio_vae/model_configurator.py @@ -0,0 +1,123 @@ +from ltx_core.loader.sd_ops import SDOps +from ltx_core.model.audio_vae.attention import AttentionType +from ltx_core.model.audio_vae.audio_vae import Decoder, Encoder +from ltx_core.model.audio_vae.causality_axis import CausalityAxis +from ltx_core.model.audio_vae.vocoder import Vocoder +from ltx_core.model.common.normalization import NormType +from ltx_core.model.model_protocol import ModelConfigurator + + +class VocoderConfigurator(ModelConfigurator[Vocoder]): + @classmethod + def from_config(cls: type[Vocoder], config: dict) -> Vocoder: + config = config.get("vocoder", {}) + return Vocoder( + resblock_kernel_sizes=config.get("resblock_kernel_sizes", [3, 7, 11]), + upsample_rates=config.get("upsample_rates", [6, 5, 2, 2, 2]), + upsample_kernel_sizes=config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]), + resblock_dilation_sizes=config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]), + upsample_initial_channel=config.get("upsample_initial_channel", 1024), + stereo=config.get("stereo", True), + resblock=config.get("resblock", "1"), + output_sample_rate=config.get("output_sample_rate", 24000), + ) + + +VOCODER_COMFY_KEYS_FILTER = ( + SDOps("VOCODER_COMFY_KEYS_FILTER").with_matching(prefix="vocoder.").with_replacement("vocoder.", "") +) + + +class VAEDecoderConfigurator(ModelConfigurator[Decoder]): + @classmethod + def from_config(cls: type[Decoder], config: dict) -> Decoder: + audio_vae_cfg = config.get("audio_vae", {}) + model_cfg = audio_vae_cfg.get("model", {}) + model_params = model_cfg.get("params", {}) + ddconfig = model_params.get("ddconfig", {}) + preprocessing_cfg = audio_vae_cfg.get("preprocessing", {}) + stft_cfg = 
preprocessing_cfg.get("stft", {}) + mel_cfg = preprocessing_cfg.get("mel", {}) + variables_cfg = audio_vae_cfg.get("variables", {}) + + sample_rate = model_params.get("sampling_rate", 16000) + mel_hop_length = stft_cfg.get("hop_length", 160) + is_causal = stft_cfg.get("causal", True) + mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins") + + return Decoder( + ch=ddconfig.get("ch", 128), + out_ch=ddconfig.get("out_ch", 2), + ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))), + num_res_blocks=ddconfig.get("num_res_blocks", 2), + attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}), + resolution=ddconfig.get("resolution", 256), + z_channels=ddconfig.get("z_channels", 8), + norm_type=NormType(ddconfig.get("norm_type", "pixel")), + causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")), + dropout=ddconfig.get("dropout", 0.0), + mid_block_add_attention=ddconfig.get("mid_block_add_attention", True), + sample_rate=sample_rate, + mel_hop_length=mel_hop_length, + is_causal=is_causal, + mel_bins=mel_bins, + ) + + +class VAEEncoderConfigurator(ModelConfigurator[Encoder]): + @classmethod + def from_config(cls: type[Encoder], config: dict) -> Encoder: + audio_vae_cfg = config.get("audio_vae", {}) + model_cfg = audio_vae_cfg.get("model", {}) + model_params = model_cfg.get("params", {}) + ddconfig = model_params.get("ddconfig", {}) + preprocessing_cfg = audio_vae_cfg.get("preprocessing", {}) + stft_cfg = preprocessing_cfg.get("stft", {}) + mel_cfg = preprocessing_cfg.get("mel", {}) + variables_cfg = audio_vae_cfg.get("variables", {}) + + sample_rate = model_params.get("sampling_rate", 16000) + mel_hop_length = stft_cfg.get("hop_length", 160) + n_fft = stft_cfg.get("filter_length", 1024) + is_causal = stft_cfg.get("causal", True) + mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins") + + return Encoder( + ch=ddconfig.get("ch", 128), + ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))), + num_res_blocks=ddconfig.get("num_res_blocks", 2), + attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}), + resolution=ddconfig.get("resolution", 256), + z_channels=ddconfig.get("z_channels", 8), + double_z=ddconfig.get("double_z", True), + dropout=ddconfig.get("dropout", 0.0), + resamp_with_conv=ddconfig.get("resamp_with_conv", True), + in_channels=ddconfig.get("in_channels", 2), + attn_type=AttentionType(ddconfig.get("attn_type", "vanilla")), + mid_block_add_attention=ddconfig.get("mid_block_add_attention", True), + norm_type=NormType(ddconfig.get("norm_type", "pixel")), + causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")), + sample_rate=sample_rate, + mel_hop_length=mel_hop_length, + n_fft=n_fft, + is_causal=is_causal, + mel_bins=mel_bins, + ) + + +AUDIO_VAE_DECODER_COMFY_KEYS_FILTER = ( + SDOps("AUDIO_VAE_DECODER_COMFY_KEYS_FILTER") + .with_matching(prefix="audio_vae.decoder.") + .with_matching(prefix="audio_vae.per_channel_statistics.") + .with_replacement("audio_vae.decoder.", "") + .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.") +) + + +AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER = ( + SDOps("AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER") + .with_matching(prefix="audio_vae.encoder.") + .with_matching(prefix="audio_vae.per_channel_statistics.") + .with_replacement("audio_vae.encoder.", "") + .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.") +) diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/ops.py 
b/packages/ltx-core/src/ltx_core/model/audio_vae/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..87a1a8f11fffb69be45d527d3ffee38c07e254fb --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/audio_vae/ops.py @@ -0,0 +1,77 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Ivan Zorin + +import torch +import torchaudio +from torch import nn + + +class AudioProcessor(nn.Module): + def __init__( + self, + sample_rate: int, + mel_bins: int, + mel_hop_length: int, + n_fft: int, + ) -> None: + super().__init__() + self.sample_rate = sample_rate + self.mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, + n_fft=n_fft, + win_length=n_fft, + hop_length=mel_hop_length, + f_min=0.0, + f_max=sample_rate / 2.0, + n_mels=mel_bins, + window_fn=torch.hann_window, + center=True, + pad_mode="reflect", + power=1.0, + mel_scale="slaney", + norm="slaney", + ) + + def resample_waveform( + self, + waveform: torch.Tensor, + source_rate: int, + target_rate: int, + ) -> torch.Tensor: + """Resample waveform to target sample rate if needed.""" + if source_rate == target_rate: + return waveform + resampled = torchaudio.functional.resample(waveform, source_rate, target_rate) + return resampled.to(device=waveform.device, dtype=waveform.dtype) + + def waveform_to_mel( + self, + waveform: torch.Tensor, + waveform_sample_rate: int, + ) -> torch.Tensor: + """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels].""" + waveform = self.resample_waveform(waveform, waveform_sample_rate, self.sample_rate) + + mel = self.mel_transform(waveform) + mel = torch.log(torch.clamp(mel, min=1e-5)) + + mel = mel.to(device=waveform.device, dtype=waveform.dtype) + return mel.permute(0, 1, 3, 2).contiguous() + + +class PerChannelStatistics(nn.Module): + """ + Per-channel statistics for normalizing and denormalizing the latent representation. + These statistics are computed over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
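+ normalize() computes (x - mean-of-means) / std-of-means along the channel dimension, and un_normalize() applies the exact inverse.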
+ """ + + def __init__(self, latent_channels: int = 128) -> None: + super().__init__() + self.register_buffer("std-of-means", torch.empty(latent_channels)) + self.register_buffer("mean-of-means", torch.empty(latent_channels)) + + def un_normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x) diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/resnet.py b/packages/ltx-core/src/ltx_core/model/audio_vae/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a529d6da8853daa1d0e2bb85e8fa2432dff5723e --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/audio_vae/resnet.py @@ -0,0 +1,176 @@ +from typing import Tuple + +import torch + +from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d +from ltx_core.model.audio_vae.causality_axis import CausalityAxis +from ltx_core.model.common.normalization import NormType, build_normalization_layer + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding="same", + ), + ] + ) + + self.convs2 = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding="same", + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv1, conv2 in zip(self.convs1, self.convs2, strict=True): + xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) + xt = conv1(xt) + xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE) + xt = conv2(xt) + x = xt + x + return x + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)): + super(ResBlock2, self).__init__() + self.convs = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding="same", + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv in self.convs: + xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) + xt = conv(xt) + x = xt + x + return x + + +class ResnetBlock(torch.nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int | None = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + norm_type: NormType = NormType.GROUP, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP: + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels 
if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = build_normalization_layer(in_channels, normtype=norm_type) + self.non_linearity = torch.nn.SiLU() + self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = build_normalization_layer(out_channels, normtype=norm_type) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + + def forward( + self, + x: torch.Tensor, + temb: torch.Tensor | None = None, + ) -> torch.Tensor: + h = x + h = self.norm1(h) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x) + + return x + h diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/upsample.py b/packages/ltx-core/src/ltx_core/model/audio_vae/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..3046e210ec62fef1118f88126b1e64ed7149a15f --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/audio_vae/upsample.py @@ -0,0 +1,106 @@ +from typing import Set, Tuple + +import torch + +from ltx_core.model.audio_vae.attention import AttentionType, make_attn +from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d +from ltx_core.model.audio_vae.causality_axis import CausalityAxis +from ltx_core.model.audio_vae.resnet import ResnetBlock +from ltx_core.model.common.normalization import NormType + + +class Upsample(torch.nn.Module): + def __init__( + self, + in_channels: int, + with_conv: bool, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n. + # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2]. + # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2], + # So the output elements rely on the following windows: + # 0: [-,-,0] + # 1: [-,0,0] + # 2: [0,0,1] + # 3: [0,1,1] + # 4: [1,1,2] + # 5: [1,2,2] + # Notice that the first and second elements in the output rely only on the first element in the input, + # while all other elements rely on two elements in the input. + # So we can drop the first element to undo the padding (rather than the last element). + # This is a no-op for non-causal convolutions. 
+ match self.causality_axis: + case CausalityAxis.NONE: + pass # x remains unchanged + case CausalityAxis.HEIGHT: + x = x[:, :, 1:, :] + case CausalityAxis.WIDTH: + x = x[:, :, :, 1:] + case CausalityAxis.WIDTH_COMPATIBILITY: + pass # x remains unchanged + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +def build_upsampling_path( # noqa: PLR0913 + *, + ch: int, + ch_mult: Tuple[int, ...], + num_resolutions: int, + num_res_blocks: int, + resolution: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + attn_resolutions: Set[int], + resamp_with_conv: bool, + initial_block_channels: int, +) -> tuple[torch.nn.ModuleList, int]: + """Build the upsampling path with residual blocks, attention, and upsampling layers.""" + up_modules = torch.nn.ModuleList() + block_in = initial_block_channels + curr_res = resolution // (2 ** (num_resolutions - 1)) + + for level in reversed(range(num_resolutions)): + stage = torch.nn.Module() + stage.block = torch.nn.ModuleList() + stage.attn = torch.nn.ModuleList() + block_out = ch * ch_mult[level] + + for _ in range(num_res_blocks + 1): + stage.block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + stage.attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type)) + + if level != 0: + stage.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res *= 2 + + up_modules.insert(0, stage) + + return up_modules, block_in diff --git a/packages/ltx-core/src/ltx_core/model/audio_vae/vocoder.py b/packages/ltx-core/src/ltx_core/model/audio_vae/vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d8a1b418d69ddaf6800e2380238ac6094416b17b --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/audio_vae/vocoder.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Ivan Zorin + +import math +from typing import List + +import einops +import torch +import torch.nn.functional as F +from torch import nn + +from ltx_core.model.audio_vae.resnet import LRELU_SLOPE, ResBlock1, ResBlock2 + + +class Vocoder(torch.nn.Module): + """ + Vocoder model for synthesizing audio from Mel spectrograms. + + Args: + resblock_kernel_sizes: List of kernel sizes for the residual blocks. + This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`. + upsample_rates: List of upsampling rates. + This value is read from the checkpoint at `config.vocoder.upsample_rates`. + upsample_kernel_sizes: List of kernel sizes for the upsampling layers. + This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`. + resblock_dilation_sizes: List of dilation sizes for the residual blocks. + This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`. + upsample_initial_channel: Initial number of channels for the upsampling layers. + This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`. + stereo: Whether to use stereo output. + This value is read from the checkpoint at `config.vocoder.stereo`. + resblock: Type of residual block to use. + This value is read from the checkpoint at `config.vocoder.resblock`. + output_sample_rate: Waveform sample rate. 
+        This value is read from the checkpoint at `config.vocoder.output_sample_rate`.
+    """
+
+    def __init__(
+        self,
+        resblock_kernel_sizes: List[int] | None = None,
+        upsample_rates: List[int] | None = None,
+        upsample_kernel_sizes: List[int] | None = None,
+        resblock_dilation_sizes: List[List[int]] | None = None,
+        upsample_initial_channel: int = 1024,
+        stereo: bool = True,
+        resblock: str = "1",
+        output_sample_rate: int = 24000,
+    ):
+        super().__init__()
+
+        # Fill in defaults here rather than in the signature, since mutable default arguments are unsafe to share.
+        if resblock_kernel_sizes is None:
+            resblock_kernel_sizes = [3, 7, 11]
+        if upsample_rates is None:
+            upsample_rates = [6, 5, 2, 2, 2]
+        if upsample_kernel_sizes is None:
+            upsample_kernel_sizes = [16, 15, 8, 4, 4]
+        if resblock_dilation_sizes is None:
+            resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+
+        self.output_sample_rate = output_sample_rate
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        in_channels = 128 if stereo else 64
+        self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
+        resblock_class = ResBlock1 if resblock == "1" else ResBlock2
+
+        self.ups = nn.ModuleList()
+        for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True)):
+            self.ups.append(
+                nn.ConvTranspose1d(
+                    upsample_initial_channel // (2**i),
+                    upsample_initial_channel // (2 ** (i + 1)),
+                    kernel_size,
+                    stride,
+                    padding=(kernel_size - stride) // 2,
+                )
+            )
+
+        self.resblocks = nn.ModuleList()
+        for i, _ in enumerate(self.ups):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):
+                self.resblocks.append(resblock_class(ch, kernel_size, dilations))
+
+        out_channels = 2 if stereo else 1
+        final_channels = upsample_initial_channel // (2**self.num_upsamples)
+        self.conv_post = nn.Conv1d(final_channels, out_channels, 7, 1, padding=3)
+
+        self.upsample_factor = math.prod(layer.stride[0] for layer in self.ups)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the vocoder.
+
+        Args:
+            x: Input Mel spectrogram tensor. Can be either:
+                - 3D: (batch_size, time, mel_bins) for mono
+                - 4D: (batch_size, 2, time, mel_bins) for stereo
+
+        Returns:
+            Audio waveform tensor of shape (batch_size, out_channels, audio_length)
+        """
+        if x.dim() == 4:  # stereo: (batch, 2, time, mel_bins) -> (batch, 2 * mel_bins, time)
+            x = x.transpose(2, 3)
+            assert x.shape[1] == 2, "Input must have 2 channels for stereo"
+            x = einops.rearrange(x, "b s c t -> b (s c) t")
+        else:  # mono: (batch, time, mel_bins) -> (batch, mel_bins, time); transpose(2, 3) would fail on a 3D tensor
+            x = x.transpose(1, 2)
+
+        x = self.conv_pre(x)
+
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            start = i * self.num_kernels
+            end = start + self.num_kernels
+
+            # Evaluate all resblocks with the same input tensor so they can run
+            # independently (and thus in parallel on accelerator hardware) before
+            # aggregating their outputs via mean.
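+            # Shape note: each resblock maps (B, ch, T) -> (B, ch, T), so stacking the
+            # num_kernels outputs yields (num_kernels, B, ch, T) and the mean over
+            # dim 0 restores (B, ch, T).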
+            block_outputs = torch.stack(
+                [self.resblocks[idx](x) for idx in range(start, end)],
+                dim=0,
+            )
+
+            x = block_outputs.mean(dim=0)
+
+        x = self.conv_post(F.leaky_relu(x))
+        return torch.tanh(x)
diff --git a/packages/ltx-core/src/ltx_core/model/clip/__init__.py b/packages/ltx-core/src/ltx_core/model/clip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
+ """ + return self.aggregate_embed(x) + + @classmethod + def from_config(cls: type[Self], _config: dict) -> Self: + return cls() diff --git a/packages/ltx-core/src/ltx_core/model/clip/gemma/__init__.py b/packages/ltx-core/src/ltx_core/model/clip/gemma/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-core/src/ltx_core/model/clip/gemma/__pycache__/__init__.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/clip/gemma/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e6c8836aac030b856ff002bc461fa5311a25a8f Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/clip/gemma/__pycache__/__init__.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/clip/gemma/__pycache__/embeddings_connector.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/clip/gemma/__pycache__/embeddings_connector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a5314cb8584f19ad04f2f078cce772965a7b331 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/clip/gemma/__pycache__/embeddings_connector.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/clip/gemma/__pycache__/feature_extractor.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/clip/gemma/__pycache__/feature_extractor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..403d5092849b70a5941fa45f8dccf0c100dfa9ab Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/clip/gemma/__pycache__/feature_extractor.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/clip/gemma/__pycache__/tokenizer.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/clip/gemma/__pycache__/tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de1c95fd0363347a8fbfe2ebdd3325cd96981a23 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/clip/gemma/__pycache__/tokenizer.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/clip/gemma/embeddings_connector.py b/packages/ltx-core/src/ltx_core/model/clip/gemma/embeddings_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..741747234eb01bed3f681383e814e370771cf6a5 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/clip/gemma/embeddings_connector.py @@ -0,0 +1,216 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Amit Pintz. 
+
+import torch
+
+from ltx_core.model.model_protocol import ModelConfigurator
+from ltx_core.model.transformer.attention import Attention
+from ltx_core.model.transformer.feed_forward import FeedForward
+from ltx_core.model.transformer.rope import (
+    LTXRopeType,
+    generate_freq_grid_np,
+    generate_freq_grid_pytorch,
+    precompute_freqs_cis,
+)
+from ltx_core.utils import rms_norm
+
+
+class _BasicTransformerBlock1D(torch.nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        heads: int,
+        dim_head: int,
+        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
+    ):
+        super().__init__()
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=heads,
+            dim_head=dim_head,
+            rope_type=rope_type,
+        )
+
+        self.ff = FeedForward(
+            dim,
+            dim_out=dim,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        pe: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        # Pre-norm design: normalization is applied before each sub-layer's computation.
+
+        # 1. Normalization before Self-Attention
+        norm_hidden_states = rms_norm(hidden_states)
+
+        norm_hidden_states = norm_hidden_states.squeeze(1)
+
+        # 2. Self-Attention
+        attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
+
+        hidden_states = attn_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        # 3. Normalization before Feed-Forward
+        norm_hidden_states = rms_norm(hidden_states)
+
+        # 4. Feed-Forward
+        ff_output = self.ff(norm_hidden_states)
+
+        hidden_states = ff_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        return hidden_states
+
+
+class Embeddings1DConnector(torch.nn.Module):
+    """
+    Embeddings1DConnector applies 1D transformer-based processing to sequential embeddings (e.g., for video, audio,
+    or other modalities). It supports rotary positional encoding (RoPE), optional causal temporal positioning, and
+    can substitute padded positions with learnable registers. The module is configurable in head size, number of
+    layers, and register usage.
+
+    Args:
+        attention_head_dim (int): Dimension of each attention head (default=128).
+        num_attention_heads (int): Number of attention heads (default=30).
+        num_layers (int): Number of transformer layers (default=2).
+        positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0).
+        positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[1]).
+        causal_temporal_positioning (bool): If True, uses causal attention (default=False).
+        num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables
+            register replacement. (default=128)
+        rope_type (LTXRopeType): The RoPE variant to use (default=LTXRopeType.INTERLEAVED).
+        double_precision_rope (bool): If True, computes RoPE frequencies in double precision (default=False).
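+
+    Example (illustrative sketch only; shapes assume the defaults above, so
+    inner_dim = 30 * 128 = 3840 and the 256-token sequence is divisible by the
+    128 learnable registers):
+
+        connector = Embeddings1DConnector()
+        tokens = torch.randn(1, 256, 3840, dtype=torch.bfloat16)
+        # Additive attention mask of shape [batch, 1, 1, seq]; 0.0 marks a valid token.
+        mask = torch.zeros(1, 1, 1, 256, dtype=torch.bfloat16)
+        features, out_mask = connector(tokens, mask)  # features: (1, 256, 3840)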
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + attention_head_dim: int = 128, + num_attention_heads: int = 30, + num_layers: int = 2, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list[int] | None = None, + causal_temporal_positioning: bool = False, + num_learnable_registers: int | None = 128, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + double_precision_rope: bool = False, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = ( + positional_embedding_max_pos if positional_embedding_max_pos is not None else [1] + ) + self.rope_type = rope_type + self.double_precision_rope = double_precision_rope + self.transformer_1d_blocks = torch.nn.ModuleList( + [ + _BasicTransformerBlock1D( + dim=self.inner_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + rope_type=rope_type, + ) + for _ in range(num_layers) + ] + ) + + self.num_learnable_registers = num_learnable_registers + if self.num_learnable_registers: + self.learnable_registers = torch.nn.Parameter( + torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0 + ) + + def _replace_padded_with_learnable_registers( + self, hidden_states: torch.Tensor, attention_mask: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.shape[1] % self.num_learnable_registers == 0, ( + f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers " + f"{self.num_learnable_registers}." + ) + + num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers + learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1)) + attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int() + + non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :] + non_zero_nums = non_zero_hidden_states.shape[1] + pad_length = hidden_states.shape[1] - non_zero_nums + adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0) + flipped_mask = torch.flip(attention_mask_binary, dims=[1]) + hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers + + attention_mask = torch.full_like( + attention_mask, + 0.0, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + return hidden_states, attention_mask + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass of Embeddings1DConnector. + + Args: + hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]). + attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states). + + Returns: + tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask. 
+ """ + if self.num_learnable_registers: + hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask) + + indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device) + indices_grid = indices_grid[None, None, :] + freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch + freqs_cis = precompute_freqs_cis( + indices_grid=indices_grid, + dim=self.inner_dim, + out_dtype=hidden_states.dtype, + theta=self.positional_embedding_theta, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + rope_type=self.rope_type, + freq_grid_generator=freq_grid_generator, + ) + + for block in self.transformer_1d_blocks: + hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis) + + hidden_states = rms_norm(hidden_states) + + return hidden_states, attention_mask + + +class Embeddings1DConnectorConfigurator(ModelConfigurator[Embeddings1DConnector]): + @classmethod + def from_config(cls: type[Embeddings1DConnector], config: dict) -> Embeddings1DConnector: + config = config.get("transformer", {}) + rope_type = LTXRopeType(config.get("rope_type", "interleaved")) + double_precision_rope = config.get("frequencies_precision", False) == "float64" + pe_max_pos = config.get("connector_positional_embedding_max_pos", [1]) + + connector = Embeddings1DConnector( + positional_embedding_max_pos=pe_max_pos, + rope_type=rope_type, + double_precision_rope=double_precision_rope, + ) + return connector diff --git a/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/.ipynb_checkpoints/av_encoder-checkpoint.py b/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/.ipynb_checkpoints/av_encoder-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..8a934a1e8f194cdb398ba1da5be94bec04fa39a6 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/.ipynb_checkpoints/av_encoder-checkpoint.py @@ -0,0 +1,101 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Amit Pintz. + +# from typing import NamedTuple, Self +from typing import NamedTuple +from typing_extensions import Self +import torch +from transformers.models.gemma3 import Gemma3ForConditionalGeneration + +from ltx_core.loader.sd_ops import SDOps +from ltx_core.model.clip.gemma.embeddings_connector import ( + Embeddings1DConnector, + Embeddings1DConnectorConfigurator, +) +from ltx_core.model.clip.gemma.encoders.base_encoder import ( + GemmaTextEncoderModelBase, +) +from ltx_core.model.clip.gemma.feature_extractor import GemmaFeaturesExtractorProjLinear +from ltx_core.model.clip.gemma.tokenizer import LTXVGemmaTokenizer +from ltx_core.model.model_protocol import ModelConfigurator + + +class AVGemmaEncoderOutput(NamedTuple): + video_encoding: torch.Tensor + audio_encoding: torch.Tensor + attention_mask: torch.Tensor + + +class AVGemmaTextEncoderModel(GemmaTextEncoderModelBase): + """ + AVGemma Text Encoder Model. + + This class combines the tokenizer, Gemma model, feature extractor from base class and a + video and audio embeddings connectors to provide a preprocessing for audio-visual pipeline. 
+ """ + + def __init__( + self, + feature_extractor_linear: GemmaFeaturesExtractorProjLinear, + embeddings_connector: Embeddings1DConnector, + audio_embeddings_connector: Embeddings1DConnector, + tokenizer: LTXVGemmaTokenizer | None = None, + model: Gemma3ForConditionalGeneration | None = None, + dtype: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__( + feature_extractor_linear=feature_extractor_linear, + tokenizer=tokenizer, + model=model, + dtype=dtype, + ) + self.embeddings_connector = embeddings_connector.to(dtype=dtype) + self.audio_embeddings_connector = audio_embeddings_connector.to(dtype=dtype) + + def _run_connectors( + self, encoded_input: torch.Tensor, attention_mask: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype) + + encoded, encoded_connector_attention_mask = self.embeddings_connector( + encoded_input, + connector_attention_mask, + ) + + # restore the mask values to int64 + attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64) + attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1]) + encoded = encoded * attention_mask + + encoded_for_audio, _ = self.audio_embeddings_connector(encoded_input, connector_attention_mask) + + return encoded, encoded_for_audio, attention_mask.squeeze(-1) + + def forward(self, text: str, padding_side: str = "left") -> AVGemmaEncoderOutput: + encoded_inputs, attention_mask = self._preprocess_text(text, padding_side) + video_encoding, audio_encoding, attention_mask = self._run_connectors(encoded_inputs, attention_mask) + return AVGemmaEncoderOutput(video_encoding, audio_encoding, attention_mask) + + +class AVGemmaTextEncoderModelConfigurator(ModelConfigurator[AVGemmaTextEncoderModel]): + @classmethod + def from_config(cls: type[Self], config: dict) -> Self: + feature_extractor_linear = GemmaFeaturesExtractorProjLinear.from_config(config) + embeddings_connector = Embeddings1DConnectorConfigurator.from_config(config) + audio_embeddings_connector = Embeddings1DConnectorConfigurator.from_config(config) + return AVGemmaTextEncoderModel( + feature_extractor_linear=feature_extractor_linear, + embeddings_connector=embeddings_connector, + audio_embeddings_connector=audio_embeddings_connector, + ) + + +AV_GEMMA_TEXT_ENCODER_KEY_OPS = ( + SDOps("AV_GEMMA_TEXT_ENCODER_KEY_OPS") + .with_matching(prefix="text_embedding_projection.") + .with_matching(prefix="model.diffusion_model.audio_embeddings_connector.") + .with_matching(prefix="model.diffusion_model.video_embeddings_connector.") + .with_replacement("text_embedding_projection.", "feature_extractor_linear.") + .with_replacement("model.diffusion_model.video_embeddings_connector.", "embeddings_connector.") + .with_replacement("model.diffusion_model.audio_embeddings_connector.", "audio_embeddings_connector.") +) diff --git a/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/__pycache__/av_encoder.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/__pycache__/av_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ec48ab7a91e82367c4533ffbab3ff5348359a99 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/__pycache__/av_encoder.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/__pycache__/base_encoder.cpython-310.pyc 
diff --git a/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/av_encoder.py b/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/av_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a934a1e8f194cdb398ba1da5be94bec04fa39a6
--- /dev/null
+++ b/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/av_encoder.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2025 Lightricks. All rights reserved.
+# Created by Amit Pintz.
+
+# from typing import NamedTuple, Self
+from typing import NamedTuple
+from typing_extensions import Self
+import torch
+from transformers.models.gemma3 import Gemma3ForConditionalGeneration
+
+from ltx_core.loader.sd_ops import SDOps
+from ltx_core.model.clip.gemma.embeddings_connector import (
+    Embeddings1DConnector,
+    Embeddings1DConnectorConfigurator,
+)
+from ltx_core.model.clip.gemma.encoders.base_encoder import (
+    GemmaTextEncoderModelBase,
+)
+from ltx_core.model.clip.gemma.feature_extractor import GemmaFeaturesExtractorProjLinear
+from ltx_core.model.clip.gemma.tokenizer import LTXVGemmaTokenizer
+from ltx_core.model.model_protocol import ModelConfigurator
+
+
+class AVGemmaEncoderOutput(NamedTuple):
+    video_encoding: torch.Tensor
+    audio_encoding: torch.Tensor
+    attention_mask: torch.Tensor
+
+
+class AVGemmaTextEncoderModel(GemmaTextEncoderModelBase):
+    """
+    AVGemma Text Encoder Model.
+
+    This class combines the tokenizer, Gemma model, and feature extractor from the base class with video and audio
+    embeddings connectors to provide preprocessing for the audio-visual pipeline.
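+
+    Typical usage (illustrative sketch; in practice the Gemma model and tokenizer
+    are attached by the loader via `module_ops_from_gemma_root`, and `config` is a
+    checkpoint configuration dict):
+
+        encoder = AVGemmaTextEncoderModelConfigurator.from_config(config)
+        video_emb, audio_emb, mask = encoder("a short description of the scene")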
+ """ + + def __init__( + self, + feature_extractor_linear: GemmaFeaturesExtractorProjLinear, + embeddings_connector: Embeddings1DConnector, + audio_embeddings_connector: Embeddings1DConnector, + tokenizer: LTXVGemmaTokenizer | None = None, + model: Gemma3ForConditionalGeneration | None = None, + dtype: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__( + feature_extractor_linear=feature_extractor_linear, + tokenizer=tokenizer, + model=model, + dtype=dtype, + ) + self.embeddings_connector = embeddings_connector.to(dtype=dtype) + self.audio_embeddings_connector = audio_embeddings_connector.to(dtype=dtype) + + def _run_connectors( + self, encoded_input: torch.Tensor, attention_mask: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype) + + encoded, encoded_connector_attention_mask = self.embeddings_connector( + encoded_input, + connector_attention_mask, + ) + + # restore the mask values to int64 + attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64) + attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1]) + encoded = encoded * attention_mask + + encoded_for_audio, _ = self.audio_embeddings_connector(encoded_input, connector_attention_mask) + + return encoded, encoded_for_audio, attention_mask.squeeze(-1) + + def forward(self, text: str, padding_side: str = "left") -> AVGemmaEncoderOutput: + encoded_inputs, attention_mask = self._preprocess_text(text, padding_side) + video_encoding, audio_encoding, attention_mask = self._run_connectors(encoded_inputs, attention_mask) + return AVGemmaEncoderOutput(video_encoding, audio_encoding, attention_mask) + + +class AVGemmaTextEncoderModelConfigurator(ModelConfigurator[AVGemmaTextEncoderModel]): + @classmethod + def from_config(cls: type[Self], config: dict) -> Self: + feature_extractor_linear = GemmaFeaturesExtractorProjLinear.from_config(config) + embeddings_connector = Embeddings1DConnectorConfigurator.from_config(config) + audio_embeddings_connector = Embeddings1DConnectorConfigurator.from_config(config) + return AVGemmaTextEncoderModel( + feature_extractor_linear=feature_extractor_linear, + embeddings_connector=embeddings_connector, + audio_embeddings_connector=audio_embeddings_connector, + ) + + +AV_GEMMA_TEXT_ENCODER_KEY_OPS = ( + SDOps("AV_GEMMA_TEXT_ENCODER_KEY_OPS") + .with_matching(prefix="text_embedding_projection.") + .with_matching(prefix="model.diffusion_model.audio_embeddings_connector.") + .with_matching(prefix="model.diffusion_model.video_embeddings_connector.") + .with_replacement("text_embedding_projection.", "feature_extractor_linear.") + .with_replacement("model.diffusion_model.video_embeddings_connector.", "embeddings_connector.") + .with_replacement("model.diffusion_model.audio_embeddings_connector.", "audio_embeddings_connector.") +) diff --git a/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/base_encoder.py b/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/base_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9830ed2b2eb05e49f87136dec1c6be1727ae794a --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/base_encoder.py @@ -0,0 +1,294 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Amit Pintz. 
+
+import functools
+from pathlib import Path
+
+import torch
+from einops import rearrange
+from PIL import Image as PILImage
+from transformers import AutoImageProcessor, Gemma3ForConditionalGeneration, Gemma3Processor
+
+from ltx_core.loader.module_ops import ModuleOps
+from ltx_core.model.clip.gemma.feature_extractor import (
+    GemmaFeaturesExtractorProjLinear,
+)
+from ltx_core.model.clip.gemma.tokenizer import LTXVGemmaTokenizer
+
+
+class GemmaTextEncoderModelBase(torch.nn.Module):
+    """
+    Gemma Text Encoder Model.
+
+    This base class combines the tokenizer, Gemma model, and feature extractor to provide preprocessing
+    for implementation classes for multimodal pipelines. It processes input text through tokenization,
+    obtains hidden states from the base language model, and applies a linear feature extractor.
+
+    Args:
+        tokenizer (LTXVGemmaTokenizer): The tokenizer used for text preprocessing.
+        model (Gemma3ForConditionalGeneration): The base Gemma LLM.
+        feature_extractor_linear (GemmaFeaturesExtractorProjLinear): Linear projection for hidden state aggregation.
+        img_processor (Gemma3Processor, optional): Processor for image inputs used by prompt enhancement
+            (default: None; created lazily when needed).
+        dtype (torch.dtype, optional): The data type for model parameters (default: torch.bfloat16).
+    """
+
+    def __init__(
+        self,
+        feature_extractor_linear: GemmaFeaturesExtractorProjLinear,
+        tokenizer: LTXVGemmaTokenizer | None = None,
+        model: Gemma3ForConditionalGeneration | None = None,
+        img_processor: Gemma3Processor | None = None,
+        dtype: torch.dtype = torch.bfloat16,
+    ) -> None:
+        super().__init__()
+        self._gemma_root = None
+        self.tokenizer = tokenizer
+        self.model = model
+        self.processor = img_processor
+        self.feature_extractor_linear = feature_extractor_linear.to(dtype=dtype)
+
+    def _run_feature_extractor(
+        self, hidden_states: tuple[torch.Tensor, ...], attention_mask: torch.Tensor, padding_side: str = "right"
+    ) -> torch.Tensor:
+        # `hidden_states` is the per-layer tuple returned by the LLM; stack it into
+        # a single tensor of shape [batch, seq_len, hidden_dim, num_layers].
+        encoded_text_features = torch.stack(hidden_states, dim=-1)
+        encoded_text_features_dtype = encoded_text_features.dtype
+
+        sequence_lengths = attention_mask.sum(dim=-1)
+        normed_concated_encoded_text_features = _norm_and_concat_padded_batch(
+            encoded_text_features, sequence_lengths, padding_side=padding_side
+        )
+
+        return self.feature_extractor_linear(normed_concated_encoded_text_features.to(encoded_text_features_dtype))
+
+    def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+        return (attention_mask - 1).to(dtype).reshape(
+            (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
+        ) * torch.finfo(dtype).max
+
+    def _preprocess_text(self, text: str, padding_side: str = "left") -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Encode a given string into feature tensors suitable for downstream tasks.
+
+        Args:
+            text (str): Input string to encode.
+            padding_side (str): Which side the tokenizer pads on ("left" or "right").
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: Projected text features and the attention mask.
+ """ + token_pairs = self.tokenizer.tokenize_with_weights(text)["gemma"] + input_ids = torch.tensor([[t[0] for t in token_pairs]], device=self.model.device) + attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=self.model.device) + outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) + projected = self._run_feature_extractor( + hidden_states=outputs.hidden_states, attention_mask=attention_mask, padding_side=padding_side + ) + return projected, attention_mask + + def _init_image_processor(self) -> None: + img_processor = AutoImageProcessor.from_pretrained(self._gemma_root, local_files_only=True) + if not self.tokenizer: + raise ValueError("Tokenizer is not loaded, cannot load image processor") + self.processor = Gemma3Processor(image_processor=img_processor, tokenizer=self.tokenizer.tokenizer) + + @torch.inference_mode() + def enhance_t2v( + self, + prompt: str, + max_new_tokens: int = 256, + system_prompt: str | None = None, + ) -> str: + """Enhance a text prompt for T2V generation.""" + if self.processor is None: + self._init_image_processor() + system_prompt = system_prompt or self.default_gemma_t2v_system_prompt + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"user prompt: {prompt}"}, + ] + + text = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + model_inputs = self.processor.tokenizer(text, return_tensors="pt").to(self.model.device) + + outputs = self.model.generate( + **model_inputs, + max_new_tokens=max_new_tokens, + do_sample=False, + ) + generated_ids = outputs[0][len(model_inputs.input_ids[0]) :] + enhanced_prompt = self.processor.tokenizer.decode(generated_ids, skip_special_tokens=True) + + return enhanced_prompt + + @torch.inference_mode() + def enhance_i2v( + self, + prompt: str, + image: PILImage.Image, + max_new_tokens: int = 256, + system_prompt: str | None = None, + ) -> str: + """Enhance a text prompt for I2V generation using a reference image.""" + if self.processor is None: + self._init_image_processor() + system_prompt = system_prompt or self.default_gemma_i2v_system_prompt + messages = [ + {"role": "system", "content": system_prompt}, + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": f"user prompt: {prompt}"}, + ], + }, + ] + + text = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + model_inputs = self.processor( + text=text, + images=image, + return_tensors="pt", + ).to(self.model.device) + + outputs = self.model.generate( + **model_inputs, + max_new_tokens=max_new_tokens, + do_sample=False, + ) + generated_ids = outputs[0][len(model_inputs.input_ids[0]) :] + enhanced_prompt = self.processor.tokenizer.decode(generated_ids, skip_special_tokens=True) + + return enhanced_prompt + + @functools.cached_property + def default_gemma_i2v_system_prompt(self) -> str: + return _load_system_prompt("gemma_i2v_system_prompt.txt") + + @functools.cached_property + def default_gemma_t2v_system_prompt(self) -> str: + return _load_system_prompt("gemma_t2v_system_prompt.txt") + + def forward(self, text: str, padding_side: str = "left") -> tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError("This method is not implemented for the base class") + + +def _norm_and_concat_padded_batch( + encoded_text: torch.Tensor, + sequence_lengths: torch.Tensor, + padding_side: str = "right", +) -> torch.Tensor: + """Normalize and flatten 
multi-layer hidden states, respecting padding. + + Performs per-batch, per-layer normalization using masked mean and range, + then concatenates across the layer dimension. + + Args: + encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers]. + sequence_lengths: Number of valid (non-padded) tokens per batch item. + padding_side: Whether padding is on "left" or "right". + + Returns: + Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers], + with padded positions zeroed out. + """ + b, t, d, l = encoded_text.shape # noqa: E741 + device = encoded_text.device + + # Build mask: [B, T, 1, 1] + token_indices = torch.arange(t, device=device)[None, :] # [1, T] + + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [B, T] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = t - sequence_lengths[:, None] # [B, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + + mask = rearrange(mask, "b t -> b t 1 1") + + eps = 1e-6 + + # Compute masked mean: [B, 1, 1, L] + masked = encoded_text.masked_fill(~mask, 0.0) + denom = (sequence_lengths * d).view(b, 1, 1, 1) + mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps) + + # Compute masked min/max: [B, 1, 1, L] + x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + range_ = x_max - x_min + + # Normalize only the valid tokens + normed = 8 * (encoded_text - mean) / (range_ + eps) + + # concat to be [Batch, T, D * L] - this preserves the original structure + normed = normed.reshape(b, t, -1) # [B, T, D * L] + + # Apply mask to preserve original padding (set padded positions to 0) + mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l) + normed = normed.masked_fill(~mask_flattened, 0.0) + + return normed + + +@functools.lru_cache(maxsize=2) +def _load_system_prompt(prompt_name: str) -> str: + with open(Path(__file__).parent / "prompts" / f"{prompt_name}", "r") as f: + return f.read() + + +def _find_matching_dir(root_path: str, pattern: str) -> str: + """ + Recursively search for files matching a glob pattern and return the parent directory of the first match. + + LT_INTERNAL_BEGIN + Handles both LT internal storage and HuggingFace directory structures for Gemma model files. + See: https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized + LT_INTERNAL_END + """ + + matches = list(Path(root_path).rglob(pattern)) + if not matches: + raise FileNotFoundError(f"No files matching pattern '{pattern}' found under {root_path}") + return str(matches[0].parent) + + +def module_ops_from_gemma_root(gemma_root: str, local_files_only: bool = True) -> tuple[ModuleOps, ...]: + if len(gemma_root.split("/")) != 2: + gemma_path = _find_matching_dir(gemma_root, "model*.safetensors") + tokenizer_path = _find_matching_dir(gemma_root, "tokenizer.model") + else: + # Hub ID: google/gemma-3-12b-it-qat-q4_0-unquantized + gemma_path = tokenizer_path = gemma_root + + # LT_INTERNAL_BEGIN + # Note: We pass torch_dtype to from_pretrained here to maintain backward compatibility with older versions of + # Transformers. This is necessary to compare results with ComfyUI, which uses an older version that raises an error + # when dtype is passed. 
Current solution only logs a warning. + # LT_INTERNAL_END + def load_gemma(module: GemmaTextEncoderModelBase) -> GemmaTextEncoderModelBase: + module.model = Gemma3ForConditionalGeneration.from_pretrained( + gemma_path, local_files_only=local_files_only, torch_dtype=torch.bfloat16 + ) + module._gemma_root = module._gemma_root or gemma_root + return module + + def load_tokenizer(module: GemmaTextEncoderModelBase) -> GemmaTextEncoderModelBase: + module.tokenizer = LTXVGemmaTokenizer(tokenizer_path, 1024, local_files_only) + module._gemma_root = module._gemma_root or gemma_root + return module + + gemma_load_ops = ModuleOps( + "GemmaLoad", + matcher=lambda module: isinstance(module, GemmaTextEncoderModelBase) and module.model is None, + mutator=load_gemma, + ) + tokenizer_load_ops = ModuleOps( + "TokenizerLoad", + matcher=lambda module: isinstance(module, GemmaTextEncoderModelBase) and module.tokenizer is None, + mutator=load_tokenizer, + ) + return (gemma_load_ops, tokenizer_load_ops) diff --git a/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/prompts/gemma_i2v_system_prompt.txt b/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/prompts/gemma_i2v_system_prompt.txt new file mode 100644 index 0000000000000000000000000000000000000000..f86f9c37098147f60bb1cb002138d24f5584e470 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/prompts/gemma_i2v_system_prompt.txt @@ -0,0 +1,32 @@ +You are a prompt enhancer for an image-to-video model. Your task is to take a conditioning image and a user's motion caption, and generate a video prompt that animates the exact scene depicted in the image. + +The Golden Rule: The conditioning image is the absolute, non-negotiable source of truth. You do not invent. You do not imagine. You animate what is visually present. + +The Rule of Priority: Handling Conflicts +The conditioning image is ALWAYS the highest authority. The user's motion caption is a request for how to animate the image, not a command to add new elements or change the scene. + +If a user's caption requests an element that IS NOT in the image (e.g., a "nebula" where there is only a starfield), you must find a visually consistent alternative. Do not add the non-existent element. Instead, animate the existing elements to create a similar effect. For example, instead of adding a gaseous nebula, you would describe the existing stars swirling to imply a forming cloud. +If a user's caption requests an impossible action (e.g., a "stone statue crying"), find a plausible, metaphorical interpretation. For example, you would describe "raindrops streaking down its cheeks like tears," not the statue itself generating water. +Your Internal Thought Process (Mandatory Steps): +Note: These steps are for your internal reasoning only. Do NOT include headings, bullet points, numbered steps, labels, or any meta/explanations about your process in the final output. + +Step 1: Analyze the Image. First, identify the core components of the static image: Subject (e.g., toy astronaut figurines), Setting (e.g., a stylized, blurry starfield background), Elements (e.g., shiny plastic, golden star shapes, lens flares), and Mood (e.g., whimsical, miniature, lo-fi). +Step 2: Interpret the Motion Caption & Resolve Conflicts. Read the user's caption and apply the "Rule of Priority." For every requested action, check if it's visually possible. If not, devise a creative, visually-grounded alternative that captures the user's intent. +Step 3: Synthesize and Write. 
Combine your image analysis and your resolved motion plan into a single, cinematic paragraph.
+Guidelines for Enhancement:
+
+Anchor to Visuals: Begin by describing the literal, static scene from the image, including its style (e.g., "Two small, toy-like astronaut figurines...").
+Contextual Motion: Add the user-requested motion, ensuring it is applied plausibly to the objects actually in the image, using your resolved interpretations from Step 2.
+Cinematic Framing: Use camera dynamics to explore the existing composition.
+Atmosphere and Style: Enhance the mood and artistic style already present in the image.
+Consistent Audio Layer (always include): Add sounds that match the actual visuals (e.g., ethereal synth for a toy-in-space scene, not rocket noises). Describe how audio contributes to mood, rhythm, or narrative—not just what is heard.
+Dialogue (when present): If the input mentions dialogue, explicitly script the exact quoted lines each character says and when they say them; describe each speaker distinctively (appearance, role, clothing, position) so it is unambiguous who speaks when. If a language other than English is required, explicitly state the language for the dialogue lines.
+General audio: Weave audio descriptions naturally into the chronological flow of the visual description. DO NOT append all audio at the end as a separate section. Instead, mention sounds as they occur temporally alongside the visual actions.
+CRITICAL CONSTRAINT: Your final description must only contain animated elements that are visually derived from the original image. The output must never introduce new objects or phenomena.
+
+Output Format (Strict):
+- Produce a continuous paragraph in natural language.
+- Do NOT include titles, headings, section labels, bullet points, numbered lists, or metadata of any kind.
+- Do NOT include prefaces like "Okay, here's..." or labels such as "Video Prompt:", "Audio Layer:", or "Internal Thought Process:".
+- Do NOT include code fences, Markdown, or JSON—plain prose only.
+- Weave audio naturally within the paragraph; never create a separate audio section.
diff --git a/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/prompts/gemma_t2v_system_prompt.txt b/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/prompts/gemma_t2v_system_prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f1c7820e0d8027bb8febd70b9c5508402a46b5b0
--- /dev/null
+++ b/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/prompts/gemma_t2v_system_prompt.txt
@@ -0,0 +1,45 @@
+You are a prompt enhancer for a text-to-video model. Your task is to take an input caption that describes a scene and expand it into a detailed, cinematic, visually expressive, temporally dynamic prompt. The output should maximize visual richness and storytelling while staying faithful to the input description.
+
+Guidelines for Enhancement:
+
+- Visual Detail: Add fine-grained visual information about lighting, color palettes, textures, reflections, and atmospheric elements.
+- Character & Action: Enrich descriptions of subjects (clothing, expressions, gestures) and specify their interactions with the environment.
+- Camera & Composition: Specify cinematic camera work — shot type (wide/close-up), camera movement (static, dolly, pan, tilt), framing, depth of field.
+- Environment & Ambiance: Expand on the setting with mood, crowd dynamics, architectural features, props, and immersive background details.
+- Temporal Flow: Suggest how the scene evolves over a few seconds — subtle changes in light, character movements, or environmental shifts. +- Style & Atmosphere: Optionally add adjectives for artistic tone (e.g., "moody cyberpunk atmosphere," "grainy analog film texture," "hyper-realistic neon lighting"). +- Audio Layer (always include): Add immersive audio cues beyond visible sources — unseen background music, atmospheric sound effects, or disembodied voices that enrich the scene's mood. The audio may come from implied or off-screen sources. Weave audio descriptions naturally into the chronological flow of the visual description. DO NOT append all audio at the end as a separate section. Instead, mention sounds as they occur temporally alongside the visual actions. +- Dialogue (when present): If the input mentions dialogue, explicitly script the exact quoted lines each character says and when they say them; describe each speaker distinctively (name, role, clothing, position) so it is unambiguous who speaks when. If a language other than English is required, explicitly state the language for the dialogue lines. + +Output Format: Produce a single continuous paragraph in natural language, optimized for text-to-video generation. + +Example 1 — +Input Caption: +A lone traveler walks through a desert at sunset. +Enhanced Output: +A solitary figure trudges across vast, golden dunes as the sun sinks low on the horizon, painting the sky in gradients of deep orange, crimson, and fading violet. Footsteps crunch softly in the sand while a low, resonant wind drones through the scene, occasionally whistling over the dune crests. Each step leaves a trail of rippling footprints as the breeze stirs small clouds of dust around their boots. The traveler, wrapped in a weathered cloak and turban, adjusts their scarf against the dry air; fabric rustles under the steady breath of the breeze, their shadow stretching long and wavering behind them. The camera holds a wide, sweeping shot from a low angle, capturing both the immense emptiness and the wanderer’s slow, determined pace. Heat haze shimmers in the distance, blurring earth and sky; now and then a faint metallic chime seems to drift from nowhere, mingling with the dry hiss of shifting grains. As dusk deepens, the wind’s lonely tone swells and falls, amplifying the hush of the desert and the measured cadence of the traveler’s march. + +Example 2 — +Input Caption: +A busy street market in a futuristic city. +Enhanced Output: +Neon signs buzz overhead, their glow reflecting off slick pavement as a dense crowd flows through narrow alleys lined with holographic stalls and vendors selling exotic, otherworldly goods. Voices tumble in overlapping bursts of haggling while a wok sizzles and vapor hisses from a steaming cart. The camera begins with a wide establishing shot that slowly dollies forward through the throng as bass-heavy music leaks from hidden speakers in pockets, rising and fading with our movement between stalls. Robotic merchants with chrome limbs gesture to passersby; servos whirr softly as they pivot, and customers bargain beneath flickering tubes that hum above glass counters of cybernetic implants. Steam vents with a soft exhale; the night air carries laughter, clipped PA snippets, and the bright snap of packaging. 
In the background, towering skyscrapers vanish into low clouds under animated billboards, while an occasional hover-car scythes overhead with a clean whoosh, briefly cutting across the music before it swells again amid the market’s layered chatter. + +Example 3 — +Input Caption: +A waterfall deep in a forest. +Enhanced Output: +A towering cascade of crystal-clear water plunges into a moss-lined pool, sending a misty spray into the air that glimmers in scattered shafts of sunlight piercing through the dense canopy. The deep, steady roar of falling water underpins the scene, punctuated by crisp plinks as stray droplets tap nearby leaves. The surrounding forest is lush with emerald foliage, tangled vines, and ancient trees whose roots twist into the damp earth; birds dart through the frame, their wings whispering as cicadas rise and fall in the canopy. The camera begins with a slow aerial tilt downward, revealing the waterfall from above, then drifts closer to frame the cascade in intimate detail as droplets spatter the lens with bright taps. Subtle movements animate the frame — ripples expand across the pool, a branch sways with a faint creak, and leaves tremble in the spray — while the roar swells and softens in waves, the forest’s breathy hush filling the spaces between, creating a meditative pulse of sound and motion. + +Example 4 — +Input Caption: +Two friends at a rainy bus stop have a brief conversation about missing the last bus. +Enhanced Output: +A medium two-shot frames two friends under a flickering streetlamp at a glass-walled bus stop on a rainy city night; droplets bead on the plexiglass and neon reflections ripple across wet pavement. Close rain patter ticks against the shelter as Maya, in a yellow raincoat, sighs and glances at the timetable, saying, "We missed it by five minutes." A distant bus whoosh fades down the block while Jordan, taller, in a charcoal hoodie, checks his watch and replies, "Yeah… next one’s in thirty." Wind nudges the route map with a soft plastic click; tires hiss past on the street as a car throws a fine mist into the edge of frame. Maya adds, "Let’s just walk—it's only ten blocks," and a nearby puddle ripples with a gentle plop from a leaky gutter. Jordan answers, "Fine, but we’ll be drenched," his voice carrying under the steady rainfall and low, muffled traffic wash. The camera holds at shoulder height with shallow depth of field, then eases into a subtle dolly back as they step off the curb together, their footsteps splashing in time with the rain while the urban ambience (distant engines, a faint PA announcement down the street) trails them into the night. + +Output Format (Strict): +- Produce a continuous paragraph in natural language. +- Do NOT include titles, headings, section labels, bullet points, numbered lists, or metadata of any kind. +- Do NOT include prefaces like "Okay, here's..." or labels such as "Video Prompt:", "Audio Layer:", or "Internal Thought Process:". +- Do NOT include code fences, Markdown, or JSON—plain prose only. +- Weave audio naturally within the paragraph; never create a separate audio section. 
diff --git a/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/video_only_encoder.py b/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/video_only_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b80648b5129135a46041ea68a5748f656f3f2cde
--- /dev/null
+++ b/packages/ltx-core/src/ltx_core/model/clip/gemma/encoders/video_only_encoder.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2025 Lightricks. All rights reserved.
+# Created by Amit Pintz.
+
+# from typing import NamedTuple, Self
+from typing import NamedTuple
+from typing_extensions import Self
+
+import torch
+from transformers import Gemma3ForConditionalGeneration
+
+from ltx_core.loader.sd_ops import SDOps
+from ltx_core.model.clip.gemma.embeddings_connector import (
+    Embeddings1DConnector,
+    Embeddings1DConnectorConfigurator,
+)
+from ltx_core.model.clip.gemma.encoders.base_encoder import (
+    GemmaTextEncoderModelBase,
+)
+from ltx_core.model.clip.gemma.feature_extractor import GemmaFeaturesExtractorProjLinear
+from ltx_core.model.clip.gemma.tokenizer import LTXVGemmaTokenizer
+from ltx_core.model.model_protocol import ModelConfigurator
+
+
+class VideoGemmaEncoderOutput(NamedTuple):
+    video_encoding: torch.Tensor
+    attention_mask: torch.Tensor
+
+
+class VideoGemmaTextEncoderModel(GemmaTextEncoderModelBase):
+    """
+    Video Gemma Text Encoder Model.
+
+    This class combines the tokenizer, Gemma model, and feature extractor from the base class with a
+    video embeddings connector to provide preprocessing for the video-only pipeline.
+    """
+
+    def __init__(
+        self,
+        feature_extractor_linear: GemmaFeaturesExtractorProjLinear,
+        embeddings_connector: Embeddings1DConnector,
+        tokenizer: LTXVGemmaTokenizer | None = None,
+        model: Gemma3ForConditionalGeneration | None = None,
+        dtype: torch.dtype = torch.bfloat16,
+    ) -> None:
+        super().__init__(
+            feature_extractor_linear=feature_extractor_linear,
+            tokenizer=tokenizer,
+            model=model,
+            dtype=dtype,
+        )
+        self.embeddings_connector = embeddings_connector.to(dtype=dtype)
+
+    def _run_connector(
+        self, encoded_input: torch.Tensor, attention_mask: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype)
+
+        encoded, encoded_connector_attention_mask = self.embeddings_connector(
+            encoded_input,
+            connector_attention_mask,
+        )
+
+        # Convert the additive connector mask back to a binary int64 mask.
+        attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64)
+        attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
+        encoded = encoded * attention_mask
+
+        return encoded, attention_mask.squeeze(-1)
+
+    def forward(self, text: str, padding_side: str = "left") -> VideoGemmaEncoderOutput:
+        encoded_inputs, attention_mask = self._preprocess_text(text, padding_side)
+        video_encoding, attention_mask = self._run_connector(encoded_inputs, attention_mask)
+        return VideoGemmaEncoderOutput(video_encoding, attention_mask)
+
+
+class VideoGemmaTextEncoderModelConfigurator(ModelConfigurator[VideoGemmaTextEncoderModel]):
+    @classmethod
+    def from_config(cls: type[Self], config: dict) -> Self:
+        feature_extractor_linear = GemmaFeaturesExtractorProjLinear.from_config(config)
+        embeddings_connector = Embeddings1DConnectorConfigurator.from_config(config)
+        return VideoGemmaTextEncoderModel(
+            feature_extractor_linear=feature_extractor_linear,
+            embeddings_connector=embeddings_connector,
+        )
+
+
+VIDEO_ONLY_GEMMA_TEXT_ENCODER_KEY_OPS = (
+    SDOps("VIDEO_ONLY_GEMMA_TEXT_ENCODER_KEY_OPS")
+    .with_matching(prefix="text_embedding_projection.")
+    .with_matching(prefix="model.diffusion_model.embeddings_connector.")
+    .with_replacement("text_embedding_projection.", "feature_extractor_linear.")
+    .with_replacement("model.diffusion_model.embeddings_connector.", "embeddings_connector.")
+)
diff --git a/packages/ltx-core/src/ltx_core/model/clip/gemma/feature_extractor.py b/packages/ltx-core/src/ltx_core/model/clip/gemma/feature_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e681236409df88f15667bb2063467985a86d0485
--- /dev/null
+++ b/packages/ltx-core/src/ltx_core/model/clip/gemma/feature_extractor.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2025 Lightricks. All rights reserved.
+# Created by Amit Pintz.
+
+# from typing import Self
+from typing_extensions import Self
+import torch
+
+from ltx_core.model.model_protocol import ModelConfigurator
+
+
+class GemmaFeaturesExtractorProjLinear(torch.nn.Module, ModelConfigurator[Self]):
+    """
+    Feature extractor module for Gemma models.
+
+    This module applies a single linear projection to the input tensor.
+    It expects a flattened feature tensor of shape (batch_size, 3840*49).
+    The linear layer maps this to a (batch_size, 3840) embedding.
+
+    Attributes:
+        aggregate_embed (torch.nn.Linear): Linear projection layer.
+    """
+
+    def __init__(self) -> None:
+        """
+        Initialize the GemmaFeaturesExtractorProjLinear module.
+
+        The input dimension is expected to be 3840 * 49, and the output is 3840.
+        """
+        super().__init__()
+        self.aggregate_embed = torch.nn.Linear(3840 * 49, 3840, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for the feature extractor.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (batch_size, 3840 * 49).
+
+        Returns:
+            torch.Tensor: Output tensor of shape (batch_size, 3840).
+        """
+        return self.aggregate_embed(x)
+
+    @classmethod
+    def from_config(cls: type[Self], _config: dict) -> Self:
+        return cls()
diff --git a/packages/ltx-core/src/ltx_core/model/clip/gemma/tokenizer.py b/packages/ltx-core/src/ltx_core/model/clip/gemma/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..794acebb701692690323734e3ebb9e9500fd1baf
--- /dev/null
+++ b/packages/ltx-core/src/ltx_core/model/clip/gemma/tokenizer.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2025 Lightricks. All rights reserved.
+# Created by Amit Pintz.
+
+from transformers import AutoTokenizer
+
+
+class LTXVGemmaTokenizer:
+    """
+    Tokenizer wrapper for Gemma models compatible with LTXV processes.
+
+    This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders,
+    ensuring correct settings and output formatting for downstream consumption.
+    """
+
+    def __init__(self, tokenizer_path: str, max_length: int = 256, local_files_only: bool = True):
+        """
+        Initialize the tokenizer.
+
+        Args:
+            tokenizer_path (str): Path to the pretrained tokenizer files or model directory.
+            max_length (int, optional): Max sequence length for encoding. Defaults to 256.
+            local_files_only (bool, optional): If True, never download from the Hub. Defaults to True.
+        """
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_path, local_files_only=local_files_only, model_max_length=max_length
+        )
+        # Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much.
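+        # (The padding side only affects where pad tokens appear in the
+        # tokenize_with_weights output; encoder-side masking handles left- and
+        # right-padded sequences alike.)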
+        self.tokenizer.padding_side = "left"
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        self.max_length = max_length
+
+    def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]:
+        """
+        Tokenize the given text and return token IDs and attention weights.
+
+        Args:
+            text (str): The input string to tokenize.
+            return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples.
+                If False (default), omits the indices.
+
+        Returns:
+            dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]:
+                A dictionary with a "gemma" key mapping to:
+                - a list of (token_id, attention_mask) tuples if return_word_ids is False;
+                - a list of (token_id, attention_mask, index) tuples if return_word_ids is True.
+
+        Example:
+            >>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8)
+            >>> tokenizer.tokenize_with_weights("hello world")
+            {'gemma': [(1234, 1), (5678, 1), (2, 0), ...]}
+        """
+        text = text.strip()
+        encoded = self.tokenizer(
+            text,
+            padding="max_length",
+            max_length=self.max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        input_ids = encoded.input_ids
+        attention_mask = encoded.attention_mask
+        tuples = [
+            (token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0], strict=True))
+        ]
+        out = {"gemma": tuples}
+
+        if not return_word_ids:
+            # Return only (token_id, attention_mask) pairs, omitting token position
+            out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()}
+
+        return out
diff --git a/packages/ltx-core/src/ltx_core/model/common/__init__.py b/packages/ltx-core/src/ltx_core/model/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/packages/ltx-core/src/ltx_core/model/common/normalization.py b/packages/ltx-core/src/ltx_core/model/common/normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3631e93e401230e0015b41b6f1b7a454e9b3af4
--- /dev/null
+++ b/packages/ltx-core/src/ltx_core/model/common/normalization.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2025 Lightricks. All rights reserved.
+# Created by Ivan Zorin
+
+from enum import Enum
+
+import torch
+from torch import nn
+
+
+class NormType(Enum):
+    GROUP = "group"
+    PIXEL = "pixel"
+
+
+class PixelNorm(nn.Module):
+    """
+    Per-pixel (per-location) RMS normalization layer.
+ + For each element along the chosen dimension, this layer normalizes the tensor + by the root-mean-square of its values across that dimension: + + y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps) + + """ + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + """ + Args: + dim: Dimension along which to compute the RMS (typically channels). + eps: Small constant added for numerical stability. + """ + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply RMS normalization along the configured dimension. + """ + # Compute mean of squared values along `dim`, keep dimensions for broadcasting. + mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True) + # Normalize by the root-mean-square (RMS). + rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +def build_normalization_layer( + in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP +) -> nn.Module: + """ + Create a normalization layer based on the normalization type. + Args: + in_channels: Number of input channels + num_groups: Number of groups for group normalization + normtype: Type of normalization: "group" or "pixel" + Returns: + A normalization layer + """ + if normtype == NormType.GROUP: + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if normtype == NormType.PIXEL: + return PixelNorm(dim=1, eps=1e-6) + raise ValueError(f"Invalid normalization type: {normtype}") diff --git a/packages/ltx-core/src/ltx_core/model/model_ledger.py b/packages/ltx-core/src/ltx_core/model/model_ledger.py new file mode 100644 index 0000000000000000000000000000000000000000..37c8a5ea78f2eb134a8b01a4a0e09fe0a8dceab2 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/model_ledger.py @@ -0,0 +1,253 @@ +from dataclasses import replace +# from typing import Self +from typing_extensions import Self + +import torch + +from ltx_core.loader.primitives import LoraPathStrengthAndSDOps +from ltx_core.loader.registry import DummyRegistry, Registry +from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder +from ltx_core.model.audio_vae.audio_vae import Decoder as AudioDecoder +from ltx_core.model.audio_vae.model_configurator import ( + AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, + VOCODER_COMFY_KEYS_FILTER, + VocoderConfigurator, +) +from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator +from ltx_core.model.audio_vae.vocoder import Vocoder +from ltx_core.model.clip.gemma.encoders.av_encoder import ( + AV_GEMMA_TEXT_ENCODER_KEY_OPS, + AVGemmaTextEncoderModel, + AVGemmaTextEncoderModelConfigurator, +) +from ltx_core.model.clip.gemma.encoders.base_encoder import module_ops_from_gemma_root +from ltx_core.model.transformer.model import X0Model +from ltx_core.model.transformer.model_configurator import ( + LTXV_MODEL_COMFY_RENAMING_MAP, + LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP, + UPCAST_DURING_INFERENCE, + LTXModelConfigurator, +) +from ltx_core.model.upsampler.model import LatentUpsampler +from ltx_core.model.upsampler.model_configurator import LatentUpsamplerConfigurator +from ltx_core.model.video_vae.model_configurator import ( + VAE_DECODER_COMFY_KEYS_FILTER, + VAE_ENCODER_COMFY_KEYS_FILTER, + VAEDecoderConfigurator, + VAEEncoderConfigurator, +) +from ltx_core.model.video_vae.video_vae import Decoder as VideoDecoder +from ltx_core.model.video_vae.video_vae import Encoder as VideoEncoder + + +class ModelLedger: 
+ """ + Central coordinator for loading, caching, and freeing models used in an LTX pipeline. + The ledger wires together multiple model builders (transformer, video VAE encoder/decoder, + audio VAE decoder, vocoder, text encoder, and optional latent upsampler) and exposes + the resulting models as lazily constructed, cached attributes. + + ### Caching behavior + + Each model attribute (e.g. :attr:`transformer`, :attr:`video_decoder`, :attr:`text_encoder`) + is implemented as a :func:`functools.cached_property`. The first time one of these + attributes is accessed, the corresponding builder loads weights from the + :class:`~ltx_core.loader.registry.StateDictRegistry`, instantiates the model on CPU with + the configured ``dtype``, moves it to ``self.device``, and stores the result in + the instance ``__dict__``. Subsequent accesses reuse the same model instance until it is + explicitly cleared via :meth:`clear_vram`. + + ### Constructor parameters + + dtype: + Torch dtype used when constructing all models (e.g. ``torch.float16``). + device: + Target device to which models are moved after construction (e.g. ``torch.device("cuda")``). + checkpoint_path: + Path to a checkpoint directory or file containing the core model weights + (transformer, video VAE, audio VAE, text encoder, vocoder). If ``None``, the + corresponding builders are not created and accessing those properties will raise + a :class:`ValueError`. + gemma_root_path: + Base path to Gemma-compatible CLIP/text encoder weights. Required to + initialize the text encoder builder; if omitted, :attr:`text_encoder` cannot be used. + spatial_upsampler_path: + Optional path to a latent upsampler checkpoint. If provided, the + :attr:`upsampler` property becomes available; otherwise accessing it raises + a :class:`ValueError`. + loras: + Optional collection of LoRA configurations (paths, strengths, and key operations) + that are applied on top of the base transformer weights when building the model. + + ### Memory management + + ``clear_ram()`` + Clears the underlying :class:`Registry` cache of state dicts and triggers a + Python garbage collection pass. Use this when you no longer need to construct new + models from the currently loaded checkpoints and want to free host (CPU) memory. + ``clear_vram()`` + Drops the cached model instances stored by the ``@cached_property`` attributes from + this ledger (by removing them from ``self.__dict__``) and calls + :func:`torch.cuda.empty_cache`. Use this when you want to release GPU memory; + subsequent access to a model property will rebuild the model from the registry + while keeping the existing builder configuration. 
+ """ + + def __init__( + self, + dtype: torch.dtype, + device: torch.device, + checkpoint_path: str | None = None, + gemma_root_path: str | None = None, + spatial_upsampler_path: str | None = None, + loras: LoraPathStrengthAndSDOps | None = None, + registry: Registry | None = None, + fp8transformer: bool = False, + local_files_only: bool = True + ): + self.dtype = dtype + self.device = device + self.checkpoint_path = checkpoint_path + self.gemma_root_path = gemma_root_path + self.spatial_upsampler_path = spatial_upsampler_path + self.loras = loras or () + self.registry = registry or DummyRegistry() + self.fp8transformer = fp8transformer + self.local_files_only = local_files_only + self.build_model_builders() + + def build_model_builders(self) -> None: + if self.checkpoint_path is not None: + self.transformer_builder = Builder( + model_path=self.checkpoint_path, + model_class_configurator=LTXModelConfigurator, + model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP, + loras=tuple(self.loras), + registry=self.registry, + ) + + self.vae_decoder_builder = Builder( + model_path=self.checkpoint_path, + model_class_configurator=VAEDecoderConfigurator, + model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER, + registry=self.registry, + ) + + self.vae_encoder_builder = Builder( + model_path=self.checkpoint_path, + model_class_configurator=VAEEncoderConfigurator, + model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER, + registry=self.registry, + ) + + self.audio_decoder_builder = Builder( + model_path=self.checkpoint_path, + model_class_configurator=AudioDecoderConfigurator, + model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, + registry=self.registry, + ) + + self.vocoder_builder = Builder( + model_path=self.checkpoint_path, + model_class_configurator=VocoderConfigurator, + model_sd_ops=VOCODER_COMFY_KEYS_FILTER, + registry=self.registry, + ) + + if self.gemma_root_path is not None: + self.text_encoder_builder = Builder( + model_path=self.checkpoint_path, + model_class_configurator=AVGemmaTextEncoderModelConfigurator, + model_sd_ops=AV_GEMMA_TEXT_ENCODER_KEY_OPS, + registry=self.registry, + module_ops=module_ops_from_gemma_root(self.gemma_root_path, self.local_files_only), + ) + + if self.spatial_upsampler_path is not None: + self.upsampler_builder = Builder( + model_path=self.spatial_upsampler_path, + model_class_configurator=LatentUpsamplerConfigurator, + registry=self.registry, + ) + + def _target_device(self) -> torch.device: + if isinstance(self.registry, DummyRegistry) or self.registry is None: + return self.device + else: + return torch.device("cpu") + + def with_loras(self, loras: LoraPathStrengthAndSDOps) -> Self: + return ModelLedger( + dtype=self.dtype, + device=self.device, + checkpoint_path=self.checkpoint_path, + gemma_root_path=self.gemma_root_path, + spatial_upsampler_path=self.spatial_upsampler_path, + loras=(*self.loras, *loras), + registry=self.registry, + fp8transformer=self.fp8transformer, + ) + + def transformer(self) -> X0Model: + if not hasattr(self, "transformer_builder"): + raise ValueError( + "Transformer not initialized. Please provide a checkpoint path to the ModelLedger constructor." 
+ ) + if self.fp8transformer: + fp8_builder = replace( + self.transformer_builder, + module_ops=(UPCAST_DURING_INFERENCE,), + model_sd_ops=LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP, + ) + return X0Model(fp8_builder.build(device=self._target_device())).to(self.device) + else: + return X0Model(self.transformer_builder.build(device=self._target_device(), dtype=self.dtype)).to( + self.device + ) + + def video_decoder(self) -> VideoDecoder: + if not hasattr(self, "vae_decoder_builder"): + raise ValueError( + "Video decoder not initialized. Please provide a checkpoint path to the ModelLedger constructor." + ) + + return self.vae_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device) + + def video_encoder(self) -> VideoEncoder: + if not hasattr(self, "vae_encoder_builder"): + raise ValueError( + "Video encoder not initialized. Please provide a checkpoint path to the ModelLedger constructor." + ) + + return self.vae_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device) + + def text_encoder(self) -> AVGemmaTextEncoderModel: + if not hasattr(self, "text_encoder_builder"): + raise ValueError( + "Text encoder not initialized. Please provide a checkpoint path and gemma root path to the " + "ModelLedger constructor." + ) + + return self.text_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device) + + def audio_decoder(self) -> AudioDecoder: + if not hasattr(self, "audio_decoder_builder"): + raise ValueError( + "Audio decoder not initialized. Please provide a checkpoint path to the ModelLedger constructor." + ) + + return self.audio_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device) + + def vocoder(self) -> Vocoder: + if not hasattr(self, "vocoder_builder"): + raise ValueError( + "Vocoder not initialized. Please provide a checkpoint path to the ModelLedger constructor." + ) + + return self.vocoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device) + + def spatial_upsampler(self) -> LatentUpsampler: + if not hasattr(self, "upsampler_builder"): + raise ValueError("Upsampler not initialized. Please provide upsampler path to the ModelLedger constructor.") + + return self.upsampler_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device) diff --git a/packages/ltx-core/src/ltx_core/model/model_protocol.py b/packages/ltx-core/src/ltx_core/model/model_protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..e2540bbe0d2e5845d068712847f0757ab2cf7333 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/model_protocol.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +from typing import Protocol, TypeVar + +ModelType = TypeVar("ModelType") + + +class ModelConfigurator(Protocol[ModelType]): + """Protocol for model loader classes that instantiates models from a configuration dictionary.""" + + @classmethod + def from_config(cls, config: dict) -> ModelType: ... 
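+
+
+# Illustrative sketch (hypothetical names, not part of the library): a configurator only
+# needs to expose `from_config`, which receives the checkpoint's config dict and returns a
+# constructed model instance.
+#
+#     class MyModelConfigurator(ModelConfigurator["MyModel"]):
+#         @classmethod
+#         def from_config(cls, config: dict) -> "MyModel":
+#             return MyModel(hidden_size=config.get("hidden_size", 512))
+#
+# Builders such as SingleGPUModelBuilder accept a configurator via `model_class_configurator=`
+# and combine it with state-dict key operations (`model_sd_ops=`) to load the matching weights.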
diff --git a/packages/ltx-core/src/ltx_core/model/transformer/__init__.py b/packages/ltx-core/src/ltx_core/model/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/__init__.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2fb48c4fdc53c92d22f24f6ecafc148c1e01d1e Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/adaln.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/adaln.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaad84bfe7f020b1e0f8b81cf2b0d5ef473be3c7 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/adaln.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/attention.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9009d5b53dbbf371571cd08433282a38bcd24420 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/attention.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/feed_forward.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/feed_forward.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fee4a480b6718bebd325a0fa76e0a7de66d20f9 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/feed_forward.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/gelu_approx.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/gelu_approx.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c9360ae46c57217e4850196398c0d7b0d8ae366 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/gelu_approx.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/modality.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/modality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7439b9e960c8de357e079c25b1c8112604ba546f Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/modality.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/model.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66d5a89eef55f49e99c5ab19fadf3cc98698f198 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/model.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/model_configurator.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/model_configurator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43e75e42cbbfabb86bf1c32eb709ca5c8f0a7642 Binary files /dev/null and 
b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/model_configurator.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/rope.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/rope.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e28a1a7f0c4ed2b9ca653980ec363794274465e Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/rope.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/text_projection.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/text_projection.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8349fd3976fd8d545e83d52d9ec4ddfc2ccfdd9a Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/text_projection.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/timestep_embedding.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/timestep_embedding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b97165f95a65998dc69e4c85ea486d14007fbb0 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/timestep_embedding.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/transformer.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2931baf3b50dd4d9958c6407123558fb479bbd0 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/transformer.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/transformer_args.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/transformer_args.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78673b45d46fd791f95b4c79ed947fa81d40f100 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/transformer/__pycache__/transformer_args.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/transformer/adaln.py b/packages/ltx-core/src/ltx_core/model/transformer/adaln.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0a1fe74e93b85ee05c9e23cf2d3ef1ad9c26f4 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/transformer/adaln.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +from typing import Optional, Tuple + +import torch + +from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTimestepSizeEmbeddings + + +class AdaLayerNormSingle(torch.nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. 
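+        embedding_coefficient (`int`, defaults to 6): Multiplier for the modulation output; the
+            internal linear layer maps `embedding_dim` to `embedding_coefficient * embedding_dim`
+            values, which downstream blocks split into scale/shift/gate modulation terms.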
+ """ + + def __init__(self, embedding_dim: int, embedding_coefficient: int = 6): + super().__init__() + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, + size_emb_dim=embedding_dim // 3, + ) + + self.silu = torch.nn.SiLU() + self.linear = torch.nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep diff --git a/packages/ltx-core/src/ltx_core/model/transformer/attention.py b/packages/ltx-core/src/ltx_core/model/transformer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..6814f2ad7fc10f341cf4ee5d88848a08d2114663 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/transformer/attention.py @@ -0,0 +1,199 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +from enum import Enum +from typing import Protocol + +import torch + +from ltx_core.model.transformer.rope import LTXRopeType, apply_rotary_emb + +memory_efficient_attention = None +flash_attn_interface = None +try: + from xformers.ops import memory_efficient_attention +except ImportError: + memory_efficient_attention = None +try: + # FlashAttention3 and XFormersAttention cannot be used together + if memory_efficient_attention is None: + import flash_attn_interface +except ImportError: + flash_attn_interface = None + + +class AttentionCallable(Protocol): + def __call__( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None + ) -> torch.Tensor: ... + + +class PytorchAttention(AttentionCallable): + def __call__( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None + ) -> torch.Tensor: + b, _, dim_head = q.shape + dim_head //= heads + q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)) + + if mask is not None: + # add a batch dimension if there isn't already one + if mask.ndim == 2: + mask = mask.unsqueeze(0) + # add a heads dimension if there isn't already one + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = out.transpose(1, 2).reshape(b, -1, heads * dim_head) + return out + + +class XFormersAttention(AttentionCallable): + def __call__( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + heads: int, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + if memory_efficient_attention is None: + raise RuntimeError("XFormersAttention was selected but `xformers` is not installed.") + + b, _, dim_head = q.shape + dim_head //= heads + + # xformers expects [B, M, H, K] + q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v)) + + # LT_INTERNAL: https://github.com/LightricksResearch/ComfyUI/blob/ee2a50cd8fb3544c66f8a3096390c741fff12ae3/comfy/ldm/modules/attention.py#L441-L459 + if mask is not None: + # add a singleton batch dimension + if mask.ndim == 2: + mask = mask.unsqueeze(0) + # add a singleton heads dimension + if mask.ndim == 3: + mask = mask.unsqueeze(1) + # pad to a multiple of 8 + pad = 8 - mask.shape[-1] % 8 + # the xformers docs says that it's allowed to have a mask of shape (1, Nq, Nk) + # but when using separated heads, the shape has to be (B, H, Nq, Nk) + # in flux, this 
matrix ends up being over 1GB + # here, we create a mask with the same batch/head size as the input mask (potentially singleton or full) + mask_out = torch.empty( + [mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device + ) + + mask_out[..., : mask.shape[-1]] = mask + # doesn't this remove the padding again?? + mask = mask_out[..., : mask.shape[-1]] + mask = mask.expand(b, heads, -1, -1) + + out = memory_efficient_attention(q.to(v.dtype), k.to(v.dtype), v, attn_bias=mask, p=0.0) + out = out.reshape(b, -1, heads * dim_head) + return out + + +class FlashAttention3(AttentionCallable): + def __call__( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + heads: int, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + if flash_attn_interface is None: + raise RuntimeError("FlashAttention3 was selected but `FlashAttention3` is not installed.") + + b, _, dim_head = q.shape + dim_head //= heads + + q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v)) + + if mask is not None: + raise NotImplementedError("Mask is not supported for FlashAttention3") + + out = flash_attn_interface.flash_attn_func(q.to(v.dtype), k.to(v.dtype), v) + out = out.reshape(b, -1, heads * dim_head) + return out + + +class AttentionFunction(Enum): + PYTORCH = "pytorch" + XFORMERS = "xformers" + FLASH_ATTENTION_3 = "flash_attention_3" + DEFAULT = "default" + + def __call__( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None + ) -> torch.Tensor: + if self is AttentionFunction.PYTORCH: + return PytorchAttention()(q, k, v, heads, mask) + elif self is AttentionFunction.XFORMERS: + return XFormersAttention()(q, k, v, heads, mask) + elif self is AttentionFunction.FLASH_ATTENTION_3: + return FlashAttention3()(q, k, v, heads, mask) + else: + # Default behavior: XFormers if installed else - PyTorch + return ( + XFormersAttention()(q, k, v, heads, mask) + if memory_efficient_attention is not None + else PytorchAttention()(q, k, v, heads, mask) + ) + + +class Attention(torch.nn.Module): + def __init__( + self, + query_dim: int, + context_dim: int | None = None, + heads: int = 8, + dim_head: int = 64, + norm_eps: float = 1e-6, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + attention_function: AttentionCallable | AttentionFunction = AttentionFunction.DEFAULT, + ) -> None: + super().__init__() + self.rope_type = rope_type + self.attention_function = attention_function + + inner_dim = dim_head * heads + context_dim = query_dim if context_dim is None else context_dim + + self.heads = heads + self.dim_head = dim_head + + self.q_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) + self.k_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) + + self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True) + self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True) + self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True) + + self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity()) + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + pe: torch.Tensor | None = None, + k_pe: torch.Tensor | None = None, + ) -> torch.Tensor: + q = self.to_q(x) + context = x if context is None else context + k = self.to_k(context) + v = self.to_v(context) + + q = self.q_norm(q) + k = self.k_norm(k) + + if pe is not None: + q = apply_rotary_emb(q, pe, self.rope_type) + k = apply_rotary_emb(k, pe if k_pe is None else k_pe, 
self.rope_type) + + # attention_function can be an enum *or* a custom callable + out = self.attention_function(q, k, v, self.heads, mask) + return self.to_out(out) diff --git a/packages/ltx-core/src/ltx_core/model/transformer/feed_forward.py b/packages/ltx-core/src/ltx_core/model/transformer/feed_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..66340f6c11e209f83d5e299eda04d93c2badd8ff --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/transformer/feed_forward.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +import torch + +from ltx_core.model.transformer.gelu_approx import GELUApprox + + +class FeedForward(torch.nn.Module): + def __init__(self, dim: int, dim_out: int, mult: int = 4) -> None: + super().__init__() + inner_dim = int(dim * mult) + project_in = GELUApprox(dim, inner_dim) + + self.net = torch.nn.Sequential(project_in, torch.nn.Identity(), torch.nn.Linear(inner_dim, dim_out)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) diff --git a/packages/ltx-core/src/ltx_core/model/transformer/gelu_approx.py b/packages/ltx-core/src/ltx_core/model/transformer/gelu_approx.py new file mode 100644 index 0000000000000000000000000000000000000000..e8c1265e6a6de5f127304fa541f8b9f52ef94e76 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/transformer/gelu_approx.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +import torch + + +class GELUApprox(torch.nn.Module): + def __init__(self, dim_in: int, dim_out: int) -> None: + super().__init__() + self.proj = torch.nn.Linear(dim_in, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(self.proj(x), approximate="tanh") diff --git a/packages/ltx-core/src/ltx_core/model/transformer/modality.py b/packages/ltx-core/src/ltx_core/model/transformer/modality.py new file mode 100644 index 0000000000000000000000000000000000000000..5137a7eba62692c3135568b568943bcf07913654 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/transformer/modality.py @@ -0,0 +1,20 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +from dataclasses import dataclass + +import torch + + +@dataclass(frozen=True) +class Modality: + latent: ( + torch.Tensor + ) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension + timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps + positions: ( + torch.Tensor + ) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens + context: torch.Tensor + enabled: bool = True + context_mask: torch.Tensor | None = None diff --git a/packages/ltx-core/src/ltx_core/model/transformer/model.py b/packages/ltx-core/src/ltx_core/model/transformer/model.py new file mode 100644 index 0000000000000000000000000000000000000000..fe7ec481b9dccca990b05cfdae0818e7a0cf5209 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/transformer/model.py @@ -0,0 +1,482 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
+# Created by Andrew Kvochko + + +from enum import Enum + +import torch + +from ltx_core.guidance.perturbations import BatchedPerturbationConfig +from ltx_core.model.transformer.adaln import AdaLayerNormSingle +from ltx_core.model.transformer.attention import AttentionCallable, AttentionFunction +from ltx_core.model.transformer.modality import Modality +from ltx_core.model.transformer.rope import LTXRopeType +from ltx_core.model.transformer.text_projection import PixArtAlphaTextProjection +from ltx_core.model.transformer.transformer import BasicAVTransformerBlock, TransformerConfig +from ltx_core.model.transformer.transformer_args import ( + MultiModalTransformerArgsPreprocessor, + TransformerArgs, + TransformerArgsPreprocessor, +) +from ltx_core.utils import to_denoised + + +class LTXModelType(Enum): + AudioVideo = "ltx av model" + VideoOnly = "ltx video only model" + AudioOnly = "ltx audio only model" + + def is_video_enabled(self) -> bool: + return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly) + + def is_audio_enabled(self) -> bool: + return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly) + + +class LTXModel(torch.nn.Module): + """ + LTX model transformer implementation. + + This class implements the transformer blocks for the LTX model. + """ + + def __init__( # noqa: PLR0913 + self, + *, + model_type: LTXModelType = LTXModelType.AudioVideo, + num_attention_heads: int = 32, + attention_head_dim: int = 128, + in_channels: int = 128, + out_channels: int = 128, + num_layers: int = 48, + cross_attention_dim: int = 4096, + norm_eps: float = 1e-06, + attention_type: AttentionFunction | AttentionCallable = AttentionFunction.DEFAULT, + caption_channels: int = 3840, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list[int] | None = None, + timestep_scale_multiplier: int = 1000, + use_middle_indices_grid: bool = True, + audio_num_attention_heads: int = 32, + audio_attention_head_dim: int = 64, + audio_in_channels: int = 128, + audio_out_channels: int = 128, + audio_cross_attention_dim: int = 2048, + audio_positional_embedding_max_pos: list[int] | None = None, + av_ca_timestep_scale_multiplier: int = 1, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + double_precision_rope: bool = False, + ): + super().__init__() + self._enable_gradient_checkpointing = False + self.use_middle_indices_grid = use_middle_indices_grid + self.rope_type = rope_type + self.double_precision_rope = double_precision_rope + self.timestep_scale_multiplier = timestep_scale_multiplier + self.positional_embedding_theta = positional_embedding_theta + self.model_type = model_type + cross_pe_max_pos = None + if model_type.is_video_enabled(): + if positional_embedding_max_pos is None: + positional_embedding_max_pos = [20, 2048, 2048] + self.positional_embedding_max_pos = positional_embedding_max_pos + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self._init_video( + in_channels=in_channels, + out_channels=out_channels, + caption_channels=caption_channels, + norm_eps=norm_eps, + ) + + if model_type.is_audio_enabled(): + if audio_positional_embedding_max_pos is None: + audio_positional_embedding_max_pos = [20] + self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + self.audio_num_attention_heads = audio_num_attention_heads + self.audio_inner_dim = self.audio_num_attention_heads * audio_attention_head_dim + self._init_audio( + in_channels=audio_in_channels, + out_channels=audio_out_channels, 
+ caption_channels=caption_channels, + norm_eps=norm_eps, + ) + + if model_type.is_video_enabled() and model_type.is_audio_enabled(): + cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]) + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + self.audio_cross_attention_dim = audio_cross_attention_dim + self._init_audio_video(num_scale_shift_values=4) + + self._init_preprocessors(cross_pe_max_pos) + # Initialize transformer blocks + self._init_transformer_blocks( + num_layers=num_layers, + attention_head_dim=attention_head_dim if model_type.is_video_enabled() else 0, + cross_attention_dim=cross_attention_dim, + audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0, + audio_cross_attention_dim=audio_cross_attention_dim, + norm_eps=norm_eps, + attention_type=attention_type, + ) + + def _init_video( + self, + in_channels: int, + out_channels: int, + caption_channels: int, + norm_eps: float, + ) -> None: + """Initialize video-specific components.""" + # Video input components + self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True) + + self.adaln_single = AdaLayerNormSingle(self.inner_dim) + + # Video caption projection + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, + hidden_size=self.inner_dim, + ) + + # Video output components + self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim)) + self.norm_out = torch.nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=norm_eps) + self.proj_out = torch.nn.Linear(self.inner_dim, out_channels) + + def _init_audio( + self, + in_channels: int, + out_channels: int, + caption_channels: int, + norm_eps: float, + ) -> None: + """Initialize audio-specific components.""" + + # Audio input components + self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True) + + self.audio_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + ) + + # Audio caption projection + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, + hidden_size=self.audio_inner_dim, + ) + + # Audio output components + self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim)) + self.audio_norm_out = torch.nn.LayerNorm(self.audio_inner_dim, elementwise_affine=False, eps=norm_eps) + self.audio_proj_out = torch.nn.Linear(self.audio_inner_dim, out_channels) + + def _init_audio_video( + self, + num_scale_shift_values: int, + ) -> None: + """Initialize audio-video cross-attention components.""" + self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=num_scale_shift_values, + ) + + self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=num_scale_shift_values, + ) + + self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=1, + ) + + self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=1, + ) + + def _init_preprocessors( + self, + cross_pe_max_pos: int | None = None, + ) -> None: + """Initialize preprocessors for LTX.""" + + if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled(): + self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor( + patchify_proj=self.patchify_proj, + adaln=self.adaln_single, + caption_projection=self.caption_projection, + 
cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single, + cross_gate_adaln=self.av_ca_a2v_gate_adaln_single, + inner_dim=self.inner_dim, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + cross_pe_max_pos=cross_pe_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + audio_cross_attention_dim=self.audio_cross_attention_dim, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier, + ) + self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor( + patchify_proj=self.audio_patchify_proj, + adaln=self.audio_adaln_single, + caption_projection=self.audio_caption_projection, + cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single, + cross_gate_adaln=self.av_ca_v2a_gate_adaln_single, + inner_dim=self.audio_inner_dim, + max_pos=self.audio_positional_embedding_max_pos, + num_attention_heads=self.audio_num_attention_heads, + cross_pe_max_pos=cross_pe_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + audio_cross_attention_dim=self.audio_cross_attention_dim, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier, + ) + elif self.model_type.is_video_enabled(): + self.video_args_preprocessor = TransformerArgsPreprocessor( + patchify_proj=self.patchify_proj, + adaln=self.adaln_single, + caption_projection=self.caption_projection, + inner_dim=self.inner_dim, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + use_middle_indices_grid=self.use_middle_indices_grid, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + ) + elif self.model_type.is_audio_enabled(): + self.audio_args_preprocessor = TransformerArgsPreprocessor( + patchify_proj=self.audio_patchify_proj, + adaln=self.audio_adaln_single, + caption_projection=self.audio_caption_projection, + inner_dim=self.audio_inner_dim, + max_pos=self.audio_positional_embedding_max_pos, + num_attention_heads=self.audio_num_attention_heads, + use_middle_indices_grid=self.use_middle_indices_grid, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + ) + + def _init_transformer_blocks( + self, + num_layers: int, + attention_head_dim: int, + cross_attention_dim: int, + audio_attention_head_dim: int, + audio_cross_attention_dim: int, + norm_eps: float, + attention_type: AttentionFunction | AttentionCallable, + ) -> None: + """Initialize transformer blocks for LTX.""" + video_config = ( + TransformerConfig( + dim=self.inner_dim, + heads=self.num_attention_heads, + d_head=attention_head_dim, + context_dim=cross_attention_dim, + ) + if self.model_type.is_video_enabled() + else None + ) + audio_config = ( + TransformerConfig( + dim=self.audio_inner_dim, + heads=self.audio_num_attention_heads, + d_head=audio_attention_head_dim, + context_dim=audio_cross_attention_dim, + ) + if 
self.model_type.is_audio_enabled() + else None + ) + self.transformer_blocks = torch.nn.ModuleList( + [ + BasicAVTransformerBlock( + idx=idx, + video=video_config, + audio=audio_config, + rope_type=self.rope_type, + norm_eps=norm_eps, + attention_function=attention_type, + ) + for idx in range(num_layers) + ] + ) + + def set_gradient_checkpointing(self, enable: bool) -> None: + """Enable or disable gradient checkpointing for transformer blocks. + + Gradient checkpointing trades compute for memory by recomputing activations + during the backward pass instead of storing them. This can significantly + reduce memory usage at the cost of ~20-30% slower training. + + Args: + enable: Whether to enable gradient checkpointing + """ + self._enable_gradient_checkpointing = enable + + def _process_transformer_blocks( + self, + video: TransformerArgs | None, + audio: TransformerArgs | None, + perturbations: BatchedPerturbationConfig, + ) -> tuple[TransformerArgs, TransformerArgs]: + """Process transformer blocks for LTXAV.""" + + # Process transformer blocks + for block in self.transformer_blocks: + if self._enable_gradient_checkpointing and self.training: + # Use gradient checkpointing to save memory during training. + # With use_reentrant=False, we can pass dataclasses directly - + # PyTorch will track all tensor leaves in the computation graph. + video, audio = torch.utils.checkpoint.checkpoint( + block, + video, + audio, + perturbations, + use_reentrant=False, + ) + else: + video, audio = block( + video=video, + audio=audio, + perturbations=perturbations, + ) + + return video, audio + + def _process_output( + self, + scale_shift_table: torch.Tensor, + norm_out: torch.nn.LayerNorm, + proj_out: torch.nn.Linear, + x: torch.Tensor, + embedded_timestep: torch.Tensor, + ) -> torch.Tensor: + """Process output for LTXV.""" + # Apply scale-shift modulation + scale_shift_values = ( + scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + x = norm_out(x) + x = x * (1 + scale) + shift + x = proj_out(x) + return x + + def forward( + self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for LTX models. + + Returns: + Processed output tensors + """ + if not self.model_type.is_video_enabled() and video is not None: + raise ValueError("Video is not enabled for this model") + if not self.model_type.is_audio_enabled() and audio is not None: + raise ValueError("Audio is not enabled for this model") + + video_args = self.video_args_preprocessor.prepare(video) if video is not None else None + audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None + # Process transformer blocks + video_out, audio_out = self._process_transformer_blocks( + video=video_args, + audio=audio_args, + perturbations=perturbations, + ) + + # Process output + vx = ( + self._process_output( + self.scale_shift_table, self.norm_out, self.proj_out, video_out.x, video_out.embedded_timestep + ) + if video_out is not None + else None + ) + ax = ( + self._process_output( + self.audio_scale_shift_table, + self.audio_norm_out, + self.audio_proj_out, + audio_out.x, + audio_out.embedded_timestep, + ) + if audio_out is not None + else None + ) + return vx, ax + + +class LegacyX0Model(torch.nn.Module): + """ + Legacy X0 model implementation. 
+ Returns fully denoised output based on the velocities produced by the base model. + LT_INTERNAL_BEGIN + Applies full sigma when denoising which is mathematically incorrect but in accordance with: + https://github.com/LightricksResearch/ComfyUI/blob/cc26711bd34135a3eac782b81f9526c5acfcf94d/comfy/model_sampling.py#L62-L68 + LT_INTERNAL_END + """ + + def __init__(self, velocity_model: LTXModel): + super().__init__() + self.velocity_model = velocity_model + + def forward( + self, + video: Modality | None, + audio: Modality | None, + perturbations: BatchedPerturbationConfig, + sigma: float, + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """ + Denoise the video and audio according to the sigma. + + Returns: + Denoised video and audio + """ + vx, ax = self.velocity_model(video, audio, perturbations) + denoised_video = to_denoised(video.latent, vx, sigma) if vx is not None else None + denoised_audio = to_denoised(audio.latent, ax, sigma) if ax is not None else None + return denoised_video, denoised_audio + + +class X0Model(torch.nn.Module): + """ + X0 model implementation. + Returns fully denoised outputs based on the velocities produced by the base model. + Applies scaled denoising to the video and audio according to the timesteps = sigma * denoising_mask. + """ + + def __init__(self, velocity_model: LTXModel): + super().__init__() + self.velocity_model = velocity_model + + def forward( + self, + video: Modality | None, + audio: Modality | None, + perturbations: BatchedPerturbationConfig, + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """ + Denoise the video and audio according to the sigma. + + Returns: + Denoised video and audio + """ + vx, ax = self.velocity_model(video, audio, perturbations) + denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None + denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None + return denoised_video, denoised_audio diff --git a/packages/ltx-core/src/ltx_core/model/transformer/model_configurator.py b/packages/ltx-core/src/ltx_core/model/transformer/model_configurator.py new file mode 100644 index 0000000000000000000000000000000000000000..6c018fa0a6149eef5fad31ce55c430456369aa10 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/transformer/model_configurator.py @@ -0,0 +1,222 @@ +import torch + +from ltx_core.loader.fuse_loras import fused_add_round_launch +from ltx_core.loader.module_ops import ModuleOps +from ltx_core.loader.sd_ops import KeyValueOperationResult, SDOps +from ltx_core.model.model_protocol import ModelConfigurator +from ltx_core.model.transformer.attention import AttentionFunction +from ltx_core.model.transformer.model import LTXModel, LTXModelType +from ltx_core.model.transformer.rope import LTXRopeType +from ltx_core.utils import check_config_value + + +class LTXModelConfigurator(ModelConfigurator[LTXModel]): + @classmethod + def from_config(cls: type[LTXModel], config: dict) -> LTXModel: + config = config.get("transformer", {}) + + check_config_value(config, "dropout", 0.0) + check_config_value(config, "attention_bias", True) + check_config_value(config, "num_vector_embeds", None) + check_config_value(config, "activation_fn", "gelu-approximate") + check_config_value(config, "num_embeds_ada_norm", 1000) + check_config_value(config, "use_linear_projection", False) + check_config_value(config, "only_cross_attention", False) + check_config_value(config, "cross_attention_norm", True) + check_config_value(config, "double_self_attention", False) + 
check_config_value(config, "upcast_attention", False) + check_config_value(config, "standardization_norm", "rms_norm") + check_config_value(config, "norm_elementwise_affine", False) + check_config_value(config, "qk_norm", "rms_norm") + check_config_value(config, "positional_embedding_type", "rope") + check_config_value(config, "use_audio_video_cross_attention", True) + check_config_value(config, "share_ff", False) + check_config_value(config, "av_cross_ada_norm", True) + check_config_value(config, "use_middle_indices_grid", True) + + return LTXModel( + model_type=LTXModelType.AudioVideo, + num_attention_heads=config.get("num_attention_heads", 32), + attention_head_dim=config.get("attention_head_dim", 128), + in_channels=config.get("in_channels", 128), + out_channels=config.get("out_channels", 128), + num_layers=config.get("num_layers", 48), + cross_attention_dim=config.get("cross_attention_dim", 4096), + norm_eps=config.get("norm_eps", 1e-06), + attention_type=AttentionFunction(config.get("attention_type", "default")), + caption_channels=config.get("caption_channels", 3840), + positional_embedding_theta=config.get("positional_embedding_theta", 10000.0), + positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]), + timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000), + use_middle_indices_grid=config.get("use_middle_indices_grid", True), + audio_num_attention_heads=config.get("audio_num_attention_heads", 32), + audio_attention_head_dim=config.get("audio_attention_head_dim", 64), + audio_in_channels=config.get("audio_in_channels", 128), + audio_out_channels=config.get("audio_out_channels", 128), + audio_cross_attention_dim=config.get("audio_cross_attention_dim", 2048), + audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]), + av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1), + rope_type=LTXRopeType(config.get("rope_type", "interleaved")), + double_precision_rope=config.get("frequencies_precision", False) == "float64", + ) + + +class LTXVideoOnlyModelConfigurator(ModelConfigurator[LTXModel]): + @classmethod + def from_config(cls: type[LTXModel], config: dict) -> LTXModel: + config = config.get("transformer", {}) + + check_config_value(config, "dropout", 0.0) + check_config_value(config, "attention_bias", True) + check_config_value(config, "num_vector_embeds", None) + check_config_value(config, "activation_fn", "gelu-approximate") + check_config_value(config, "num_embeds_ada_norm", 1000) + check_config_value(config, "use_linear_projection", False) + check_config_value(config, "only_cross_attention", False) + check_config_value(config, "cross_attention_norm", True) + check_config_value(config, "double_self_attention", False) + check_config_value(config, "upcast_attention", False) + check_config_value(config, "standardization_norm", "rms_norm") + check_config_value(config, "norm_elementwise_affine", False) + check_config_value(config, "qk_norm", "rms_norm") + check_config_value(config, "positional_embedding_type", "rope") + check_config_value(config, "use_middle_indices_grid", True) + + return LTXModel( + model_type=LTXModelType.VideoOnly, + num_attention_heads=config.get("num_attention_heads", 32), + attention_head_dim=config.get("attention_head_dim", 128), + in_channels=config.get("in_channels", 128), + out_channels=config.get("out_channels", 128), + num_layers=config.get("num_layers", 48), + cross_attention_dim=config.get("cross_attention_dim", 4096), + 
norm_eps=config.get("norm_eps", 1e-06), + attention_type=AttentionFunction(config.get("attention_type", "default")), + caption_channels=config.get("caption_channels", 3840), + positional_embedding_theta=config.get("positional_embedding_theta", 10000.0), + positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]), + timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000), + use_middle_indices_grid=config.get("use_middle_indices_grid", True), + rope_type=LTXRopeType(config.get("rope_type", "interleaved")), + double_precision_rope=config.get("frequencies_precision", False) == "float64", + ) + + +def _naive_weight_or_bias_downcast(key: str, value: torch.Tensor) -> list[KeyValueOperationResult]: + """ + Downcast the weight or bias to the float8_e4m3fn dtype. + """ + return [KeyValueOperationResult(key, value.to(dtype=torch.float8_e4m3fn))] + + +def _upcast_and_round( + weight: torch.Tensor, dtype: torch.dtype, with_stochastic_rounding: bool = False, seed: int = 0 +) -> torch.Tensor: + """ + Upcast the weight to the given dtype and optionally apply stochastic rounding. + Input weight needs to have float8_e4m3fn or float8_e5m2 dtype. + """ + if not with_stochastic_rounding: + return weight.to(dtype) + return fused_add_round_launch(torch.zeros_like(weight, dtype=dtype), weight, seed) + + +def replace_fwd_with_upcast(layer: torch.nn.Linear, with_stochastic_rounding: bool = False, seed: int = 0) -> None: + """ + Replace linear.forward and rms_norm.forward with a version that: + - upcasts weight and bias to input's dtype + - returns F.linear or F.rms_norm calculated in that dtype + """ + + layer.original_forward = layer.forward + + def new_linear_forward(*args, **_kwargs) -> torch.Tensor: + # assume first arg is the input tensor + x = args[0] + w_up = _upcast_and_round(layer.weight, x.dtype, with_stochastic_rounding, seed) + b_up = None + + if layer.bias is not None: + b_up = _upcast_and_round(layer.bias, x.dtype, with_stochastic_rounding, seed) + + return torch.nn.functional.linear(x, w_up, b_up) + + layer.forward = new_linear_forward + + +def amend_forward_with_upcast( + model: torch.nn.Module, with_stochastic_rounding: bool = False, seed: int = 0 +) -> torch.nn.Module: + """ + Replace the forward method of the model's Linear and RMSNorm layers to forward + with upcast and optional stochastic rounding. 
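+
+    Note: only ``torch.nn.Linear`` modules are patched by the loop below. Weights are expected
+    to be stored as ``float8_e4m3fn``/``float8_e5m2`` and are upcast to the activation dtype on
+    every forward call, so computation runs in the higher precision while storage stays in fp8.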
+ """ + for m in model.modules(): + if isinstance(m, (torch.nn.Linear)): + replace_fwd_with_upcast(m, with_stochastic_rounding, seed) + return model + + +LTXV_MODEL_COMFY_RENAMING_MAP = ( + SDOps("LTXV_MODEL_COMFY_PREFIX_MAP") + .with_matching(prefix="model.diffusion_model.") + .with_replacement("model.diffusion_model.", "") +) + +LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP = ( + SDOps("LTXV_MODEL_COMFY_PREFIX_MAP") + .with_matching(prefix="model.diffusion_model.") + .with_replacement("model.diffusion_model.", "") + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".to_q.weight", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".to_q.bias", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".to_k.weight", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".to_k.bias", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".to_v.weight", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".to_v.bias", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".to_out.0.weight", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".to_out.0.bias", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".ff.net.0.proj.weight", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".ff.net.0.proj.bias", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".ff.net.2.weight", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".ff.net.2.bias", operation=_naive_weight_or_bias_downcast + ) +) + +UPCAST_DURING_INFERENCE = ModuleOps( + name="upcast_fp8_during_linear_forward", + matcher=lambda model: isinstance(model, LTXModel), + mutator=lambda model: amend_forward_with_upcast(model, False), +) + + +class UpcastWithStochasticRounding(ModuleOps): + def __new__(cls, seed: int = 0): + return super().__new__( + cls, + name="upcast_fp8_during_linear_forward_with_stochastic_rounding", + matcher=lambda model: isinstance(model, LTXModel), + mutator=lambda model: amend_forward_with_upcast(model, True, seed), + ) diff --git a/packages/ltx-core/src/ltx_core/model/transformer/rope.py b/packages/ltx-core/src/ltx_core/model/transformer/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..99dc66d14964bb1ef07fb6ad797cf0a54d615fb4 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/transformer/rope.py @@ -0,0 +1,207 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
+# Created by Andrew Kvochko + +import functools +import math +from enum import Enum +from typing import Callable, Tuple + +import numpy as np +import torch +from einops import rearrange + + +class LTXRopeType(Enum): + INTERLEAVED = "interleaved" + SPLIT = "split" + + +def apply_rotary_emb( + input_tensor: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, +) -> torch.Tensor: + if rope_type == LTXRopeType.INTERLEAVED: + return apply_interleaved_rotary_emb(input_tensor, *freqs_cis) + elif rope_type == LTXRopeType.SPLIT: + return apply_split_rotary_emb(input_tensor, *freqs_cis) + else: + raise ValueError(f"Invalid rope type: {rope_type}") + + +def apply_interleaved_rotary_emb( + input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor +) -> torch.Tensor: + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + + +def apply_split_rotary_emb( + input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor +) -> torch.Tensor: + needs_reshape = False + if input_tensor.ndim != 4 and cos_freqs.ndim == 4: + b, h, t, _ = cos_freqs.shape + input_tensor = input_tensor.reshape(b, t, h, -1).swapaxes(1, 2) + needs_reshape = True + + split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2) + first_half_input = split_input[..., :1, :] + second_half_input = split_input[..., 1:, :] + + output = split_input * cos_freqs.unsqueeze(-2) + first_half_output = output[..., :1, :] + second_half_output = output[..., 1:, :] + + first_half_output.addcmul_(-sin_freqs.unsqueeze(-2), second_half_input) + second_half_output.addcmul_(sin_freqs.unsqueeze(-2), first_half_input) + + output = rearrange(output, "... d r -> ... 
(d r)") + if needs_reshape: + output = output.swapaxes(1, 2).reshape(b, t, -1) + + return output + + +@functools.lru_cache(maxsize=5) +def generate_freq_grid_np( + positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int +) -> torch.Tensor: + theta = positional_embedding_theta + start = 1 + end = theta + + n_elem = 2 * positional_embedding_max_pos_count + pow_indices = np.power( + theta, + np.linspace( + np.log(start) / np.log(theta), + np.log(end) / np.log(theta), + inner_dim // n_elem, + dtype=np.float64, + ), + ) + return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32) + + +@functools.lru_cache(maxsize=5) +def generate_freq_grid_pytorch( + positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int +) -> torch.Tensor: + theta = positional_embedding_theta + start = 1 + end = theta + n_elem = 2 * positional_embedding_max_pos_count + + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + inner_dim // n_elem, + dtype=torch.float32, + ) + ) + indices = indices.to(dtype=torch.float32) + + indices = indices * math.pi / 2 + + return indices + + +def get_fractional_positions(indices_grid: torch.Tensor, max_pos: list[int]) -> torch.Tensor: + n_pos_dims = indices_grid.shape[1] + assert n_pos_dims == len(max_pos), ( + f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})" + ) + fractional_positions = torch.stack( + [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)], + dim=-1, + ) + return fractional_positions + + +def generate_freqs( + indices: torch.Tensor, indices_grid: torch.Tensor, max_pos: list[int], use_middle_indices_grid: bool +) -> torch.Tensor: + if use_middle_indices_grid: + assert len(indices_grid.shape) == 4 + assert indices_grid.shape[-1] == 2 + indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1] + indices_grid = (indices_grid_start + indices_grid_end) / 2.0 + elif len(indices_grid.shape) == 4: + indices_grid = indices_grid[..., 0] + + fractional_positions = get_fractional_positions(indices_grid, max_pos) + indices = indices.to(device=fractional_positions.device) + + freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) + return freqs + + +def split_freqs_cis(freqs: torch.Tensor, pad_size: int, num_attention_heads: int) -> tuple[torch.Tensor, torch.Tensor]: + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + # Reshape freqs to be compatible with multi-head attention + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1) + + cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + return cos_freq, sin_freq + + +def interleaved_freqs_cis(freqs: torch.Tensor, pad_size: int) -> tuple[torch.Tensor, torch.Tensor]: + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(cos_freq[:, :, :pad_size]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + 
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq, sin_freq + + +def precompute_freqs_cis( + indices_grid: torch.Tensor, + dim: int, + out_dtype: torch.dtype, + theta: float = 10000.0, + max_pos: list[int] | None = None, + use_middle_indices_grid: bool = False, + num_attention_heads: int = 32, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + freq_grid_generator: Callable[[float, int, int, torch.device], torch.Tensor] = generate_freq_grid_pytorch, +) -> tuple[torch.Tensor, torch.Tensor]: + if max_pos is None: + max_pos = [20, 2048, 2048] + + indices = freq_grid_generator(theta, indices_grid.shape[1], dim) + freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) + + if rope_type == LTXRopeType.SPLIT: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads) + else: + # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only + n_elem = 2 * indices_grid.shape[1] + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(out_dtype), sin_freq.to(out_dtype) diff --git a/packages/ltx-core/src/ltx_core/model/transformer/text_projection.py b/packages/ltx-core/src/ltx_core/model/transformer/text_projection.py new file mode 100644 index 0000000000000000000000000000000000000000..185dbe3bac9d60e9e2a35f44cdbb39afed62b34a --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/transformer/text_projection.py @@ -0,0 +1,31 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +import torch + + +class PixArtAlphaTextProjection(torch.nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_features: int, hidden_size: int, out_features: int | None = None, act_fn: str = "gelu_tanh"): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = torch.nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + if act_fn == "gelu_tanh": + self.act_1 = torch.nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = torch.nn.SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = torch.nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) + + def forward(self, caption: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states diff --git a/packages/ltx-core/src/ltx_core/model/transformer/timestep_embedding.py b/packages/ltx-core/src/ltx_core/model/transformer/timestep_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..89aa0cf09b1472c151b1dbfd8ca22545cd8140a3 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/transformer/timestep_embedding.py @@ -0,0 +1,148 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +import math + +import torch + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. 
+ + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(torch.nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + out_dim: int | None = None, + post_act_fn: str | None = None, + cond_proj_dim: int | None = None, + sample_proj_bias: bool = True, + ): + super().__init__() + + self.linear_1 = torch.nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = torch.nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = torch.nn.SiLU() + time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim + + self.linear_2 = torch.nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + + def forward(self, sample: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor: + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(torch.nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class PixArtAlphaCombinedTimestepSizeEmbeddings(torch.nn.Module): + """ + For PixArt-Alpha. 
+ + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__( + self, + embedding_dim: int, + size_emb_dim: int, + ): + super().__init__() + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward( + self, + timestep: torch.Tensor, + hidden_dtype: torch.dtype, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + return timesteps_emb diff --git a/packages/ltx-core/src/ltx_core/model/transformer/transformer.py b/packages/ltx-core/src/ltx_core/model/transformer/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..8b4650a446d255ef2df8c80c09fc74863361ef4d --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/transformer/transformer.py @@ -0,0 +1,258 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +from dataclasses import dataclass, replace + +import torch + +from ltx_core.guidance.perturbations import BatchedPerturbationConfig, PerturbationType +from ltx_core.model.transformer.attention import Attention, AttentionCallable, AttentionFunction +from ltx_core.model.transformer.feed_forward import FeedForward +from ltx_core.model.transformer.rope import LTXRopeType +from ltx_core.model.transformer.transformer_args import TransformerArgs +from ltx_core.utils import rms_norm + + +@dataclass +class TransformerConfig: + dim: int + heads: int + d_head: int + context_dim: int + + +class BasicAVTransformerBlock(torch.nn.Module): + def __init__( + self, + idx: int, + video: TransformerConfig | None = None, + audio: TransformerConfig | None = None, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + norm_eps: float = 1e-6, + attention_function: AttentionFunction | AttentionCallable = AttentionFunction.DEFAULT, + ): + super().__init__() + + self.idx = idx + if video is not None: + self.attn1 = Attention( + query_dim=video.dim, + heads=video.heads, + dim_head=video.d_head, + context_dim=None, + rope_type=rope_type, + norm_eps=norm_eps, + attention_function=attention_function, + ) + self.attn2 = Attention( + query_dim=video.dim, + context_dim=video.context_dim, + heads=video.heads, + dim_head=video.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + attention_function=attention_function, + ) + self.ff = FeedForward(video.dim, dim_out=video.dim) + self.scale_shift_table = torch.nn.Parameter(torch.empty(6, video.dim)) + + if audio is not None: + self.audio_attn1 = Attention( + query_dim=audio.dim, + heads=audio.heads, + dim_head=audio.d_head, + context_dim=None, + rope_type=rope_type, + norm_eps=norm_eps, + attention_function=attention_function, + ) + self.audio_attn2 = Attention( + query_dim=audio.dim, + context_dim=audio.context_dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + attention_function=attention_function, + ) + self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim) + self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(6, audio.dim)) + + if audio is not None and video is not None: + # Q: Video, K,V: Audio + self.audio_to_video_attn = Attention( + query_dim=video.dim, + context_dim=audio.dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + 
norm_eps=norm_eps, + attention_function=attention_function, + ) + + # Q: Audio, K,V: Video + self.video_to_audio_attn = Attention( + query_dim=audio.dim, + context_dim=video.dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + attention_function=attention_function, + ) + + self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim)) + self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim)) + + self.norm_eps = norm_eps + + def get_ada_values( + self, + scale_shift_table: torch.Tensor, + batch_size: int, + timestep: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + num_ada_params = scale_shift_table.shape[0] + + ada_values = ( + scale_shift_table.unsqueeze(0).unsqueeze(0).to(timestep.dtype) + + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1) + ).unbind(dim=2) + return ada_values + + def get_av_ca_ada_values( + self, + scale_shift_table: torch.Tensor, + batch_size: int, + scale_shift_timestep: torch.Tensor, + gate_timestep: torch.Tensor, + num_scale_shift_values: int = 4, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + scale_shift_ada_values = self.get_ada_values( + scale_shift_table[:num_scale_shift_values, :], + batch_size, + scale_shift_timestep, + ) + gate_ada_values = self.get_ada_values( + scale_shift_table[num_scale_shift_values:, :], + batch_size, + gate_timestep, + ) + + scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values] + gate_ada_values = [t.squeeze(2) for t in gate_ada_values] + + return (*scale_shift_chunks, *gate_ada_values) + + def forward( + self, + video: TransformerArgs | None, + audio: TransformerArgs | None, + perturbations: BatchedPerturbationConfig | None = None, + ) -> tuple[TransformerArgs | None, TransformerArgs | None]: + batch_size = video.x.shape[0] + if perturbations is None: + perturbations = BatchedPerturbationConfig.empty(batch_size) + + vx = video.x if video is not None else None + ax = audio.x if audio is not None else None + + run_vx = video is not None and video.enabled and vx.numel() > 0 + run_ax = audio is not None and audio.enabled and ax.numel() > 0 + + run_a2v = run_vx and (audio is not None and audio.enabled and ax.numel() > 0) + run_v2a = run_ax and (video is not None and video.enabled and vx.numel() > 0) + + if run_vx: + vshift_msa, vscale_msa, vgate_msa, vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values( + self.scale_shift_table, vx.shape[0], video.timesteps + ) + if not perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx): + norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa + v_mask = perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx) + vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa * v_mask + + vx = vx + self.attn2(rms_norm(vx, eps=self.norm_eps), context=video.context, mask=video.context_mask) + + if run_ax: + ashift_msa, ascale_msa, agate_msa, ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values( + self.audio_scale_shift_table, ax.shape[0], audio.timesteps + ) + + if not perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx): + norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa + a_mask = perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax) + ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa * a_mask + + ax = ax + 
self.audio_attn2(rms_norm(ax, eps=self.norm_eps), context=audio.context, mask=audio.context_mask) + + # Audio - Video cross attention. + if run_a2v or run_v2a: + vx_norm3 = rms_norm(vx, eps=self.norm_eps) + ax_norm3 = rms_norm(ax, eps=self.norm_eps) + + ( + scale_ca_audio_hidden_states_a2v, + shift_ca_audio_hidden_states_a2v, + scale_ca_audio_hidden_states_v2a, + shift_ca_audio_hidden_states_v2a, + gate_out_v2a, + ) = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_audio, + ax.shape[0], + audio.cross_scale_shift_timestep, + audio.cross_gate_timestep, + ) + + ( + scale_ca_video_hidden_states_a2v, + shift_ca_video_hidden_states_a2v, + scale_ca_video_hidden_states_v2a, + shift_ca_video_hidden_states_v2a, + gate_out_a2v, + ) = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_video, + vx.shape[0], + video.cross_scale_shift_timestep, + video.cross_gate_timestep, + ) + + if run_a2v: + vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) + shift_ca_video_hidden_states_a2v + ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v + a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx) + vx = vx + ( + self.audio_to_video_attn( + vx_scaled, + context=ax_scaled, + pe=video.cross_positional_embeddings, + k_pe=audio.cross_positional_embeddings, + ) + * gate_out_a2v + * a2v_mask + ) + + if run_v2a: + ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a + vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a + v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax) + ax = ax + ( + self.video_to_audio_attn( + ax_scaled, + context=vx_scaled, + pe=audio.cross_positional_embeddings, + k_pe=video.cross_positional_embeddings, + ) + * gate_out_v2a + * v2a_mask + ) + + if run_vx: + vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp + vx = vx + self.ff(vx_scaled) * vgate_mlp + + if run_ax: + ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp + ax = ax + self.audio_ff(ax_scaled) * agate_mlp + + return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None diff --git a/packages/ltx-core/src/ltx_core/model/transformer/transformer_args.py b/packages/ltx-core/src/ltx_core/model/transformer/transformer_args.py new file mode 100644 index 0000000000000000000000000000000000000000..58dffa2a6ad46d371045b59ef3601cc10fd2b955 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/transformer/transformer_args.py @@ -0,0 +1,242 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
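+# Minimal sketch of the adaLN-style modulation used by BasicAVTransformerBlock above:
+# the timestep embedding is split into shift/scale/gate chunks that wrap each sub-layer
+# around an RMS-normalized residual branch. Names below are local to this comment; the
+# lambda stands in for attention or the feed-forward, not for a real API.
+#
+#     import torch
+#     x = torch.randn(2, 16, 64)                               # (batch, tokens, dim)
+#     shift, scale, gate = torch.randn(3, 2, 1, 64).unbind(0)  # chunks of the timestep embedding
+#     normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
+#     sublayer = lambda t: t                                    # stand-in for attn1 / ff
+#     x = x + gate * sublayer(normed * (1 + scale) + shift)
+#
+# The TransformerArgs prepared in this module carry the timestep, context and positional
+# tensors that feed that pattern for each modality.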
+# Created by Andrew Kvochko + +from dataclasses import dataclass, replace + +import torch + +from ltx_core.model.transformer.adaln import AdaLayerNormSingle +from ltx_core.model.transformer.modality import Modality +from ltx_core.model.transformer.rope import ( + LTXRopeType, + generate_freq_grid_np, + generate_freq_grid_pytorch, + precompute_freqs_cis, +) +from ltx_core.model.transformer.text_projection import PixArtAlphaTextProjection + + +@dataclass(frozen=True) +class TransformerArgs: + x: torch.Tensor + context: torch.Tensor + context_mask: torch.Tensor + timesteps: torch.Tensor + embedded_timestep: torch.Tensor + positional_embeddings: torch.Tensor + cross_positional_embeddings: torch.Tensor | None + cross_scale_shift_timestep: torch.Tensor | None + cross_gate_timestep: torch.Tensor | None + enabled: bool + + +class TransformerArgsPreprocessor: + def __init__( # noqa: PLR0913 + self, + patchify_proj: torch.nn.Linear, + adaln: AdaLayerNormSingle, + caption_projection: PixArtAlphaTextProjection, + inner_dim: int, + max_pos: list[int], + num_attention_heads: int, + use_middle_indices_grid: bool, + timestep_scale_multiplier: int, + double_precision_rope: bool, + positional_embedding_theta: float, + rope_type: LTXRopeType, + ) -> None: + self.patchify_proj = patchify_proj + self.adaln = adaln + self.caption_projection = caption_projection + self.inner_dim = inner_dim + self.max_pos = max_pos + self.num_attention_heads = num_attention_heads + self.use_middle_indices_grid = use_middle_indices_grid + self.timestep_scale_multiplier = timestep_scale_multiplier + self.double_precision_rope = double_precision_rope + self.positional_embedding_theta = positional_embedding_theta + self.rope_type = rope_type + + def _prepare_timestep( + self, timestep: torch.Tensor, batch_size: int, hidden_dtype: torch.dtype + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare timestep embeddings.""" + + timestep = timestep * self.timestep_scale_multiplier + timestep, embedded_timestep = self.adaln( + timestep.flatten(), + hidden_dtype=hidden_dtype, + ) + + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) + return timestep, embedded_timestep + + def _prepare_context( + self, + context: torch.Tensor, + x: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Prepare context for transformer blocks.""" + batch_size = x.shape[0] + context = self.caption_projection(context) + context = context.view(batch_size, -1, x.shape[-1]) + + return context, attention_mask + + def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None: + """Prepare attention mask.""" + if attention_mask is None or torch.is_floating_point(attention_mask): + return attention_mask + + return (attention_mask - 1).to(x_dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + ) * torch.finfo(x_dtype).max + + def _prepare_positional_embeddings( + self, + positions: torch.Tensor, + inner_dim: int, + max_pos: list[int], + use_middle_indices_grid: bool, + num_attention_heads: int, + x_dtype: torch.dtype, + ) -> torch.Tensor: + """Prepare positional embeddings.""" + freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch + pe = precompute_freqs_cis( + positions, + dim=inner_dim, + out_dtype=x_dtype, + 
theta=self.positional_embedding_theta, + max_pos=max_pos, + use_middle_indices_grid=use_middle_indices_grid, + num_attention_heads=num_attention_heads, + rope_type=self.rope_type, + freq_grid_generator=freq_grid_generator, + ) + return pe + + def prepare( + self, + modality: Modality, + ) -> TransformerArgs: + x = self.patchify_proj(modality.latent) + timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], modality.latent.dtype) + context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask) + attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype) + pe = self._prepare_positional_embeddings( + positions=modality.positions, + inner_dim=self.inner_dim, + max_pos=self.max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.num_attention_heads, + x_dtype=modality.latent.dtype, + ) + return TransformerArgs( + x=x, + context=context, + context_mask=attention_mask, + timesteps=timestep, + embedded_timestep=embedded_timestep, + positional_embeddings=pe, + cross_positional_embeddings=None, + cross_scale_shift_timestep=None, + cross_gate_timestep=None, + enabled=modality.enabled, + ) + + +class MultiModalTransformerArgsPreprocessor: + def __init__( # noqa: PLR0913 + self, + patchify_proj: torch.nn.Linear, + adaln: AdaLayerNormSingle, + caption_projection: PixArtAlphaTextProjection, + cross_scale_shift_adaln: AdaLayerNormSingle, + cross_gate_adaln: AdaLayerNormSingle, + inner_dim: int, + max_pos: list[int], + num_attention_heads: int, + cross_pe_max_pos: int, + use_middle_indices_grid: bool, + audio_cross_attention_dim: int, + timestep_scale_multiplier: int, + double_precision_rope: bool, + positional_embedding_theta: float, + rope_type: LTXRopeType, + av_ca_timestep_scale_multiplier: int, + ) -> None: + self.simple_preprocessor = TransformerArgsPreprocessor( + patchify_proj=patchify_proj, + adaln=adaln, + caption_projection=caption_projection, + inner_dim=inner_dim, + max_pos=max_pos, + num_attention_heads=num_attention_heads, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + double_precision_rope=double_precision_rope, + positional_embedding_theta=positional_embedding_theta, + rope_type=rope_type, + ) + self.cross_scale_shift_adaln = cross_scale_shift_adaln + self.cross_gate_adaln = cross_gate_adaln + self.cross_pe_max_pos = cross_pe_max_pos + self.audio_cross_attention_dim = audio_cross_attention_dim + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + + def prepare( + self, + modality: Modality, + ) -> TransformerArgs: + transformer_args = self.simple_preprocessor.prepare(modality) + cross_pe = self.simple_preprocessor._prepare_positional_embeddings( + positions=modality.positions[:, 0:1, :], + inner_dim=self.audio_cross_attention_dim, + max_pos=[self.cross_pe_max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.simple_preprocessor.num_attention_heads, + x_dtype=modality.latent.dtype, + ) + + cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep( + timestep=modality.timesteps, + timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, + batch_size=transformer_args.x.shape[0], + hidden_dtype=modality.latent.dtype, + ) + + return replace( + transformer_args, + cross_positional_embeddings=cross_pe, + cross_scale_shift_timestep=cross_scale_shift_timestep, + cross_gate_timestep=cross_gate_timestep, + ) + + def 
_prepare_cross_attention_timestep( + self, + timestep: torch.Tensor, + timestep_scale_multiplier: int, + batch_size: int, + hidden_dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare cross attention timestep embeddings.""" + timestep = timestep * timestep_scale_multiplier + + av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier + + scale_shift_timestep, _ = self.cross_scale_shift_adaln( + timestep.flatten(), + hidden_dtype=hidden_dtype, + ) + scale_shift_timestep = scale_shift_timestep.view(batch_size, -1, scale_shift_timestep.shape[-1]) + gate_noise_timestep, _ = self.cross_gate_adaln( + timestep.flatten() * av_ca_factor, + hidden_dtype=hidden_dtype, + ) + gate_noise_timestep = gate_noise_timestep.view(batch_size, -1, gate_noise_timestep.shape[-1]) + + return scale_shift_timestep, gate_noise_timestep diff --git a/packages/ltx-core/src/ltx_core/model/upsampler/__init__.py b/packages/ltx-core/src/ltx_core/model/upsampler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/__init__.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42855c437d68bf95ee439796becb0968ae942af7 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/__init__.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/blur_downsample.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/blur_downsample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3875295dc925ff86ccbd7bcba99394634903561 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/blur_downsample.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/model.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..817e5a0b42e61b6728b56f9eed8f4210c988adf6 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/model.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/model_configurator.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/model_configurator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a17afbaeb26fed91868ecfad9b7442c57809799 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/model_configurator.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/pixel_shuffle.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/pixel_shuffle.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd1cd750d1f21b61c995a01540643c51b57c2f9e Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/pixel_shuffle.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/res_block.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/res_block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c444e573831c11e1e2d7b7c7f109445f1895a50a Binary files /dev/null and 
b/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/res_block.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/spatial_rational_resampler.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/spatial_rational_resampler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c49b885c09ed8d81aa5686677bd1743d3d807b2 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/upsampler/__pycache__/spatial_rational_resampler.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/upsampler/blur_downsample.py b/packages/ltx-core/src/ltx_core/model/upsampler/blur_downsample.py new file mode 100644 index 0000000000000000000000000000000000000000..2e540cda255c9fc0dd2769c94da5bc6871727f88 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/upsampler/blur_downsample.py @@ -0,0 +1,56 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +import math + +import torch +import torch.nn.functional as F +from einops import rearrange + + +class BlurDownsample(torch.nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. + Applies only on H,W. Works for dims=2 or dims=3 (per-frame). + """ + + def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None: + super().__init__() + assert dims in (2, 3) + assert isinstance(stride, int) + assert stride >= 1 + assert kernel_size >= 3 + assert kernel_size % 2 == 1 + self.dims = dims + self.stride = stride + self.kernel_size = kernel_size + + # 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from + # the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and + # provides a smooth approximation of a Gaussian filter (often called a "binomial filter"). + # The 2D kernel is constructed as the outer product and normalized. + k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + if self.dims == 2: + return self._apply_2d(x) + else: + # dims == 3: apply per-frame on H,W + b, _, f, _, _ = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self._apply_2d(x) + h2, w2 = x.shape[-2:] + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2) + return x + + def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor: + c = x2d.shape[1] + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x2d = F.conv2d(x2d, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + return x2d diff --git a/packages/ltx-core/src/ltx_core/model/upsampler/model.py b/packages/ltx-core/src/ltx_core/model/upsampler/model.py new file mode 100644 index 0000000000000000000000000000000000000000..7de29e95a09bd45fad38b7d44c49221ae76119c3 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/upsampler/model.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
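+# Quick standalone check of the binomial blur kernel that BlurDownsample above
+# constructs (names local to this comment):
+#
+#     import math
+#     import torch
+#     k = torch.tensor([math.comb(4, i) for i in range(5)], dtype=torch.float32)  # [1, 4, 6, 4, 1]
+#     k2d = torch.outer(k, k)
+#     k2d = k2d / k2d.sum()   # separable binomial kernel, sums to 1
+#
+# Applied depthwise (groups = channels) with stride 2, this low-pass filters H and W
+# before subsampling, which is what keeps the downsample anti-aliased.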
+# Created by Andrew Kvochko + + +import torch +from einops import rearrange + +from ltx_core.model.upsampler.pixel_shuffle import PixelShuffleND +from ltx_core.model.upsampler.res_block import ResBlock +from ltx_core.model.upsampler.spatial_rational_resampler import SpatialRationalResampler + + +class LatentUpsampler(torch.nn.Module): + """ + Model to spatially upsample VAE latents. + + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + spatial_scale (`float`): Scale factor for spatial upsampling + rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling + """ + + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + spatial_scale: float = 2.0, + rational_resampler: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + self.spatial_scale = float(spatial_scale) + self.rational_resampler = rational_resampler + + conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = torch.nn.GroupNorm(32, mid_channels) + self.initial_activation = torch.nn.SiLU() + + self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + if spatial_upsample and temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + if rational_resampler: + self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale) + else: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError("Either spatial_upsample or temporal_upsample must be True") + + self.post_upsample_res_blocks = torch.nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, _, f, _, _ = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if 
self.temporal_upsample: + x = self.upsampler(x) + # remove the first frame after upsampling. + # This is done because the first frame encodes one pixel frame. + x = x[:, :, 1:, :, :] + elif isinstance(self.upsampler, SpatialRationalResampler): + x = self.upsampler(x) + else: + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x diff --git a/packages/ltx-core/src/ltx_core/model/upsampler/model_configurator.py b/packages/ltx-core/src/ltx_core/model/upsampler/model_configurator.py new file mode 100644 index 0000000000000000000000000000000000000000..2adbbe6bb104a7d0f555ed18ed5078eceef000e4 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/upsampler/model_configurator.py @@ -0,0 +1,25 @@ +from ltx_core.model.model_protocol import ModelConfigurator +from ltx_core.model.upsampler.model import LatentUpsampler + + +class LatentUpsamplerConfigurator(ModelConfigurator[LatentUpsampler]): + @classmethod + def from_config(cls: type[LatentUpsampler], config: dict) -> LatentUpsampler: + in_channels = config.get("in_channels", 128) + mid_channels = config.get("mid_channels", 512) + num_blocks_per_stage = config.get("num_blocks_per_stage", 4) + dims = config.get("dims", 3) + spatial_upsample = config.get("spatial_upsample", True) + temporal_upsample = config.get("temporal_upsample", False) + spatial_scale = config.get("spatial_scale", 2.0) + rational_resampler = config.get("rational_resampler", False) + return LatentUpsampler( + in_channels=in_channels, + mid_channels=mid_channels, + num_blocks_per_stage=num_blocks_per_stage, + dims=dims, + spatial_upsample=spatial_upsample, + temporal_upsample=temporal_upsample, + spatial_scale=spatial_scale, + rational_resampler=rational_resampler, + ) diff --git a/packages/ltx-core/src/ltx_core/model/upsampler/pixel_shuffle.py b/packages/ltx-core/src/ltx_core/model/upsampler/pixel_shuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..53ccd780b3ffc009dd97215694e95f87df1b8f64 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/upsampler/pixel_shuffle.py @@ -0,0 +1,60 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +import torch +from einops import rearrange + + +class PixelShuffleND(torch.nn.Module): + """ + N-dimensional pixel shuffle operation for upsampling tensors. + + Args: + dims (int): Number of dimensions to apply pixel shuffle to. + - 1: Temporal (e.g., frames) + - 2: Spatial (e.g., height and width) + - 3: Spatiotemporal (e.g., depth, height, width) + upscale_factors (tuple[int, int, int], optional): Upscaling factors for each dimension. + For dims=1, only the first value is used. + For dims=2, the first two values are used. + For dims=3, all three values are used. + + The input tensor is rearranged so that the channel dimension is split into + smaller channels and upscaling factors, and the upscaling factors are moved + into the corresponding spatial/temporal dimensions. + + Note: + This operation is equivalent to the patchifier operation in for the models. Consider + using this class instead. 
+ """ + + def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) + else: + raise ValueError(f"Unsupported dims: {self.dims}") diff --git a/packages/ltx-core/src/ltx_core/model/upsampler/res_block.py b/packages/ltx-core/src/ltx_core/model/upsampler/res_block.py new file mode 100644 index 0000000000000000000000000000000000000000..9600056e3d20ed5765c965ee0bb0d8a08f34937e --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/upsampler/res_block.py @@ -0,0 +1,40 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +from typing import Optional + +import torch + + +class ResBlock(torch.nn.Module): + """ + Residual block with two convolutional layers, group normalization, and SiLU activation. + Args: + channels (int): Number of input and output channels. + mid_channels (Optional[int]): Number of channels in the intermediate convolution layer. Defaults to `channels` + if not specified. + dims (int): Dimensionality of the convolution (2 for Conv2d, 3 for Conv3d). Defaults to 3. + """ + + def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.conv1 = conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = torch.nn.GroupNorm(32, mid_channels) + self.conv2 = conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = torch.nn.GroupNorm(32, channels) + self.activation = torch.nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x diff --git a/packages/ltx-core/src/ltx_core/model/upsampler/spatial_rational_resampler.py b/packages/ltx-core/src/ltx_core/model/upsampler/spatial_rational_resampler.py new file mode 100644 index 0000000000000000000000000000000000000000..73e92b75838db509e956377efc8f5ec589bc4ff3 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/upsampler/spatial_rational_resampler.py @@ -0,0 +1,52 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +from typing import Tuple + +import torch +from einops import rearrange + +from ltx_core.model.upsampler.blur_downsample import BlurDownsample +from ltx_core.model.upsampler.pixel_shuffle import PixelShuffleND + + +def _rational_for_scale(scale: float) -> Tuple[int, int]: + mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)} + if float(scale) not in mapping: + raise ValueError(f"Unsupported scale {scale}. 
Choose from {list(mapping.keys())}") + return mapping[float(scale)] + + +class SpatialRationalResampler(torch.nn.Module): + """ + Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased + downsample by 'den' using fixed blur + stride. Operates on H,W only. + + For dims==3, work per-frame for spatial scaling (temporal axis untouched). + + Args: + mid_channels (`int`): Number of intermediate channels for the convolution layer + scale (`float`): Spatial scaling factor. Supported values are: + - 0.75: Downsample by 3/4 (reduce spatial size) + - 1.5: Upsample by 3/2 (increase spatial size) + - 2.0: Upsample by 2x (double spatial size) + - 4.0: Upsample by 4x (quadruple spatial size) + Any other value will raise a ValueError. + """ + + def __init__(self, mid_channels: int, scale: float): + super().__init__() + self.scale = float(scale) + self.num, self.den = _rational_for_scale(self.scale) + self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, _, f, _, _ = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + return x diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/__init__.py b/packages/ltx-core/src/ltx_core/model/video_vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb2e8abfdd86bdbb9c064794d50e68ef93f5e626 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/video_vae/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
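+# Rough shape walk-through for the SpatialRationalResampler defined above, assuming
+# scale=1.5, i.e. (num, den) = (3, 2); tensor sizes are illustrative only:
+#
+#     import torch
+#     from ltx_core.model.upsampler.spatial_rational_resampler import SpatialRationalResampler
+#     resampler = SpatialRationalResampler(mid_channels=512, scale=1.5)
+#     x = torch.randn(1, 512, 8, 64, 64)   # (B, C, F, H, W)
+#     y = resampler(x)                     # conv to 9*512 channels, pixel-shuffle x3 to 192x192,
+#                                          # then blur + stride-2 down to (1, 512, 8, 96, 96)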
+# Created by Ivan Zorin + +"""Video VAE package.""" + +__all__: list[str] = [] diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/__init__.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0de9f72c060e8563dec5571e30e54ae17599078 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/__init__.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/convolution.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/convolution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bb174a02e1a301700b5b8424ae2e33b6246127f Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/convolution.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/enums.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/enums.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15fa3d706281ef57031ae282ffc4644c437c4c6c Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/enums.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/model_configurator.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/model_configurator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cffdb9705f2ec9ab24546d6cdfcc85a319daf88 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/model_configurator.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/ops.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a3891e58d6a1473ce1768d5dc2311a2a47b547e Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/ops.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/resnet.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad29d94caa5f7bbda1b9ab0220ccac68e6fdaf76 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/resnet.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/sampling.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/sampling.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd23f5c10d06b82da93c1aca52d4c54c9da1b38a Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/sampling.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/video_vae.cpython-310.pyc b/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/video_vae.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..182c64f3e65f38a95be5e3eaed2076733e2a4ca0 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/video_vae.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/convolution.py b/packages/ltx-core/src/ltx_core/model/video_vae/convolution.py new file mode 
100644 index 0000000000000000000000000000000000000000..8e28dec68fce9008ffa3a1b8f6f995aa98883877 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/video_vae/convolution.py @@ -0,0 +1,320 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Ivan Zorin + +from typing import Tuple, Union + +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional as F + +from ltx_core.model.video_vae.enums import PaddingModeType + + +def make_conv_nd( # noqa: PLR0913 + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + temporal_padding_mode: PaddingModeType = PaddingModeType.ZEROS, +) -> nn.Module: + if not (spatial_padding_mode == temporal_padding_mode or causal): + raise NotImplementedError("spatial and temporal padding modes must be equal") + if dims == 2: + return nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode.value, + ) + elif dims == 3: + if causal: + return CausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + bias=bias, + spatial_padding_mode=spatial_padding_mode, + ) + return nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode.value, + ) + elif dims == (2, 1): + return DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + padding_mode=spatial_padding_mode.value, + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def make_linear_nd( + dims: int, + in_channels: int, + out_channels: int, + bias: bool = True, +) -> nn.Module: + if dims == 2: + return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) + elif dims in (3, (2, 1)): + return nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +class DualConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + ) -> None: + super(DualConv3d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.padding_mode = padding_mode + # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if kernel_size == (1, 1, 1): + raise ValueError("kernel_size must be greater than 1. 
Use make_linear_nd instead.") + if isinstance(stride, int): + stride = (stride, stride, stride) + if isinstance(padding, int): + padding = (padding, padding, padding) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + + # Set parameters for convolutions + self.groups = groups + self.bias = bias + + # Define the size of the channels after the first convolution + intermediate_channels = out_channels if in_channels < out_channels else in_channels + + # Define parameters for the first convolution + self.weight1 = nn.Parameter( + torch.Tensor( + intermediate_channels, + in_channels // groups, + 1, + kernel_size[1], + kernel_size[2], + ) + ) + self.stride1 = (1, stride[1], stride[2]) + self.padding1 = (0, padding[1], padding[2]) + self.dilation1 = (1, dilation[1], dilation[2]) + if bias: + self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) + else: + self.register_parameter("bias1", None) + + # Define parameters for the second convolution + self.weight2 = nn.Parameter(torch.Tensor(out_channels, intermediate_channels // groups, kernel_size[0], 1, 1)) + self.stride2 = (stride[0], 1, 1) + self.padding2 = (padding[0], 0, 0) + self.dilation2 = (dilation[0], 1, 1) + if bias: + self.bias2 = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias2", None) + + # Initialize weights and biases + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.kaiming_uniform_(self.weight1, a=torch.sqrt(5)) + nn.init.kaiming_uniform_(self.weight2, a=torch.sqrt(5)) + if self.bias: + fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) + bound1 = 1 / torch.sqrt(fan_in1) + nn.init.uniform_(self.bias1, -bound1, bound1) + fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) + bound2 = 1 / torch.sqrt(fan_in2) + nn.init.uniform_(self.bias2, -bound2, bound2) + + def forward( + self, + x: torch.Tensor, + use_conv3d: bool = False, + skip_time_conv: bool = False, + ) -> torch.Tensor: + if use_conv3d: + return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) + else: + return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) + + def forward_with_3d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor: + # First convolution + x = F.conv3d( + x, + self.weight1, + self.bias1, + self.stride1, + self.padding1, + self.dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + if skip_time_conv: + return x + + # Second convolution + x = F.conv3d( + x, + self.weight2, + self.bias2, + self.stride2, + self.padding2, + self.dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + + return x + + def forward_with_2d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor: + b, _, _, h, w = x.shape + + # First 2D convolution + x = rearrange(x, "b c d h w -> (b d) c h w") + # Squeeze the depth dimension out of weight1 since it's 1 + weight1 = self.weight1.squeeze(2) + # Select stride, padding, and dilation for the 2D convolution + stride1 = (self.stride1[1], self.stride1[2]) + padding1 = (self.padding1[1], self.padding1[2]) + dilation1 = (self.dilation1[1], self.dilation1[2]) + x = F.conv2d( + x, + weight1, + self.bias1, + stride1, + padding1, + dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + _, _, h, w = x.shape + + if skip_time_conv: + x = rearrange(x, "(b d) c h w -> b c d h w", b=b) + return x + + # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) + + # 
Reshape weight2 to match the expected dimensions for conv1d + weight2 = self.weight2.squeeze(-1).squeeze(-1) + # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution + stride2 = self.stride2[0] + padding2 = self.padding2[0] + dilation2 = self.dilation2[0] + x = F.conv1d( + x, + weight2, + self.bias2, + stride2, + padding2, + dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) + + return x + + @property + def weight(self) -> torch.Tensor: + return self.weight2 + + +class CausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: Union[int, Tuple[int]] = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + kernel_size = (kernel_size, kernel_size, kernel_size) + self.time_kernel_size = kernel_size[0] + + dilation = (dilation, 1, 1) + + height_pad = kernel_size[1] // 2 + width_pad = kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + padding_mode=spatial_padding_mode.value, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor, causal: bool = True) -> torch.Tensor: + if causal: + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) + x = torch.concatenate((first_frame_pad, x), dim=2) + else: + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)) + last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)) + x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) + x = self.conv(x) + return x + + @property + def weight(self) -> torch.Tensor: + return self.conv.weight diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/enums.py b/packages/ltx-core/src/ltx_core/model/video_vae/enums.py new file mode 100644 index 0000000000000000000000000000000000000000..c0ba593dca049de7a71c7ca3cbced122dff0dd13 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/video_vae/enums.py @@ -0,0 +1,23 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Ivan Zorin + +from enum import Enum + + +class NormLayerType(Enum): + GROUP_NORM = "group_norm" + PIXEL_NORM = "pixel_norm" + + +class LogVarianceType(Enum): + PER_CHANNEL = "per_channel" + UNIFORM = "uniform" + CONSTANT = "constant" + NONE = "none" + + +class PaddingModeType(Enum): + ZEROS = "zeros" + REFLECT = "reflect" + REPLICATE = "replicate" + CIRCULAR = "circular" diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/model_configurator.py b/packages/ltx-core/src/ltx_core/model/video_vae/model_configurator.py new file mode 100644 index 0000000000000000000000000000000000000000..b498629c28a773ea71e3f1f314ac86b2c57db241 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/video_vae/model_configurator.py @@ -0,0 +1,76 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
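+# Shape sketch for the CausalConv3d defined in convolution.py above: the temporal axis
+# is left-padded with repeats of the first frame, so the frame count is preserved and
+# each output frame depends only on the current and earlier input frames.
+#
+#     import torch
+#     from ltx_core.model.video_vae.convolution import CausalConv3d
+#     conv = CausalConv3d(in_channels=3, out_channels=8, kernel_size=3)
+#     x = torch.randn(1, 3, 9, 32, 32)   # (B, C, F, H, W)
+#     y = conv(x, causal=True)           # time padded with two copies of frame 0 -> (1, 8, 9, 32, 32)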
+# Created by Ivan Zorin + +from ltx_core.loader.sd_ops import SDOps +from ltx_core.model.model_protocol import ModelConfigurator +from ltx_core.model.video_vae.enums import LogVarianceType, NormLayerType, PaddingModeType +from ltx_core.model.video_vae.video_vae import Decoder, Encoder + + +class VAEEncoderConfigurator(ModelConfigurator[Encoder]): + @classmethod + def from_config(cls: type[Encoder], config: dict) -> Encoder: + config = config.get("vae", {}) + convolution_dimensions = config.get("dims", 3) + in_channels = config.get("in_channels", 3) + latent_channels = config.get("latent_channels", 128) + encoder_spatial_padding_mode = PaddingModeType(config.get("encoder_spatial_padding_mode", "zeros")) + encoder_blocks = config.get("encoder_blocks", []) + patch_size = config.get("patch_size", 4) + norm_layer_str = config.get("norm_layer", "pixel_norm") + latent_log_var_str = config.get("latent_log_var", "uniform") + + return Encoder( + convolution_dimensions=convolution_dimensions, + in_channels=in_channels, + out_channels=latent_channels, + encoder_blocks=encoder_blocks, + patch_size=patch_size, + norm_layer=NormLayerType(norm_layer_str), + latent_log_var=LogVarianceType(latent_log_var_str), + encoder_spatial_padding_mode=encoder_spatial_padding_mode, + ) + + +class VAEDecoderConfigurator(ModelConfigurator[Decoder]): + @classmethod + def from_config(cls: type[Decoder], config: dict) -> Decoder: + config = config.get("vae", {}) + convolution_dimensions = config.get("dims", 3) + latent_channels = config.get("latent_channels", 128) + decoder_spatial_padding_mode = PaddingModeType(config.get("decoder_spatial_padding_mode", "reflect")) + out_channels = config.get("out_channels", 3) + decoder_blocks = config.get("decoder_blocks", []) + patch_size = config.get("patch_size", 4) + norm_layer_str = config.get("norm_layer", "pixel_norm") + causal = config.get("causal_decoder", False) + timestep_conditioning = config.get("timestep_conditioning", True) + + return Decoder( + convolution_dimensions=convolution_dimensions, + in_channels=latent_channels, + out_channels=out_channels, + decoder_blocks=decoder_blocks, + patch_size=patch_size, + norm_layer=NormLayerType(norm_layer_str), + causal=causal, + timestep_conditioning=timestep_conditioning, + decoder_spatial_padding_mode=decoder_spatial_padding_mode, + ) + + +VAE_DECODER_COMFY_KEYS_FILTER = ( + SDOps("VAE_DECODER_COMFY_KEYS_FILTER") + .with_matching(prefix="vae.decoder.") + .with_matching(prefix="vae.per_channel_statistics.") + .with_replacement("vae.decoder.", "") + .with_replacement("vae.per_channel_statistics.", "per_channel_statistics.") +) + +VAE_ENCODER_COMFY_KEYS_FILTER = ( + SDOps("VAE_ENCODER_COMFY_KEYS_FILTER") + .with_matching(prefix="vae.encoder.") + .with_matching(prefix="vae.per_channel_statistics.") + .with_replacement("vae.encoder.", "") + .with_replacement("vae.per_channel_statistics.", "per_channel_statistics.") +) diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/normalization.py b/packages/ltx-core/src/ltx_core/model/video_vae/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..ad907198c9b12d2b52784c18fe1cef1960938092 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/video_vae/normalization.py @@ -0,0 +1,6 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
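+# Illustrative sketch (standalone, plain Python with hypothetical key names; the real remapping
+# is applied through the SDOps chains above): the decoder filter keeps only "vae.decoder." and
+# "vae.per_channel_statistics." checkpoint entries and rewrites their prefixes so the keys line
+# up with the decoder module's own state_dict layout.
+#   >>> keys = ["vae.decoder.conv_in.weight", "vae.per_channel_statistics.std-of-means", "vae.encoder.conv_in.weight"]
+#   >>> rules = {"vae.decoder.": "", "vae.per_channel_statistics.": "per_channel_statistics."}
+#   >>> [new + k[len(old):] for k in keys for old, new in rules.items() if k.startswith(old)]
+#   ['conv_in.weight', 'per_channel_statistics.std-of-means']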
+# Created by Ivan Zorin
+
+from ltx_core.model.common.normalization import PixelNorm, build_normalization_layer
+
+__all__ = ["PixelNorm", "build_normalization_layer"]
diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/ops.py b/packages/ltx-core/src/ltx_core/model/video_vae/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d664e0828bb58625de4d91dcc09bcaa12a71ff40
--- /dev/null
+++ b/packages/ltx-core/src/ltx_core/model/video_vae/ops.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2025 Lightricks. All rights reserved.
+# Created by Ivan Zorin
+
+import torch
+from einops import rearrange
+from torch import nn
+
+
+def patchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor:
+    """
+    Rearrange spatial dimensions into channels. Divides image into patch_size x patch_size blocks
+    and moves pixels from each block into separate channels (space-to-depth).
+
+    Args:
+        x: Input tensor (4D or 5D)
+        patch_size_hw: Spatial patch size for height and width. With patch_size_hw=4, divides HxW into 4x4 blocks.
+        patch_size_t: Temporal patch size for frames. Default=1 (no temporal patching).
+
+    For 5D: (B, C, F, H, W) -> (B, Cx(patch_size_hw^2)x(patch_size_t), F/patch_size_t, H/patch_size_hw, W/patch_size_hw)
+    Example: (B, 3, 33, 512, 512) with patch_size_hw=4, patch_size_t=1 -> (B, 48, 33, 128, 128)
+    """
+    if patch_size_hw == 1 and patch_size_t == 1:
+        return x
+    if x.dim() == 4:
+        x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw)
+    elif x.dim() == 5:
+        x = rearrange(
+            x,
+            "b c (f p) (h q) (w r) -> b (c p r q) f h w",
+            p=patch_size_t,
+            q=patch_size_hw,
+            r=patch_size_hw,
+        )
+    else:
+        raise ValueError(f"Invalid input shape: {x.shape}")
+
+    return x
+
+
+def unpatchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor:
+    """
+    Rearrange channels back into spatial dimensions. Inverse of patchify - moves pixels from
+    channels back into patch_size x patch_size blocks (depth-to-space).
+
+    Args:
+        x: Input tensor (4D or 5D)
+        patch_size_hw: Spatial patch size for height and width. With patch_size_hw=4, expands HxW by 4x.
+        patch_size_t: Temporal patch size for frames. Default=1 (no temporal expansion).
+
+    For 5D: (B, Cx(patch_size_hw^2)x(patch_size_t), F, H, W) -> (B, C, Fxpatch_size_t, Hxpatch_size_hw, Wxpatch_size_hw)
+    Example: (B, 48, 33, 128, 128) with patch_size_hw=4, patch_size_t=1 -> (B, 3, 33, 512, 512)
+    """
+    if patch_size_hw == 1 and patch_size_t == 1:
+        return x
+
+    if x.dim() == 4:
+        x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw)
+    elif x.dim() == 5:
+        x = rearrange(
+            x,
+            "b (c p r q) f h w -> b c (f p) (h q) (w r)",
+            p=patch_size_t,
+            q=patch_size_hw,
+            r=patch_size_hw,
+        )
+
+    return x
+
+
+class PerChannelStatistics(nn.Module):
+    """
+    Per-channel statistics for normalizing and denormalizing the latent representation.
+    These statistics are computed over the entire dataset and stored in the model's checkpoint under the VAE state_dict.
+ """ + + def __init__(self, latent_channels: int = 128): + super().__init__() + self.register_buffer("std-of-means", torch.empty(latent_channels)) + self.register_buffer("mean-of-means", torch.empty(latent_channels)) + self.register_buffer("mean-of-stds", torch.empty(latent_channels)) + self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(latent_channels)) + self.register_buffer("channel", torch.empty(latent_channels)) + + def un_normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view( + 1, -1, 1, 1, 1 + ).to(x) + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view( + 1, -1, 1, 1, 1 + ).to(x) diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/resnet.py b/packages/ltx-core/src/ltx_core/model/video_vae/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9dfb757745fd6bc75cb006c64cbd1055c1d54304 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/video_vae/resnet.py @@ -0,0 +1,284 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Ivan Zorin + +from typing import Optional, Tuple, Union + +import torch +from torch import nn + +from ltx_core.model.common.normalization import PixelNorm +from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTimestepSizeEmbeddings +from ltx_core.model.video_vae.convolution import make_conv_nd, make_linear_nd +from ltx_core.model.video_vae.enums import NormLayerType, PaddingModeType + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. 
+ """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.inject_noise = inject_noise + + if norm_layer == NormLayerType.GROUP_NORM: + self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.norm1 = PixelNorm() + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + if norm_layer == NormLayerType.GROUP_NORM: + self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.norm2 = PixelNorm() + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + self.conv_shortcut = ( + make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels) + if in_channels != out_channels + else nn.Identity() + ) + + # Using GroupNorm with 1 group is equivalent to LayerNorm but works with (B, C, ...) layout + # avoiding the need for dimension rearrangement used in standard nn.LayerNorm + self.norm3 = ( + nn.GroupNorm(num_groups=1, num_channels=in_channels, eps=eps, affine=True) + if in_channels != out_channels + else nn.Identity() + ) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) + + def _feed_spatial_noise( + self, + hidden_states: torch.Tensor, + per_channel_scale: torch.Tensor, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + spatial_shape = hidden_states.shape[-2:] + device = hidden_states.device + dtype = hidden_states.dtype + + # similar to the "explicit noise inputs" method in style-gan + spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype, generator=generator)[None] + scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] 
+ hidden_states = hidden_states + scaled_noise + + return hidden_states + + def forward( + self, + input_tensor: torch.Tensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + hidden_states = input_tensor + batch_size = hidden_states.shape[0] + + hidden_states = self.norm1(hidden_states) + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + ada_values = self.scale_shift_table[None, ..., None, None, None].to( + device=hidden_states.device, dtype=hidden_states.dtype + ) + timestep.reshape( + batch_size, + 4, + -1, + timestep.shape[-3], + timestep.shape[-2], + timestep.shape[-1], + ) + shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) + + hidden_states = hidden_states * (1 + scale1) + shift1 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, + self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype), + generator=generator, + ) + + hidden_states = self.norm2(hidden_states) + + if self.timestep_conditioning: + hidden_states = hidden_states * (1 + scale2) + shift2 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, + self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype), + generator=generator, + ) + + input_tensor = self.norm3(input_tensor) + + batch_size = input_tensor.shape[0] + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + inject_noise (`bool`, *optional*, defaults to `False`): + Whether to inject noise into the hidden states. + timestep_conditioning (`bool`, *optional*, defaults to `False`): + Whether to condition the hidden states on the timestep. + + Returns: + `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. 
+ + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: NormLayerType = NormLayerType.GROUP_NORM, + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim=in_channels * 4, size_emb_dim=0 + ) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + timestep_embed = None + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + batch_size = hidden_states.shape[0] + timestep_embed = self.time_embedder( + timestep=timestep.flatten(), + hidden_dtype=hidden_states.dtype, + ) + timestep_embed = timestep_embed.view(batch_size, timestep_embed.shape[-1], 1, 1, 1) + + for resnet in self.res_blocks: + hidden_states = resnet( + hidden_states, + causal=causal, + timestep=timestep_embed, + generator=generator, + ) + + return hidden_states diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/sampling.py b/packages/ltx-core/src/ltx_core/model/video_vae/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..c975710267853f3d75ef3ef83f0b1574400a1c07 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/video_vae/sampling.py @@ -0,0 +1,126 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
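+# Illustrative sketch (standalone, assumes only torch; the real table and timestep shapes come
+# from ResnetBlock3D above): the AdaLN-style modulation used for timestep conditioning. A learned
+# (4, C) table is added to a broadcastable timestep embedding and split into two (shift, scale)
+# pairs; activations are then modulated as x * (1 + scale) + shift.
+#   >>> import torch
+#   >>> b, c = 2, 8
+#   >>> table = torch.randn(4, c)                  # shift1, scale1, shift2, scale2
+#   >>> timestep = torch.randn(b, 4, c, 1, 1, 1)   # per-sample embedding, already reshaped
+#   >>> ada = table[None, ..., None, None, None] + timestep
+#   >>> shift1, scale1, shift2, scale2 = ada.unbind(dim=1)
+#   >>> x = torch.randn(b, c, 3, 4, 4)
+#   >>> (x * (1 + scale1) + shift1).shape
+#   torch.Size([2, 8, 3, 4, 4])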
+# Created by Ivan Zorin + +import math +from typing import Tuple, Union + +import torch +from einops import rearrange +from torch import nn + +from .convolution import make_conv_nd +from .enums import PaddingModeType + + +class SpaceToDepthDownsample(nn.Module): + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + stride: Tuple[int, int, int], + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.stride = stride + self.group_size = in_channels * math.prod(stride) // out_channels + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels // math.prod(stride), + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward( + self, + x: torch.Tensor, + causal: bool = True, + ) -> torch.Tensor: + if self.stride[0] == 2: + x = torch.cat([x[:, :, :1, :, :], x], dim=2) # duplicate first frames for padding + + # skip connection + x_in = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) + x_in = x_in.mean(dim=2) + + # conv + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + x = x + x_in + + return x + + +class DepthToSpaceUpsample(nn.Module): + def __init__( + self, + dims: int | Tuple[int, int], + in_channels: int, + stride: Tuple[int, int, int], + residual: bool = False, + out_channels_reduction_factor: int = 1, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.stride = stride + self.out_channels = math.prod(stride) * in_channels // out_channels_reduction_factor + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + self.residual = residual + self.out_channels_reduction_factor = out_channels_reduction_factor + + def forward( + self, + x: torch.Tensor, + causal: bool = True, + ) -> torch.Tensor: + if self.residual: + # Reshape and duplicate the input to match the output shape + x_in = rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor + x_in = x_in.repeat(1, num_repeat, 1, 1, 1) + if self.stride[0] == 2: + x_in = x_in[:, :, 1:, :, :] + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + if self.stride[0] == 2: + x = x[:, :, 1:, :, :] + if self.residual: + x = x + x_in + return x diff --git a/packages/ltx-core/src/ltx_core/model/video_vae/video_vae.py b/packages/ltx-core/src/ltx_core/model/video_vae/video_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..b21ca226b6aed762f2095164a29289026c52dca5 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/model/video_vae/video_vae.py @@ -0,0 +1,848 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
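+# Illustrative sketch (standalone, assumes torch and einops; not used by this module): the
+# space-to-depth rearrangement behind SpaceToDepthDownsample above, for a purely spatial
+# stride of (1, 2, 2). Each 2x2 spatial block is folded into the channel dimension.
+#   >>> import torch
+#   >>> from einops import rearrange
+#   >>> x = torch.randn(1, 4, 5, 8, 8)  # (B, C, F, H, W)
+#   >>> y = rearrange(x, "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", p1=1, p2=2, p3=2)
+#   >>> y.shape                          # channels: 4 * 1 * 2 * 2 = 16
+#   torch.Size([1, 16, 5, 4, 4])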
+# Created by Ivan Zorin + +from typing import Any, Generator, List, Optional, Tuple + +import torch +from torch import nn + +from ltx_core.model.common.normalization import PixelNorm +from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTimestepSizeEmbeddings +from ltx_core.model.video_vae.convolution import make_conv_nd +from ltx_core.model.video_vae.enums import LogVarianceType, NormLayerType, PaddingModeType +from ltx_core.model.video_vae.ops import PerChannelStatistics, patchify, unpatchify +from ltx_core.model.video_vae.resnet import ResnetBlock3D, UNetMidBlock3D +from ltx_core.model.video_vae.sampling import DepthToSpaceUpsample, SpaceToDepthDownsample +from ltx_core.pipeline.components.protocols import SpatioTemporalScaleFactors, VideoLatentShape +from ltx_core.tiling import ( + Tile, + TilingConfig, + compute_trapezoidal_mask_1d, + create_tiles_from_tile_sizes, +) + + +def _make_encoder_block( + block_name: str, + block_config: dict[str, Any], + in_channels: int, + convolution_dimensions: int, + norm_layer: NormLayerType, + norm_num_groups: int, + spatial_padding_mode: PaddingModeType, +) -> Tuple[nn.Module, int]: + out_channels = in_channels + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + out_channels = in_channels * block_config.get("multiplier", 2) + block = ResnetBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 1, 1), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(1, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_x_y": + out_channels = in_channels * block_config.get("multiplier", 2) + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time_res": + out_channels = in_channels * 
block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown block: {block_name}") + + return block, out_channels + + +class Encoder(nn.Module): + _DEFAULT_NORM_NUM_GROUPS = 32 + """ + Variational Autoencoder Encoder. Encodes video frames into a latent representation. + + The encoder compresses the input video through a series of downsampling operations controlled by + patch_size and encoder_blocks. The output is a normalized latent tensor with shape (B, 128, F', H', W'). + + Compression Behavior: + The total compression is determined by: + 1. Initial spatial compression via patchify: H -> H/4, W -> W/4 (patch_size=4) + 2. Sequential compression through encoder_blocks based on their stride patterns + + Compression blocks apply 2x compression in specified dimensions: + - "compress_time" / "compress_time_res": temporal only + - "compress_space" / "compress_space_res": spatial only (H and W) + - "compress_all" / "compress_all_res": all dimensions (F, H, W) + - "res_x" / "res_x_y": no compression + + Standard LTX Video configuration: + - patch_size=4 + - encoder_blocks: 1x compress_space_res, 1x compress_time_res, 2x compress_all_res + - Final dimensions: F' = 1 + (F-1)/8, H' = H/32, W' = W/32 + - Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16) + - Note: Input must have 1 + 8*k frames (e.g., 1, 9, 17, 25, 33...) + + Args: + convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D). + in_channels: The number of input channels. For RGB images, this is 3. + out_channels: The number of output channels (latent channels). For latent channels, this is 128. + encoder_blocks: The list of blocks to construct the encoder. Each block is a tuple of (block_name, params) + where params is either an int (num_layers) or a dict with configuration. + patch_size: The patch size for initial spatial compression. Should be a power of 2. + norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var: The log variance mode. Can be either `per_channel`, `uniform`, `constant` or `none`. 
+ """ + + def __init__( + self, + convolution_dimensions: int = 3, + in_channels: int = 3, + out_channels: int = 128, + encoder_blocks: List[Tuple[str, int]] | List[Tuple[str, dict[str, Any]]] = [], # noqa: B006 + patch_size: int = 4, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + latent_log_var: LogVarianceType = LogVarianceType.UNIFORM, + encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + + self.patch_size = patch_size + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS + + # Per-channel statistics for normalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels) + + in_channels = in_channels * patch_size**2 + feature_channels = out_channels + + self.conv_in = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=feature_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + self.down_blocks = nn.ModuleList([]) + + for block_name, block_params in encoder_blocks: + # Convert int to dict format for uniform handling + block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params + + block, feature_channels = _make_encoder_block( + block_name=block_name, + block_config=block_config, + in_channels=feature_channels, + convolution_dimensions=convolution_dimensions, + norm_layer=norm_layer, + norm_num_groups=self._norm_num_groups, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + self.down_blocks.append(block) + + # out + if norm_layer == NormLayerType.GROUP_NORM: + self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == LogVarianceType.PER_CHANNEL: + conv_out_channels *= 2 + elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}: + conv_out_channels += 1 + elif latent_log_var != LogVarianceType.NONE: + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + + self.conv_out = make_conv_nd( + dims=convolution_dimensions, + in_channels=feature_channels, + out_channels=conv_out_channels, + kernel_size=3, + padding=1, + causal=True, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + r""" + Encode video frames into normalized latent representation. + + Args: + sample: Input video (B, C, F, H, W). F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...). + + Returns: + Normalized latent means (B, 128, F', H', W') where F' = 1+(F-1)/8, H' = H/32, W' = W/32. + Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16). + """ + # Validate frame count + frames_count = sample.shape[2] + if ((frames_count - 1) % 8) != 0: + raise ValueError( + "Invalid number of frames: Encode input must have 1 + 8 * x frames " + "(e.g., 1, 9, 17, ...). Please check your input." 
+ ) + + # Initial spatial compression: trade spatial resolution for channel depth + # This reduces H,W by patch_size and increases channels, making convolutions more efficient + # Example: (B, 3, F, 512, 512) -> (B, 48, F, 128, 128) with patch_size=4 + sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + sample = self.conv_in(sample) + + for down_block in self.down_blocks: + sample = down_block(sample) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == LogVarianceType.UNIFORM: + # Uniform Variance: model outputs N means and 1 shared log-variance channel. + # We need to expand the single logvar to match the number of means channels + # to create a format compatible with PER_CHANNEL (means + logvar, each with N channels). + # Sample shape: (B, N+1, ...) where N = latent_channels (e.g., 128 means + 1 logvar = 129) + # Target shape: (B, 2*N, ...) where first N are means, last N are logvar + + if sample.shape[1] < 2: + raise ValueError( + f"Invalid channel count for UNIFORM mode: expected at least 2 channels " + f"(N means + 1 logvar), got {sample.shape[1]}" + ) + + # Extract means (first N channels) and logvar (last 1 channel) + means = sample[:, :-1, ...] # (B, N, ...) + logvar = sample[:, -1:, ...] # (B, 1, ...) + + # Repeat logvar N times to match means channels + # Use expand/repeat pattern that works for both 4D and 5D tensors + num_channels = means.shape[1] + repeat_shape = [1, num_channels] + [1] * (sample.ndim - 2) + repeated_logvar = logvar.repeat(*repeat_shape) # (B, N, ...) + + # Concatenate to create (B, 2*N, ...) format: [means, repeated_logvar] + sample = torch.cat([means, repeated_logvar], dim=1) + elif self.latent_log_var == LogVarianceType.CONSTANT: + sample = sample[:, :-1, ...] 
+ approx_ln_0 = -30 # this is the minimal clamp value in DiagonalGaussianDistribution objects + sample = torch.cat( + [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], + dim=1, + ) + + # Split into means and logvar, then normalize means + means, _ = torch.chunk(sample, 2, dim=1) + return self.per_channel_statistics.normalize(means) + + +def _make_decoder_block( + block_name: str, + block_config: dict[str, Any], + in_channels: int, + convolution_dimensions: int, + norm_layer: NormLayerType, + timestep_conditioning: bool, + norm_num_groups: int, + spatial_padding_mode: PaddingModeType, +) -> Tuple[nn.Module, int]: + out_channels = in_channels + if block_name == "res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "attn_res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + attention_head_dim=block_config["attention_head_dim"], + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + out_channels = in_channels // block_config.get("multiplier", 2) + block = ResnetBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=False, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + out_channels = in_channels // block_config.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(2, 2, 2), + residual=block_config.get("residual", False), + out_channels_reduction_factor=block_config.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown layer: {block_name}") + + return block, out_channels + + +class Decoder(nn.Module): + _DEFAULT_NORM_NUM_GROUPS = 32 + """ + Variational Autoencoder Decoder. Decodes latent representation into video frames. + + The decoder upsamples latents through a series of upsampling operations (inverse of encoder). + Output dimensions: F = 8x(F'-1) + 1, H = 32xH', W = 32xW' for standard LTX Video configuration. + + Upsampling blocks expand dimensions by 2x in specified dimensions: + - "compress_time": temporal only + - "compress_space": spatial only (H and W) + - "compress_all": all dimensions (F, H, W) + - "res_x" / "res_x_y" / "attn_res_x": no upsampling + + Causal Mode: + causal=False (standard): Symmetric padding, allows future frame dependencies. + causal=True: Causal padding, each frame depends only on past/current frames. 
+ First frame removed after temporal upsampling in both modes. Output shape unchanged. + Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512) for both modes. + + Args: + convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D). + in_channels: The number of input channels (latent channels). Default is 128. + out_channels: The number of output channels. For RGB images, this is 3. + decoder_blocks: The list of blocks to construct the decoder. Each block is a tuple of (block_name, params) + where params is either an int (num_layers) or a dict with configuration. + patch_size: Final spatial expansion factor. For standard LTX Video, use 4 for 4x spatial expansion: + H -> Hx4, W -> Wx4. Should be a power of 2. + norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + causal: Whether to use causal convolutions. For standard LTX Video, use False for symmetric padding. + When True, uses causal padding (past/current frames only). + timestep_conditioning: Whether to condition the decoder on timestep for denoising. + """ + + def __init__( + self, + convolution_dimensions: int = 3, + in_channels: int = 128, + out_channels: int = 3, + decoder_blocks: List[Tuple[str, int | dict]] = [], # noqa: B006 + patch_size: int = 4, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + causal: bool = False, + timestep_conditioning: bool = False, + decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT, + ): + super().__init__() + + # Spatiotemporal downscaling between decoded video space and VAE latents. + # According to the LTXV paper, the standard configuration downsamples + # video inputs by a factor of 8 in the temporal dimension and 32 in + # each spatial dimension (height and width). This parameter determines how + # many video frames and pixels correspond to a single latent cell. 
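+        # For example, with the standard factors below a 121-frame, 768x512 (WxH) video maps to
+        # a latent grid of 1 + (121 - 1) / 8 = 16 latent frames, 512 / 32 = 16 latent rows and
+        # 768 / 32 = 24 latent columns.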
+ self.video_downscale_factors = SpatioTemporalScaleFactors( + time=8, + width=32, + height=32, + ) + + self.patch_size = patch_size + out_channels = out_channels * patch_size**2 + self.causal = causal + self.timestep_conditioning = timestep_conditioning + self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS + + # Per-channel statistics for denormalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels) + + # Noise and timestep parameters for decoder conditioning + self.decode_noise_scale = 0.025 + self.decode_timestep = 0.05 + + # Compute initial feature_channels by going through blocks in reverse + # This determines the channel width at the start of the decoder + feature_channels = in_channels + for block_name, block_params in list(reversed(decoder_blocks)): + block_config = block_params if isinstance(block_params, dict) else {} + if block_name == "res_x_y": + feature_channels = feature_channels * block_config.get("multiplier", 2) + if block_name == "compress_all": + feature_channels = feature_channels * block_config.get("multiplier", 1) + + self.conv_in = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=feature_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + self.up_blocks = nn.ModuleList([]) + + for block_name, block_params in list(reversed(decoder_blocks)): + # Convert int to dict format for uniform handling + block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params + + block, feature_channels = _make_decoder_block( + block_name=block_name, + block_config=block_config, + in_channels=feature_channels, + convolution_dimensions=convolution_dimensions, + norm_layer=norm_layer, + timestep_conditioning=timestep_conditioning, + norm_num_groups=self._norm_num_groups, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + self.up_blocks.append(block) + + if norm_layer == NormLayerType.GROUP_NORM: + self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims=convolution_dimensions, + in_channels=feature_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + causal=True, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0)) + self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim=feature_channels * 2, size_emb_dim=0 + ) + self.last_scale_shift_table = nn.Parameter(torch.empty(2, feature_channels)) + + # def forward(self, sample: torch.Tensor, target_shape) -> torch.Tensor: + def forward( + self, + sample: torch.Tensor, + timestep: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + r""" + Decode latent representation into video frames. + + Args: + sample: Latent tensor (B, 128, F', H', W'). + timestep: Timestep for conditioning (if timestep_conditioning=True). Uses default 0.05 if None. + generator: Random generator for deterministic noise injection (if inject_noise=True in blocks). + + Returns: + Decoded video (B, 3, F, H, W) where F = 8x(F'-1) + 1, H = 32xH', W = 32xW'. + Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512). 
+ Note: First frame is removed after temporal upsampling regardless of causal mode. + When causal=False, allows future frame dependencies in convolutions but maintains same output shape. + """ + batch_size = sample.shape[0] + + # Add noise if timestep conditioning is enabled + if self.timestep_conditioning: + noise = ( + torch.randn( + sample.size(), + generator=generator, + dtype=sample.dtype, + device=sample.device, + ) + * self.decode_noise_scale + ) + + sample = noise + (1.0 - self.decode_noise_scale) * sample + + # Denormalize latents + sample = self.per_channel_statistics.un_normalize(sample) + + # Use default decode_timestep if timestep not provided + if timestep is None and self.timestep_conditioning: + timestep = torch.full((batch_size,), self.decode_timestep, device=sample.device, dtype=sample.dtype) + + sample = self.conv_in(sample, causal=self.causal) + + scaled_timestep = None + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + scaled_timestep = timestep * self.timestep_scale_multiplier.to(sample) + + for up_block in self.up_blocks: + if isinstance(up_block, UNetMidBlock3D): + block_kwargs = { + "causal": self.causal, + "timestep": scaled_timestep if self.timestep_conditioning else None, + "generator": generator, + } + sample = up_block(sample, **block_kwargs) + elif isinstance(up_block, ResnetBlock3D): + sample = up_block(sample, causal=self.causal, generator=generator) + else: + sample = up_block(sample, causal=self.causal) + + sample = self.conv_norm_out(sample) + + if self.timestep_conditioning: + embedded_timestep = self.last_time_embedder( + timestep=scaled_timestep.flatten(), + hidden_dtype=sample.dtype, + ) + embedded_timestep = embedded_timestep.view(batch_size, embedded_timestep.shape[-1], 1, 1, 1) + ada_values = self.last_scale_shift_table[None, ..., None, None, None].to( + device=sample.device, dtype=sample.dtype + ) + embedded_timestep.reshape( + batch_size, + 2, + -1, + embedded_timestep.shape[-3], + embedded_timestep.shape[-2], + embedded_timestep.shape[-1], + ) + shift, scale = ada_values.unbind(dim=1) + sample = sample * (1 + scale) + shift + + sample = self.conv_act(sample) + sample = self.conv_out(sample, causal=self.causal) + + # Final spatial expansion: reverse the initial patchify from encoder + # Moves pixels from channels back to spatial dimensions + # Example: (B, 48, F, 128, 128) -> (B, 3, F, 512, 512) with patch_size=4 + sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + + return sample + + def map_temporal_slice(self, begin: int, end: int, left_ramp: int, right_ramp: int) -> Tuple[slice, torch.Tensor]: + scale = self.video_downscale_factors.time + start = begin * scale + stop = 1 + (end - 1) * scale + left_ramp = 1 + (left_ramp - 1) * scale + right_ramp = right_ramp * scale + + return slice(start, stop), compute_trapezoidal_mask_1d(stop - start, left_ramp, right_ramp, True) + + def map_spatial_slice(self, begin: int, end: int, left_ramp: int, right_ramp: int) -> Tuple[slice, torch.Tensor]: + scale = self.video_downscale_factors.height + start = begin * scale + stop = end * scale + left_ramp = left_ramp * scale + right_ramp = right_ramp * scale + + return slice(start, stop), compute_trapezoidal_mask_1d(stop - start, left_ramp, right_ramp, False) + + def _prepare_tiles( + self, + latent: torch.Tensor, + tiling_config: TilingConfig | None = None, + ) -> List[Tile]: + spatial_axes_indices = () + spatial_tile_size = 1 + 
spatial_overlap = 0 + temporal_axes_indices = () + temporal_tile_size = 1 + temporal_overlap = 0 + if tiling_config is not None: + spatial_config = tiling_config.spatial_config + if spatial_config is not None: + spatial_axes_indices = (3, 4) + spatial_tile_size = spatial_config.tile_size_in_pixels // self.video_downscale_factors.width + spatial_overlap = spatial_config.tile_overlap_in_pixels // self.video_downscale_factors.width + + temporal_config = tiling_config.temporal_config + if temporal_config is not None: + temporal_axes_indices = (2,) + temporal_tile_size = temporal_config.tile_size_in_frames // self.video_downscale_factors.time + temporal_overlap = temporal_config.tile_overlap_in_frames // self.video_downscale_factors.time + + return create_tiles_from_tile_sizes( + self, + latent.shape, + spatial_tile_size, + temporal_tile_size, + spatial_overlap, + temporal_overlap, + spatial_axes_indices, + temporal_axes_indices, + ) + + def tiled_decode( + self, + latent: torch.Tensor, + tiling_config: TilingConfig | None = None, + timestep: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> Generator[tuple[torch.Tensor, int], None, None]: + """ + Decode a latent tensor into video frames using tiled processing. + + Splits the latent tensor into tiles, decodes each tile individually, + and yields video chunks as they become available. + + Args: + latent: Input latent tensor (B, C, F', H', W'). + tiling_config: Tiling configuration for the latent tensor. + timestep: Optional timestep for decoder conditioning. + generator: Optional random generator for deterministic decoding. + + Yields: + Video chunks (B, C, T, H, W) by temporal slices; + Total number of chunks. + """ + + # Calculate full video shape from latent shape to get spatial dimensions + full_video_shape = VideoLatentShape.from_torch_shape(latent.shape).upscale(self.video_downscale_factors) + tiles = self._prepare_tiles(latent, tiling_config) + + temporal_groups = self._group_tiles_by_temporal_slice(tiles) + total_number_of_chunks = len(temporal_groups) + + # State for temporal overlap handling + previous_chunk = None + previous_weights = None + previous_temporal_slice = None + + for temporal_group_tiles in temporal_groups: + curr_temporal_slice = temporal_group_tiles[0].out_coords[2] + + # Calculate the shape of the temporal buffer for this group of tiles. + # The temporal length depends on whether this is the first tile (starts at 0) or not. + # - First tile: (frames - 1) * scale + 1 + # - Subsequent tiles: frames * scale + # This logic is handled by TemporalAxisMapping and reflected in out_coords. + temporal_tile_buffer_shape = full_video_shape._replace( + frames=curr_temporal_slice.stop - curr_temporal_slice.start, + ) + + buffer = torch.zeros( + temporal_tile_buffer_shape.to_torch_shape(), + device=latent.device, + dtype=latent.dtype, + ) + + curr_weights = self._accumulate_temporal_group_into_buffer( + group_tiles=temporal_group_tiles, + buffer=buffer, + latent=latent, + timestep=timestep, + generator=generator, + ) + + # Blend with previous temporal chunk if it exists + if previous_chunk is not None: + # Check if current temporal slice overlaps with previous temporal slice + if previous_temporal_slice.stop > curr_temporal_slice.start: + overlap_len = previous_temporal_slice.stop - curr_temporal_slice.start + temporal_overlap_slice = slice(curr_temporal_slice.start - previous_temporal_slice.start, None) + + # The overlap is already masked before it reaches this step. 
Each tile is accumulated into buffer + # with its trapezoidal mask, and curr_weights accumulates the same mask. In the overlap blend we add + # the masked values (buffer[...]) and the corresponding weights (curr_weights[...]) into the + # previous buffers, then later normalize by weights. + previous_chunk[:, :, temporal_overlap_slice, :, :] += buffer[:, :, slice(0, overlap_len), :, :] + previous_weights[:, :, temporal_overlap_slice, :, :] += curr_weights[ + :, :, slice(0, overlap_len), :, : + ] + + buffer[:, :, slice(0, overlap_len), :, :] = previous_chunk[:, :, temporal_overlap_slice, :, :] + curr_weights[:, :, slice(0, overlap_len), :, :] = previous_weights[ + :, :, temporal_overlap_slice, :, : + ] + + # Yield the non-overlapping part of the previous chunk + previous_weights = previous_weights.clamp(min=1e-8) + yield_len = curr_temporal_slice.start - previous_temporal_slice.start + yield (previous_chunk / previous_weights)[:, :, :yield_len, :, :], total_number_of_chunks + + # Update state for next iteration + previous_chunk = buffer + previous_weights = curr_weights + previous_temporal_slice = curr_temporal_slice + + # Yield any remaining chunk + if previous_chunk is not None: + previous_weights = previous_weights.clamp(min=1e-8) + yield previous_chunk / previous_weights, total_number_of_chunks + + def _group_tiles_by_temporal_slice(self, tiles: List[Tile]) -> List[List[Tile]]: + """Group tiles by their temporal output slice.""" + if not tiles: + return [] + + groups = [] + current_slice = tiles[0].out_coords[2] + current_group = [] + + for tile in tiles: + tile_slice = tile.out_coords[2] + if tile_slice == current_slice: + current_group.append(tile) + else: + groups.append(current_group) + current_slice = tile_slice + current_group = [tile] + + # Add the final group + if current_group: + groups.append(current_group) + + return groups + + def _accumulate_temporal_group_into_buffer( + self, + group_tiles: List[Tile], + buffer: torch.Tensor, + latent: torch.Tensor, + timestep: Optional[torch.Tensor], + generator: Optional[torch.Generator], + ) -> torch.Tensor: + """ + Decode and accumulate all tiles of a temporal group into a local buffer. + + The buffer is local to the group and always starts at time 0; temporal coordinates + are rebased by subtracting temporal_slice.start. + """ + temporal_slice = group_tiles[0].out_coords[2] + + weights = torch.zeros_like(buffer) + + for tile in group_tiles: + decoded_tile = self.forward(latent[tile.in_coords], timestep, generator) + mask = tile.blend_mask.to(device=buffer.device, dtype=buffer.dtype) + # LT_INTERNAL: useless, always zero! 
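+            # Note: every tile in this group shares the same temporal output slice (grouping in
+            # _group_tiles_by_temporal_slice is by out_coords[2]), so this offset evaluates to 0;
+            # it is computed anyway so the buffer indexing below stays general.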
+ temporal_offset = tile.out_coords[2].start - temporal_slice.start + # Use the tile's output coordinate length, not the decoded tile's length, + # as the decoder may produce a different number of frames than expected + expected_temporal_len = tile.out_coords[2].stop - tile.out_coords[2].start + decoded_temporal_len = decoded_tile.shape[2] + + # Ensure we don't exceed the buffer or decoded tile bounds + actual_temporal_len = min(expected_temporal_len, decoded_temporal_len, buffer.shape[2] - temporal_offset) + + chunk_coords = ( + slice(None), # batch + slice(None), # channels + slice(temporal_offset, temporal_offset + actual_temporal_len), + tile.out_coords[3], # height + tile.out_coords[4], # width + ) + + # Slice decoded_tile and mask to match the actual length we're writing + decoded_slice = decoded_tile[:, :, :actual_temporal_len, :, :] + mask_slice = mask[:, :, :actual_temporal_len, :, :] if mask.shape[2] > 1 else mask + + buffer[chunk_coords] += decoded_slice * mask_slice + weights[chunk_coords] += mask_slice + + return weights diff --git a/packages/ltx-core/src/ltx_core/pipeline/__init__.py b/packages/ltx-core/src/ltx_core/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-core/src/ltx_core/pipeline/__pycache__/__init__.cpython-310.pyc b/packages/ltx-core/src/ltx_core/pipeline/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adeeee9b0c436418e34cb208df6374c7c5b0c8f8 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/pipeline/__pycache__/__init__.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/pipeline/components/__init__.py b/packages/ltx-core/src/ltx_core/pipeline/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/__init__.cpython-310.pyc b/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07d8962dcfba2460e1b14475e786b3ad4e383e0a Binary files /dev/null and b/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/__init__.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/diffusion_steps.cpython-310.pyc b/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/diffusion_steps.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a491c192528c2ed92827db754f082b0666d71f47 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/diffusion_steps.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/guiders.cpython-310.pyc b/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/guiders.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb3c3ac6b8e5b7172a889d7ba6e34d6825f9100e Binary files /dev/null and b/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/guiders.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/noisers.cpython-310.pyc b/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/noisers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93223b76fb86b17f03a604273798c796598239b3 Binary files /dev/null and 
b/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/noisers.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/patchifiers.cpython-310.pyc b/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/patchifiers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..203fda05acb9aed32257bf3e71e841a6e1740de7 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/patchifiers.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/protocols.cpython-310.pyc b/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/protocols.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0c79cb4f89ad413c5b0286b94e5d593dba90710 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/protocols.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/schedulers.cpython-310.pyc b/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/schedulers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b0274a9d1d4aca51f0267ecdaf58d013bbd2091 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/pipeline/components/__pycache__/schedulers.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/pipeline/components/diffusion_steps.py b/packages/ltx-core/src/ltx_core/pipeline/components/diffusion_steps.py new file mode 100644 index 0000000000000000000000000000000000000000..c295575c59cb1aec66c2d3f4f34cb75261c62374 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/pipeline/components/diffusion_steps.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Amit Pintz. + +import torch + +from ltx_core.pipeline.components.protocols import DiffusionStepProtocol +from ltx_core.utils import to_velocity + + +class EulerDiffusionStep(DiffusionStepProtocol): + def step( + self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int + ) -> torch.Tensor: + sigma = sigmas[step_index] + sigma_next = sigmas[step_index + 1] + dt = sigma_next - sigma + velocity = to_velocity(sample, sigma, denoised_sample) + + return (sample.to(torch.float32) + velocity.to(torch.float32) * dt).to(sample.dtype) diff --git a/packages/ltx-core/src/ltx_core/pipeline/components/guiders.py b/packages/ltx-core/src/ltx_core/pipeline/components/guiders.py new file mode 100644 index 0000000000000000000000000000000000000000..3c14ceee32f233b1eb75215a7ba614bf950c955f --- /dev/null +++ b/packages/ltx-core/src/ltx_core/pipeline/components/guiders.py @@ -0,0 +1,230 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Amit Pintz. + +from dataclasses import dataclass + +import torch + +from ltx_core.pipeline.components.protocols import GuiderProtocol + + +@dataclass(frozen=True) +class CFGGuider(GuiderProtocol): + scale: float + + def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: + return (self.scale - 1) * (cond - uncond) + + def enabled(self) -> bool: + return self.scale != 1.0 + + +@dataclass(frozen=True) +class CFGStarRescalingGuider(GuiderProtocol): + """ + Calculates the CFG delta between conditioned and unconditioned samples. 
+ + To minimize offset in the denoising direction and move mostly along the + conditioning axis within the distribution, the unconditioned sample is + rescaled in accordance with the norm of the conditioned sample. + + Attributes: + scale (float): + Global guidance strength. A value of 1.0 corresponds to no extra + guidance beyond the base model prediction. Values > 1.0 increase + the influence of the conditioned sample relative to the + unconditioned one. + """ + + scale: float + + def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: + # LT_INTERNAL_BEGIN + # https://github.com/LightricksResearch/ComfyUI-LTXVideo-Internal/blob/96de7501458e6d83ec849e9feb62dd0cb33a6d14/stg.py#L375-L385 + # LT_INTERNAL_END + rescaled_neg = projection_coef(cond, uncond) * uncond + return (self.scale - 1) * (cond - rescaled_neg) + + def enabled(self) -> bool: + return self.scale != 1.0 + + +@dataclass(frozen=True) +class STGGuider(GuiderProtocol): + """ + Calculates the STG delta between conditioned and perturbed denoised samples. + + Perturbed samples are the result of the denoising process with perturbations, + e.g. attentions acting as passthrough for certain layers and modalities. + + Attributes: + scale (float): + Global strength of the STG guidance. A value of 0.0 disables the + guidance. Larger values increase the correction applied in the + direction of (pos_denoised - perturbed_denoised). + """ + + scale: float + + def delta(self, pos_denoised: torch.Tensor, perturbed_denoised: torch.Tensor) -> torch.Tensor: + return self.scale * (pos_denoised - perturbed_denoised) + + def enabled(self) -> bool: + return self.scale != 0.0 + + +@dataclass(frozen=True) +class LtxAPGGuider(GuiderProtocol): + """ + Calculates the APG (adaptive projected guidance) delta between conditioned + and unconditioned samples. + + To minimize offset in the denoising direction and move mostly along the + conditioning axis within the distribution, the (cond - uncond) delta is + decomposed into components parallel and orthogonal to the conditioned + sample. The `eta` parameter weights the parallel component, while `scale` + is applied to the orthogonal component. Optionally, a norm threshold can + be used to suppress guidance when the magnitude of the correction is small. + + Attributes: + scale (float): + Strength applied to the component of the guidance that is orthogonal + to the conditioned sample. Controls how aggressively we move in + directions that change semantics but stay consistent with the + conditioning manifold. + + eta (float): + Weight of the component of the guidance that is parallel to the + conditioned sample. A value of 1.0 keeps the full parallel + component; values in [0, 1] attenuate it, and values > 1.0 amplify + motion along the conditioning direction. + + norm_threshold (float): + Minimum L2 norm of the guidance delta below which the guidance + can be reduced or ignored (depending on implementation). + This is useful for avoiding noisy or unstable updates when the + guidance signal is very small. 
+ """ + + scale: float + eta: float = 1.0 + norm_threshold: float = 0.0 + + def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: + """ + LT_INTERNAL_BEGIN + according to + https://github.com/LightricksResearch/ComfyUI-LTXVideo-Internal/blob/e214b4dff8fc0647d45f629ddd50f75ac7beb6d4/stg.py#L55-L71 + LT_INTERNAL_END + """ + guidance = cond - uncond + if self.norm_threshold > 0: + ones = torch.ones_like(guidance) + guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True) + scale_factor = torch.minimum(ones, self.norm_threshold / guidance_norm) + guidance = guidance * scale_factor + proj_coeff = projection_coef(guidance, cond) + g_parallel = proj_coeff * cond + g_orth = guidance - g_parallel + g_apg = g_parallel * self.eta + g_orth + + return g_apg * (self.scale - 1) + + def enabled(self) -> bool: + return self.scale != 1.0 + + +@dataclass(frozen=False) +class LegacyStatefulAPGGuider(GuiderProtocol): + """ + Calculates the APG (adaptive projected guidance) delta between conditioned + and unconditioned samples. + + LT_INTERNAL_BEGIN + according to + https://github.com/LightricksResearch/ComfyUI/blob/f897754ce318b4184e8e4f3807a7c76293947738/comfy_extras/nodes_apg.py#L13 + in comfy APG replaces CFG by substituting cond with a modified one + it is users responsibility to use only one of them, not both in the same pipeline + LT_INTERNAL_END + + To minimize offset in the denoising direction and move mostly along the + conditioning axis within the distribution, the (cond - uncond) delta is + decomposed into components parallel and orthogonal to the conditioned + sample. The `eta` parameter weights the parallel component, while `scale` + is applied to the orthogonal component. Optionally, a norm threshold can + be used to suppress guidance when the magnitude of the correction is small. + + Attributes: + scale (float): + Strength applied to the component of the guidance that is orthogonal + to the conditioned sample. Controls how aggressively we move in + directions that change semantics but stay consistent with the + conditioning manifold. + + eta (float): + Weight of the component of the guidance that is parallel to the + conditioned sample. A value of 1.0 keeps the full parallel + component; values in [0, 1] attenuate it, and values > 1.0 amplify + motion along the conditioning direction. + + norm_threshold (float): + Minimum L2 norm of the guidance delta below which the guidance + can be reduced or ignored (depending on implementation). + This is useful for avoiding noisy or unstable updates when the + guidance signal is very small. + + momentum (float): + Exponential moving-average coefficient for accumulating guidance + over time. 
running_avg = momentum * running_avg + guidance + """ + + scale: float + eta: float + norm_threshold: float = 5.0 + momentum: float = 0.0 + # it is user's responsibility not to use same APGGuider for several denoisings or different modalities + # in order not to share accumulated average across different denoisings or modalities + running_avg: torch.Tensor | None = None + + def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: + """ + LT_INTERNAL_BEGIN + combining https://github.com/LightricksResearch/ComfyUI/blob/f897754ce318b4184e8e4f3807a7c76293947738/comfy_extras/nodes_apg.py#L89 + with https://github.com/LightricksResearch/ComfyUI/blob/f897754ce318b4184e8e4f3807a7c76293947738/comfy/samplers.py#L359 + we get cfg_result = cond + modified_guidance * cond_scale + our convention is for guiders to return the delta which is to be added to the cond, + so we return g_apg * self.guidance_scale + LT_INTERNAL_END + """ + guidance = cond - uncond + if self.momentum != 0: + if self.running_avg is None: + self.running_avg = guidance.clone() + else: + self.running_avg = self.momentum * self.running_avg + guidance + guidance = self.running_avg + + if self.norm_threshold > 0: + ones = torch.ones_like(guidance) + guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True) + scale_factor = torch.minimum(ones, self.norm_threshold / guidance_norm) + guidance = guidance * scale_factor + + proj_coeff = projection_coef(guidance, cond) + g_parallel = proj_coeff * cond + g_orth = guidance - g_parallel + g_apg = g_parallel * self.eta + g_orth + + return g_apg * self.scale + + def enabled(self) -> bool: + return self.scale != 0.0 + + +def projection_coef(to_project: torch.Tensor, project_onto: torch.Tensor) -> torch.Tensor: + batch_size = to_project.shape[0] + positive_flat = to_project.reshape(batch_size, -1) + negative_flat = project_onto.reshape(batch_size, -1) + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8 + return dot_product / squared_norm diff --git a/packages/ltx-core/src/ltx_core/pipeline/components/noisers.py b/packages/ltx-core/src/ltx_core/pipeline/components/noisers.py new file mode 100644 index 0000000000000000000000000000000000000000..897d702f23013e39b25ff6dcd17865f1ca7b196b --- /dev/null +++ b/packages/ltx-core/src/ltx_core/pipeline/components/noisers.py @@ -0,0 +1,31 @@ +from dataclasses import replace +from typing import Protocol + +import torch + +from ltx_core.pipeline.conditioning.tools import LatentState + + +class Noiser(Protocol): + def __call__(self, latent_state: LatentState, noise_scale: float) -> LatentState: ... 
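
Editor's note, illustrating the Noiser contract above with the GaussianNoiser defined just below: the denoise mask drives the blend, so tokens with mask 1.0 are replaced by fresh Gaussian noise while tokens with mask 0.0 keep their conditioning values. A minimal sketch with placeholder shapes (the tensor sizes are illustrative, not values taken from this diff):

import torch
from ltx_core.pipeline.components.noisers import GaussianNoiser
from ltx_core.pipeline.conditioning.tools import LatentState

# Illustrative patchified layout: (batch, tokens, features); positions are dummies here.
latent = torch.zeros(1, 16, 128)
mask = torch.ones(1, 16, 1)
mask[:, :4] = 0.0  # pretend the first 4 tokens carry conditioning
state = LatentState(
    latent=latent,
    denoise_mask=mask,
    positions=torch.zeros(1, 3, 16, 2),
    clean_latent=latent.clone(),
)

noiser = GaussianNoiser(generator=torch.Generator().manual_seed(42))
noisy = noiser(state, noise_scale=1.0)
assert torch.equal(noisy.latent[:, :4], latent[:, :4])  # conditioning tokens untouched
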
+ + +class GaussianNoiser(Noiser): + def __init__(self, generator: torch.Generator): + super().__init__() + + self.generator = generator + + def __call__(self, latent_state: LatentState, noise_scale: float = 1.0) -> LatentState: + noise = torch.randn( + *latent_state.latent.shape, + device=latent_state.latent.device, + dtype=latent_state.latent.dtype, + generator=self.generator, + ) + scaled_mask = latent_state.denoise_mask * noise_scale + latent = noise * scaled_mask + latent_state.latent * (1 - scaled_mask) + return replace( + latent_state, + latent=latent.to(latent_state.latent.dtype), + ) diff --git a/packages/ltx-core/src/ltx_core/pipeline/components/patchifiers.py b/packages/ltx-core/src/ltx_core/pipeline/components/patchifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..752fa957dfa1d6a6e2167fca92ebdf8fe8d6ffe2 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/pipeline/components/patchifiers.py @@ -0,0 +1,359 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Ivan Zorin + +import math +from typing import Optional, Tuple + +import einops +import torch + +from ltx_core.pipeline.components.protocols import AudioLatentShape, Patchifier, VideoLatentShape + + +class VideoLatentPatchifier(Patchifier): + def __init__(self, patch_size: int): + # Patch sizes for video latents. + self._patch_size = ( + 1, # temporal dimension + patch_size, # height dimension + patch_size, # width dimension + ) + + @property + def patch_size(self) -> Tuple[int, int, int]: + return self._patch_size + + def get_token_count(self, tgt_shape: VideoLatentShape) -> int: + return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size) + + def patchify( + self, + latents: torch.Tensor, + ) -> torch.Tensor: + latents = einops.rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + + return latents + + def unpatchify( + self, + latents: torch.Tensor, + output_shape: VideoLatentShape, + ) -> torch.Tensor: + assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier" + + patch_grid_frames = output_shape.frames // self._patch_size[0] + patch_grid_height = output_shape.height // self._patch_size[1] + patch_grid_width = output_shape.width // self._patch_size[2] + + latents = einops.rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + f=patch_grid_frames, + h=patch_grid_height, + w=patch_grid_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + + return latents + + def get_patch_grid_bounds( + self, + output_shape: AudioLatentShape | VideoLatentShape, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Return the per-dimension bounds [inclusive start, exclusive end) for every + patch produced by `patchify`. The bounds are expressed in the original + video grid coordinates: frame/time, height, and width. + + The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where: + - axis 1 (size 3) enumerates (frame/time, height, width) dimensions + - axis 3 (size 2) stores `[start, end)` indices within each dimension + + Args: + output_shape: Video grid description containing frames, height, and width. + device: Device of the latent tensor. 
+ """ + if not isinstance(output_shape, VideoLatentShape): + raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates") + + frames = output_shape.frames + height = output_shape.height + width = output_shape.width + batch_size = output_shape.batch + + # Validate inputs to ensure positive dimensions + assert frames > 0, f"frames must be positive, got {frames}" + assert height > 0, f"height must be positive, got {height}" + assert width > 0, f"width must be positive, got {width}" + assert batch_size > 0, f"batch_size must be positive, got {batch_size}" + + # Generate grid coordinates for each dimension (frame, height, width) + # We use torch.arange to create the starting coordinates for each patch. + # indexing='ij' ensures the dimensions are in the order (frame, height, width). + grid_coords = torch.meshgrid( + torch.arange(start=0, end=frames, step=self._patch_size[0], device=device), + torch.arange(start=0, end=height, step=self._patch_size[1], device=device), + torch.arange(start=0, end=width, step=self._patch_size[2], device=device), + indexing="ij", + ) + + # Stack the grid coordinates to create the start coordinates tensor. + # Shape becomes (3, grid_f, grid_h, grid_w) + patch_starts = torch.stack(grid_coords, dim=0) + + # Create a tensor containing the size of a single patch: + # (frame_patch_size, height_patch_size, width_patch_size). + # Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates. + patch_size_delta = torch.tensor( + self._patch_size, + device=patch_starts.device, + dtype=patch_starts.dtype, + ).view(3, 1, 1, 1) + + # Calculate end coordinates: start + patch_size + # Shape becomes (3, grid_f, grid_h, grid_w) + patch_ends = patch_starts + patch_size_delta + + # Stack start and end coordinates together along the last dimension + # Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end] + latent_coords = torch.stack((patch_starts, patch_ends), dim=-1) + + # Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence. + # Final Shape: (batch_size, 3, num_patches, 2) + latent_coords = einops.repeat( + latent_coords, + "c f h w bounds -> b c (f h w) bounds", + b=batch_size, + bounds=2, + ) + + return latent_coords + + +def get_pixel_coords( + latent_coords: torch.Tensor, + scale_factors: Tuple[int, int, int], + causal_fix: bool = False, +) -> torch.Tensor: + """ + Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling + each axis (frame/time, height, width) with the corresponding VAE downsampling factors. + Optionally compensate for causal encoding that keeps the first frame at unit temporal scale. + + Args: + latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`. + scale_factors: `(temporal, height, width)` integer scale factors applied per axis. + causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs + that treat frame zero differently still yield non-negative timestamps. + """ + # Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout. + broadcast_shape = [1] * latent_coords.ndim + broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width) + scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape) + + # Apply per-axis scaling to convert latent bounds into pixel-space coordinates. 
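+    # For example, with scale_factors=(8, 32, 32) a latent frame bound [2, 3) maps to
+    # pixel frames [16, 24). With causal_fix, every frame bound is then shifted by
+    # 1 - scale_factors[0] and clamped at zero, so the first frame's [0, 8) becomes [0, 1).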
+ pixel_coords = latent_coords * scale_tensor + + if causal_fix: + # VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`. + # Shift and clamp to keep the first-frame timestamps causal and non-negative. + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0) + + return pixel_coords + + +class AudioPatchifier(Patchifier): + def __init__( + self, + patch_size: int, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + is_causal: bool = True, + shift: int = 0, + ): + """ + Patchifier tailored for spectrogram/audio latents. + + Args: + patch_size: Number of mel bins combined into a single patch. This + controls the resolution along the frequency axis. + sample_rate: Original waveform sampling rate. Used to map latent + indices back to seconds so downstream consumers can align audio + and video cues. + hop_length: Window hop length used for the spectrogram. Determines + how many real-time samples separate two consecutive latent frames. + audio_latent_downsample_factor: Ratio between spectrogram frames and + latent frames; compensates for additional downsampling inside the + VAE encoder. + is_causal: When True, timing is shifted to account for causal + receptive fields so timestamps do not peek into the future. + shift: Integer offset applied to the latent indices. Enables + constructing overlapping windows from the same latent sequence. + """ + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self.shift = shift + self._patch_size = (1, patch_size, patch_size) + + @property + def patch_size(self) -> Tuple[int, int, int]: + return self._patch_size + + def get_token_count(self, tgt_shape: AudioLatentShape) -> int: + return tgt_shape.frames + + def _get_audio_latent_time_in_sec( + self, + start_latent: int, + end_latent: int, + dtype: torch.dtype, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Converts latent indices into real-time seconds while honoring causal + offsets and the configured hop length. + + Args: + start_latent: Inclusive start index inside the latent sequence. This + sets the first timestamp returned. + end_latent: Exclusive end index. Determines how many timestamps get + generated. + dtype: Floating-point dtype used for the returned tensor, allowing + callers to control precision. + device: Target device for the timestamp tensor. When omitted the + computation occurs on CPU to avoid surprising GPU allocations. + """ + if device is None: + device = torch.device("cpu") + + audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device) + + audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor + + if self.is_causal: + # Frame offset for causal alignment. + # The "+1" ensures the timestamp corresponds to the first sample that is fully available. + causal_offset = 1 + audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0) + + return audio_mel_frame * self.hop_length / self.sample_rate + + def _compute_audio_timings( + self, + batch_size: int, + num_steps: int, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame. + This helper method underpins `get_patch_grid_bounds` for the audio patchifier. + + Args: + batch_size: Number of sequences to broadcast the timings over. 
+ num_steps: Number of latent frames (time steps) to convert into timestamps. + device: Device on which the resulting tensor should reside. + """ + resolved_device = device + if resolved_device is None: + resolved_device = torch.device("cpu") + + start_timings = self._get_audio_latent_time_in_sec( + self.shift, + num_steps + self.shift, + torch.float32, + resolved_device, + ) + start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) + + end_timings = self._get_audio_latent_time_in_sec( + self.shift + 1, + num_steps + self.shift + 1, + torch.float32, + resolved_device, + ) + end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) + + return torch.stack([start_timings, end_timings], dim=-1) + + def patchify( + self, + audio_latents: torch.Tensor, + ) -> torch.Tensor: + """ + Flattens the audio latent tensor along time. Use `get_patch_grid_bounds` + to derive timestamps for each latent frame based on the configured hop + length and downsampling. + + Args: + audio_latents: Latent tensor to patchify. + Returns: + Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute the + corresponding timing metadata when needed. + """ + audio_latents = einops.rearrange( + audio_latents, + "b c t f -> b t (c f)", + ) + + return audio_latents + + def unpatchify( + self, + audio_latents: torch.Tensor, + output_shape: AudioLatentShape, + ) -> torch.Tensor: + """ + Restores the `(B, C, T, F)` spectrogram tensor from flattened patches. + Use `get_patch_grid_bounds` to recompute the timestamps that describe each + frame's position in real time. + + Args: + audio_latents: Latent tensor to unpatchify. + output_shape: Shape of the unpatched output tensor. + Returns: + Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing + metadata associated with the restored latents. + """ + # audio_latents shape: (batch, time, freq * channels) + audio_latents = einops.rearrange( + audio_latents, + "b t (c f) -> b c t f", + c=output_shape.channels, + f=output_shape.mel_bins, + ) + + return audio_latents + + def get_patch_grid_bounds( + self, + output_shape: AudioLatentShape | VideoLatentShape, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Return the temporal bounds `[inclusive start, exclusive end)` for every + patch emitted by `patchify`. For audio this corresponds to timestamps in + seconds aligned with the original spectrogram grid. + + The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where: + - axis 1 (size 1) represents the temporal dimension + - axis 3 (size 2) stores the `[start, end)` timestamps per patch + + Args: + output_shape: Audio grid specification describing the number of time steps. + device: Target device for the returned tensor. + """ + if not isinstance(output_shape, AudioLatentShape): + raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates") + + return self._compute_audio_timings(output_shape.batch, output_shape.frames, device) diff --git a/packages/ltx-core/src/ltx_core/pipeline/components/protocols.py b/packages/ltx-core/src/ltx_core/pipeline/components/protocols.py new file mode 100644 index 0000000000000000000000000000000000000000..132de2f4cc6c4f06b9fe2cb953bc6a29d97313ac --- /dev/null +++ b/packages/ltx-core/src/ltx_core/pipeline/components/protocols.py @@ -0,0 +1,241 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Amit Pintz. 
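
Editor's note on the AudioPatchifier timing above: with the documented defaults (sample_rate=16000, hop_length=160, audio_latent_downsample_factor=4) there are 25 latent frames per second, and the causal offset pins the first timestamp to zero. The snippet below is a standalone recomputation of the _get_audio_latent_time_in_sec arithmetic for the first few frames, for illustration only:

import torch

sample_rate, hop_length, downsample = 16000, 160, 4
latent_frames = torch.arange(0, 5, dtype=torch.float32)
mel_frames = latent_frames * downsample
# Causal alignment: "+1" points at the first fully available sample, clipped at zero.
mel_frames = (mel_frames + 1 - downsample).clip(min=0)
seconds = mel_frames * hop_length / sample_rate
print(seconds)  # tensor([0.0000, 0.0100, 0.0500, 0.0900, 0.1300])
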
+ +from typing import NamedTuple, Protocol, Tuple + +import torch + + +class VideoPixelShape(NamedTuple): + """ + Shape of the tensor representing the video pixel array. Assumes BGR channel format. + """ + + batch: int + frames: int + height: int + width: int + fps: float + + +class SpatioTemporalScaleFactors(NamedTuple): + """ + Describes the spatiotemporal downscaling between decoded video space and + the corresponding VAE latent grid. + """ + + time: int + width: int + height: int + + +class VideoLatentShape(NamedTuple): + batch: int + channels: int + frames: int + height: int + width: int + + def to_torch_shape(self) -> torch.Size: + return torch.Size([self.batch, self.channels, self.frames, self.height, self.width]) + + @staticmethod + def from_torch_shape(shape: torch.Size) -> "VideoLatentShape": + return VideoLatentShape( + batch=shape[0], + channels=shape[1], + frames=shape[2], + height=shape[3], + width=shape[4], + ) + + def mask_shape(self) -> "VideoLatentShape": + return self._replace(channels=1) + + @staticmethod + def from_pixel_shape( + shape: VideoPixelShape, + latent_channels: int = 128, + scale_factors: tuple[int, int, int] = (8, 32, 32), + ) -> "VideoLatentShape": + frames = (shape.frames - 1) // scale_factors[0] + 1 + height = shape.height // scale_factors[1] + width = shape.width // scale_factors[2] + + return VideoLatentShape( + batch=shape.batch, + channels=latent_channels, + frames=frames, + height=height, + width=width, + ) + + def upscale(self, scale_factors: SpatioTemporalScaleFactors = (8, 32, 32)) -> "VideoLatentShape": + return self._replace( + channels=3, + frames=(self.frames - 1) * scale_factors.time + 1, + height=self.height * scale_factors.height, + width=self.width * scale_factors.width, + ) + + +class AudioLatentShape(NamedTuple): + batch: int + channels: int + frames: int + mel_bins: int + + def to_torch_shape(self) -> torch.Size: + return torch.Size([self.batch, self.channels, self.frames, self.mel_bins]) + + def mask_shape(self) -> "AudioLatentShape": + return self._replace(channels=1, mel_bins=1) + + @staticmethod + def from_torch_shape(shape: torch.Size) -> "AudioLatentShape": + return AudioLatentShape( + batch=shape[0], + channels=shape[1], + frames=shape[2], + mel_bins=shape[3], + ) + + @staticmethod + def from_duration( + batch: int, + duration: float, + channels: int = 8, + mel_bins: int = 16, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + ) -> "AudioLatentShape": + latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor) + + return AudioLatentShape( + batch=batch, + channels=channels, + frames=round(duration * latents_per_second), + mel_bins=mel_bins, + ) + + @staticmethod + def from_video_pixel_shape( + shape: VideoPixelShape, + channels: int = 8, + mel_bins: int = 16, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + ) -> "AudioLatentShape": + return AudioLatentShape.from_duration( + batch=shape.batch, + duration=float(shape.frames) / float(shape.fps), + channels=channels, + mel_bins=mel_bins, + sample_rate=sample_rate, + hop_length=hop_length, + audio_latent_downsample_factor=audio_latent_downsample_factor, + ) + + +class Patchifier(Protocol): + """ + Protocol for patchifiers that convert latent tensors into patches and assemble them back. + """ + + def patchify( + self, + latents: torch.Tensor, + ) -> torch.Tensor: + ... + """ + Convert latent tensors into flattened patch tokens. 
+ + Args: + latents: Latent tensor to patchify. + + Returns: + Flattened patch tokens tensor. + """ + + def unpatchify( + self, + latents: torch.Tensor, + output_shape: AudioLatentShape | VideoLatentShape, + ) -> torch.Tensor: + """ + Converts latent tensors between spatio-temporal formats and flattened sequence representations. + + Args: + latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`. + output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or + VideoLatentShape. + + Returns: + Dense latent tensor restored from the flattened representation. + """ + + @property + def patch_size(self) -> Tuple[int, int, int]: + ... + """ + Returns the patch size as a tuple of (temporal, height, width) dimensions + """ + + def get_patch_grid_bounds( + self, + output_shape: AudioLatentShape | VideoLatentShape, + device: torch.device | None = None, + ) -> torch.Tensor: + ... + """ + Compute metadata describing where each latent patch resides within the + grid specified by `output_shape`. + + Args: + output_shape: Target grid layout for the patches. + device: Target device for the returned tensor. + + Returns: + Tensor containing patch coordinate metadata such as spatial or temporal intervals. + """ + + +class SchedulerProtocol(Protocol): + """ + Protocol for schedulers that provide a sigmas schedule tensor for a + given number of steps. Device is cpu. + """ + + def execute(self, steps: int, **kwargs) -> torch.FloatTensor: ... + + +class GuiderProtocol(Protocol): + """ + Protocol for guiders that compute a delta tensor given conditioning inputs. + The returned delta should be added to the conditional output (cond), enabling + multiple guiders to be chained together by accumulating their deltas. + """ + + scale: float + + def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: ... + + def enabled(self) -> bool: + """ + Returns whether the corresponding perturbation is enabled. E.g. for CFG, this should return False if the scale + is 1.0. + """ + ... + + +class DiffusionStepProtocol(Protocol): + """ + Protocol for diffusion steps that provide a next sample tensor for a given current sample tensor, + current denoised sample tensor, and sigmas tensor. + """ + + def step( + self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int + ) -> torch.Tensor: ... diff --git a/packages/ltx-core/src/ltx_core/pipeline/components/schedulers.py b/packages/ltx-core/src/ltx_core/pipeline/components/schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..5b3068c7a6f6f1c8f3a6061a92a65680a420da60 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/pipeline/components/schedulers.py @@ -0,0 +1,136 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Amit Pintz. 
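
Editor's note: the protocols defined above are meant to compose into a denoising loop, with SchedulerProtocol producing the sigma schedule, GuiderProtocol returning a delta that is added to the conditional prediction, and DiffusionStepProtocol advancing the sample. The sketch below wires together LTX2Scheduler, CFGGuider and EulerDiffusionStep from this diff; the model function is a stand-in placeholder, not an API defined here:

import torch

from ltx_core.pipeline.components.diffusion_steps import EulerDiffusionStep
from ltx_core.pipeline.components.guiders import CFGGuider
from ltx_core.pipeline.components.schedulers import LTX2Scheduler


def model(sample: torch.Tensor, sigma: torch.Tensor, conditioned: bool) -> torch.Tensor:
    # Placeholder denoiser; a real pipeline would run the transformer here.
    return sample * (0.5 if conditioned else 0.4)


sample = torch.randn(1, 256, 128)  # (batch, tokens, features), illustrative
sigmas = LTX2Scheduler().execute(steps=8)
guider = CFGGuider(scale=3.0)
stepper = EulerDiffusionStep()

for i in range(len(sigmas) - 1):
    cond = model(sample, sigmas[i], conditioned=True)
    denoised = cond
    if guider.enabled():
        uncond = model(sample, sigmas[i], conditioned=False)
        denoised = cond + guider.delta(cond, uncond)
    sample = stepper.step(sample, denoised, sigmas, i)
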
+ +import math +from functools import lru_cache + +import numpy +import scipy +import torch + +from ltx_core.pipeline.components.protocols import SchedulerProtocol + +BASE_SHIFT_ANCHOR = 1024 +MAX_SHIFT_ANCHOR = 4096 + + +class LTX2Scheduler(SchedulerProtocol): + def execute( + self, + steps: int, + latent: torch.Tensor | None = None, + max_shift: float = 2.05, + base_shift: float = 0.95, + stretch: bool = True, + terminal: float = 0.1, + **_kwargs, + ) -> torch.FloatTensor: + tokens = math.prod(latent.shape[2:]) if latent is not None else MAX_SHIFT_ANCHOR + sigmas = torch.linspace(1.0, 0.0, steps + 1) + + x1 = BASE_SHIFT_ANCHOR + x2 = MAX_SHIFT_ANCHOR + mm = (max_shift - base_shift) / (x2 - x1) + b = base_shift - mm * x1 + sigma_shift = (tokens) * mm + b + + power = 1 + sigmas = torch.where( + sigmas != 0, + math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power), + 0, + ) + + # Stretch sigmas so that its final value matches the given terminal value. + if stretch: + non_zero_mask = sigmas != 0 + non_zero_sigmas = sigmas[non_zero_mask] + one_minus_z = 1.0 - non_zero_sigmas + scale_factor = one_minus_z[-1] / (1.0 - terminal) + stretched = 1.0 - (one_minus_z / scale_factor) + sigmas[non_zero_mask] = stretched + + return sigmas.to(torch.float32) + + +class LinearQuadraticScheduler(SchedulerProtocol): + """ + LT_INTERNAL_BEGIN + Default value for linear_steps in ClownSampler is steps // 2, we produce the same for linear_steps=None + https://github.com/ClownsharkBatwing/RES4LYF/blob/7750bf7800b6ad9d670308a09989fc0c04c40cec/sigmas.py#L1397 + LT_INTERNAL_END + """ + + def execute( + self, steps: int, threshold_noise: float = 0.025, linear_steps: int | None = None, **_kwargs + ) -> torch.FloatTensor: + if steps == 1: + return torch.FloatTensor([1.0, 0.0]) + + if linear_steps is None: + linear_steps = steps // 2 + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * steps + quadratic_steps = steps - linear_steps + quadratic_sigma_schedule = [] + if quadratic_steps > 0: + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] + sigma_schedule = [1.0 - x for x in sigma_schedule] + # LT_INTERNAL: in comfy it's multiplied by model.get_model_object("model_sampling").sigma_max, + # LT_INTERNAL: which is 1 for ltxv, so we don't _precalculate_model_sampling_sigmas just to get 1.0 + return torch.FloatTensor(sigma_schedule) + + +class BetaScheduler(SchedulerProtocol): + # Implemented based on: https://arxiv.org/abs/2407.12173 + # LT_INTERNAL: https://github.com/LightricksResearch/ComfyUI/blob/8ea56795b5b8da48b018756373caa8893e0bf907/comfy/supported_models.py#L813 + shift = 2.37 + # LT_INTERNAL: default value for timesteps_length in comfy + timesteps_length = 10000 + + # LT_INTERNAL: ClownSampler uses alpha=0.5, beta=0.7 in beta57 scheduler + def execute(self, steps: int, alpha: float = 0.6, beta: float = 0.6) -> torch.FloatTensor: + """ + Execute the beta scheduler. + + Args: + steps: The number of steps to execute the scheduler for. + alpha: The alpha parameter for the beta distribution. 
+ beta: The beta parameter for the beta distribution. + + Warnings: + The number of steps within `sigmas` theoretically might be less than `steps+1`, + because of the deduplication of the identical timesteps + + Returns: + A tensor of sigmas. + """ + model_sampling_sigmas = _precalculate_model_sampling_sigmas(self.shift, self.timesteps_length) + total_timesteps = len(model_sampling_sigmas) - 1 + ts = 1 - numpy.linspace(0, 1, steps, endpoint=False) + ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps).tolist() + ts = list(dict.fromkeys(ts)) + + sigmas = [float(model_sampling_sigmas[int(t)]) for t in ts] + [0.0] + return torch.FloatTensor(sigmas) + + +@lru_cache(maxsize=5) +def _precalculate_model_sampling_sigmas(shift: float, timesteps_length: int) -> torch.Tensor: + # LT_INTERNAL: https://github.com/LightricksResearch/ComfyUI/blob/8ea56795b5b8da48b018756373caa8893e0bf907/comfy/model_sampling.py#L353 + timesteps = torch.arange(1, timesteps_length + 1, 1) / timesteps_length + return torch.Tensor([flux_time_shift(shift, 1.0, t) for t in timesteps]) + + +def flux_time_shift(mu: float, sigma: float, t: float) -> float: + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +beta_scheduler = BetaScheduler() +sigmas = beta_scheduler.execute(steps=5, alpha=0.5, beta=0.7) diff --git a/packages/ltx-core/src/ltx_core/pipeline/conditioning/__init__.py b/packages/ltx-core/src/ltx_core/pipeline/conditioning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-core/src/ltx_core/pipeline/conditioning/__pycache__/__init__.cpython-310.pyc b/packages/ltx-core/src/ltx_core/pipeline/conditioning/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91abb6e8d6993cd47272b62e0b817bc63f371099 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/pipeline/conditioning/__pycache__/__init__.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/pipeline/conditioning/__pycache__/exceptions.cpython-310.pyc b/packages/ltx-core/src/ltx_core/pipeline/conditioning/__pycache__/exceptions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf40717caac033f9299d17f74da764b6ad389bea Binary files /dev/null and b/packages/ltx-core/src/ltx_core/pipeline/conditioning/__pycache__/exceptions.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/pipeline/conditioning/__pycache__/item.cpython-310.pyc b/packages/ltx-core/src/ltx_core/pipeline/conditioning/__pycache__/item.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb95fa3b4a3abf0061261e370acadbc24d2de27f Binary files /dev/null and b/packages/ltx-core/src/ltx_core/pipeline/conditioning/__pycache__/item.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/pipeline/conditioning/__pycache__/tools.cpython-310.pyc b/packages/ltx-core/src/ltx_core/pipeline/conditioning/__pycache__/tools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae8972bed091cefec0a67dca74ff4580856a3623 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/pipeline/conditioning/__pycache__/tools.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/pipeline/conditioning/exceptions.py b/packages/ltx-core/src/ltx_core/pipeline/conditioning/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..d2b0f9beaff90175ea7cc1930f781af0978c4596 --- /dev/null 
+++ b/packages/ltx-core/src/ltx_core/pipeline/conditioning/exceptions.py @@ -0,0 +1,8 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + + +class ConditioningError(Exception): + """ + Class for conditioning-related errors. + """ diff --git a/packages/ltx-core/src/ltx_core/pipeline/conditioning/item.py b/packages/ltx-core/src/ltx_core/pipeline/conditioning/item.py new file mode 100644 index 0000000000000000000000000000000000000000..5975d115c9d6e12fa444de139328f165de0935d4 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/pipeline/conditioning/item.py @@ -0,0 +1,23 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +from typing import Protocol + +from ltx_core.pipeline.conditioning.tools import LatentState, LatentTools + + +class ConditioningItem(Protocol): + def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState: + """ + Apply the conditioning to the latent state. + + Args: + latent_state: The latent state to apply the conditioning to. This is state always patchified. + + Returns: + The latent state after the conditioning has been applied. + + IMPORTANT: If the conditioning needs to add extra tokens to the latent, it should add them to the end of the + latent. + """ + ... diff --git a/packages/ltx-core/src/ltx_core/pipeline/conditioning/tools.py b/packages/ltx-core/src/ltx_core/pipeline/conditioning/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..502ebc4d7b0cb5fb1a1f253640f349c1f6c31dac --- /dev/null +++ b/packages/ltx-core/src/ltx_core/pipeline/conditioning/tools.py @@ -0,0 +1,200 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +from dataclasses import dataclass, replace +from typing import Protocol + +import torch +from torch._prims_common import DeviceLikeType + +from ltx_core.pipeline.components.patchifiers import ( + AudioLatentShape, + AudioPatchifier, + VideoLatentPatchifier, + VideoLatentShape, + get_pixel_coords, +) +from ltx_core.pipeline.components.protocols import Patchifier + + +@dataclass(frozen=True) +class LatentState: + latent: torch.Tensor + denoise_mask: torch.Tensor + positions: torch.Tensor + clean_latent: torch.Tensor + + def clone(self) -> "LatentState": + return LatentState( + latent=self.latent.clone(), + denoise_mask=self.denoise_mask.clone(), + positions=self.positions.clone(), + clean_latent=self.clean_latent.clone(), + ) + + +class LatentTools(Protocol): + """ + Tools for building latent states. + """ + + patchifier: Patchifier + target_shape: VideoLatentShape | AudioLatentShape + + def create_initial_state( + self, + device: DeviceLikeType, + dtype: torch.dtype, + initial_latent: torch.Tensor | None = None, + ) -> LatentState: + """ + Create an initial latent state. If initial_latent is provided, it will be used to create the latent state. + """ + ... + + def patchify(self, latent_state: LatentState) -> LatentState: + """ + Patchify the latent state. 
+ """ + if latent_state.latent.shape != self.target_shape.to_torch_shape(): + raise ValueError( + f"Latent state has shape {latent_state.latent.shape}, expected shape is " + f"{self.target_shape.to_torch_shape()}" + ) + latent_state = latent_state.clone() + latent = self.patchifier.patchify(latent_state.latent) + clean_latent = self.patchifier.patchify(latent_state.clean_latent) + denoise_mask = self.patchifier.patchify(latent_state.denoise_mask) + return replace(latent_state, latent=latent, denoise_mask=denoise_mask, clean_latent=clean_latent) + + def unpatchify(self, latent_state: LatentState) -> LatentState: + """ + Unpatchify the latent state. + """ + latent_state = latent_state.clone() + latent = self.patchifier.unpatchify(latent_state.latent, output_shape=self.target_shape) + clean_latent = self.patchifier.unpatchify(latent_state.clean_latent, output_shape=self.target_shape) + denoise_mask = self.patchifier.unpatchify( + latent_state.denoise_mask, output_shape=self.target_shape.mask_shape() + ) + return replace(latent_state, latent=latent, denoise_mask=denoise_mask, clean_latent=clean_latent) + + def clear_conditioning(self, latent_state: LatentState) -> LatentState: + """ + Clear the conditioning from the latent state. This method removes extra tokens from the end of the latent. + Therefore, conditioning items should add extra tokens ONLY to the end of the latent. + """ + latent_state = latent_state.clone() + + num_tokens = self.patchifier.get_token_count(self.target_shape) + latent = latent_state.latent[:, :num_tokens] + clean_latent = latent_state.clean_latent[:, :num_tokens] + denoise_mask = torch.ones_like(latent_state.denoise_mask)[:, :num_tokens] + positions = latent_state.positions[:, :, :num_tokens] + + return LatentState(latent=latent, denoise_mask=denoise_mask, positions=positions, clean_latent=clean_latent) + + +@dataclass(frozen=True) +class VideoLatentTools(LatentTools): + """ + Tools for building video latent states. + """ + + patchifier: VideoLatentPatchifier + target_shape: VideoLatentShape + fps: float + scale_factors: tuple[int, int, int] = (8, 32, 32) + causal_fix: bool = True + + def create_initial_state( + self, + device: DeviceLikeType, + dtype: torch.dtype, + initial_latent: torch.Tensor | None = None, + ) -> LatentState: + if initial_latent is not None: + assert initial_latent.shape == self.target_shape.to_torch_shape(), ( + f"Latent shape {initial_latent.shape} does not match target shape {self.target_shape.to_torch_shape()}" + ) + else: + initial_latent = torch.zeros( + *self.target_shape.to_torch_shape(), + device=device, + dtype=dtype, + ) + + clean_latent = initial_latent.clone() + + denoise_mask = torch.ones( + *self.target_shape.mask_shape().to_torch_shape(), + device=device, + dtype=torch.float32, + ) + + latent_coords = self.patchifier.get_patch_grid_bounds( + output_shape=self.target_shape, + device=device, + ) + + positions = get_pixel_coords( + latent_coords=latent_coords, + scale_factors=self.scale_factors, + causal_fix=self.causal_fix, + ).float() + positions[:, 0, ...] = positions[:, 0, ...] / self.fps + + return self.patchify( + LatentState( + latent=initial_latent, + denoise_mask=denoise_mask, + positions=positions.to(dtype), + clean_latent=clean_latent, + ) + ) + + +@dataclass(frozen=True) +class AudioLatentTools(LatentTools): + """ + Tools for building audio latent states. 
+ """ + + patchifier: AudioPatchifier + target_shape: AudioLatentShape + + def create_initial_state( + self, + device: DeviceLikeType, + dtype: torch.dtype, + initial_latent: torch.Tensor | None = None, + ) -> LatentState: + if initial_latent is not None: + assert initial_latent.shape == self.target_shape.to_torch_shape(), ( + f"Latent shape {initial_latent.shape} does not match target shape {self.target_shape.to_torch_shape()}" + ) + else: + initial_latent = torch.zeros( + *self.target_shape.to_torch_shape(), + device=device, + dtype=dtype, + ) + + clean_latent = initial_latent.clone() + + denoise_mask = torch.ones( + *self.target_shape.mask_shape().to_torch_shape(), + device=device, + dtype=torch.float32, + ) + + latent_coords = self.patchifier.get_patch_grid_bounds( + output_shape=self.target_shape, + device=device, + ) + + return self.patchify( + LatentState( + latent=initial_latent, denoise_mask=denoise_mask, positions=latent_coords, clean_latent=clean_latent + ) + ) diff --git a/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/__init__.py b/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/__pycache__/__init__.cpython-310.pyc b/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f972ebc9516e6a0e89a5fafdf922a77fe4928528 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/__pycache__/__init__.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/__pycache__/keyframe_cond.cpython-310.pyc b/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/__pycache__/keyframe_cond.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..506baac36805aaaef76f9b98f451983116927c1e Binary files /dev/null and b/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/__pycache__/keyframe_cond.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/__pycache__/latent_cond.cpython-310.pyc b/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/__pycache__/latent_cond.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6c30aac7aa3d97ab2a03a9bedb1512dc6cdaac5 Binary files /dev/null and b/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/__pycache__/latent_cond.cpython-310.pyc differ diff --git a/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/keyframe_cond.py b/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/keyframe_cond.py new file mode 100644 index 0000000000000000000000000000000000000000..38df193785c6976443ec267652d49a342f0c7ce8 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/keyframe_cond.py @@ -0,0 +1,47 @@ +import torch + +from ltx_core.pipeline.components.patchifiers import get_pixel_coords +from ltx_core.pipeline.components.protocols import VideoLatentShape +from ltx_core.pipeline.conditioning.item import ConditioningItem, LatentState +from ltx_core.pipeline.conditioning.tools import VideoLatentTools + + +class VideoConditionByKeyframeIndex(ConditioningItem): + def __init__(self, keyframes: torch.Tensor, frame_idx: int, strength: float): + self.keyframes = keyframes + self.frame_idx = frame_idx + self.strength = strength + + def 
apply_to( + self, + latent_state: LatentState, + latent_tools: VideoLatentTools, + ) -> LatentState: + tokens = latent_tools.patchifier.patchify(self.keyframes) + latent_coords = latent_tools.patchifier.get_patch_grid_bounds( + output_shape=VideoLatentShape.from_torch_shape(self.keyframes.shape), + device=self.keyframes.device, + ) + positions = get_pixel_coords( + latent_coords=latent_coords, + scale_factors=latent_tools.scale_factors, + causal_fix=latent_tools.causal_fix if self.frame_idx == 0 else False, + ) + + positions[:, 0, ...] += self.frame_idx + positions = positions.to(dtype=torch.float32) + positions[:, 0, ...] /= latent_tools.fps + + denoise_mask = torch.full( + size=(*tokens.shape[:2], 1), + fill_value=1.0 - self.strength, + device=self.keyframes.device, + dtype=self.keyframes.dtype, + ) + + return LatentState( + latent=torch.cat([latent_state.latent, tokens], dim=1), + denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1), + positions=torch.cat([latent_state.positions, positions], dim=2), + clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1), + ) diff --git a/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/latent_cond.py b/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/latent_cond.py new file mode 100644 index 0000000000000000000000000000000000000000..5a6d8afac314eb1c030dd2cd2ea987298188acf1 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/pipeline/conditioning/types/latent_cond.py @@ -0,0 +1,40 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +import torch + +from ltx_core.pipeline.conditioning.exceptions import ConditioningError +from ltx_core.pipeline.conditioning.item import ConditioningItem, LatentState +from ltx_core.pipeline.conditioning.tools import LatentTools + + +class VideoConditionByLatentIndex(ConditioningItem): + def __init__(self, latent: torch.Tensor, strength: float, latent_idx: int): + self.latent = latent + self.strength = strength + self.latent_idx = latent_idx + + def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState: + cond_batch, cond_channels, _, cond_height, cond_width = self.latent.shape + tgt_batch, tgt_channels, tgt_frames, tgt_height, tgt_width = latent_tools.target_shape.to_torch_shape() + + if (cond_batch, cond_channels, cond_height, cond_width) != (tgt_batch, tgt_channels, tgt_height, tgt_width): + raise ConditioningError( + f"Can't apply image conditioning item to latent with shape {latent_tools.target_shape}, expected " + f"shape is ({tgt_batch}, {tgt_channels}, {tgt_frames}, {tgt_height}, {tgt_width}). Make sure " + "the image and latent have the same spatial shape." 
+ ) + + tokens = latent_tools.patchifier.patchify(self.latent) + start_token = latent_tools.patchifier.get_token_count( + latent_tools.target_shape._replace(frames=self.latent_idx) + ) + stop_token = start_token + tokens.shape[1] + + latent_state = latent_state.clone() + + latent_state.latent[:, start_token:stop_token] = tokens + latent_state.clean_latent[:, start_token:stop_token] = tokens + latent_state.denoise_mask[:, start_token:stop_token] = 1.0 - self.strength + + return latent_state diff --git a/packages/ltx-core/src/ltx_core/tiling.py b/packages/ltx-core/src/ltx_core/tiling.py new file mode 100644 index 0000000000000000000000000000000000000000..d2a303197d6e982bdd2c273b62207e58c810b4e8 --- /dev/null +++ b/packages/ltx-core/src/ltx_core/tiling.py @@ -0,0 +1,363 @@ +import itertools +from dataclasses import dataclass +from typing import Any, List, NamedTuple, Tuple + +import torch + + +def compute_trapezoidal_mask_1d( + length: int, + ramp_left: int, + ramp_right: int, + left_starts_from_0: bool = False, +) -> torch.Tensor: + """ + Generate a 1D trapezoidal blending mask with linear ramps. + + Args: + length: Output length of the mask. + ramp_left: Fade-in length on the left. + ramp_right: Fade-out length on the right. + left_starts_from_0: Whether the ramp starts from 0 or first non-zero value. + Useful for temporal tiles where the first tile is causal. + Returns: + A 1D tensor of shape `(length,)` with values in [0, 1]. + """ + if length <= 0: + raise ValueError("Mask length must be positive.") + + ramp_left = max(0, min(ramp_left, length)) + ramp_right = max(0, min(ramp_right, length)) + + mask = torch.ones(length) + + if ramp_left > 0: + interval_length = ramp_left + 1 if left_starts_from_0 else ramp_left + 2 + fade_in = torch.linspace(0.0, 1.0, interval_length)[:-1] + if not left_starts_from_0: + fade_in = fade_in[1:] + mask[:ramp_left] *= fade_in + + if ramp_right > 0: + fade_out = torch.linspace(1.0, 0.0, steps=ramp_right + 2)[1:-1] + mask[-ramp_right:] *= fade_out + + return mask.clamp_(0, 1) + + +@dataclass(frozen=True) +class SpatialTilingConfig: + """Configuration for dividing each frame into spatial tiles with optional overlap. + + Args: + tile_size_in_pixels (int): Size of each tile in pixels. Must be at least 64 and divisible by 32. + tile_overlap_in_pixels (int, optional): Overlap between tiles in pixels. Must be divisible by 32. Defaults to 0. + """ + + tile_size_in_pixels: int + tile_overlap_in_pixels: int = 0 + + def __post_init__(self) -> None: + if self.tile_size_in_pixels < 64: + raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}") + if self.tile_size_in_pixels % 32 != 0: + raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}") + if self.tile_overlap_in_pixels % 32 != 0: + raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}") + + +@dataclass(frozen=True) +class TemporalTilingConfig: + """Configuration for dividing a video into temporal tiles (chunks of frames) with optional overlap. + + Args: + tile_size_in_frames (int): Number of frames in each tile. Must be at least 16 and divisible by 8. + tile_overlap_in_frames (int, optional): Number of overlapping frames between consecutive tiles. + Must be divisible by 8. Defaults to 0. 
+ """ + + tile_size_in_frames: int + tile_overlap_in_frames: int = 0 + + def __post_init__(self) -> None: + if self.tile_size_in_frames < 16: + raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}") + if self.tile_size_in_frames % 8 != 0: + raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}") + if self.tile_overlap_in_frames % 8 != 0: + raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}") + + +@dataclass(frozen=True) +class TilingConfig: + """Configuration for splitting video into tiles with optional overlap. + + Attributes: + spatial_config: Configuration for splitting spatial dimensions into tiles. + temporal_config: Configuration for splitting temporal dimension into tiles. + """ + + spatial_config: SpatialTilingConfig | None = None + temporal_config: TemporalTilingConfig | None = None + + @classmethod + def default(cls) -> "TilingConfig": + return cls( + spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64), + temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24), + ) + + +@dataclass(frozen=True) +class LatentIntervals: + original_shape: torch.Size + starts_per_dimension: Tuple[List[int], ...] + ends_per_dimension: Tuple[List[int], ...] + left_ramps_per_dimension: Tuple[List[int], ...] + right_ramps_per_dimension: Tuple[List[int], ...] + + +class Tile(NamedTuple): + """ + Represents a single tile. + + Attributes: + in_coords: + Tuple of slices specifying where to cut the tile from the INPUT tensor. + + out_coords: + Tuple of slices specifying where this tile's OUTPUT should be placed in the reconstructed OUTPUT tensor. + + masks_1d: + Per-dimension masks in OUTPUT units. + These are used to create all-dimensional blending mask. + + Methods: + blend_mask: + Create a single N-D mask from the per-dimension masks. + """ + + in_coords: Tuple[slice, ...] + out_coords: Tuple[slice, ...] + masks_1d: Tuple[Tuple[torch.Tensor, ...]] + + @property + def blend_mask(self) -> torch.Tensor: + num_dims = len(self.out_coords) + per_dimension_masks: List[torch.Tensor] = [] + + for dim_idx in range(num_dims): + mask_1d = self.masks_1d[dim_idx] + view_shape = [1] * num_dims + if mask_1d is None: + # Broadcast mask along this dimension (length 1). + one = torch.ones(1) + + view_shape[dim_idx] = 1 + per_dimension_masks.append(one.view(*view_shape)) + continue + + # Reshape (L,) -> (1, ..., L, ..., 1) so masks across dimensions broadcast-multiply. + view_shape[dim_idx] = mask_1d.shape[0] + per_dimension_masks.append(mask_1d.view(*view_shape)) + + # Multiply per-dimension masks to form the full N-D mask (separable blending window). + combined_mask = per_dimension_masks[0] + for mask in per_dimension_masks[1:]: + combined_mask = combined_mask * mask + + return combined_mask + + +def create_tiles_from_tile_sizes( + # LT_INTERNAL: make vae conform to smth like SpatialTiler | TemporalTiler | SpatialTemporalTiler + vae: Any, # noqa: ANN401 + latent_shape: torch.Size, + spatial_tile_size: int, + temporal_tile_size: int, + spatial_overlap: int = 0, + temporal_overlap: int = 0, + spatial_axes_indices: Tuple[int, ...] 
= (3, 4), + temporal_axes_indices: Tuple[int] = (2,), +) -> List[Tile]: + latent_intervals = _create_intervals_from_tile_sizes( + latent_shape=latent_shape, + spatial_tile_size=spatial_tile_size, + spatial_overlap=spatial_overlap, + temporal_tile_size=temporal_tile_size, + temporal_overlap=temporal_overlap, + spatial_axes_indices=spatial_axes_indices, + temporal_axes_indices=temporal_axes_indices, + ) + return create_tiles_from_latent_intervals(vae, latent_intervals, temporal_axes_indices, spatial_axes_indices) + + +def create_tiles_from_tiles_amount( + # LT_INTERNAL: make vae conform to smth like SpatialTiler | TemporalTiler | SpatialTemporalTiler + vae: Any, # noqa: ANN401 + latent_shape: torch.Size, + spatial_tiles_amount: int, + temporal_tile_size: int, + spatial_overlap: int = 0, + temporal_overlap: int = 0, + # LT_INTERNAL: vae.temporal_axes_indices if isinstance(vae, TemporalTiler) + temporal_axes_indices: Tuple[int] = (2,), + # LT_INTERNAL: vae.spatial_axes_indices if isinstance(vae, SpatialTiler) + spatial_axes_indices: Tuple[int, ...] = (3, 4), +) -> List[Tile]: + latent_intervals = _create_intervals_from_tiles_amount( + latent_shape, + spatial_tiles_amount, + temporal_tile_size, + temporal_overlap, + spatial_overlap, + temporal_axes_indices, + spatial_axes_indices, + ) + return create_tiles_from_latent_intervals(vae, latent_intervals, temporal_axes_indices, spatial_axes_indices) + + +def _create_intervals_from_tile_sizes( + latent_shape: torch.Size, + spatial_tile_size: int, + temporal_tile_size: int, + spatial_overlap: int = 0, + temporal_overlap: int = 0, + spatial_axes_indices: Tuple[int, ...] = (3, 4), + temporal_axes_indices: Tuple[int] = (2,), +) -> LatentIntervals: + starts_per_dimension = [] + ends_per_dimension = [] + left_ramps_per_dimension = [] + right_ramps_per_dimension = [] + for axis_index in range(len(latent_shape)): + dimension_size = latent_shape[axis_index] + size = dimension_size + overlap = 0 + amount = 1 + if axis_index in temporal_axes_indices or axis_index in spatial_axes_indices: + size = temporal_tile_size if axis_index in temporal_axes_indices else spatial_tile_size + overlap = temporal_overlap if axis_index in temporal_axes_indices else spatial_overlap + amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap) + starts = [i * (size - overlap) for i in range(amount)] + ends = [start + size for start in starts] + ends[-1] = dimension_size + left_ramps = [0] + [overlap] * (amount - 1) + right_ramps = [overlap] * (amount - 1) + [0] + if axis_index in temporal_axes_indices: + # each temporal tile is causal / grabs a latent frame before the start + starts[1:] = [s - 1 for s in starts[1:]] + left_ramps[1:] = [r + 1 for r in left_ramps[1:]] + starts_per_dimension.append(starts) + ends_per_dimension.append(ends) + left_ramps_per_dimension.append(left_ramps) + right_ramps_per_dimension.append(right_ramps) + + return LatentIntervals( + original_shape=latent_shape, + starts_per_dimension=tuple(starts_per_dimension), + ends_per_dimension=tuple(ends_per_dimension), + left_ramps_per_dimension=tuple(left_ramps_per_dimension), + right_ramps_per_dimension=tuple(right_ramps_per_dimension), + ) + + +def _create_intervals_from_tiles_amount( + latent_shape: torch.Size, + spatial_tiles_amount: int, + temporal_tile_size: int, + temporal_overlap: int = 0, + spatial_overlap: int = 0, + temporal_axes_indices: Tuple[int] = (2,), + spatial_axes_indices: Tuple[int, ...] 
= (3, 4), +) -> LatentIntervals: + starts_per_dimension = [] + ends_per_dimension = [] + left_ramps_per_dimension = [] + right_ramps_per_dimension = [] + for axis_index in range(len(latent_shape)): + dimension_size = latent_shape[axis_index] + size = dimension_size + overlap = 0 + amount = 1 + if axis_index in temporal_axes_indices: + size = temporal_tile_size + overlap = temporal_overlap + amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap) + elif axis_index in spatial_axes_indices: + amount = spatial_tiles_amount + overlap = spatial_overlap + size = (dimension_size + (spatial_tiles_amount - 1) * overlap) // spatial_tiles_amount + starts = [i * (size - overlap) for i in range(amount)] + ends = [start + size for start in starts] + ends[-1] = dimension_size + left_ramps = [0] + [overlap] * (amount - 1) + right_ramps = [overlap] * (amount - 1) + [0] + if axis_index in temporal_axes_indices: + # each temporal tile is causal / grabs a latent frame before the start + starts[1:] = [s - 1 for s in starts[1:]] + left_ramps[1:] = [r + 1 for r in left_ramps[1:]] + starts_per_dimension.append(starts) + ends_per_dimension.append(ends) + left_ramps_per_dimension.append(left_ramps) + right_ramps_per_dimension.append(right_ramps) + + return LatentIntervals( + original_shape=latent_shape, + starts_per_dimension=tuple(starts_per_dimension), + ends_per_dimension=tuple(ends_per_dimension), + left_ramps_per_dimension=tuple(left_ramps_per_dimension), + right_ramps_per_dimension=tuple(right_ramps_per_dimension), + ) + + +def create_tiles_from_latent_intervals( + vae: Any, # noqa: ANN401 + latent_intervals: LatentIntervals, + temporal_axes_indices: Tuple[int] = (2,), + spatial_axes_indices: Tuple[int, ...] = (3, 4), +) -> List[Tile]: + full_dim_input_slices = [] + full_dim_output_slices = [] + full_dim_masks_1d = [] + for axis_index in range(len(latent_intervals.original_shape)): + starts = latent_intervals.starts_per_dimension[axis_index] + ends = latent_intervals.ends_per_dimension[axis_index] + left_ramps = latent_intervals.left_ramps_per_dimension[axis_index] + right_ramps = latent_intervals.right_ramps_per_dimension[axis_index] + input_slices = [slice(s, e) for s, e in zip(starts, ends, strict=True)] + output_slices = [slice(0, None) for _ in input_slices] + masks_1d = [None] * len(input_slices) + # LT_INTERNAL: and isinstance(vae, TemporalTiler)? + if axis_index in temporal_axes_indices: + output_slices = [] + masks_1d = [] + for s, e, lr, rr in zip(starts, ends, left_ramps, right_ramps, strict=True): + output_slice, mask_1d = vae.map_temporal_slice(s, e, lr, rr) + output_slices.append(output_slice) + masks_1d.append(mask_1d) + # LT_INTERNAL: and isinstance(vae, SpatialTiler)? 
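+        # Spatial axes mirror the temporal branch above: the VAE maps each latent interval
+        # to an output slice plus a 1-D blending mask for the overlap region (compare
+        # compute_trapezoidal_mask_1d defined earlier in this module).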
+ elif axis_index in spatial_axes_indices: + output_slices = [] + masks_1d = [] + for s, e, lr, rr in zip(starts, ends, left_ramps, right_ramps, strict=True): + output_slice, mask_1d = vae.map_spatial_slice(s, e, lr, rr) + output_slices.append(output_slice) + masks_1d.append(mask_1d) + full_dim_input_slices.append(input_slices) + full_dim_output_slices.append(output_slices) + full_dim_masks_1d.append(masks_1d) + + tiles = [] + tile_in_coords = list(itertools.product(*full_dim_input_slices)) + tile_out_coords = list(itertools.product(*full_dim_output_slices)) + tile_mask_1ds = list(itertools.product(*full_dim_masks_1d)) + for in_coord, out_coord, mask_1d in zip(tile_in_coords, tile_out_coords, tile_mask_1ds, strict=True): + tiles.append( + Tile( + in_coords=in_coord, + out_coords=out_coord, + masks_1d=mask_1d, + ) + ) + return tiles diff --git a/packages/ltx-core/src/ltx_core/utils.py b/packages/ltx-core/src/ltx_core/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed38bf822bce122d7ba5774ed1778ce3491468e --- /dev/null +++ b/packages/ltx-core/src/ltx_core/utils.py @@ -0,0 +1,57 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Amit Pintz. + +from typing import Any + +import torch + + +def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor: + """Root-mean-square (RMS) normalize `x` over its last dimension. + + Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized + shape and forwards `weight` and `eps`. + """ + return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps) + + +def check_config_value(config: dict, key: str, expected: Any) -> None: # noqa: ANN401 + actual = config.get(key) + if actual != expected: + raise ValueError(f"Config value {key} is {actual}, expected {expected}") + + +def to_velocity( + sample: torch.Tensor, + sigma: float | torch.Tensor, + denoised_sample: torch.Tensor, + calc_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Convert the sample and its denoised version to velocity. + + Returns: + Velocity + """ + if isinstance(sigma, torch.Tensor): + sigma = sigma.to(calc_dtype).item() + if sigma == 0: + raise ValueError("Sigma can't be 0.0") + return ((sample.to(calc_dtype) - denoised_sample.to(calc_dtype)) / sigma).to(sample.dtype) + + +def to_denoised( + sample: torch.Tensor, + velocity: torch.Tensor, + sigma: float | torch.Tensor, + calc_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Convert the sample and its denoising velocity to denoised sample. 
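+ + Computes `denoised = sample - velocity * sigma`, the inverse of `to_velocity` at the same sigma.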
+ + Returns: + Denoised sample + """ + if isinstance(sigma, torch.Tensor): + sigma = sigma.to(calc_dtype) + return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype) diff --git a/packages/ltx-core/tests/__init__.py b/packages/ltx-core/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-core/tests/ltx_core/__init__.py b/packages/ltx-core/tests/ltx_core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-core/tests/ltx_core/conftest.py b/packages/ltx-core/tests/ltx_core/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..154d7e1c070aa9455c18fd16d6d0ccf4f9398237 --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/conftest.py @@ -0,0 +1,106 @@ +import os +from pathlib import Path +from typing import Callable + +import av +import pytest +import torch +import torch.nn.functional as F +from torch._prims_common import DeviceLikeType + +torch.use_deterministic_algorithms(True) +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + +MODELS_PATH = Path(os.getenv("MODELS_PATH", "/models")) +CHECKPOINTS_DIR = MODELS_PATH / "comfyui_models" / "checkpoints" +LORAS_DIR = MODELS_PATH / "comfyui_models" / "loras" + +GEMMA_ROOT = MODELS_PATH / "comfyui_models" / "text_encoders" / "gemma-3-12b-it-qat-q4_0-unquantized_readout_proj" +DISTILLED_CHECKPOINT_PATH = CHECKPOINTS_DIR / "ltx-av-distilled-teacher-1933500-step-25700-ema.safetensors" +AV_CHECKPOINT_SPLIT_PATH = CHECKPOINTS_DIR / "ltx-av-step-1933500-split-new-vae.safetensors" +SPATIAL_UPSAMPLER_PATH = CHECKPOINTS_DIR / "ltx2-spatial-upscaler-x2-1.0.bf16.safetensors" +DISTILLED_LORA_PATH = ( + LORAS_DIR / "ltxv" / "ltx2" / "ltx-av-distilled-teacher-1933500-step-25700-ema-lora-384_comfy.safetensors" +) +ROOT_DIR = Path(__file__).parent +OUTPUT_DIR = ROOT_DIR / "output" +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + + +def _psnr(pred: torch.Tensor, target: torch.Tensor, max_val: float = 1.0, eps: float = 1e-8) -> torch.Tensor: + """ + Compute Peak Signal-to-Noise Ratio (PSNR) between two images (or batches of images). + + Args: + pred: Predicted image tensor, shape (..., H, W) or (..., C, H, W) + target: Ground truth image tensor, same shape as `pred` + max_val: Maximum possible pixel value of the images. + For images in [0, 1] use 1.0, for [0, 255] use 255.0, etc. + eps: Small value to avoid log of zero. + + Returns: + psnr: PSNR value (in dB). 
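+ + Computed as `10 * log10(max_val**2 / (mse + eps))`; higher values indicate a closer match.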
+ """ + # Ensure same shape + if pred.shape != target.shape: + raise ValueError(f"Shape mismatch: pred {pred.shape}, target {target.shape}") + + # Convert to float for safety + pred = pred.float() + target = target.float() + + # Mean squared error per sample + # Flatten over all dims + if pred.dim() > 1: + mse = F.mse_loss(pred, target, reduction="none") + # Reduce over spatial (and channel) dims + dims = list(range(mse.dim())) + mse = mse.mean(dim=dims) + else: + # 1D case + mse = F.mse_loss(pred, target, reduction="mean") + + # PSNR computation + psnr_val = 10.0 * torch.log10((max_val**2) / (mse + eps)) + + return psnr_val + + +@pytest.fixture +def psnr() -> Callable[[torch.Tensor, torch.Tensor, float, float], float]: + """Fixture that returns the PSNR function.""" + return _psnr + + +def _decode_video_from_file(path: str, device: DeviceLikeType) -> tuple[torch.Tensor, torch.Tensor | None]: + container = av.open(path) + try: + video_stream = next(s for s in container.streams if s.type == "video") + audio_stream = next((s for s in container.streams if s.type == "audio"), None) + + frames = [] + audio = [] if audio_stream else None + + streams_to_decode = [video_stream] + if audio_stream: + streams_to_decode.append(audio_stream) + + for frame in container.decode(*streams_to_decode): + if isinstance(frame, av.VideoFrame): + tensor = torch.tensor(frame.to_rgb().to_ndarray(), dtype=torch.uint8, device=device).unsqueeze(0) + frames.append(tensor) + elif isinstance(frame, av.AudioFrame): + audio.append(torch.tensor(frame.to_ndarray(), dtype=torch.float32, device=device).unsqueeze(0)) + + if audio: + audio = torch.cat(audio) + finally: + container.close() + + return torch.cat(frames), audio + + +@pytest.fixture +def decode_video_from_file() -> Callable[[str], tuple[torch.Tensor, torch.Tensor | None]]: + """Fixture that returns the function to decode a video from a file.""" + return _decode_video_from_file diff --git a/packages/ltx-core/tests/ltx_core/guidance/test_perturbations.py b/packages/ltx-core/tests/ltx_core/guidance/test_perturbations.py new file mode 100644 index 0000000000000000000000000000000000000000..fab67d5983b87b1c854721bf632381274ff56acf --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/guidance/test_perturbations.py @@ -0,0 +1,89 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
+# Created by Andrew Kvochko + +import pytest +import torch + +import ltx_core.guidance.perturbations as ptb + + +@pytest.fixture +def positive_config() -> ptb.PerturbationConfig: + return ptb.PerturbationConfig(None) + + +@pytest.fixture +def modality_config() -> ptb.PerturbationConfig: + return ptb.PerturbationConfig( + [ + ptb.Perturbation(ptb.PerturbationType.SKIP_A2V_CROSS_ATTN, None), + ptb.Perturbation(ptb.PerturbationType.SKIP_V2A_CROSS_ATTN, None), + ] + ) + + +@pytest.fixture +def stg_config() -> ptb.PerturbationConfig: + return ptb.PerturbationConfig( + [ + ptb.Perturbation(ptb.PerturbationType.SKIP_VIDEO_SELF_ATTN, [0, 1, 2]), + ptb.Perturbation(ptb.PerturbationType.SKIP_AUDIO_SELF_ATTN, [0, 1]), + ] + ) + + +def test_perturbation() -> None: + perturbation = ptb.Perturbation(ptb.PerturbationType.SKIP_A2V_CROSS_ATTN, [0, 1, 2]) + assert perturbation.is_perturbed(ptb.PerturbationType.SKIP_A2V_CROSS_ATTN, 0) + assert not perturbation.is_perturbed(ptb.PerturbationType.SKIP_A2V_CROSS_ATTN, 3) + assert not perturbation.is_perturbed(ptb.PerturbationType.SKIP_V2A_CROSS_ATTN, 0) + + all_blocks = ptb.Perturbation(ptb.PerturbationType.SKIP_A2V_CROSS_ATTN, None) + assert all_blocks.is_perturbed(ptb.PerturbationType.SKIP_A2V_CROSS_ATTN, 0) + + +def test_perturbation_config( + positive_config: ptb.PerturbationConfig, modality_config: ptb.PerturbationConfig, stg_config: ptb.PerturbationConfig +) -> None: + assert modality_config.is_perturbed(ptb.PerturbationType.SKIP_A2V_CROSS_ATTN, 0) + assert modality_config.is_perturbed(ptb.PerturbationType.SKIP_V2A_CROSS_ATTN, 0) + assert stg_config.is_perturbed(ptb.PerturbationType.SKIP_AUDIO_SELF_ATTN, 0) + assert stg_config.is_perturbed(ptb.PerturbationType.SKIP_AUDIO_SELF_ATTN, 1) + assert stg_config.is_perturbed(ptb.PerturbationType.SKIP_VIDEO_SELF_ATTN, 0) + assert not stg_config.is_perturbed(ptb.PerturbationType.SKIP_VIDEO_SELF_ATTN, 3) + assert not stg_config.is_perturbed(ptb.PerturbationType.SKIP_A2V_CROSS_ATTN, 3) + assert not stg_config.is_perturbed(ptb.PerturbationType.SKIP_V2A_CROSS_ATTN, 3) + + assert not positive_config.is_perturbed(ptb.PerturbationType.SKIP_A2V_CROSS_ATTN, 0) + assert not positive_config.is_perturbed(ptb.PerturbationType.SKIP_V2A_CROSS_ATTN, 0) + assert not positive_config.is_perturbed(ptb.PerturbationType.SKIP_VIDEO_SELF_ATTN, 0) + assert not positive_config.is_perturbed(ptb.PerturbationType.SKIP_AUDIO_SELF_ATTN, 0) + + +def test_batched_perturbation_config( + positive_config: ptb.PerturbationConfig, modality_config: ptb.PerturbationConfig, stg_config: ptb.PerturbationConfig +) -> None: + batched_ptb = ptb.BatchedPerturbationConfig([positive_config, modality_config, stg_config]) + assert batched_ptb.any_in_batch(ptb.PerturbationType.SKIP_A2V_CROSS_ATTN, 0) + assert batched_ptb.any_in_batch(ptb.PerturbationType.SKIP_V2A_CROSS_ATTN, 0) + assert batched_ptb.any_in_batch(ptb.PerturbationType.SKIP_VIDEO_SELF_ATTN, 0) + assert batched_ptb.any_in_batch(ptb.PerturbationType.SKIP_AUDIO_SELF_ATTN, 0) + assert not batched_ptb.any_in_batch(ptb.PerturbationType.SKIP_VIDEO_SELF_ATTN, 3) + assert not batched_ptb.any_in_batch(ptb.PerturbationType.SKIP_AUDIO_SELF_ATTN, 2) + assert not batched_ptb.all_in_batch(ptb.PerturbationType.SKIP_A2V_CROSS_ATTN, 0) + assert not batched_ptb.all_in_batch(ptb.PerturbationType.SKIP_V2A_CROSS_ATTN, 0) + + mask = batched_ptb.mask(ptb.PerturbationType.SKIP_A2V_CROSS_ATTN, 0, torch.device("cpu"), torch.float32) + assert mask.tolist() == [1, 0, 1] + + mask = 
batched_ptb.mask_like(ptb.PerturbationType.SKIP_AUDIO_SELF_ATTN, 2, mask) + assert mask.tolist() == [1, 1, 1] + + mask = batched_ptb.mask_like(ptb.PerturbationType.SKIP_AUDIO_SELF_ATTN, 1, mask) + assert mask.tolist() == [1, 1, 0] + + +def test_empty_batched_perturbation_config() -> None: + batched_ptb = ptb.BatchedPerturbationConfig.empty(2) + mask = batched_ptb.mask(ptb.PerturbationType.SKIP_A2V_CROSS_ATTN, 0, torch.device("cpu"), torch.float32) + assert mask.tolist() == [1, 1] diff --git a/packages/ltx-core/tests/ltx_core/loader/test_sd_keys_ops.py b/packages/ltx-core/tests/ltx_core/loader/test_sd_keys_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..bc7221c1f5be9226c466d4bc095e0bbb553b73dd --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/loader/test_sd_keys_ops.py @@ -0,0 +1,141 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. + +from ltx_core.loader.sd_ops import ( + ContentMatching, + ContentReplacement, + SDOps, +) + + +class TestSDOps: + def test_creation_minimal(self) -> None: + ops = SDOps("test_ops") + assert ops.name == "test_ops" + assert ops.mapping == () + + def test_with_replacement_creates_new_instance(self) -> None: + ops = SDOps("test") + new_ops = ops.with_replacement("old", "new") + assert ops.mapping == () + assert len(new_ops.mapping) == 1 + assert isinstance(new_ops.mapping[0], ContentReplacement) + assert new_ops.mapping[0].content == "old" + assert new_ops.mapping[0].replacement == "new" + + def test_with_replacement_chaining(self) -> None: + ops = SDOps("test").with_replacement("a", "b").with_replacement("c", "d") + assert len(ops.mapping) == 2 + assert ops.mapping[0] == ContentReplacement("a", "b") + assert ops.mapping[1] == ContentReplacement("c", "d") + + def test_with_matching_creates_new_instance(self) -> None: + ops = SDOps("test") + new_ops = ops.with_matching(prefix="model.", suffix=".weight") + assert ops.mapping == () + assert len(new_ops.mapping) == 1 + assert isinstance(new_ops.mapping[0], ContentMatching) + assert new_ops.mapping[0].prefix == "model." 
+ assert new_ops.mapping[0].suffix == ".weight" + + def test_with_matching_chaining(self) -> None: + ops = SDOps("test").with_matching(prefix="a.", suffix="b.").with_matching(prefix="c.", suffix="d.") + assert len(ops.mapping) == 2 + assert ops.mapping[0] == ContentMatching(prefix="a.", suffix="b.") + assert ops.mapping[1] == ContentMatching(prefix="c.", suffix="d.") + + def test_mixed_chaining(self) -> None: + ops = SDOps("test").with_matching(prefix="model.").with_replacement("old", "new") + assert len(ops.mapping) == 2 + assert isinstance(ops.mapping[0], ContentMatching) + assert isinstance(ops.mapping[1], ContentReplacement) + + +class TestSDOpsApply: + def test_apply_without_matching_returns_none(self) -> None: + ops = SDOps("test") + result = ops.apply_to_key("any.key.name") + assert result is None + + def test_apply_with_passthrough_matcher(self) -> None: + ops = SDOps("test").with_matching() + result = ops.apply_to_key("any.key.name") + assert result == "any.key.name" + + def test_apply_single_replacement(self) -> None: + ops = SDOps("test").with_matching().with_replacement("old", "new") + assert ops.apply_to_key("old.key") == "new.key" + assert ops.apply_to_key("my.old.key") == "my.new.key" + assert ops.apply_to_key("no_match") == "no_match" + + def test_apply_multiple_replacements(self) -> None: + ops = SDOps("test").with_matching().with_replacement("model.", "").with_replacement(".weight", ".bias") + result = ops.apply_to_key("model.layer.weight") + assert result == "layer.bias" + + def test_apply_with_matching_prefix_passes(self) -> None: + ops = SDOps("test").with_matching(prefix="model.").with_replacement("model.", "") + result = ops.apply_to_key("model.layer.weight") + assert result == "layer.weight" + + def test_apply_with_matching_prefix_fails(self) -> None: + ops = SDOps("test").with_matching(prefix="model.").with_replacement("layer", "block") + result = ops.apply_to_key("other.layer.weight") + assert result is None + + def test_apply_with_matching_suffix_passes(self) -> None: + ops = SDOps("test").with_matching(suffix=".weight") + result = ops.apply_to_key("model.layer.weight") + assert result == "model.layer.weight" + + def test_apply_with_matching_suffix_fails(self) -> None: + ops = SDOps("test").with_matching(suffix=".weight") + result = ops.apply_to_key("model.layer.bias") + assert result is None + + def test_apply_with_matching_prefix_and_suffix_passes(self) -> None: + ops = SDOps("test").with_matching(prefix="model.", suffix=".weight") + result = ops.apply_to_key("model.layer.weight") + assert result == "model.layer.weight" + + def test_apply_with_matching_prefix_and_suffix_fails_prefix(self) -> None: + ops = SDOps("test").with_matching(prefix="model.", suffix=".weight") + result = ops.apply_to_key("other.layer.weight") + assert result is None + + def test_apply_with_matching_prefix_and_suffix_fails_suffix(self) -> None: + ops = SDOps("test").with_matching(prefix="model.", suffix=".weight") + result = ops.apply_to_key("model.layer.bias") + assert result is None + + def test_apply_with_multiple_matchers_any_match_passes(self) -> None: + ops = SDOps("test").with_matching(prefix="model.").with_matching(prefix="other.") + assert ops.apply_to_key("model.layer") == "model.layer" + assert ops.apply_to_key("other.layer") == "other.layer" + assert ops.apply_to_key("unknown.layer") is None + + def test_apply_replacement_all_occurrences(self) -> None: + ops = SDOps("test").with_matching().with_replacement("block", "layer") + result = ops.apply_to_key("block.sub_block.block") 
+ assert result == "layer.sub_layer.layer" + + +class TestSDOpsImmutability: + def test_original_unchanged_after_with_replacement(self) -> None: + original = SDOps("test") + _ = original.with_replacement("a", "b") + assert original.mapping == () + + def test_original_unchanged_after_with_matching(self) -> None: + original = SDOps("test") + _ = original.with_matching(prefix="model.") + assert original.mapping == () + + def test_chained_ops_are_independent(self) -> None: + base = SDOps("test") + ops1 = base.with_replacement("a", "b") + ops2 = base.with_replacement("c", "d") + assert ops1.mapping != ops2.mapping + assert len(ops1.mapping) == 1 + assert len(ops2.mapping) == 1 + assert ops1.mapping[0].content == "a" + assert ops2.mapping[0].content == "c" diff --git a/packages/ltx-core/tests/ltx_core/loader/test_sgpu_sft_builder.py b/packages/ltx-core/tests/ltx_core/loader/test_sgpu_sft_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..6e784c475e8d56c3c17423711f50b65e2ae06e64 --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/loader/test_sgpu_sft_builder.py @@ -0,0 +1,329 @@ +from pathlib import Path +from tempfile import TemporaryDirectory + +import pytest +import torch + +from ltx_core.loader.registry import StateDictRegistry +from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP +from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder +from ltx_core.model.transformer.model import LTXModel +from ltx_core.model.transformer.model_configurator import ( + LTXV_MODEL_COMFY_RENAMING_MAP, + LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP, + UPCAST_DURING_INFERENCE, + LTXModelConfigurator, +) + +MODEL_PATH = "packages/ltx-core/tests/ltx_core/loader/assets/model-transformer_block.7.attn1.to_v.sft" +LORA_PATH = "packages/ltx-core/tests/ltx_core/loader/assets/lora-transformer_block.7.attn1.to_v.sft" + + +# Fixtures for float values +@pytest.fixture +def tolerance() -> float: + """Tolerance value for tensor comparisons.""" + return 1e-3 + + +@pytest.fixture +def expected_tensor_values() -> dict[str, float]: + """Expected tensor values at specific indices.""" + return { + "vanilla_0_0": 0.0275, + "vanilla_2048_2048": -0.0830, + "single_lora_0_0": 0.0225, + "single_lora_2048_2048": -0.0688, + "double_lora_0_0": 0.0175, + "double_lora_2048_2048": -0.0546, + } + + +@pytest.fixture +def model_dimensions() -> dict[str, int]: + """Model dimension constants.""" + return { + "num_attention_heads": 32, + "attention_head_dim": 128, + "weight_shape_0": 4096, + "weight_shape_1": 4096, + "inner_dim": 32 * 128, + } + + +@pytest.fixture +def lora_scale() -> float: + """LoRA scale value.""" + return 1.0 + + +# Fixtures for string values +@pytest.fixture +def device() -> torch.device: + """Device for model operations.""" + return torch.device("cpu") + + +@pytest.fixture +def state_dict_keys() -> dict[str, str]: + """State dict key names.""" + return { + "vanilla_weight": "model.diffusion_model.transformer_blocks.7.attn1.to_v.weight", + "renamed_weight": "transformer_blocks.7.attn1.to_v.weight", + "model_weight": "transformer.blocks.7.attn1.to_v.weight", + "lora_A": "transformer_blocks.7.attn1.to_v.lora_A.weight", + "lora_B": "transformer_blocks.7.attn1.to_v.lora_B.weight", + } + + +@pytest.fixture +def metadata_keys() -> dict[str, str]: + """Metadata dictionary keys.""" + return { + "transformer": "transformer", + "num_attention_heads": "num_attention_heads", + "attention_head_dim": "attention_head_dim", + } + + +@pytest.fixture +def 
tensor_indices() -> dict[str, tuple[int, int]]: + """Tensor index tuples for assertions.""" + return { + "top_left": (0, 0), + "center": (2048, 2048), + } + + +def test_sft_metadata_loading(metadata_keys: dict[str, str], model_dimensions: dict[str, int]) -> None: + builder = Builder( + model_path=MODEL_PATH, + model_class_configurator=LTXModelConfigurator, + ) + metadata = builder.model_loader.metadata(MODEL_PATH) + assert metadata_keys["transformer"] in metadata + transformer_config = metadata[metadata_keys["transformer"]] + assert transformer_config[metadata_keys["num_attention_heads"]] == model_dimensions["num_attention_heads"] + assert transformer_config[metadata_keys["attention_head_dim"]] == model_dimensions["attention_head_dim"] + + +def test_metamodel_creation(model_dimensions: dict[str, int]) -> None: + builder = Builder( + model_path=MODEL_PATH, + model_class_configurator=LTXModelConfigurator, + ) + transformer_cfg = builder.model_config() + meta_transformer = builder.meta_model(transformer_cfg, ()) + assert isinstance(meta_transformer, LTXModel) + assert meta_transformer.inner_dim == model_dimensions["inner_dim"] + + +def test_model_state_dict_loading( + device: torch.device, + state_dict_keys: dict[str, str], + model_dimensions: dict[str, int], + tensor_indices: dict[str, tuple[int, int]], + expected_tensor_values: dict[str, float], + tolerance: float, +) -> None: + registry = StateDictRegistry() + builder = Builder( + model_path=MODEL_PATH, + model_class_configurator=LTXModelConfigurator, + registry=registry, + ) + model_sd = builder.load_sd( + [MODEL_PATH], + registry=builder.registry, + device=device, + ) + model_renamed_sd = builder.load_sd( + [MODEL_PATH], + sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP, + registry=builder.registry, + device=device, + ) + assert model_sd.sd is not None + assert len(model_sd.sd) == 2 + assert state_dict_keys["vanilla_weight"] in model_sd.sd + to_v = model_sd.sd[state_dict_keys["vanilla_weight"]] + assert to_v.shape == (model_dimensions["weight_shape_0"], model_dimensions["weight_shape_1"]) + assert torch.isclose( + to_v[tensor_indices["top_left"]].float(), + torch.tensor(expected_tensor_values["vanilla_0_0"]), + atol=tolerance, + ) + assert torch.isclose( + to_v[tensor_indices["center"]].float(), + torch.tensor(expected_tensor_values["vanilla_2048_2048"]), + atol=tolerance, + ) + assert model_renamed_sd.sd is not None + assert len(model_renamed_sd.sd) == 2 + assert state_dict_keys["renamed_weight"] in model_renamed_sd.sd + to_v_renamed = model_renamed_sd.sd[state_dict_keys["renamed_weight"]] + assert torch.allclose(to_v, to_v_renamed, atol=tolerance) + vanilla_id = registry.get([MODEL_PATH], None) + assert vanilla_id is not None + renamed_id = registry.get([MODEL_PATH], LTXV_MODEL_COMFY_RENAMING_MAP) + assert renamed_id is not None + assert renamed_id != vanilla_id + + +def test_lora_state_dict_loading(device: torch.device, state_dict_keys: dict[str, str]) -> None: + builder = Builder( + model_path=MODEL_PATH, + model_class_configurator=LTXModelConfigurator, + ) + lora_sd = builder.load_sd( + [LORA_PATH], + registry=builder.registry, + device=device, + sd_ops=LTXV_LORA_COMFY_RENAMING_MAP, + ) + assert lora_sd.sd is not None + assert len(lora_sd.sd) == 2 + assert state_dict_keys["lora_A"] in lora_sd.sd + assert state_dict_keys["lora_B"] in lora_sd.sd + + +def test_vanilla_model_building( + device: torch.device, state_dict_keys: dict[str, str], model_dimensions: dict[str, int] +) -> None: + model = Builder( + model_path=MODEL_PATH, + 
model_class_configurator=LTXModelConfigurator, + ).build(device=device) + assert isinstance(model, LTXModel) + assert model.inner_dim == model_dimensions["inner_dim"] + for name, param in model.named_parameters(): + if name != state_dict_keys["model_weight"]: + continue + assert str(param.device) == str(device) + + +def test_model_with_single_lora_building( + device: torch.device, + state_dict_keys: dict[str, str], + tensor_indices: dict[str, tuple[int, int]], + expected_tensor_values: dict[str, float], + tolerance: float, + lora_scale: float, +) -> None: + builder = Builder( + model_path=MODEL_PATH, + model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP, + model_class_configurator=LTXModelConfigurator, + ) + model_with_lora = builder.lora(LORA_PATH, lora_scale, LTXV_LORA_COMFY_RENAMING_MAP).build(device=device) + for name, param in model_with_lora.named_parameters(): + if name != state_dict_keys["renamed_weight"]: + continue + assert torch.isclose( + param[tensor_indices["top_left"]].float(), + torch.tensor(expected_tensor_values["single_lora_0_0"]), + atol=tolerance, + ) + assert torch.isclose( + param[tensor_indices["center"]].float(), + torch.tensor(expected_tensor_values["single_lora_2048_2048"]), + atol=tolerance, + ) + + +def test_model_and_registry_with_multiple_loras_building( + device: torch.device, + state_dict_keys: dict[str, str], + tensor_indices: dict[str, tuple[int, int]], + expected_tensor_values: dict[str, float], + tolerance: float, + lora_scale: float, +) -> None: + registry = StateDictRegistry() + builder = Builder( + model_path=MODEL_PATH, + model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP, + model_class_configurator=LTXModelConfigurator, + registry=registry, + ) + with TemporaryDirectory() as temp_dir: + lora_path = str(Path(temp_dir) / "lora.sft") + Path(lora_path).write_bytes(Path(LORA_PATH).read_bytes()) + single_lora_builder = builder.lora(LORA_PATH, lora_scale, LTXV_LORA_COMFY_RENAMING_MAP) + dbl_lora_builder = builder.lora(LORA_PATH, lora_scale, LTXV_LORA_COMFY_RENAMING_MAP).lora( + lora_path, lora_scale, LTXV_LORA_COMFY_RENAMING_MAP + ) + single_lora_model = single_lora_builder.build(device=device) + dbl_lora_model = dbl_lora_builder.build(device=device) + for name, param in single_lora_model.named_parameters(): + if name != state_dict_keys["renamed_weight"]: + continue + assert torch.isclose( + param[tensor_indices["top_left"]].float(), + torch.tensor(expected_tensor_values["single_lora_0_0"]), + atol=tolerance, + ) + assert torch.isclose( + param[tensor_indices["center"]].float(), + torch.tensor(expected_tensor_values["single_lora_2048_2048"]), + atol=tolerance, + ) + for name, param in dbl_lora_model.named_parameters(): + if name != state_dict_keys["renamed_weight"]: + continue + assert torch.isclose( + param[tensor_indices["top_left"]].float(), + torch.tensor(expected_tensor_values["double_lora_0_0"]), + atol=tolerance, + ) + assert torch.isclose( + param[tensor_indices["center"]].float(), + torch.tensor(expected_tensor_values["double_lora_2048_2048"]), + atol=tolerance, + ) + assert len(registry._state_dicts) == 3 + + +def test_fp8_model_building_and_linear_forward( + device: torch.device, state_dict_keys: dict[str, str], model_dimensions: dict[str, int] +) -> None: + model = Builder( + model_path=MODEL_PATH, + model_class_configurator=LTXModelConfigurator, + model_sd_ops=LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP, + module_ops=(UPCAST_DURING_INFERENCE,), + ).build(device=device) + assert isinstance(model, LTXModel) + assert model.inner_dim == 
model_dimensions["inner_dim"] + torch.manual_seed(42) + x = torch.randn(1, model.inner_dim, dtype=torch.bfloat16) + + for name, param in model.named_parameters(): + if name != state_dict_keys["model_weight"]: + continue + assert str(param.device) == str(device) + assert param.dtype == torch.float8_e4m3fn + + func = list(model.transformer_blocks)[7].attn1.to_v + output = func(x)[0] + fixture = torch.Tensor( + [ + -1.4531, + 1.5625, + 1.7969, + -2.1875, + -0.3652, + 0.5312, + -0.9258, + 0.2617, + -0.7734, + -3.3281, + -0.6914, + 1.3906, + 0.2412, + 0.5430, + 1.5547, + -3.2656, + ] + ) + assert torch.allclose(output[:16], fixture.to(dtype=output.dtype), atol=1e-4, rtol=1e-4) diff --git a/packages/ltx-core/tests/ltx_core/model/audio_vae/test_audio_vae.py b/packages/ltx-core/tests/ltx_core/model/audio_vae/test_audio_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..22743ed2789fa5ec8861308ab70197c4528d15c3 --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/model/audio_vae/test_audio_vae.py @@ -0,0 +1,250 @@ +import pytest +import torch +from tests.ltx_core.utils import resolve_model_path + +from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder +from ltx_core.model.audio_vae.model_configurator import ( + AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, + AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER, + VOCODER_COMFY_KEYS_FILTER, + VAEDecoderConfigurator, + VAEEncoderConfigurator, + VocoderConfigurator, +) +from ltx_core.model.audio_vae.ops import AudioProcessor + + +@pytest.fixture(scope="module") +def model_path() -> str: + return resolve_model_path() + + +@pytest.fixture(scope="module") +def encoder_builder(model_path: str) -> Builder: + return Builder( + model_path=model_path, + model_class_configurator=VAEEncoderConfigurator, + model_sd_ops=AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER, + ) + + +@pytest.fixture(scope="module") +def decoder_builder(model_path: str) -> Builder: + return Builder( + model_path=model_path, + model_class_configurator=VAEDecoderConfigurator, + model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, + ) + + +@pytest.fixture(scope="module") +def vocoder_builder(model_path: str) -> Builder: + return Builder( + model_path=model_path, + model_class_configurator=VocoderConfigurator, + model_sd_ops=VOCODER_COMFY_KEYS_FILTER, + ) + + +def test_audio_vae_decoder_instantiation(decoder_builder: Builder) -> None: + vae_decoder = decoder_builder.build() + assert vae_decoder is not None + assert not any(param.device.type == "meta" for param in vae_decoder.parameters()) + + +def generate_test_waveform( + duration_seconds: float, + sample_rate: int, + num_channels: int = 2, + frequency: float = 440.0, +) -> torch.Tensor: + """Generate a multi-channel sine wave test signal.""" + num_samples = int(duration_seconds * sample_rate) + t = torch.linspace(0, duration_seconds, num_samples) + + # Generate slightly different frequencies for left/right channels + waveform = torch.stack( + [torch.sin(2 * torch.pi * (frequency + i * 10) * t) for i in range(num_channels)], + dim=0, + ) + # Add batch dimension: (channels, time) -> (batch, channels, time) + return waveform.unsqueeze(0) * 0.5 # Scale to avoid clipping + + +def test_spectrogram_reconstruction_quality(encoder_builder: Builder, decoder_builder: Builder) -> None: + """Test that input and output spectrograms are similar after VAE encode/decode.""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Load encoder and decoder + encoder = encoder_builder.build() + decoder = 
decoder_builder.build() + + encoder = encoder.to(device).eval() + decoder = decoder.to(device).eval() + + # Pipeline parameters + sample_rate = encoder.sample_rate + n_mels = encoder.mel_bins + hop_length = encoder.mel_hop_length + n_fft = encoder.n_fft + + # Generate test waveform + input_waveform = generate_test_waveform( + duration_seconds=1.0, + sample_rate=sample_rate, + num_channels=2, + frequency=440.0, + ).to(device) + + # Get model dtype for input conversion + model_dtype = next(encoder.parameters()).dtype + + with torch.no_grad(): + # Convert to spectrogram + audio_processor = AudioProcessor( + sample_rate=sample_rate, + mel_bins=n_mels, + mel_hop_length=hop_length, + n_fft=n_fft, + ).to(device) + + input_spectrogram = audio_processor.waveform_to_mel( + input_waveform, + waveform_sample_rate=sample_rate, + ).to(dtype=model_dtype) + + # Encode and decode + latent = encoder(input_spectrogram) + reconstructed_spectrogram = decoder(latent) + + # Compare spectrograms (allow for some reconstruction loss) + # Align shapes for comparison (decoder may have slightly different output shape) + min_time = min(input_spectrogram.shape[2], reconstructed_spectrogram.shape[2]) + min_freq = min(input_spectrogram.shape[3], reconstructed_spectrogram.shape[3]) + + input_cropped = input_spectrogram[:, :, :min_time, :min_freq] + reconstructed_cropped = reconstructed_spectrogram[:, :, :min_time, :min_freq] + + # Calculate reconstruction error + mse = torch.nn.functional.mse_loss(input_cropped, reconstructed_cropped) + correlation = torch.corrcoef(torch.stack([input_cropped.flatten(), reconstructed_cropped.flatten()]))[0, 1] + + # Assert reasonable reconstruction quality + # These thresholds may need adjustment based on model quality + assert mse < 10.0, f"Spectrogram MSE too high: {mse.item():.4f}" + assert correlation > 0.5, f"Spectrogram correlation too low: {correlation.item():.4f}" + + +def test_waveform_roundtrip_similarity( + encoder_builder: Builder, decoder_builder: Builder, vocoder_builder: Builder +) -> None: + """ + Test full waveform roundtrip: waveform -> spectrogram -> VAE encode/decode -> vocoder -> waveform. + + Compares input and output waveforms, as well as their spectrograms. 
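+ + Waveform similarity is measured via correlation (more meaningful than MSE for audio); it can still be low because the VAE and vocoder do not preserve phase.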
+ """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Load models + encoder = encoder_builder.build() + decoder = decoder_builder.build() + vocoder = vocoder_builder.build() + + encoder = encoder.to(device).eval() + decoder = decoder.to(device).eval() + vocoder = vocoder.to(device).eval() + + # Pipeline parameters + input_sample_rate = encoder.sample_rate # 16kHz + output_sample_rate = vocoder.output_sample_rate # 24kHz + n_mels = encoder.mel_bins + hop_length = encoder.mel_hop_length + n_fft = encoder.n_fft + + # Generate test waveform (stereo, 1 second) + duration = 1.0 + input_waveform = generate_test_waveform( + duration_seconds=duration, + sample_rate=input_sample_rate, + num_channels=2, + frequency=440.0, + ).to(device) + + # Get model dtype for input conversion + model_dtype = next(encoder.parameters()).dtype + + with torch.no_grad(): + audio_processor = AudioProcessor( + sample_rate=input_sample_rate, + mel_bins=n_mels, + mel_hop_length=hop_length, + n_fft=n_fft, + ).to(device) + + input_spectrogram = audio_processor.waveform_to_mel( + input_waveform, + waveform_sample_rate=input_sample_rate, + ).to(dtype=model_dtype) + + latent = encoder(input_spectrogram) + reconstructed_spectrogram = decoder(latent) + + output_waveform = vocoder(reconstructed_spectrogram) + + output_waveform_resampled = audio_processor.resample_waveform( + output_waveform, + source_rate=output_sample_rate, + target_rate=input_sample_rate, + ) + + output_spectrogram = audio_processor.waveform_to_mel( + output_waveform_resampled, + waveform_sample_rate=input_sample_rate, + ) + + # Waveform comparison + # Align waveform lengths for comparison + min_samples = min(input_waveform.shape[2], output_waveform_resampled.shape[2]) + input_waveform_aligned = input_waveform[:, :, :min_samples] + output_waveform_aligned = output_waveform_resampled[:, :, :min_samples] + + # Calculate waveform correlation (more meaningful than MSE for audio) + waveform_correlation = torch.corrcoef( + torch.stack([input_waveform_aligned.flatten(), output_waveform_aligned.flatten()]) + )[0, 1] + + # Spectrogram comparison + # Align spectrogram shapes for comparison + min_time = min(input_spectrogram.shape[2], output_spectrogram.shape[2]) + min_freq = min(input_spectrogram.shape[3], output_spectrogram.shape[3]) + + input_spec_aligned = input_spectrogram[:, :, :min_time, :min_freq] + output_spec_aligned = output_spectrogram[:, :, :min_time, :min_freq] + + spectrogram_mse = torch.nn.functional.mse_loss(input_spec_aligned, output_spec_aligned) + spectrogram_correlation = torch.corrcoef( + torch.stack([input_spec_aligned.flatten(), output_spec_aligned.flatten()]) + )[0, 1] + + # Assertions + # Waveform correlation can be low since VAE+vocoder don't preserve phase. + # A low but positive correlation is acceptable. + assert waveform_correlation > 0.0, ( + f"Waveform correlation is negative: {waveform_correlation.item():.4f}. " + "Input and output waveforms should have at least positive correlation." + ) + + # Spectrograms should be well correlated + assert spectrogram_correlation > 0.5, ( + f"Spectrogram correlation too low: {spectrogram_correlation.item():.4f}. " + "Input and output spectrograms are not similar enough." + ) + + # Spectrogram MSE should be reasonable + assert spectrogram_mse < 10.0, ( + f"Spectrogram MSE too high: {spectrogram_mse.item():.4f}. Reconstruction quality is poor." 
+ ) + + # Output waveform should be in valid range + assert output_waveform_resampled.min() >= -1.0, f"Output waveform min {output_waveform_resampled.min()} below -1.0" + assert output_waveform_resampled.max() <= 1.0, f"Output waveform max {output_waveform_resampled.max()} above 1.0" diff --git a/packages/ltx-core/tests/ltx_core/model/audio_vae/test_vocoder.py b/packages/ltx-core/tests/ltx_core/model/audio_vae/test_vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a7e371b689f798df236c80b4f47649b98ee94a21 --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/model/audio_vae/test_vocoder.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Ivan Zorin + + +from tests.ltx_core.utils import resolve_model_path + +from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder +from ltx_core.model.audio_vae.model_configurator import VOCODER_COMFY_KEYS_FILTER, VocoderConfigurator + + +def test_vocoder() -> None: + builder = Builder( + model_path=resolve_model_path(), + model_class_configurator=VocoderConfigurator, + model_sd_ops=VOCODER_COMFY_KEYS_FILTER, + ) + model = builder.build() + assert model is not None diff --git a/packages/ltx-core/tests/ltx_core/model/clip/test_enhancing.py b/packages/ltx-core/tests/ltx_core/model/clip/test_enhancing.py new file mode 100644 index 0000000000000000000000000000000000000000..c1aa4f47a5f8f94f3e728a3e51ba08d34fe13541 --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/model/clip/test_enhancing.py @@ -0,0 +1,62 @@ +import json +import os +from pathlib import Path +from typing import Generator + +import pytest +import torch +from PIL import Image as PILImage + +from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder +from ltx_core.model.clip.gemma.encoders.av_encoder import ( + AV_GEMMA_TEXT_ENCODER_KEY_OPS, + AVGemmaTextEncoderModel, + AVGemmaTextEncoderModelConfigurator, +) +from ltx_core.model.clip.gemma.encoders.base_encoder import module_ops_from_gemma_root + +MODELS_PATH = Path(os.getenv("MODELS_PATH", "/models")) +GEMMA_ROOT_PATH = MODELS_PATH / "comfyui_models" / "text_encoders" / "gemma-3-12b-it-qat-q4_0-unquantized_readout_proj" +CHECKPOINT_PATH = MODELS_PATH / "comfyui_models" / "checkpoints" / "ltx-av-step-1933500-split-new-vae.safetensors" +TEXT_PROMPT = "A knight with a red cape faces fire breathing dragon." 
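+# Expected (golden) enhanced prompts; the tests below compare only a leading prefix of each.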
+with open("packages/ltx-core/tests/ltx_core/model/clip/assets/enhanced_prompts.json", "r") as f: + ENHANCED_PROMPTS = json.load(f) +I2V_ENHANCED_TEXT_PROMPT = ENHANCED_PROMPTS["I2V_ENHANCED_TEXT_PROMPT"] +T2V_ENHANCED_TEXT_PROMPT = ENHANCED_PROMPTS["T2V_ENHANCED_TEXT_PROMPT"] +IMG = PILImage.open("packages/ltx-core/tests/ltx_core/model/clip/assets/dragon_1.png").convert("RGB") + + +@pytest.fixture(scope="session") +def text_encoder() -> Generator[AVGemmaTextEncoderModel, None, None]: + if not torch.cuda.is_available(): + pytest.skip("This test runs too slow on CPU") + if not CHECKPOINT_PATH.exists() or not GEMMA_ROOT_PATH.exists(): + pytest.skip("Checkpoints inaccessible") + + model = Builder( + model_path=CHECKPOINT_PATH, + model_class_configurator=AVGemmaTextEncoderModelConfigurator, + model_sd_ops=AV_GEMMA_TEXT_ENCODER_KEY_OPS, + module_ops=module_ops_from_gemma_root(GEMMA_ROOT_PATH), + ).build(device=torch.device("cuda")) + yield model + + # optional cleanup + del model + torch.cuda.empty_cache() + + +def test_model_loading_with_img_processor(text_encoder: AVGemmaTextEncoderModel) -> None: + assert text_encoder is not None + + +def test_enhance_i2v(text_encoder: AVGemmaTextEncoderModel) -> None: + enhanced_text_prompt = text_encoder.enhance_i2v(TEXT_PROMPT, IMG) + assert enhanced_text_prompt is not None + assert enhanced_text_prompt[:64] == I2V_ENHANCED_TEXT_PROMPT[:64] + + +def test_enhance_t2v(text_encoder: AVGemmaTextEncoderModel) -> None: + enhanced_text_prompt = text_encoder.enhance_t2v(TEXT_PROMPT) + assert enhanced_text_prompt is not None + assert enhanced_text_prompt[:128] == T2V_ENHANCED_TEXT_PROMPT[:128] diff --git a/packages/ltx-core/tests/ltx_core/model/transformer/test_attention.py b/packages/ltx-core/tests/ltx_core/model/transformer/test_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..891ca2efd71eff81445f09c7f85611d8cbce5f43 --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/model/transformer/test_attention.py @@ -0,0 +1,105 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
+# Created by Andrew Kvochko + +import pytest +import torch + +from ltx_core.model.transformer.attention import AttentionFunction + +FIXTURE = torch.Tensor( + [ + [ + [ + 2.5469, + -0.7148, + -0.4941, + 0.1270, + 0.1016, + -0.4043, + 0.9023, + 0.8086, + -0.6875, + 0.1377, + 1.0391, + 0.0928, + -0.3750, + -0.0908, + 2.0625, + -1.8125, + -0.2715, + 0.2812, + -1.0391, + 0.7773, + 0.8828, + 0.0444, + -1.4844, + 1.1328, + 1.3281, + -1.2578, + 0.9492, + -0.6562, + 0.9102, + -0.6289, + -0.6602, + 2.0781, + ] + ] + ] +) + + +def _xformers_available() -> bool: + """Check if xformers can be imported.""" + try: + from xformers.ops import memory_efficient_attention # noqa: F401, PLC0415 + + return True + except ImportError: + return False + + +def test_attention_function_pytorch() -> None: + attention_function = AttentionFunction.PYTORCH + assert attention_function( + torch.tensor([[[1, 2, 3]]], dtype=torch.float32), + torch.tensor([[[4, 5, 6]]], dtype=torch.float32), + torch.tensor([[[7, 8, 9]]], dtype=torch.float32), + 1, + ).tolist() == [[[7, 8, 9]]] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="XFormersAttention requires CUDA") +def test_xformers_attention() -> None: + try: + from xformers.ops import memory_efficient_attention # noqa: F401, PLC0415 + except ImportError: + pytest.skip("XFormersAttention requires xformers to be installed") + attention_function = AttentionFunction.XFORMERS + shape = FIXTURE.shape + torch.manual_seed(0) + q = torch.randn(shape, dtype=torch.bfloat16, device=torch.device("cuda")) + k = torch.randn(shape, dtype=torch.bfloat16, device=torch.device("cuda")) + v = torch.randn(shape, dtype=torch.bfloat16, device=torch.device("cuda")) + fixture = FIXTURE.to(device=torch.device("cuda"), dtype=torch.bfloat16) + result = attention_function(q, k, v, 1) + assert torch.allclose(result, fixture, atol=1e-4, rtol=1e-4) + + +@pytest.mark.skipif( + not torch.cuda.is_available() or _xformers_available(), + reason="FlashAttention3 requires CUDA and should only run if xformers is not available", +) +def test_flash_attention_3() -> None: + try: + import flash_attn_interface # noqa: F401, PLC0415 + except ImportError: + pytest.skip("FlashAttention3 requires FlashAttention3 to be installed") + attention_function = AttentionFunction.FLASH_ATTENTION_3 + shape = FIXTURE.shape + torch.manual_seed(0) + q = torch.randn(shape, dtype=torch.bfloat16, device=torch.device("cuda")) + k = torch.randn(shape, dtype=torch.bfloat16, device=torch.device("cuda")) + v = torch.randn(shape, dtype=torch.bfloat16, device=torch.device("cuda")) + fixture = FIXTURE.to(device=torch.device("cuda"), dtype=torch.bfloat16) + result = attention_function(q, k, v, 1) + assert torch.allclose(result, fixture, atol=1e-4, rtol=1e-4) diff --git a/packages/ltx-core/tests/ltx_core/model/transformer/test_gelu_approx.py b/packages/ltx-core/tests/ltx_core/model/transformer/test_gelu_approx.py new file mode 100644 index 0000000000000000000000000000000000000000..bc776c700b0ff2b3c6914ac1fd44497fa0032820 --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/model/transformer/test_gelu_approx.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
+# Created by Andrew Kvochko + +import torch + +from ltx_core.model.transformer.gelu_approx import GELUApprox + + +def test_gelu_approx() -> None: + gelu_approx = GELUApprox(1, 1) + gelu_approx.load_state_dict({"proj.weight": torch.ones(1, 1), "proj.bias": torch.zeros(1)}) + x = torch.tensor([[2.0]]) + output = gelu_approx(x) + assert output.allclose(torch.tensor([[2.0]]), atol=0.05) diff --git a/packages/ltx-core/tests/ltx_core/model/transformer/test_model.py b/packages/ltx-core/tests/ltx_core/model/transformer/test_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3d68fa3cd6cf137161e9682723a24fda049cb2a1 --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/model/transformer/test_model.py @@ -0,0 +1,63 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +import pytest +import torch + +from ltx_core.model.transformer.model import LTXModelType +from ltx_core.model.transformer.model_configurator import LTXModelConfigurator, LTXVideoOnlyModelConfigurator + +VIDEO_ONLY_CONFIG = { + "transformer": { + "dropout": 0.0, + "norm_num_groups": 32, + "attention_bias": True, + "num_vector_embeds": None, + "activation_fn": "gelu-approximate", + "num_embeds_ada_norm": 1000, + "use_linear_projection": False, + "only_cross_attention": False, + "cross_attention_norm": True, + "double_self_attention": False, + "upcast_attention": False, + "standardization_norm": "rms_norm", + "norm_elementwise_affine": False, + "qk_norm": "rms_norm", + "positional_embedding_type": "rope", + "causal_temporal_positioning": True, + "use_middle_indices_grid": True, + } +} + +AUDIO_VIDEO_TRANSFORMER_CONFIG_DELTA = { + "use_audio_video_cross_attention": True, + "share_ff": False, + "av_cross_ada_norm": True, + "audio_num_attention_heads": 32, + "audio_attention_head_dim": 64, + "audio_in_channels": 128, + "audio_out_channels": 128, + "audio_cross_attention_dim": 2048, + "audio_positional_embedding_max_pos": [20], + "av_ca_timestep_scale_multiplier": 1, +} + + +def test_audio_video_model() -> None: + transformer_config = VIDEO_ONLY_CONFIG.copy() + transformer_config["transformer"].update(AUDIO_VIDEO_TRANSFORMER_CONFIG_DELTA) + with torch.device("meta"): + with pytest.raises(ValueError, match="Config value"): + LTXModelConfigurator.from_config({}) + model = LTXModelConfigurator.from_config(transformer_config) + assert model is not None + assert model.model_type == LTXModelType.AudioVideo + + +def test_video_only_model() -> None: + with torch.device("meta"): + with pytest.raises(ValueError, match="Config value"): + LTXVideoOnlyModelConfigurator.from_config({}) + model = LTXVideoOnlyModelConfigurator.from_config(VIDEO_ONLY_CONFIG) + assert model is not None + assert model.model_type == LTXModelType.VideoOnly diff --git a/packages/ltx-core/tests/ltx_core/model/upsampler/test_model.py b/packages/ltx-core/tests/ltx_core/model/upsampler/test_model.py new file mode 100644 index 0000000000000000000000000000000000000000..8b77a8374832189fa1c6e35a7b50874a4ae9bd0e --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/model/upsampler/test_model.py @@ -0,0 +1,26 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
+# Created by Andrew Kvochko + +import torch + +from ltx_core.model.upsampler.model_configurator import LatentUpsamplerConfigurator + + +def test_model() -> None: + model = LatentUpsamplerConfigurator.from_config( + { + "in_channels": 1, + "mid_channels": 32, + "num_blocks_per_stage": 1, + "dims": 3, + "spatial_upsample": True, + "temporal_upsample": True, + "spatial_scale": 2.0, + "rational_resampler": False, + } + ) + assert model is not None + latent = torch.randn(1, 1, 2, 2, 2) + with torch.inference_mode(): + output = model(latent) + assert output.shape == (1, 1, 3, 4, 4) diff --git a/packages/ltx-core/tests/ltx_core/model/video_vae/test_video_vae.py b/packages/ltx-core/tests/ltx_core/model/video_vae/test_video_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..a246d1c3197288b0932657ff6f3c1ffbcbeea6b5 --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/model/video_vae/test_video_vae.py @@ -0,0 +1,224 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Ivan Zorin + +from typing import Callable + +import numpy as np +import pytest +import torch +from skimage import data, io +from skimage.transform import resize +from tests.ltx_core.conftest import OUTPUT_DIR +from tests.ltx_core.utils import resolve_model_path + +from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder +from ltx_core.model.video_vae.model_configurator import ( + VAE_DECODER_COMFY_KEYS_FILTER, + VAE_ENCODER_COMFY_KEYS_FILTER, + VAEDecoderConfigurator, + VAEEncoderConfigurator, +) +from ltx_core.tiling import SpatialTilingConfig, TemporalTilingConfig, TilingConfig +from ltx_pipelines.media_io import encode_video + +SAVE_IMAGES = False + + +VAE_CONFIG = { + "vae": { + "dims": 3, + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "reflect", + "encoder_blocks": [ + ["res_x", {"num_layers": 4}], + ["compress_space_res", {"multiplier": 2}], + ["res_x", {"num_layers": 6}], + ["compress_time_res", {"multiplier": 2}], + ["res_x", {"num_layers": 6}], + ["compress_all_res", {"multiplier": 2}], + ["res_x", {"num_layers": 2}], + ["compress_all_res", {"multiplier": 2}], + ["res_x", {"num_layers": 2}], + ], + "decoder_blocks": [ + ["res_x", {"num_layers": 5, "inject_noise": False}], + ["compress_all", {"residual": True, "multiplier": 2}], + ["res_x", {"num_layers": 5, "inject_noise": False}], + ["compress_all", {"residual": True, "multiplier": 2}], + ["res_x", {"num_layers": 5, "inject_noise": False}], + ["compress_all", {"residual": True, "multiplier": 2}], + ["res_x", {"num_layers": 5, "inject_noise": False}], + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, + "timestep_conditioning": True, + "normalize_latent_channels": False, + } +} + + +def test_encoder_instantiation() -> None: + vae_encoder = VAEEncoderConfigurator.from_config(VAE_CONFIG) + assert vae_encoder is not None + + +def test_decoder_instantiation() -> None: + vae_decoder = VAEDecoderConfigurator.from_config(VAE_CONFIG) + assert vae_decoder is not None + + +@pytest.mark.e2e +def test_encoder_loading() -> None: + model_path = resolve_model_path() + + vae_encoder = Builder( + model_path=model_path, + model_class_configurator=VAEEncoderConfigurator, + model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER, + ).build() + assert vae_encoder is not None + + +@pytest.mark.e2e +def test_decoder_loading() -> None: 
+ model_path = resolve_model_path() + + vae_decoder = Builder( + model_path=model_path, + model_class_configurator=VAEDecoderConfigurator, + model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER, + ).build() + assert vae_decoder is not None + + +@pytest.mark.e2e +@pytest.mark.parametrize( + ("image_name", "image_func"), + [ + ("astronaut", data.astronaut), + ("chelsea", data.chelsea), + ("coffee", data.coffee), + ], +) +def test_encode_decode_cycle(image_name: str, image_func: Callable) -> None: + # Load weights from $MODEL_PATH or fall back to the default checkpoint + model_path = resolve_model_path() + + dtype = torch.bfloat16 + device = torch.device("cuda") + + vae_decoder = Builder( + model_path=model_path, + model_class_configurator=VAEDecoderConfigurator, + model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER, + ).build() + + vae_encoder = Builder( + model_path=model_path, + model_class_configurator=VAEEncoderConfigurator, + model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER, + ).build() + + vae_encoder.to(dtype=dtype, device=device) + vae_decoder.to(dtype=dtype, device=device) + + # Prepare Image + image = image_func() + + # Resize if needed to match target shape + target_shape = (512, 512) + if image.shape[:2] != target_shape: + # resize returns float 0-1 + image = resize(image, target_shape, anti_aliasing=True) + # Convert to 0-255 uint8 range to match original pipeline assumption + image = (image * 255).astype(np.uint8) + + # Normalize to [-1, 1] + image = np.array(image).astype(np.float32) / 127.5 - 1.0 + + # Convert to tensor (B, C, F, H, W) + # Replicate the image 33 times to create a video + image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).unsqueeze(2) + sample_video = image_tensor.repeat(1, 1, 33, 1, 1).to(device=device, dtype=dtype) + + # Run VAE + with torch.autocast(device_type="cuda", dtype=dtype): + encoded = vae_encoder(sample_video) + assert not torch.isnan(encoded).any(), f"Encoded tensor contains NaNs for {image_name}" + decoded = vae_decoder(encoded) + assert not torch.isnan(decoded).any(), f"Decoded tensor contains NaNs for {image_name}" + + # Verify reconstruction shape + assert decoded.shape == sample_video.shape, f"Shape mismatch for {image_name}" + + # Verify reconstruction error + diff = (sample_video - decoded).float() + mse = diff.pow(2).mean().item() + + # Assert MSE is reasonable + # MSE threshold < 0.05 is conservative but safe for diverse images. 
+ assert mse < 0.02, f"MSE too high for {image_name}: {mse:.4f}" + + if SAVE_IMAGES: + img_out = decoded[0, :, 0].detach().float().cpu().numpy() + img_out = (img_out + 1.0) * 127.5 + img_out = np.clip(img_out, 0, 255).astype(np.uint8) + img_out = np.transpose(img_out, (1, 2, 0)) # (H, W, C) + io.imsave(f"test_output_{image_name}.png", img_out) + + # Cleanup + del encoded, decoded, sample_video, image_tensor, diff + torch.cuda.empty_cache() + + +@pytest.mark.e2e +def test_tiled_compare_video( + psnr: Callable[[torch.Tensor, torch.Tensor, float, float], float], + decode_video_from_file: Callable[[str], tuple[torch.Tensor, torch.Tensor | None]], +) -> None: + """Test that compares tiled and non-tiled video decoding.""" + model_path = resolve_model_path() + decoder = Builder( + model_path=model_path, + model_class_configurator=VAEDecoderConfigurator, + model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER, + ).build() + encoder = Builder( + model_path=model_path, + model_class_configurator=VAEEncoderConfigurator, + model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER, + ).build() + video, _ = decode_video_from_file(path="packages/ltx-pipelines/tests/assets/expected_keyframes.mp4", device="cpu") + sample_video = video.permute(3, 0, 1, 2).unsqueeze(0) / 127.5 - 1.0 + tiling_config = TilingConfig( + spatial_config=SpatialTilingConfig(tile_size_in_pixels=192, tile_overlap_in_pixels=64), + temporal_config=TemporalTilingConfig(tile_size_in_frames=48, tile_overlap_in_frames=24), + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16), torch.no_grad(): + encoded_video = encoder(sample_video.cuda().bfloat16()) + decoded_video = decoder(encoded_video).to(device="cpu") + chunks = [] + for frames, _ in decoder.tiled_decode(encoded_video, tiling_config): + chunks.append(frames.cpu()) + tiled_video = torch.cat(chunks, dim=2) + decoded_video = torch.clamp((decoded_video[0].permute(1, 2, 3, 0) + 1.0) / 2.0, 0.0, 1.0) + tiled_video = torch.clamp((tiled_video[0].permute(1, 2, 3, 0) + 1.0) / 2.0, 0.0, 1.0) + + psnr_tiled_non_tiled = psnr(tiled_video, decoded_video) + psnr_tiled_original = psnr(tiled_video, video / 255.0) + psnr_non_tiled_original = psnr(decoded_video, video / 255.0) + encode_video((tiled_video * 255.0).to(torch.uint8), 25, None, None, str(OUTPUT_DIR / "tiled_video.mp4")) + encode_video((decoded_video * 255.0).to(torch.uint8), 25, None, None, str(OUTPUT_DIR / "decoded_video.mp4")) + + assert psnr_tiled_non_tiled > 35.0, f"Decoding in tiles is different from non-tiled: {psnr_tiled_non_tiled}" + assert psnr_tiled_original > 30.0, f"Decoding in tiles is too different from original: {psnr_tiled_original}" + assert psnr_non_tiled_original > 30.0, ( + f"Decoding non-tiled is too different from original: {psnr_non_tiled_original}" + ) diff --git a/packages/ltx-core/tests/ltx_core/pipeline/components/test_diffusion_steps.py b/packages/ltx-core/tests/ltx_core/pipeline/components/test_diffusion_steps.py new file mode 100644 index 0000000000000000000000000000000000000000..4112ba3bf396c16b7ec967a9bc5c570cda4ec626 --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/pipeline/components/test_diffusion_steps.py @@ -0,0 +1,18 @@ +import torch + +from ltx_core.pipeline.components.diffusion_steps import EulerDiffusionStep + + +def test_euler_diffusion_step_simple_update() -> None: + step = EulerDiffusionStep() + sample = torch.tensor([1.0, 2.0]) + denoised_sample = torch.tensor([0.5, 1.5]) + sigmas = torch.tensor([1.0, 0.5]) + + out = step.step(sample=sample, denoised_sample=denoised_sample, sigmas=sigmas, step_index=0) + 
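+ # Euler update: x_next = x + v * dt, where v = (sample - denoised) / sigma and dt = sigma_next - sigma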
+ # sigma = 1.0, sigma_next = 0.5, dt = -0.5 + # v = (x - x0) / sigma = [(1.0 - 0.5) / 1.0, (2.0 - 1.5) / 1.0] = [0.5, 0.5] + # x + v * dt = [1.0, 2.0] + [0.5, 0.5] * -0.5 = [0.75, 1.75] + expected = [0.75, 1.75] + assert out.tolist() == expected diff --git a/packages/ltx-core/tests/ltx_core/pipeline/components/test_guiders.py b/packages/ltx-core/tests/ltx_core/pipeline/components/test_guiders.py new file mode 100644 index 0000000000000000000000000000000000000000..9974d1569f1fb428ec0e5806a40c4388b482abc0 --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/pipeline/components/test_guiders.py @@ -0,0 +1,148 @@ +import torch + +from ltx_core.pipeline.components.guiders import ( + CFGGuider, + CFGStarRescalingGuider, + LegacyStatefulAPGGuider, + LtxAPGGuider, + STGGuider, + projection_coef, +) + + +def test_cfg_guider_delta_scales_difference() -> None: + guider = CFGGuider(scale=2.0) + cond = torch.tensor([2.0, 4.0]) + uncond = torch.tensor([1.0, 1.0]) + + # (scale - 1) * (cond - uncond) = 1.0 * [1.0, 3.0] + delta = guider.delta(cond=cond, uncond=uncond) + expected = [1.0, 3.0] + + assert delta.tolist() == expected + + +def test_cfg_star_rescaling_guider_delta() -> None: + guider = CFGStarRescalingGuider(scale=2.0) + cond = torch.tensor([[2.0, 4.0]]) + uncond = torch.tensor([[1.0, 1.0]]) + + # projection_coef(cond, uncond) = (2*1 + 4*1) / (1^2 + 1^2 + 1e-8) = 6 / 2 = 3.0 + # rescaled_neg = 3.0 * [1.0, 1.0] = [3.0, 3.0] + # (scale - 1) * (cond - rescaled_neg) = 1.0 * ([2.0, 4.0] - [3.0, 3.0]) = [-1.0, 1.0] + delta = guider.delta(cond=cond, uncond=uncond) + expected = [[-1.0, 1.0]] + + assert torch.allclose(delta, torch.tensor(expected)) + + +def test_stg_guider_delta() -> None: + guider = STGGuider(scale=2.0) + pos_denoised = torch.tensor([2.0, 4.0]) + perturbed_denoised = torch.tensor([1.0, 1.0]) + + # scale * (pos_denoised - perturbed_denoised) = 2.0 * [1.0, 3.0] = [2.0, 6.0] + delta = guider.delta(pos_denoised=pos_denoised, perturbed_denoised=perturbed_denoised) + expected = [2.0, 6.0] + + assert delta.tolist() == expected + + +def test_ltx_apg_guider_delta() -> None: + guider = LtxAPGGuider(scale=2.0, eta=0.5) + cond = torch.tensor([[[[2.0, 4.0]]]]) + uncond = torch.tensor([[[[1.0, 1.0]]]]) + + # guidance = cond - uncond = [1.0, 3.0] + # proj_coeff = projection_coef(guidance, cond) = (1*2 + 3*4) / (2^2 + 4^2 + 1e-8) = 14 / 20 = 0.7 + # g_parallel = 0.7 * [2.0, 4.0] = [1.4, 2.8] + # g_orth = [1.0, 3.0] - [1.4, 2.8] = [-0.4, 0.2] + # g_apg = [1.4, 2.8] * 0.5 + [-0.4, 0.2] = [0.7, 1.4] + [-0.4, 0.2] = [0.3, 1.6] + # delta = g_apg * (guidance_scale - 1) = [0.3, 1.6] * 1.0 = [0.3, 1.6] + delta = guider.delta(cond=cond, uncond=uncond) + expected = [[0.3, 1.6]] + + assert torch.allclose(delta, torch.tensor(expected)) + + +def test_ltx_apg_guider_delta_with_norm_threshold() -> None: + guider = LtxAPGGuider(scale=2.0, eta=0.5, norm_threshold=1.0) + cond = torch.tensor([[[[2.0, 4.0]]]]) + uncond = torch.tensor([[[[1.0, 1.0]]]]) + + # guidance = cond - uncond = [1.0, 3.0] + # guidance_norm = sqrt(1^2 + 3^2) = sqrt(10) ≈ 3.16 + # Since norm_threshold (1.0) < guidance_norm, scale_factor = 1.0 / 3.16 ≈ 0.316 + # guidance = [1.0, 3.0] * 0.316 ≈ [0.316, 0.949] + # Then proceed with APG calculation + delta = guider.delta(cond=cond, uncond=uncond) + + # Verify the shape and that it's not all zeros + assert delta.shape == cond.shape + assert not torch.allclose(delta, torch.zeros_like(delta)) + + +def test_legacy_stateful_apg_guider_delta() -> None: + guider = LegacyStatefulAPGGuider(scale=2.0, eta=0.5, 
momentum=0.0) + cond = torch.tensor([[[[2.0, 4.0]]]]) + uncond = torch.tensor([[[[1.0, 1.0]]]]) + + # guidance = cond - uncond = [1.0, 3.0] + # Since momentum=0.0, no running average is used + # Since norm_threshold=5.0 (default) > guidance_norm, no scaling + # proj_coeff = projection_coef(guidance, cond) = (1*2 + 3*4) / (2^2 + 4^2 + 1e-8) = 14 / 20 = 0.7 + # g_parallel = 0.7 * [2.0, 4.0] = [1.4, 2.8] + # g_orth = [1.0, 3.0] - [1.4, 2.8] = [-0.4, 0.2] + # g_apg = [1.4, 2.8] * 0.5 + [-0.4, 0.2] = [0.3, 1.6] + # delta = g_apg * guidance_scale = [0.3, 1.6] * 2.0 = [0.6, 3.2] + delta = guider.delta(cond=cond, uncond=uncond) + expected = [[0.6, 3.2]] + + assert torch.allclose(delta, torch.tensor(expected)) + + +def test_legacy_stateful_apg_guider_delta_with_momentum() -> None: + guider = LegacyStatefulAPGGuider(scale=2.0, eta=0.5, momentum=0.5, norm_threshold=0.0) + cond = torch.tensor([[[[2.0, 4.0]]]]) + uncond = torch.tensor([[[[1.0, 1.0]]]]) + + # First call: running_avg = guidance = [1.0, 3.0] + delta1 = guider.delta(cond=cond, uncond=uncond) + + # Second call: running_avg = 0.5 * [1.0, 3.0] + [1.0, 3.0] = [1.5, 4.5] + # and guidance = [1.5, 4.5] + delta2 = guider.delta(cond=cond, uncond=uncond) + + # Verify the shape and that deltas are different (momentum affects result) + assert delta1.shape == cond.shape + assert delta2.shape == cond.shape + assert not torch.allclose(delta1, delta2) + + +def test_projection_coef() -> None: + to_project = torch.tensor([[2.0, 4.0]]) + project_onto = torch.tensor([[1.0, 1.0]]) + + # dot_product = 2*1 + 4*1 = 6 + # squared_norm = 1^2 + 1^2 + 1e-8 = 2 + 1e-8 ≈ 2 + # projection_coef = 6 / 2 = 3.0 + coef = projection_coef(to_project=to_project, project_onto=project_onto) + expected = [[3.0]] + + assert torch.allclose(coef, torch.tensor(expected)) + + +def test_projection_coef_orthogonal() -> None: + to_project = torch.tensor([[1.0, 0.0]]) + project_onto = torch.tensor([[0.0, 1.0]]) + + # dot_product = 1*0 + 0*1 = 0 + # squared_norm = 0^2 + 1^2 + 1e-8 = 1 + 1e-8 + # projection_coef = 0 / (1 + 1e-8) ≈ 0 + coef = projection_coef(to_project=to_project, project_onto=project_onto) + expected = [[0.0]] + + assert torch.allclose(coef, torch.tensor(expected), atol=1e-6) + + +test_ltx_apg_guider_delta() diff --git a/packages/ltx-core/tests/ltx_core/pipeline/components/test_patchifiers.py b/packages/ltx-core/tests/ltx_core/pipeline/components/test_patchifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..032087d9d4af8c60ac4738e5bba41f3ad51a2dca --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/pipeline/components/test_patchifiers.py @@ -0,0 +1,95 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
+# Created by Ivan Zorin + +import torch + +from ltx_core.pipeline.components.patchifiers import AudioPatchifier, VideoLatentPatchifier +from ltx_core.pipeline.components.protocols import AudioLatentShape, VideoLatentShape + + +def test_video_latent_patchifier() -> None: + # Setup + batch_size = 2 + channels = 128 + frames = 8 + height = 32 + width = 32 + patch_size = 4 + + # Create patchifier + patchifier = VideoLatentPatchifier(patch_size=patch_size) + assert patchifier is not None + assert patchifier.patch_size == (1, patch_size, patch_size) + + output_shape = VideoLatentShape( + batch=batch_size, + channels=channels, + frames=frames, + height=height, + width=width, + ) + # Create random latents + latents = torch.randn(batch_size, channels, frames, height, width) + + # Test patchify + patches = patchifier.patchify(latents) + coords = patchifier.get_patch_grid_bounds( + output_shape=output_shape, + device=latents.device, + ) + + expected_num_patches = frames * (height // patch_size) * (width // patch_size) + expected_features = channels * patch_size * patch_size + + assert patches.shape == (batch_size, expected_num_patches, expected_features) + assert coords.shape == (batch_size, 3, expected_num_patches, 2) + + # Test unpatchify + reconstructed_latents = patchifier.unpatchify(patches, output_shape) + reconstructed_coords = patchifier.get_patch_grid_bounds( + output_shape=output_shape, + device=reconstructed_latents.device, + ) + + # Verify roundtrip + assert torch.allclose(latents, reconstructed_latents, atol=1e-6), ( + f"Unpatchified latents do not match original latents: {latents.shape} != {reconstructed_latents.shape}" + ) + assert torch.allclose(coords, reconstructed_coords, atol=1e-6), ( + f"Coordinates of unpatchified latents do not match: {coords} != {reconstructed_coords}" + ) + + +def test_audio_patchifier() -> None: + batch_size = 2 + channels = 4 + frames = 12 + freq_bins = 16 + patchifier = AudioPatchifier(patch_size=16) + + latents = torch.randn(batch_size, channels, frames, freq_bins) + patches = patchifier.patchify(latents) + + expected_features = channels * freq_bins + assert patches.shape == (batch_size, frames, expected_features) + + output_shape = AudioLatentShape( + batch=batch_size, + channels=channels, + frames=frames, + mel_bins=freq_bins, + ) + coords = patchifier.get_patch_grid_bounds( + output_shape=output_shape, + device=latents.device, + ) + assert coords.shape == (batch_size, 1, frames, 2) + + reconstructed_latents = patchifier.unpatchify(patches, output_shape) + assert torch.allclose(latents, reconstructed_latents, atol=1e-6) + + reconstructed_coords = patchifier.get_patch_grid_bounds( + output_shape=output_shape, + device=reconstructed_latents.device, + ) + assert torch.allclose(coords, reconstructed_coords, atol=1e-6) diff --git a/packages/ltx-core/tests/ltx_core/pipeline/components/test_schedulers.py b/packages/ltx-core/tests/ltx_core/pipeline/components/test_schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..234db4ec3519a011d6417fdd9032f468c114b8b6 --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/pipeline/components/test_schedulers.py @@ -0,0 +1,42 @@ +import torch + +from ltx_core.pipeline.components.schedulers import BetaScheduler, LinearQuadraticScheduler, LTX2Scheduler + + +def test_ltx2_scheduler_basic_properties() -> None: + scheduler = LTX2Scheduler() + + steps = 4 + latent = torch.zeros(1, 4, 8, 8) # non-None latent to exercise token-based shift + + sigmas = scheduler.execute(steps=steps, latent=latent) + + # 
We expect `steps + 1` sigma values. + assert isinstance(sigmas, torch.Tensor) + assert sigmas.shape == (steps + 1,) + + # All sigmas should be in [0, 1] and non-negative. + assert torch.all(sigmas >= 0.0) + assert torch.all(sigmas <= 1.0) + + +def test_linear_quadratic_scheduler_basic_properties() -> None: + scheduler = LinearQuadraticScheduler() + + steps = 5 + sigmas = scheduler.execute(steps=steps) + fixture = torch.Tensor([1.0000, 0.9875, 0.9750, 0.8583, 0.5333, 0.0000]) + assert isinstance(sigmas, torch.Tensor) + assert sigmas.shape == (steps + 1,) + assert torch.allclose(sigmas, fixture, atol=1e-4, rtol=1e-5) + + +def test_beta_scheduler_basic_properties() -> None: + scheduler = BetaScheduler() + + steps = 5 + sigmas = scheduler.execute(steps=steps, alpha=0.5, beta=0.7) + fixture = torch.Tensor([1.0000, 0.9758, 0.9144, 0.7701, 0.4146, 0.0000]) + assert isinstance(sigmas, torch.Tensor) + assert sigmas.shape == (steps + 1,) + assert torch.allclose(sigmas, fixture, atol=1e-4, rtol=1e-5) diff --git a/packages/ltx-core/tests/ltx_core/pipeline/conditioning/test_conditioning.py b/packages/ltx-core/tests/ltx_core/pipeline/conditioning/test_conditioning.py new file mode 100644 index 0000000000000000000000000000000000000000..aadcf1332c23bec0487968cbe56e4d652eb0dcfc --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/pipeline/conditioning/test_conditioning.py @@ -0,0 +1,263 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Andrew Kvochko + +import pytest +import torch + +from ltx_core.pipeline.components.patchifiers import AudioPatchifier, VideoLatentPatchifier +from ltx_core.pipeline.components.protocols import AudioLatentShape, VideoLatentShape, VideoPixelShape +from ltx_core.pipeline.conditioning.item import LatentState +from ltx_core.pipeline.conditioning.tools import AudioLatentTools, VideoLatentTools +from ltx_core.pipeline.conditioning.types.keyframe_cond import VideoConditionByKeyframeIndex +from ltx_core.pipeline.conditioning.types.latent_cond import VideoConditionByLatentIndex + + +def test_video_conditioning_tools_initialization() -> None: + """Test VideoLatentTools initialization with valid parameters.""" + patchifier = VideoLatentPatchifier(patch_size=4) + + tools = VideoLatentTools( + patchifier=patchifier, + target_shape=VideoLatentShape.from_pixel_shape( + shape=VideoPixelShape(batch=2, frames=9, height=128, width=128, fps=30.0) + ), + fps=30.0, + ) + + assert tools.fps == 30.0 + assert tools.target_shape.batch == 2 + assert tools.target_shape.channels == 128 + assert tools.target_shape.height == 128 // 32 + assert tools.target_shape.width == 128 // 32 + assert tools.target_shape.frames == (9 - 1) // 8 + 1 + + +def test_video_conditioning_builder_initialization_non_causal() -> None: + """Test VideoLatentBuilder initialization with causal_fix=False.""" + patchifier = VideoLatentPatchifier(patch_size=4) + + tools = VideoLatentTools( + patchifier=patchifier, + target_shape=VideoLatentShape.from_pixel_shape( + shape=VideoPixelShape(batch=1, frames=16, height=64, width=64, fps=24.0) + ), + fps=24.0, + causal_fix=False, + ) + assert tools.causal_fix is False + assert tools.target_shape.frames == 16 // 8 + + +def test_video_conditioning_tools_build_empty_state() -> None: + """Test VideoLatentTools.build_empty_state() method.""" + patchifier = VideoLatentPatchifier(patch_size=1) + tools = VideoLatentTools( + patchifier=patchifier, + target_shape=VideoLatentShape.from_pixel_shape( + shape=VideoPixelShape(batch=2, frames=9, height=128, width=128, fps=30.0) + ), + 
fps=30.0, + ) + latent_state = tools.create_initial_state(device=torch.device("cpu"), dtype=torch.float32) + + assert latent_state.latent.shape == (2, 32, 128) + assert latent_state.denoise_mask.shape == (2, 32, 1) + assert latent_state.positions.shape == (2, 3, 32, 2) + assert latent_state.positions.dtype == torch.float32 + + +def test_video_conditioning_tools_build_with_latent_conditioning() -> None: + """Test VideoLatentTools.build_empty_state() method with LatentConditionByFrame.""" + patchifier = VideoLatentPatchifier(patch_size=1) + batch = 2 + in_channels = 128 + latent_height = 4 + latent_width = 4 + + # Create a conditioning latent (single frame image) + conditioning = VideoConditionByLatentIndex( + latent=torch.randn(batch, in_channels, 1, latent_height, latent_width), + strength=0.5, + latent_idx=0, + ) + + tools = VideoLatentTools( + patchifier=patchifier, + target_shape=VideoLatentShape.from_pixel_shape( + shape=VideoPixelShape(batch=batch, frames=9, height=128, width=128, fps=30.0) + ), + fps=30.0, + ) + empty_state = tools.create_initial_state(device=torch.device("cpu"), dtype=torch.float32) + latent_state = conditioning.apply_to(latent_state=empty_state, latent_tools=tools) + + # Verify latent state structure + assert isinstance(latent_state, LatentState) + assert latent_state.latent.shape[0] == batch + assert empty_state.latent.shape == latent_state.latent.shape + + +def test_video_conditioning_builder_apply_to() -> None: + """Test VideoLatentBuilder.apply_to() method.""" + patchifier = VideoLatentPatchifier(patch_size=4) + batch = 1 + in_channels = 128 + latent_height = 4 + latent_width = 4 + + # Create conditioning latents + conditioning1 = VideoConditionByLatentIndex( + latent=torch.randn(batch, in_channels, 1, latent_height, latent_width), + strength=0.5, + latent_idx=0, + ) + conditioning2 = VideoConditionByLatentIndex( + latent=torch.randn(batch, in_channels, 1, latent_height, latent_width), + strength=0.7, + latent_idx=1, + ) + + tools = VideoLatentTools( + patchifier=patchifier, + target_shape=VideoLatentShape.from_pixel_shape( + shape=VideoPixelShape(batch=batch, frames=9, height=128, width=128, fps=30.0) + ), + fps=30.0, + ) + + device = torch.device("cpu") + dtype = torch.float32 + + # Build to create conditioning items + built_state = tools.create_initial_state(device=device, dtype=dtype) + + # Create a new state and apply conditioning manually + test_latent = torch.randn_like(built_state.latent) + test_denoise_mask = torch.ones_like(built_state.denoise_mask) + test_positions = built_state.positions.clone() + test_state = LatentState( + latent=test_latent, + denoise_mask=test_denoise_mask, + positions=test_positions, + clean_latent=test_latent.clone(), + ) + + result = conditioning1.apply_to(latent_state=test_state, latent_tools=tools) + result = conditioning2.apply_to(latent_state=result, latent_tools=tools) + + # Verify result is different from input (conditioning was applied) + assert not torch.allclose(result.latent, test_latent) + assert not torch.allclose(result.denoise_mask, test_denoise_mask) + + +def test_video_conditioning_builder_roundtrip() -> None: + """Test VideoLatentBuilder build -> revert roundtrip.""" + patchifier = VideoLatentPatchifier(patch_size=1) + batch = 1 + in_channels = 128 + latent_height = 4 + latent_width = 4 + + conditioning = VideoConditionByKeyframeIndex( + keyframes=torch.randn(batch, in_channels, 1, latent_height, latent_width), + frame_idx=0, + strength=0.5, + ) + + tools = VideoLatentTools( + patchifier=patchifier, + 
target_shape=VideoLatentShape.from_pixel_shape( + shape=VideoPixelShape(batch=batch, frames=9, height=128, width=128, fps=30.0) + ), + fps=30.0, + ) + + device = torch.device("cpu") + dtype = torch.float32 + + empty_state = tools.create_initial_state(device=device, dtype=dtype) + latent_state = conditioning.apply_to(latent_state=empty_state, latent_tools=tools) + unconditioned_state = tools.clear_conditioning(latent_state) + assert torch.allclose(unconditioned_state.latent, empty_state.latent) + + +def test_audio_conditioning_builder_initialization() -> None: + """Test AudioLatentBuilder initialization with valid parameters.""" + patchifier = AudioPatchifier( + patch_size=16, + sample_rate=16000, + hop_length=160, + audio_latent_downsample_factor=4, + ) + + tools = AudioLatentTools( + patchifier=patchifier, + target_shape=AudioLatentShape.from_duration(batch=2, duration=2.0, channels=8, mel_bins=16), + ) + + assert tools.target_shape.batch == 2 + assert tools.target_shape.channels == 8 + assert tools.target_shape.mel_bins == 16 + assert tools.target_shape.frames == int(2.0 * 16000.0 / 160.0 / 4.0) + + +def test_audio_conditioning_builder_build() -> None: + """Test AudioLatentBuilder.build() method.""" + patchifier = AudioPatchifier( + patch_size=16, + sample_rate=16000, + hop_length=160, + audio_latent_downsample_factor=4, + ) + + tools = AudioLatentTools( + patchifier=patchifier, + target_shape=AudioLatentShape.from_duration(batch=2, duration=1.0, channels=8, mel_bins=16), + ) + + device = torch.device("cpu") + dtype = torch.float32 + + latent_state = tools.create_initial_state(device=device, dtype=dtype) + + # Verify latent state structure + assert isinstance(latent_state, LatentState) + assert latent_state.latent.shape[0] == 2 # batch + assert latent_state.denoise_mask.shape[0] == 2 # batch + assert latent_state.positions.shape[0] == 2 # batch + + # Verify positions shape for audio (1D time dimension) + assert latent_state.positions.shape[1] == 1 # time dimension only + assert latent_state.positions.dtype == torch.float32 + + +def test_audio_conditioning_builder_roundtrip() -> None: + """Test AudioLatentBuilder build -> clear conditioning roundtrip.""" + patchifier = AudioPatchifier( + patch_size=16, + sample_rate=16000, + hop_length=160, + audio_latent_downsample_factor=4, + ) + + tools = AudioLatentTools( + patchifier=patchifier, + target_shape=AudioLatentShape.from_duration(batch=1, duration=1.0, channels=8, mel_bins=16), + ) + + device = torch.device("cpu") + dtype = torch.float32 + + # Build to get patchified state + built_state = tools.create_initial_state(device=device, dtype=dtype) + + # Revert + reverted_state = tools.clear_conditioning(built_state) + + # Verify result has unpatchified shape + assert reverted_state.latent.shape[0] == 1 # batch + assert reverted_state.latent.shape[2] == 128 # channels * mel_bins + + +if __name__ == "__main__": + pytest.main() diff --git a/packages/ltx-core/tests/ltx_core/test_placeholder.py b/packages/ltx-core/tests/ltx_core/test_placeholder.py new file mode 100644 index 0000000000000000000000000000000000000000..bf2603dedfc34dde5eec76c9d069627e17fdaa15 --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/test_placeholder.py @@ -0,0 +1,6 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. 
+# Created by Andrew Kvochko + + +def test_placeholder() -> None: + assert True diff --git a/packages/ltx-core/tests/ltx_core/utils.py b/packages/ltx-core/tests/ltx_core/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..270a33a505a93473526db6c69aad2030eb031edd --- /dev/null +++ b/packages/ltx-core/tests/ltx_core/utils.py @@ -0,0 +1,27 @@ +"""Shared helpers for model-based tests.""" + +import os +from pathlib import Path + +import pytest + +DEFAULT_MODEL_PATH = Path( + "/models/comfyui_models/checkpoints/ltx-av-step-1932500-interleaved-new-vae.safetensors", +) + + +def resolve_model_path() -> str: + """Return the checkpoint path, preferring $MODEL_PATH when provided.""" + model_path = os.getenv("MODEL_PATH") + if model_path: + env_path = Path(model_path) + if not env_path.is_file(): + raise FileNotFoundError(f"MODEL_PATH points to a missing file: {model_path}") + return str(env_path) + + if DEFAULT_MODEL_PATH.is_file(): + return str(DEFAULT_MODEL_PATH) + + pytest.skip( + "MODEL_PATH is not set and the default checkpoint is unavailable; skipping test.", allow_module_level=True + ) diff --git a/packages/ltx-pipelines/README.md b/packages/ltx-pipelines/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a53fbe0af49abce26d921ee9112dcaeb9eb5af63 --- /dev/null +++ b/packages/ltx-pipelines/README.md @@ -0,0 +1 @@ +# LTX-2 Core diff --git a/packages/ltx-pipelines/pyproject.toml b/packages/ltx-pipelines/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..3db0b319b8eae486124bc28b85da39a7d777416a --- /dev/null +++ b/packages/ltx-pipelines/pyproject.toml @@ -0,0 +1,11 @@ +[project] +name = "ltx-pipelines" +version = "0.1.0" +description = "Pipelines implementation for Lightricks' LTX-2 model" +readme = "README.md" +requires-python = ">=3.12" +dependencies = ["ltx-core", "av", "tqdm", "pillow"] + +[build-system] +requires = ["uv_build>=0.9.8,<0.10.0"] +build-backend = "uv_build" diff --git a/packages/ltx-pipelines/src/ltx_pipelines/.ipynb_checkpoints/ti2vid_two_stages-checkpoint.py b/packages/ltx-pipelines/src/ltx_pipelines/.ipynb_checkpoints/ti2vid_two_stages-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..1f3ebd1c1b14dc9a7e6d2737dcab5301ae7f555e --- /dev/null +++ b/packages/ltx-pipelines/src/ltx_pipelines/.ipynb_checkpoints/ti2vid_two_stages-checkpoint.py @@ -0,0 +1,266 @@ +import torch + +from ltx_core.loader.primitives import LoraPathStrengthAndSDOps +from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP +from ltx_core.model.model_ledger import ModelLedger +from ltx_core.pipeline.components.diffusion_steps import EulerDiffusionStep +from ltx_core.pipeline.components.guiders import CFGGuider +from ltx_core.pipeline.components.noisers import GaussianNoiser +from ltx_core.pipeline.components.protocols import DiffusionStepProtocol, VideoPixelShape +from ltx_core.pipeline.components.schedulers import LTX2Scheduler +from ltx_core.pipeline.conditioning.item import LatentState +from ltx_core.tiling import TilingConfig +from ltx_pipelines import utils +from ltx_pipelines.constants import ( + AUDIO_SAMPLE_RATE, + DEFAULT_LORA_STRENGTH, + STAGE_2_DISTILLED_SIGMA_VALUES, +) +from ltx_pipelines.media_io import encode_video +from ltx_pipelines.pipeline_utils import ( + PipelineComponents, + denoise_audio_video, + encode_text, + euler_denoising_loop, + guider_denoising_func, + simple_denoising_func, +) +from ltx_pipelines.pipeline_utils import ( + decode_audio as vae_decode_audio, +) 
+from ltx_pipelines.pipeline_utils import ( + decode_video as vae_decode_video, +) +from ltx_pipelines.utils import image_conditionings_by_replacing_latent + + +class TI2VidTwoStagesPipeline: + def __init__( + self, + checkpoint_path: str, + distilled_lora_path: str, + distilled_lora_strength: float, + spatial_upsampler_path: str, + gemma_root: str, + loras: list[LoraPathStrengthAndSDOps], + device: str = utils.get_device(), + fp8transformer: bool = False, + local_files_only: bool = True, + ): + self.device = device + self.dtype = torch.bfloat16 + self.stage_1_model_ledger = ModelLedger( + dtype=self.dtype, + device=device, + checkpoint_path=checkpoint_path, + gemma_root_path=gemma_root, + spatial_upsampler_path=spatial_upsampler_path, + loras=loras, + fp8transformer=fp8transformer, + local_files_only=local_files_only + ) + + self.stage_2_model_ledger = self.stage_1_model_ledger.with_loras( + loras=[ + LoraPathStrengthAndSDOps( + path=distilled_lora_path, + strength=distilled_lora_strength, + sd_ops=LTXV_LORA_COMFY_RENAMING_MAP, + ) + ], + ) + + self.pipeline_components = PipelineComponents( + dtype=self.dtype, + device=device, + ) + + @torch.inference_mode() + def __call__( # noqa: PLR0913 + self, + prompt: str, + output_path: str, + negative_prompt: str, + seed: int, + height: int, + width: int, + num_frames: int, + frame_rate: float, + num_inference_steps: int, + cfg_guidance_scale: float, + images: list[tuple[str, int, float]], + tiling_config: TilingConfig | None = None, + ) -> None: + generator = torch.Generator(device=self.device).manual_seed(seed) + noiser = GaussianNoiser(generator=generator) + stepper = EulerDiffusionStep() + cfg_guider = CFGGuider(cfg_guidance_scale) + dtype = torch.bfloat16 + + text_encoder = self.stage_1_model_ledger.text_encoder() + context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt]) + v_context_p, a_context_p = context_p + v_context_n, a_context_n = context_n + + torch.cuda.synchronize() + del text_encoder + utils.cleanup_memory() + + # Stage 1: Initial low resolution video generation. 
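+        # In short: stage 1 denoises a joint video/audio latent at the requested base resolution,
+        # using classifier-free guidance over the LTX2Scheduler sigma schedule; stage 2 (below)
+        # spatially upsamples the stage-1 video latent by 2x and refines it with the distilled
+        # LoRA over the short distilled sigma schedule before decoding and muxing the output.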
+ video_encoder = self.stage_1_model_ledger.video_encoder() + transformer = self.stage_1_model_ledger.transformer() + sigmas = LTX2Scheduler().execute(steps=num_inference_steps).to(dtype=torch.float32, device=self.device) + + def first_stage_denoising_loop( + sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol + ) -> tuple[LatentState, LatentState]: + return euler_denoising_loop( + sigmas=sigmas, + video_state=video_state, + audio_state=audio_state, + stepper=stepper, + denoise_fn=guider_denoising_func( + cfg_guider, + v_context_p, + v_context_n, + a_context_p, + a_context_n, + transformer=transformer, # noqa: F821 + ), + ) + + stage_1_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate) + stage_1_conditionings = image_conditionings_by_replacing_latent( + images=images, + height=stage_1_output_shape.height, + width=stage_1_output_shape.width, + video_encoder=video_encoder, + dtype=dtype, + device=self.device, + ) + video_state, audio_state = denoise_audio_video( + output_shape=stage_1_output_shape, + conditionings=stage_1_conditionings, + noiser=noiser, + sigmas=sigmas, + stepper=stepper, + denoising_loop_fn=first_stage_denoising_loop, + components=self.pipeline_components, + dtype=dtype, + device=self.device, + ) + + torch.cuda.synchronize() + del transformer + utils.cleanup_memory() + + # Stage 2: Upsample and refine the video at higher resolution with distilled LORA. + upscaled_video_latent = utils.upsample_video( + latent=video_state.latent[:1], + video_encoder=video_encoder, + upsampler=self.stage_2_model_ledger.spatial_upsampler(), + ) + + torch.cuda.synchronize() + utils.cleanup_memory() + + transformer = self.stage_2_model_ledger.transformer() + distilled_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device) + + def second_stage_denoising_loop( + sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol + ) -> tuple[LatentState, LatentState]: + return euler_denoising_loop( + sigmas=sigmas, + video_state=video_state, + audio_state=audio_state, + stepper=stepper, + denoise_fn=simple_denoising_func( + video_context=v_context_p, + audio_context=a_context_p, + transformer=transformer, # noqa: F821 + ), + ) + + stage_2_output_shape = VideoPixelShape( + batch=1, frames=num_frames, width=width * 2, height=height * 2, fps=frame_rate + ) + stage_2_conditionings = image_conditionings_by_replacing_latent( + images=images, + height=stage_2_output_shape.height, + width=stage_2_output_shape.width, + video_encoder=video_encoder, + dtype=dtype, + device=self.device, + ) + video_state, audio_state = denoise_audio_video( + output_shape=stage_2_output_shape, + conditionings=stage_2_conditionings, + noiser=noiser, + sigmas=distilled_sigmas, + stepper=stepper, + denoising_loop_fn=second_stage_denoising_loop, + components=self.pipeline_components, + dtype=dtype, + device=self.device, + noise_scale=distilled_sigmas[0], + initial_video_latent=upscaled_video_latent, + initial_audio_latent=audio_state.latent, + ) + + torch.cuda.synchronize() + del transformer + del video_encoder + utils.cleanup_memory() + + decoded_video = vae_decode_video(video_state, self.stage_2_model_ledger.video_decoder(), tiling_config) + + decoded_audio = vae_decode_audio( + audio_state, self.stage_2_model_ledger.audio_decoder(), self.stage_2_model_ledger.vocoder() + ) + + encode_video( + video=decoded_video, + fps=frame_rate, + audio=decoded_audio, + 
audio_sample_rate=AUDIO_SAMPLE_RATE, + output_path=output_path, + ) + + +def main() -> None: + parser = utils.default_2_stage_arg_parser() + args = parser.parse_args() + lora_strengths = (args.lora_strength + [DEFAULT_LORA_STRENGTH] * len(args.lora))[: len(args.lora)] + loras = [ + LoraPathStrengthAndSDOps(lora, strength, LTXV_LORA_COMFY_RENAMING_MAP) + for lora, strength in zip(args.lora, lora_strengths, strict=True) + ] + pipeline = TI2VidTwoStagesPipeline( + checkpoint_path=args.checkpoint_path, + distilled_lora_path=args.distilled_lora_path, + distilled_lora_strength=args.distilled_lora_strength, + spatial_upsampler_path=args.spatial_upsampler_path, + gemma_root=args.gemma_root, + loras=loras, + fp8transformer=args.enable_fp8, + ) + + pipeline( + prompt=args.prompt, + negative_prompt=args.negative_prompt, + output_path=args.output_path, + seed=args.seed, + height=args.height, + width=args.width, + num_frames=args.num_frames, + frame_rate=args.frame_rate, + num_inference_steps=args.num_inference_steps, + cfg_guidance_scale=args.cfg_guidance_scale, + images=args.images, + tiling_config=TilingConfig.default(), + ) + + +if __name__ == "__main__": + main() diff --git a/packages/ltx-pipelines/src/ltx_pipelines/__init__.py b/packages/ltx-pipelines/src/ltx_pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/__init__.cpython-310.pyc b/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f8c67c75ecd68c2f9a0da2d241fdf83ac9e7f5a Binary files /dev/null and b/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/__init__.cpython-310.pyc differ diff --git a/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/constants.cpython-310.pyc b/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ded4e5a7ef5cd0fff2fe70bb090fba5f8579361 Binary files /dev/null and b/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/constants.cpython-310.pyc differ diff --git a/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/media_io.cpython-310.pyc b/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/media_io.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0dbab8ce87fd63b150f41f1f2bb18d1f963f2c33 Binary files /dev/null and b/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/media_io.cpython-310.pyc differ diff --git a/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/pipeline_utils.cpython-310.pyc b/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/pipeline_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38804bb75493d9e83bef7433a878ba8264d76da5 Binary files /dev/null and b/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/pipeline_utils.cpython-310.pyc differ diff --git a/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/ti2vid_one_stage.cpython-310.pyc b/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/ti2vid_one_stage.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..359988842debf1932e99a04a59f2c191f36b62c0 Binary files /dev/null and b/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/ti2vid_one_stage.cpython-310.pyc differ diff --git a/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/ti2vid_two_stages.cpython-310.pyc 
b/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/ti2vid_two_stages.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..356363224dc43cf57020c0b660f8bfe174968561 Binary files /dev/null and b/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/ti2vid_two_stages.cpython-310.pyc differ diff --git a/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/utils.cpython-310.pyc b/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7baef5a7ae48ce4c4327617dc7fcab7bb72ff6c9 Binary files /dev/null and b/packages/ltx-pipelines/src/ltx_pipelines/__pycache__/utils.cpython-310.pyc differ diff --git a/packages/ltx-pipelines/src/ltx_pipelines/constants.py b/packages/ltx-pipelines/src/ltx_pipelines/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..0ddfe22a823d8fd98b64ccdaeba1b286a8965a48 --- /dev/null +++ b/packages/ltx-pipelines/src/ltx_pipelines/constants.py @@ -0,0 +1,73 @@ +# ============================================================================= +# Diffusion Schedule +# ============================================================================= + +# Noise schedule for the distilled pipeline. These sigma values control noise +# levels at each denoising step and were tuned to match the distillation process. +DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] + +# Reduced schedule for super-resolution stage 2 (subset of distilled values) +STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0] + + +# ============================================================================= +# Video Generation Defaults +# ============================================================================= + +DEFAULT_SEED = 42 +DEFAULT_HEIGHT = 512 +DEFAULT_WIDTH = 768 +DEFAULT_NUM_FRAMES = 121 +DEFAULT_FRAME_RATE = 24.0 +DEFAULT_NUM_INFERENCE_STEPS = 40 +DEFAULT_CFG_GUIDANCE_SCALE = 3.0 + + +# ============================================================================= +# Audio +# ============================================================================= + +AUDIO_SAMPLE_RATE = 24000 + + +# ============================================================================= +# LoRA +# ============================================================================= + +DEFAULT_LORA_STRENGTH = 1.0 + + +# ============================================================================= +# Video VAE Architecture +# ============================================================================= + +VIDEO_SCALE_FACTORS = (8, 32, 32) # (temporal, height, width) +VIDEO_LATENT_CHANNELS = 128 + + +# ============================================================================= +# Image Preprocessing +# ============================================================================= + +# CRF (Constant Rate Factor) for H.264 encoding used in image conditioning. +# Lower = higher quality, 0 = lossless. This mimics compression artifacts. 
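+# (Presumably the conditioning image is passed through an H.264 encode/decode round trip at this
+# CRF so that it resembles a decoded video frame rather than a pristine still image.)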
+DEFAULT_IMAGE_CRF = 29 + + +# ============================================================================= +# Prompts +# ============================================================================= + +DEFAULT_NEGATIVE_PROMPT = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) diff --git a/packages/ltx-pipelines/src/ltx_pipelines/distilled.py b/packages/ltx-pipelines/src/ltx_pipelines/distilled.py new file mode 100644 index 0000000000000000000000000000000000000000..bb7eb3fd8d13cbf972fa0b6bdd1ecdf88fd924f7 --- /dev/null +++ b/packages/ltx-pipelines/src/ltx_pipelines/distilled.py @@ -0,0 +1,215 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Amit Pintz. + + +import torch + +from ltx_core.loader.primitives import LoraPathStrengthAndSDOps +from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP +from ltx_core.model.model_ledger import ModelLedger +from ltx_core.pipeline.components.diffusion_steps import EulerDiffusionStep +from ltx_core.pipeline.components.noisers import GaussianNoiser +from ltx_core.pipeline.components.protocols import DiffusionStepProtocol, VideoPixelShape +from ltx_core.pipeline.conditioning.item import LatentState +from ltx_core.tiling import TilingConfig +from ltx_pipelines import utils +from ltx_pipelines.constants import ( + AUDIO_SAMPLE_RATE, + DEFAULT_LORA_STRENGTH, + DISTILLED_SIGMA_VALUES, + STAGE_2_DISTILLED_SIGMA_VALUES, +) +from ltx_pipelines.media_io import encode_video +from ltx_pipelines.pipeline_utils import ( + PipelineComponents, + denoise_audio_video, + encode_text, + euler_denoising_loop, + simple_denoising_func, +) +from ltx_pipelines.pipeline_utils import decode_audio as vae_decode_audio +from ltx_pipelines.pipeline_utils import decode_video as vae_decode_video +from ltx_pipelines.utils import image_conditionings_by_replacing_latent + +device = utils.get_device() + + +class DistilledPipeline: + def __init__( + self, + checkpoint_path: str, + gemma_root: str, + spatial_upsampler_path: str, + loras: list[LoraPathStrengthAndSDOps], + device: torch.device = device, + fp8transformer: bool = False, + ): + self.device = device + self.dtype = torch.bfloat16 + + self.model_ledger = ModelLedger( + dtype=self.dtype, + device=device, + checkpoint_path=checkpoint_path, + spatial_upsampler_path=spatial_upsampler_path, + gemma_root_path=gemma_root, + loras=loras, + fp8transformer=fp8transformer, + ) + + self.pipeline_components = PipelineComponents( + 
dtype=self.dtype, + device=device, + ) + + @torch.inference_mode() + def __call__( + self, + prompt: str, + output_path: str, + seed: int, + height: int, + width: int, + num_frames: int, + frame_rate: float, + images: list[tuple[str, int, float]], + tiling_config: TilingConfig | None = None, + ) -> None: + generator = torch.Generator(device=self.device).manual_seed(seed) + noiser = GaussianNoiser(generator=generator) + stepper = EulerDiffusionStep() + dtype = torch.bfloat16 + + text_encoder = self.model_ledger.text_encoder() + context_p = encode_text(text_encoder, prompts=[prompt])[0] + video_context, audio_context = context_p + + torch.cuda.synchronize() + del text_encoder + utils.cleanup_memory() + + # Stage 1: Initial low resolution video generation. + video_encoder = self.model_ledger.video_encoder() + transformer = self.model_ledger.transformer() + stage_1_sigmas = torch.Tensor(DISTILLED_SIGMA_VALUES).to(self.device) + + def denoising_loop( + sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol + ) -> tuple[LatentState, LatentState]: + return euler_denoising_loop( + sigmas=sigmas, + video_state=video_state, + audio_state=audio_state, + stepper=stepper, + denoise_fn=simple_denoising_func( + video_context=video_context, + audio_context=audio_context, + transformer=transformer, # noqa: F821 + ), + ) + + stage_1_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate) + stage_1_conditionings = image_conditionings_by_replacing_latent( + images=images, + height=stage_1_output_shape.height, + width=stage_1_output_shape.width, + video_encoder=video_encoder, + dtype=dtype, + device=self.device, + ) + + video_state, audio_state = denoise_audio_video( + output_shape=stage_1_output_shape, + conditionings=stage_1_conditionings, + noiser=noiser, + sigmas=stage_1_sigmas, + stepper=stepper, + denoising_loop_fn=denoising_loop, + components=self.pipeline_components, + dtype=dtype, + device=self.device, + ) + + # Stage 2: Upsample and refine the video at higher resolution with distilled LORA. 
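+        # (Note: this distilled pipeline reuses the same transformer and denoising loop for
+        # stage 2; unlike the two-stage pipeline, no separate distilled LoRA is loaded here.)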
+ upscaled_video_latent = utils.upsample_video( + latent=video_state.latent[:1], video_encoder=video_encoder, upsampler=self.model_ledger.spatial_upsampler() + ) + + torch.cuda.synchronize() + utils.cleanup_memory() + + stage_2_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device) + stage_2_output_shape = VideoPixelShape( + batch=1, frames=num_frames, width=width * 2, height=height * 2, fps=frame_rate + ) + stage_2_conditionings = image_conditionings_by_replacing_latent( + images=images, + height=stage_2_output_shape.height, + width=stage_2_output_shape.width, + video_encoder=video_encoder, + dtype=dtype, + device=self.device, + ) + video_state, audio_state = denoise_audio_video( + output_shape=stage_2_output_shape, + conditionings=stage_2_conditionings, + noiser=noiser, + sigmas=stage_2_sigmas, + stepper=stepper, + denoising_loop_fn=denoising_loop, + components=self.pipeline_components, + dtype=dtype, + device=self.device, + noise_scale=stage_2_sigmas[0], + initial_video_latent=upscaled_video_latent, + initial_audio_latent=audio_state.latent, + ) + + torch.cuda.synchronize() + del transformer + del video_encoder + utils.cleanup_memory() + + decoded_video = vae_decode_video(video_state, self.model_ledger.video_decoder(), tiling_config) + + decoded_audio = vae_decode_audio(audio_state, self.model_ledger.audio_decoder(), self.model_ledger.vocoder()) + + encode_video( + video=decoded_video, + fps=frame_rate, + audio=decoded_audio, + audio_sample_rate=AUDIO_SAMPLE_RATE, + output_path=output_path, + ) + + +def main() -> None: + parser = utils.default_2_stage_distilled_arg_parser() + args = parser.parse_args() + lora_strengths = (args.lora_strength + [DEFAULT_LORA_STRENGTH] * len(args.lora))[: len(args.lora)] + loras = [ + LoraPathStrengthAndSDOps(lora, strength, LTXV_LORA_COMFY_RENAMING_MAP) + for lora, strength in zip(args.lora, lora_strengths, strict=True) + ] + pipeline = DistilledPipeline( + checkpoint_path=args.checkpoint_path, + spatial_upsampler_path=args.spatial_upsampler_path, + gemma_root=args.gemma_root, + loras=loras, + fp8transformer=args.enable_fp8, + ) + pipeline( + prompt=args.prompt, + output_path=args.output_path, + seed=args.seed, + height=args.height, + width=args.width, + num_frames=args.num_frames, + frame_rate=args.frame_rate, + images=args.images, + tiling_config=TilingConfig.default(), + ) + + +if __name__ == "__main__": + main() diff --git a/packages/ltx-pipelines/src/ltx_pipelines/ic_lora.py b/packages/ltx-pipelines/src/ltx_pipelines/ic_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..cbf152775dfb47517f5817abb8076644ffefef9a --- /dev/null +++ b/packages/ltx-pipelines/src/ltx_pipelines/ic_lora.py @@ -0,0 +1,304 @@ +import torch + +from ltx_core.loader.primitives import LoraPathStrengthAndSDOps +from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP +from ltx_core.model.model_ledger import ModelLedger +from ltx_core.model.video_vae.video_vae import Encoder as VideoEncoder +from ltx_core.pipeline.components.diffusion_steps import EulerDiffusionStep +from ltx_core.pipeline.components.guiders import CFGGuider +from ltx_core.pipeline.components.noisers import GaussianNoiser +from ltx_core.pipeline.components.protocols import DiffusionStepProtocol, VideoPixelShape +from ltx_core.pipeline.components.schedulers import LTX2Scheduler +from ltx_core.pipeline.conditioning.item import ConditioningItem, LatentState +from ltx_core.pipeline.conditioning.types.keyframe_cond import VideoConditionByKeyframeIndex +from ltx_core.tiling 
import TilingConfig +from ltx_pipelines import utils +from ltx_pipelines.constants import ( + AUDIO_SAMPLE_RATE, + DEFAULT_LORA_STRENGTH, + STAGE_2_DISTILLED_SIGMA_VALUES, +) +from ltx_pipelines.media_io import encode_video, load_video_conditioning +from ltx_pipelines.pipeline_utils import ( + PipelineComponents, + denoise_audio_video, + encode_text, + euler_denoising_loop, + guider_denoising_func, + simple_denoising_func, +) +from ltx_pipelines.pipeline_utils import decode_audio as vae_decode_audio +from ltx_pipelines.pipeline_utils import decode_video as vae_decode_video + +device = utils.get_device() + + +class ICLoraPipeline: + def __init__( + self, + checkpoint_path: str, + distilled_lora_path: str, + distilled_lora_strength: float, + spatial_upsampler_path: str, + gemma_root: str, + loras: list[LoraPathStrengthAndSDOps], + device: torch.device = device, + fp8transformer: bool = False, + ): + self.dtype = torch.bfloat16 + self.stage_1_model_ledger = ModelLedger( + dtype=self.dtype, + device=device, + checkpoint_path=checkpoint_path, + spatial_upsampler_path=spatial_upsampler_path, + gemma_root_path=gemma_root, + loras=loras, + fp8transformer=fp8transformer, + ) + self.stage_2_model_ledger = self.stage_1_model_ledger.with_loras( + loras=[ + LoraPathStrengthAndSDOps( + path=distilled_lora_path, + strength=distilled_lora_strength, + sd_ops=LTXV_LORA_COMFY_RENAMING_MAP, + ) + ], + ) + self.pipeline_components = PipelineComponents( + dtype=self.dtype, + device=device, + ) + self.device = device + + @torch.inference_mode() + def __call__( # noqa: PLR0913 + self, + prompt: str, + output_path: str, + negative_prompt: str, + seed: int, + height: int, + width: int, + num_frames: int, + frame_rate: float, + num_inference_steps: int, + cfg_guidance_scale: float, + images: list[tuple[str, int, float]], + video_conditioning: list[tuple[str, float]], + tiling_config: TilingConfig | None = None, + ) -> None: + generator = torch.Generator(device=self.device).manual_seed(seed) + noiser = GaussianNoiser(generator=generator) + stepper = EulerDiffusionStep() + cfg_guider = CFGGuider(cfg_guidance_scale) + dtype = torch.bfloat16 + + text_encoder = self.stage_1_model_ledger.text_encoder() + context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt]) + v_context_p, a_context_p = context_p + v_context_n, a_context_n = context_n + + torch.cuda.synchronize() + del text_encoder + utils.cleanup_memory() + + # Stage 1: Initial low resolution video generation. 
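+        # Stage 1 conditions on both the reference images and the encoded conditioning videos
+        # (_create_conditionings below appends a VideoConditionByKeyframeIndex item per
+        # conditioning video) while denoising with CFG at the base resolution.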
+ video_encoder = self.stage_1_model_ledger.video_encoder() + transformer = self.stage_1_model_ledger.transformer() + sigmas = LTX2Scheduler().execute(steps=num_inference_steps).to(dtype=torch.float32, device=self.device) + + def first_stage_denoising_loop( + sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol + ) -> tuple[LatentState, LatentState]: + return euler_denoising_loop( + sigmas=sigmas, + video_state=video_state, + audio_state=audio_state, + stepper=stepper, + denoise_fn=guider_denoising_func( + cfg_guider, + v_context_p, + v_context_n, + a_context_p, + a_context_n, + transformer=transformer, # noqa: F821 + ), + ) + + stage_1_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate) + stage_1_conditionings = self._create_conditionings( + images=images, + video_conditioning=video_conditioning, + height=stage_1_output_shape.height, + width=stage_1_output_shape.width, + video_encoder=video_encoder, + num_frames=num_frames, + ) + video_state, audio_state = denoise_audio_video( + output_shape=stage_1_output_shape, + conditionings=stage_1_conditionings, + noiser=noiser, + sigmas=sigmas, + stepper=stepper, + denoising_loop_fn=first_stage_denoising_loop, + components=self.pipeline_components, + dtype=dtype, + device=self.device, + ) + + torch.cuda.synchronize() + del transformer + utils.cleanup_memory() + + # Stage 2: Upsample and refine the video at higher resolution with distilled LORA. + upscaled_video_latent = utils.upsample_video( + latent=video_state.latent[:1], + video_encoder=video_encoder, + upsampler=self.stage_2_model_ledger.spatial_upsampler(), + ) + + torch.cuda.synchronize() + utils.cleanup_memory() + + transformer = self.stage_2_model_ledger.transformer() + distilled_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device) + + def second_stage_denoising_loop( + sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol + ) -> tuple[LatentState, LatentState]: + return euler_denoising_loop( + sigmas=sigmas, + video_state=video_state, + audio_state=audio_state, + stepper=stepper, + denoise_fn=simple_denoising_func( + video_context=v_context_p, + audio_context=a_context_p, + transformer=transformer, # noqa: F821 + ), + ) + + stage_2_output_shape = VideoPixelShape( + batch=1, frames=num_frames, width=width * 2, height=height * 2, fps=frame_rate + ) + stage_2_conditionings = utils.image_conditionings_by_replacing_latent( + images=images, + height=stage_2_output_shape.height, + width=stage_2_output_shape.width, + video_encoder=video_encoder, + dtype=self.dtype, + device=self.device, + ) + + video_state, audio_state = denoise_audio_video( + output_shape=stage_2_output_shape, + conditionings=stage_2_conditionings, + noiser=noiser, + sigmas=distilled_sigmas, + stepper=stepper, + denoising_loop_fn=second_stage_denoising_loop, + components=self.pipeline_components, + dtype=dtype, + device=self.device, + noise_scale=distilled_sigmas[0], + initial_video_latent=upscaled_video_latent, + initial_audio_latent=audio_state.latent, + ) + + torch.cuda.synchronize() + del transformer + del video_encoder + utils.cleanup_memory() + + decoded_video = vae_decode_video(video_state, self.stage_2_model_ledger.video_decoder(), tiling_config) + + decoded_audio = vae_decode_audio( + audio_state, self.stage_2_model_ledger.audio_decoder(), self.stage_2_model_ledger.vocoder() + ) + + encode_video( + video=decoded_video, + fps=frame_rate, + 
audio=decoded_audio, + audio_sample_rate=AUDIO_SAMPLE_RATE, + output_path=output_path, + ) + + def _create_conditionings( + self, + images: list[tuple[str, int, float]], + video_conditioning: list[tuple[str, float]], + height: int, + width: int, + num_frames: int, + video_encoder: VideoEncoder, + ) -> list[ConditioningItem]: + conditionings = utils.image_conditionings_by_replacing_latent( + images=images, + height=height, + width=width, + video_encoder=video_encoder, + dtype=self.dtype, + device=self.device, + ) + + for video_path, strength in video_conditioning: + video = load_video_conditioning( + video_path=video_path, + height=height, + width=width, + frame_cap=num_frames, + dtype=self.dtype, + device=self.device, + ) + encoded_video = video_encoder(video) + conditionings.append(VideoConditionByKeyframeIndex(keyframes=encoded_video, frame_idx=0, strength=strength)) + + return conditionings + + +def main() -> None: + parser = utils.default_2_stage_arg_parser() + parser.add_argument( + "--video_conditioning", + dest="video_conditioning", + action=utils.VideoConditioningAction, + nargs=2, + metavar=("PATH", "STRENGTH"), + required=True, + ) + args = parser.parse_args() + lora_strengths = (args.lora_strength + [DEFAULT_LORA_STRENGTH] * len(args.lora))[: len(args.lora)] + loras = [ + LoraPathStrengthAndSDOps(lora, strength, LTXV_LORA_COMFY_RENAMING_MAP) + for lora, strength in zip(args.lora, lora_strengths, strict=True) + ] + pipeline = ICLoraPipeline( + checkpoint_path=args.checkpoint_path, + distilled_lora_path=args.distilled_lora_path, + distilled_lora_strength=args.distilled_lora_strength, + spatial_upsampler_path=args.spatial_upsampler_path, + gemma_root=args.gemma_root, + loras=loras, + fp8transformer=args.enable_fp8, + ) + + pipeline( + prompt=args.prompt, + output_path=args.output_path, + negative_prompt=args.negative_prompt, + seed=args.seed, + height=args.height, + width=args.width, + num_frames=args.num_frames, + frame_rate=args.frame_rate, + num_inference_steps=args.num_inference_steps, + cfg_guidance_scale=args.cfg_guidance_scale, + images=args.images, + video_conditioning=args.video_conditioning, + tiling_config=TilingConfig.default(), + ) + + +if __name__ == "__main__": + main() diff --git a/packages/ltx-pipelines/src/ltx_pipelines/keyframe_interpolation.py b/packages/ltx-pipelines/src/ltx_pipelines/keyframe_interpolation.py new file mode 100644 index 0000000000000000000000000000000000000000..8241d14c4dd1b42fa0f148b83415e91e1e63d20d --- /dev/null +++ b/packages/ltx-pipelines/src/ltx_pipelines/keyframe_interpolation.py @@ -0,0 +1,258 @@ +import torch + +from ltx_core.loader.primitives import LoraPathStrengthAndSDOps +from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP +from ltx_core.model.model_ledger import ModelLedger +from ltx_core.pipeline.components.diffusion_steps import EulerDiffusionStep +from ltx_core.pipeline.components.guiders import CFGGuider +from ltx_core.pipeline.components.noisers import GaussianNoiser +from ltx_core.pipeline.components.protocols import DiffusionStepProtocol, VideoPixelShape +from ltx_core.pipeline.components.schedulers import LTX2Scheduler +from ltx_core.pipeline.conditioning.item import LatentState +from ltx_core.tiling import TilingConfig +from ltx_pipelines import utils +from ltx_pipelines.constants import ( + AUDIO_SAMPLE_RATE, + DEFAULT_LORA_STRENGTH, + STAGE_2_DISTILLED_SIGMA_VALUES, +) +from ltx_pipelines.media_io import encode_video +from ltx_pipelines.pipeline_utils import ( + PipelineComponents, + denoise_audio_video, + 
encode_text, + euler_denoising_loop, + guider_denoising_func, + simple_denoising_func, +) +from ltx_pipelines.pipeline_utils import decode_audio as vae_decode_audio +from ltx_pipelines.pipeline_utils import decode_video as vae_decode_video + +device = utils.get_device() + + +class KeyframeInterpolationPipeline: + def __init__( + self, + checkpoint_path: str, + distilled_lora_path: str, + distilled_lora_strength: float, + spatial_upsampler_path: str, + gemma_root: str, + loras: list[LoraPathStrengthAndSDOps], + device: torch.device = device, + fp8transformer: bool = False, + ): + self.device = device + self.dtype = torch.bfloat16 + self.stage_1_model_ledger = ModelLedger( + dtype=self.dtype, + device=device, + checkpoint_path=checkpoint_path, + spatial_upsampler_path=spatial_upsampler_path, + gemma_root_path=gemma_root, + loras=loras, + fp8transformer=fp8transformer, + ) + self.stage_2_model_ledger = self.stage_1_model_ledger.with_loras( + loras=[ + LoraPathStrengthAndSDOps( + path=distilled_lora_path, + strength=distilled_lora_strength, + sd_ops=LTXV_LORA_COMFY_RENAMING_MAP, + ) + ], + ) + self.pipeline_components = PipelineComponents( + dtype=self.dtype, + device=device, + ) + + @torch.inference_mode() + def __call__( # noqa: PLR0913 + self, + prompt: str, + output_path: str, + negative_prompt: str, + seed: int, + height: int, + width: int, + num_frames: int, + frame_rate: float, + num_inference_steps: int, + cfg_guidance_scale: float, + images: list[tuple[str, int, float]], + tiling_config: TilingConfig | None = None, + ) -> None: + generator = torch.Generator(device=self.device).manual_seed(seed) + noiser = GaussianNoiser(generator=generator) + stepper = EulerDiffusionStep() + cfg_guider = CFGGuider(cfg_guidance_scale) + dtype = torch.bfloat16 + + text_encoder = self.stage_1_model_ledger.text_encoder() + context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt]) + v_context_p, a_context_p = context_p + v_context_n, a_context_n = context_n + + torch.cuda.synchronize() + del text_encoder + utils.cleanup_memory() + + # Stage 1: Initial low resolution video generation. 
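+        # Unlike the other pipelines, keyframe interpolation builds its conditionings with
+        # image_conditionings_by_adding_guiding_latent (rather than
+        # image_conditionings_by_replacing_latent), so the supplied keyframes act as guiding
+        # anchors for the frames interpolated between them.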
+ video_encoder = self.stage_1_model_ledger.video_encoder() + transformer = self.stage_1_model_ledger.transformer() + sigmas = LTX2Scheduler().execute(steps=num_inference_steps).to(dtype=torch.float32, device=self.device) + + def first_stage_denoising_loop( + sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol + ) -> tuple[LatentState, LatentState]: + return euler_denoising_loop( + sigmas=sigmas, + video_state=video_state, + audio_state=audio_state, + stepper=stepper, + denoise_fn=guider_denoising_func( + cfg_guider, + v_context_p, + v_context_n, + a_context_p, + a_context_n, + transformer=transformer, # noqa: F821 + ), + ) + + stage_1_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate) + stage_1_conditionings = utils.image_conditionings_by_adding_guiding_latent( + images=images, + height=stage_1_output_shape.height, + width=stage_1_output_shape.width, + video_encoder=video_encoder, + dtype=dtype, + device=self.device, + ) + video_state, audio_state = denoise_audio_video( + output_shape=stage_1_output_shape, + conditionings=stage_1_conditionings, + noiser=noiser, + sigmas=sigmas, + stepper=stepper, + denoising_loop_fn=first_stage_denoising_loop, + components=self.pipeline_components, + dtype=dtype, + device=self.device, + ) + + torch.cuda.synchronize() + del transformer + utils.cleanup_memory() + + # Stage 2: Upsample and refine the video at higher resolution with distilled LORA. + upscaled_video_latent = utils.upsample_video( + latent=video_state.latent[:1], + video_encoder=video_encoder, + upsampler=self.stage_2_model_ledger.spatial_upsampler(), + ) + + torch.cuda.synchronize() + utils.cleanup_memory() + + transformer = self.stage_2_model_ledger.transformer() + distilled_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device) + + def second_stage_denoising_loop( + sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol + ) -> tuple[LatentState, LatentState]: + return euler_denoising_loop( + sigmas=sigmas, + video_state=video_state, + audio_state=audio_state, + stepper=stepper, + denoise_fn=simple_denoising_func( + video_context=v_context_p, + audio_context=a_context_p, + transformer=transformer, # noqa: F821 + ), + ) + + stage_2_output_shape = VideoPixelShape( + batch=1, frames=num_frames, width=width * 2, height=height * 2, fps=frame_rate + ) + stage_2_conditionings = utils.image_conditionings_by_adding_guiding_latent( + images=images, + height=stage_2_output_shape.height, + width=stage_2_output_shape.width, + video_encoder=video_encoder, + dtype=dtype, + device=self.device, + ) + video_state, audio_state = denoise_audio_video( + output_shape=stage_2_output_shape, + conditionings=stage_2_conditionings, + noiser=noiser, + sigmas=distilled_sigmas, + stepper=stepper, + denoising_loop_fn=second_stage_denoising_loop, + components=self.pipeline_components, + dtype=dtype, + device=self.device, + noise_scale=distilled_sigmas[0], + initial_video_latent=upscaled_video_latent, + initial_audio_latent=audio_state.latent, + ) + + torch.cuda.synchronize() + del transformer + del video_encoder + utils.cleanup_memory() + + decoded_video = vae_decode_video(video_state, self.stage_2_model_ledger.video_decoder(), tiling_config) + + decoded_audio = vae_decode_audio( + audio_state, self.stage_2_model_ledger.audio_decoder(), self.stage_2_model_ledger.vocoder() + ) + + encode_video( + video=decoded_video, + fps=frame_rate, + audio=decoded_audio, 
+ audio_sample_rate=AUDIO_SAMPLE_RATE, + output_path=output_path, + ) + + +def main() -> None: + parser = utils.default_2_stage_arg_parser() + args = parser.parse_args() + lora_strengths = (args.lora_strength + [DEFAULT_LORA_STRENGTH] * len(args.lora))[: len(args.lora)] + loras = [ + LoraPathStrengthAndSDOps(lora, strength, LTXV_LORA_COMFY_RENAMING_MAP) + for lora, strength in zip(args.lora, lora_strengths, strict=True) + ] + pipeline = KeyframeInterpolationPipeline( + checkpoint_path=args.checkpoint_path, + distilled_lora_path=args.distilled_lora_path, + distilled_lora_strength=args.distilled_lora_strength, + spatial_upsampler_path=args.spatial_upsampler_path, + gemma_root=args.gemma_root, + loras=loras, + fp8transformer=args.enable_fp8, + ) + pipeline( + prompt=args.prompt, + output_path=args.output_path, + negative_prompt=args.negative_prompt, + seed=args.seed, + height=args.height, + width=args.width, + num_frames=args.num_frames, + frame_rate=args.frame_rate, + num_inference_steps=args.num_inference_steps, + cfg_guidance_scale=args.cfg_guidance_scale, + images=args.images, + tiling_config=TilingConfig.default(), + ) + + +if __name__ == "__main__": + main() diff --git a/packages/ltx-pipelines/src/ltx_pipelines/media_io.py b/packages/ltx-pipelines/src/ltx_pipelines/media_io.py new file mode 100644 index 0000000000000000000000000000000000000000..e3fcff122022548a78eff9405f32250f4231da5c --- /dev/null +++ b/packages/ltx-pipelines/src/ltx_pipelines/media_io.py @@ -0,0 +1,277 @@ +import math +from collections.abc import Generator +from fractions import Fraction +from io import BytesIO + +import av +import numpy as np +import torch +from einops import rearrange +from PIL import Image +from torch._prims_common import DeviceLikeType +from tqdm import tqdm + +from ltx_pipelines.constants import DEFAULT_IMAGE_CRF + + +def resize_and_center_crop(latent: torch.Tensor, height: int, width: int) -> torch.Tensor: + """Resize image preserving aspect ratio (filling target), then center crop to exact dimensions. + + Args: + latent: Input tensor with shape (H, W, C) or (F, H, W, C) + height: Target height + width: Target width + + Returns: + Tensor with shape (1, C, 1, height, width) for 3D input or (1, C, F, height, width) for 4D input + """ + if latent.ndim == 3: + latent = rearrange(latent, "h w c -> 1 c h w") + elif latent.ndim == 4: + latent = rearrange(latent, "f h w c -> f c h w") + else: + raise ValueError(f"Expected input with 3 or 4 dimensions; got shape {latent.shape}.") + + _, _, src_h, src_w = latent.shape + + scale = max(height / src_h, width / src_w) + # Use ceil to avoid floating-point rounding causing new_h/new_w to be + # slightly smaller than target, which would result in negative crop offsets. + new_h = math.ceil(src_h * scale) + new_w = math.ceil(src_w * scale) + + latent = torch.nn.functional.interpolate(latent, size=(new_h, new_w), mode="bilinear", align_corners=False) + + crop_top = (new_h - height) // 2 + crop_left = (new_w - width) // 2 + latent = latent[:, :, crop_top : crop_top + height, crop_left : crop_left + width] + + latent = rearrange(latent, "f c h w -> 1 c f h w") + return latent + + +def normalize_latent(latent: torch.Tensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + return (latent / 127.5 - 1.0).to(device=device, dtype=dtype) + + +def load_image_conditioning( + image_path: str, height: int, width: int, dtype: torch.dtype, device: torch.device +) -> torch.Tensor: + """ + Loads an image from a path and preprocesses it for conditioning. 
+ Note: The image is resized to the nearest multiple of 2 for compatibility with video codecs. + """ + image = decode_image(image_path=image_path) + image = preprocess(image=image) + image = torch.tensor(image, dtype=torch.float32, device=device) + image = resize_and_center_crop(image, height, width) + image = normalize_latent(image, device, dtype) + return image + + +def load_video_conditioning( + video_path: str, height: int, width: int, frame_cap: int, dtype: torch.dtype, device: torch.device +) -> torch.Tensor: + """ + Loads a video from a path and preprocesses it for conditioning. + Note: The video is resized to the nearest multiple of 2 for compatibility with video codecs. + """ + frames = decode_video_from_file(path=video_path, frame_cap=frame_cap, device=device) + result = None + for f in frames: + frame = resize_and_center_crop(f.to(torch.float32), height, width) + frame = normalize_latent(frame, device, dtype) + result = frame if result is None else torch.cat([result, frame], dim=2) + return result + + +def decode_image(image_path: str) -> np.ndarray: + image = Image.open(image_path) + np_array = np.array(image)[..., :3] + return np_array + + +def _write_audio( + container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int +) -> None: + if samples.ndim == 1: + samples = samples[:, None] + + if samples.shape[1] != 2 and samples.shape[0] == 2: + samples = samples.T + + if samples.shape[1] != 2: + raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.") + + # Convert to int16 packed for ingestion; resampler converts to encoder fmt. + if samples.dtype != torch.int16: + samples = torch.clip(samples, -1.0, 1.0) + samples = (samples * 32767.0).to(torch.int16) + + frame_in = av.AudioFrame.from_ndarray( + samples.contiguous().reshape(1, -1).cpu().numpy(), + format="s16", + layout="stereo", + ) + frame_in.sample_rate = audio_sample_rate + + _resample_audio(container, audio_stream, frame_in) + + +def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream: + """ + Prepare the audio stream for writing. 
+ """ + audio_stream = container.add_stream("aac", rate=audio_sample_rate) + audio_stream.codec_context.sample_rate = audio_sample_rate + audio_stream.codec_context.layout = "stereo" + audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate) + return audio_stream + + +def _resample_audio( + container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame +) -> None: + cc = audio_stream.codec_context + + # Use the encoder's format/layout/rate as the *target* + target_format = cc.format or "fltp" # AAC → usually fltp + target_layout = cc.layout or "stereo" + target_rate = cc.sample_rate or frame_in.sample_rate + + audio_resampler = av.audio.resampler.AudioResampler( + format=target_format, + layout=target_layout, + rate=target_rate, + ) + + audio_next_pts = 0 + for rframe in audio_resampler.resample(frame_in): + if rframe.pts is None: + rframe.pts = audio_next_pts + audio_next_pts += rframe.samples + rframe.sample_rate = frame_in.sample_rate + container.mux(audio_stream.encode(rframe)) + + # flush audio encoder + for packet in audio_stream.encode(): + container.mux(packet) + + +def encode_video( + video: torch.Tensor | Generator[tuple[torch.Tensor, int], None, None], + fps: int, + audio: torch.Tensor | None, + audio_sample_rate: int | None, + output_path: str, +) -> None: + if isinstance(video, torch.Tensor): + video = iter([(video, 1)]) + + first_chunk, total = next(video) + + _, height, width, _ = first_chunk.shape + + container = av.open(output_path, mode="w") + stream = container.add_stream("libx264", rate=int(fps)) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + + if audio is not None: + if audio_sample_rate is None: + raise ValueError("audio_sample_rate is required when audio is provided") + + audio_stream = _prepare_audio_stream(container, audio_sample_rate) + + def all_tiles( + first_chunk: torch.Tensor, tiles_generator: Generator[tuple[torch.Tensor, int], None, None] + ) -> Generator[tuple[torch.Tensor, int], None, None]: + yield first_chunk, total + yield from tiles_generator + + for video_chunk, _ in tqdm(all_tiles(first_chunk, video), total=total): + video_chunk_cpu = video_chunk.to("cpu").numpy() + for frame_array in video_chunk_cpu: + frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + + # Flush encoder + for packet in stream.encode(): + container.mux(packet) + + if audio is not None: + _write_audio(container, audio_stream, audio, audio_sample_rate) + + container.close() + + +def decode_audio_from_file(path: str, device: torch.device) -> torch.Tensor | None: + container = av.open(path) + try: + audio = [] + audio_stream = next(s for s in container.streams if s.type == "audio") + for frame in container.decode(audio_stream): + audio.append(torch.tensor(frame.to_ndarray(), dtype=torch.float32, device=device).unsqueeze(0)) + container.close() + audio = torch.cat(audio) + except StopIteration: + audio = None + finally: + container.close() + + return audio + + +def decode_video_from_file(path: str, frame_cap: int, device: DeviceLikeType) -> Generator[torch.Tensor]: + container = av.open(path) + try: + video_stream = next(s for s in container.streams if s.type == "video") + for frame in container.decode(video_stream): + tensor = torch.tensor(frame.to_rgb().to_ndarray(), dtype=torch.uint8, device=device).unsqueeze(0) + yield tensor + frame_cap = frame_cap - 1 + if frame_cap == 0: + break + finally: + container.close() + + +def 
encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None: + container = av.open(output_file, "w", format="mp4") + try: + stream = container.add_stream("libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}) + # Round to nearest multiple of 2 for compatibility with video codecs + height = image_array.shape[0] // 2 * 2 + width = image_array.shape[1] // 2 * 2 + image_array = image_array[:height, :width] + stream.height = height + stream.width = width + av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(format="yuv420p") + container.mux(stream.encode(av_frame)) + container.mux(stream.encode()) + finally: + container.close() + + +def decode_single_frame(video_file: str) -> np.array: + container = av.open(video_file) + try: + stream = next(s for s in container.streams if s.type == "video") + frame = next(container.decode(stream)) + finally: + container.close() + return frame.to_ndarray(format="rgb24") + + +def preprocess(image: np.array, crf: float = DEFAULT_IMAGE_CRF) -> np.array: + if crf == 0: + return image + + with BytesIO() as output_file: + encode_single_frame(output_file, image, crf) + video_bytes = output_file.getvalue() + with BytesIO(video_bytes) as video_file: + image_array = decode_single_frame(video_file) + return image_array diff --git a/packages/ltx-pipelines/src/ltx_pipelines/pipeline_utils.py b/packages/ltx-pipelines/src/ltx_pipelines/pipeline_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..15c7bf0fa662a7077ea330f2af5b97c6e38ac150 --- /dev/null +++ b/packages/ltx-pipelines/src/ltx_pipelines/pipeline_utils.py @@ -0,0 +1,442 @@ +from dataclasses import replace +from typing import Generator, Protocol + +import torch +from einops import rearrange +from tqdm import tqdm + +from ltx_core.model.audio_vae.vocoder import Vocoder +from ltx_core.model.clip.gemma.encoders.base_encoder import GemmaTextEncoderModelBase +from ltx_core.model.model_ledger import AudioDecoder, VideoDecoder +from ltx_core.model.transformer.modality import Modality +from ltx_core.model.transformer.model import X0Model +from ltx_core.pipeline.components.noisers import Noiser +from ltx_core.pipeline.components.patchifiers import AudioPatchifier, VideoLatentPatchifier +from ltx_core.pipeline.components.protocols import ( + AudioLatentShape, + DiffusionStepProtocol, + GuiderProtocol, + VideoLatentShape, + VideoPixelShape, +) +from ltx_core.pipeline.conditioning.item import ConditioningItem +from ltx_core.pipeline.conditioning.tools import AudioLatentTools, LatentState, LatentTools, VideoLatentTools +from ltx_core.tiling import ( + TilingConfig, +) +from ltx_core.utils import to_denoised, to_velocity +from ltx_pipelines.constants import VIDEO_LATENT_CHANNELS, VIDEO_SCALE_FACTORS + + +class PipelineComponents: + def __init__( + self, + dtype: torch.dtype, + device: torch.device, + ): + self.dtype = dtype + self.device = device + + self.video_scale_factors = VIDEO_SCALE_FACTORS + self.video_latent_channels = VIDEO_LATENT_CHANNELS + + self.video_patchifier = VideoLatentPatchifier(patch_size=1) + self.audio_patchifier = AudioPatchifier(patch_size=1) + + +class DenoisingFunc(Protocol): + def __call__( + self, video_state: LatentState, audio_state: LatentState, sigmas: torch.Tensor, step_index: int + ) -> tuple[torch.Tensor, torch.Tensor]: ... 
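+
+
+# Illustrative sketch only (an assumption for documentation, not used by the
+# pipelines): any callable with this signature satisfies DenoisingFunc. The
+# factories `simple_denoising_func` and `guider_denoising_func` below build
+# real implementations that close over a transformer; this hypothetical stub
+# simply echoes the current noisy latents back as the "denoised" estimate.
+def _example_identity_denoise_fn(
+    video_state: LatentState, audio_state: LatentState, sigmas: torch.Tensor, step_index: int
+) -> tuple[torch.Tensor, torch.Tensor]:
+    del sigmas, step_index  # unused by this no-op example
+    return video_state.latent, audio_state.latent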
+ + +def euler_denoising_loop( + sigmas: torch.Tensor, + video_state: LatentState, + audio_state: LatentState, + stepper: DiffusionStepProtocol, + denoise_fn: DenoisingFunc, +) -> tuple[LatentState, LatentState]: + """ + Perform the joint audio-video denoising loop over a diffusion schedule. + This function iterates over all but the final value in ``sigmas`` and, at + each diffusion step, calls ``denoise_fn`` to obtain denoised video and + audio latents. The denoised latents are post-processed with their + respective denoise masks and clean latents, then passed to ``stepper`` to + advance the noisy latents one step along the diffusion schedule. + + ### Parameters + + sigmas: + A 1D tensor of noise levels (diffusion sigmas) defining the sampling + schedule. All steps except the last element are iterated over. + video_state: + The current video :class:`LatentState`, containing the noisy latent, + its clean reference latent, and the denoising mask. + audio_state: + The current audio :class:`LatentState`, analogous to ``video_state`` + but for the audio modality. + stepper: + An implementation of :class:`DiffusionStepProtocol` that updates a + latent given the current latent, its denoised estimate, the full + ``sigmas`` schedule, and the current step index. + denoise_fn: + A callable implementing :class:`DenoisingFunc`. It is invoked as + ``denoise_fn(video_state, audio_state, sigmas, step_index)`` and must + return a tuple ``(denoised_video, denoised_audio)``, where each element + is a tensor with the same shape as the corresponding latent. + + ### Returns + + tuple[LatentState, LatentState] + A pair ``(video_state, audio_state)`` containing the final video and + audio latent states after completing the denoising loop. + """ + for step_idx, _ in enumerate(tqdm(sigmas[:-1])): + denoised_video, denoised_audio = denoise_fn(video_state, audio_state, sigmas, step_idx) + + denoised_video = post_process_latent(denoised_video, video_state.denoise_mask, video_state.clean_latent) + denoised_audio = post_process_latent(denoised_audio, audio_state.denoise_mask, audio_state.clean_latent) + + video_state = replace(video_state, latent=stepper.step(video_state.latent, denoised_video, sigmas, step_idx)) + audio_state = replace(audio_state, latent=stepper.step(audio_state.latent, denoised_audio, sigmas, step_idx)) + + return (video_state, audio_state) + + +def gradient_estimating_euler_denoising_loop( + sigmas: torch.Tensor, + video_state: LatentState, + audio_state: LatentState, + stepper: DiffusionStepProtocol, + denoise_fn: DenoisingFunc, + ge_gamma: float = 2.0, +) -> tuple[LatentState, LatentState]: + """ + Perform the joint audio-video denoising loop using gradient-estimation sampling. + + This function is similar to :func:`euler_denoising_loop`, but applies + gradient estimation to improve the denoised estimates by tracking velocity + changes across steps. See the referenced function for detailed parameter + documentation. + + ### Parameters + + ge_gamma: + Gradient estimation coefficient controlling the velocity correction term. + Default is 2.0. Paper: https://openreview.net/pdf?id=o2ND9v0CeK + + sigmas, video_state, audio_state, stepper, denoise_fn: + See :func:`euler_denoising_loop` for parameter descriptions. + + ### Returns + + tuple[LatentState, LatentState] + See :func:`euler_denoising_loop` for return value description. 
+ """ + + previous_audio_velocity = None + previous_video_velocity = None + + def update_velocity_and_sample( + noisy_sample: torch.Tensor, denoised_sample: torch.Tensor, sigma: float, previous_velocity: torch.Tensor | None + ) -> tuple[torch.Tensor, torch.Tensor]: + current_velocity = to_velocity(noisy_sample, sigma, denoised_sample) + if previous_velocity is not None: + delta_v = current_velocity - previous_velocity + total_velocity = ge_gamma * delta_v + previous_velocity + denoised_sample = to_denoised(noisy_sample, total_velocity, sigma) + return current_velocity, denoised_sample + + for step_idx, _ in enumerate(tqdm(sigmas[:-1])): + denoised_video, denoised_audio = denoise_fn(video_state, audio_state, sigmas, step_idx) + + denoised_video = post_process_latent(denoised_video, video_state.denoise_mask, video_state.clean_latent) + denoised_audio = post_process_latent(denoised_audio, audio_state.denoise_mask, audio_state.clean_latent) + + if sigmas[step_idx + 1] == 0: + return replace(video_state, latent=denoised_video), replace(audio_state, latent=denoised_audio) + + previous_video_velocity, denoised_video = update_velocity_and_sample( + video_state.latent, denoised_video, sigmas[step_idx], previous_video_velocity + ) + previous_audio_velocity, denoised_audio = update_velocity_and_sample( + audio_state.latent, denoised_audio, sigmas[step_idx], previous_audio_velocity + ) + + video_state = replace(video_state, latent=stepper.step(video_state.latent, denoised_video, sigmas, step_idx)) + audio_state = replace(audio_state, latent=stepper.step(audio_state.latent, denoised_audio, sigmas, step_idx)) + + return (video_state, audio_state) + + +def noise_video_state( + output_shape: VideoPixelShape, + noiser: Noiser, + conditionings: list[ConditioningItem], + components: PipelineComponents, + dtype: torch.dtype, + device: torch.device, + noise_scale: float = 1.0, + initial_latent: torch.Tensor | None = None, +) -> tuple[LatentState, VideoLatentTools]: + """Initialize and noise a video latent state for the diffusion pipeline. + + Creates a video latent state from the output shape, applies conditionings, + and adds noise using the provided noiser. Returns the noised state and + video latent tools for further processing. If initial_latent is provided, it will be used to create the initial + state, otherwise an empty initial state will be created. + """ + video_latent_shape = VideoLatentShape.from_pixel_shape( + shape=output_shape, + latent_channels=components.video_latent_channels, + scale_factors=components.video_scale_factors, + ) + video_tools = VideoLatentTools(components.video_patchifier, video_latent_shape, output_shape.fps) + video_state = create_noised_state( + tools=video_tools, + conditionings=conditionings, + noiser=noiser, + dtype=dtype, + device=device, + noise_scale=noise_scale, + initial_latent=initial_latent, + ) + + return video_state, video_tools + + +def noise_audio_state( + output_shape: VideoPixelShape, + noiser: Noiser, + conditionings: list[ConditioningItem], + components: PipelineComponents, + dtype: torch.dtype, + device: torch.device, + noise_scale: float = 1.0, + initial_latent: torch.Tensor | None = None, +) -> tuple[LatentState, AudioLatentTools]: + """Initialize and noise an audio latent state for the diffusion pipeline. + + Creates an audio latent state from the output shape, applies conditionings, + and adds noise using the provided noiser. Returns the noised state and + audio latent tools for further processing. 
If initial_latent is provided, it will be used to create the initial + state, otherwise an empty initial state will be created. + """ + audio_latent_shape = AudioLatentShape.from_video_pixel_shape(output_shape) + audio_tools = AudioLatentTools(components.audio_patchifier, audio_latent_shape) + audio_state = create_noised_state( + tools=audio_tools, + conditionings=conditionings, + noiser=noiser, + dtype=dtype, + device=device, + noise_scale=noise_scale, + initial_latent=initial_latent, + ) + + return audio_state, audio_tools + + +def create_noised_state( + tools: LatentTools, + conditionings: list[ConditioningItem], + noiser: Noiser, + dtype: torch.dtype, + device: torch.device, + noise_scale: float = 1.0, + initial_latent: torch.Tensor | None = None, +) -> LatentState: + """Create a noised latent state from empty state, conditionings, and noiser. + + Creates an empty latent state, applies conditionings, and then adds noise + using the provided noiser. Returns the final noised state ready for diffusion. + """ + state = tools.create_initial_state(device, dtype, initial_latent) + state = state_with_conditionings(state, conditionings, tools) + state = noiser(state, noise_scale) + + return state + + +def state_with_conditionings( + latent_state: LatentState, conditioning_items: list[ConditioningItem], latent_tools: LatentTools +) -> LatentState: + """Apply a list of conditionings to a latent state. + + Iterates through the conditioning items and applies each one to the latent + state in sequence. Returns the modified state with all conditionings applied. + """ + for conditioning in conditioning_items: + latent_state = conditioning.apply_to(latent_state=latent_state, latent_tools=latent_tools) + + return latent_state + + +def post_process_latent(denoised: torch.Tensor, denoise_mask: torch.Tensor, clean: torch.Tensor) -> torch.Tensor: + """Blend denoised output with clean state based on mask.""" + return (denoised * denoise_mask + clean.float() * (1 - denoise_mask)).to(denoised.dtype) + + +def modality_from_latent_state( + state: LatentState, context: torch.Tensor, sigma: float | torch.Tensor, enabled: bool = True +) -> Modality: + """Create a Modality from a latent state. + + Constructs a Modality object with the latent state's data, timesteps derived + from the denoise mask and sigma, positions, and the provided context. + """ + return Modality( + enabled=enabled, + latent=state.latent, + timesteps=timesteps_from_mask(state.denoise_mask, sigma), + positions=state.positions, + context=context, + context_mask=None, + ) + + +def timesteps_from_mask(denoise_mask: torch.Tensor, sigma: float | torch.Tensor) -> torch.Tensor: + """Compute timesteps from a denoise mask and sigma value. + + Multiplies the denoise mask by sigma to produce timesteps for each position + in the latent state. Areas where the mask is 0 will have zero timesteps. 
+ """ + return denoise_mask * sigma + + +@torch.inference_mode() +def decode_video( + video_state: LatentState, + video_decoder: VideoDecoder, + tiling_config: TilingConfig | None = None, +) -> Generator[torch.Tensor, None, None]: + def convert_to_uint8(frames: torch.Tensor) -> torch.Tensor: + frames = (((frames + 1.0) / 2.0).clamp(0.0, 1.0) * 255.0).to(torch.uint8) + frames = rearrange(frames[0], "c f h w -> f h w c") + return frames + + if tiling_config is not None: + for frames, total in video_decoder.tiled_decode(video_state.latent[:1], tiling_config): + yield convert_to_uint8(frames), total + else: + decoded_video = video_decoder(video_state.latent[:1]) + yield convert_to_uint8(decoded_video), 1 + + +@torch.inference_mode() +def decode_audio(audio_state: LatentState, audio_decoder: AudioDecoder, vocoder: Vocoder) -> torch.Tensor: + decoded_audio = audio_decoder(audio_state.latent[:1]) + decoded_audio = vocoder(decoded_audio).squeeze(0).float() + return decoded_audio + + +@torch.inference_mode() +def encode_text(text_encoder: GemmaTextEncoderModelBase, prompts: list[str]) -> list[tuple[torch.Tensor, torch.Tensor]]: + result = [] + for prompt in prompts: + v_context, a_context, _ = text_encoder(prompt) + result.append((v_context, a_context)) + return result + + +def simple_denoising_func( + video_context: torch.Tensor, audio_context: torch.Tensor, transformer: X0Model +) -> DenoisingFunc: + def simple_denoising_step( + video_state: LatentState, audio_state: LatentState, sigmas: torch.Tensor, step_index: int + ) -> tuple[torch.Tensor, torch.Tensor]: + sigma = sigmas[step_index] + pos_video = modality_from_latent_state(video_state, video_context, sigma) + pos_audio = modality_from_latent_state(audio_state, audio_context, sigma) + + denoised_video, denoised_audio = transformer(video=pos_video, audio=pos_audio, perturbations=None) + return denoised_video, denoised_audio + + return simple_denoising_step + + +def guider_denoising_func( + guider: GuiderProtocol, + v_context_p: torch.Tensor, + v_context_n: torch.Tensor, + a_context_p: torch.Tensor, + a_context_n: torch.Tensor, + transformer: X0Model, +) -> DenoisingFunc: + def guider_denoising_step( + video_state: LatentState, audio_state: LatentState, sigmas: torch.Tensor, step_index: int + ) -> tuple[torch.Tensor, torch.Tensor]: + sigma = sigmas[step_index] + pos_video = modality_from_latent_state(video_state, v_context_p, sigma) + pos_audio = modality_from_latent_state(audio_state, a_context_p, sigma) + + denoised_video, denoised_audio = transformer(video=pos_video, audio=pos_audio, perturbations=None) + if guider.enabled(): + neg_video = modality_from_latent_state(video_state, v_context_n, sigma) + neg_audio = modality_from_latent_state(audio_state, a_context_n, sigma) + + neg_denoised_video, neg_denoised_audio = transformer(video=neg_video, audio=neg_audio, perturbations=None) + + denoised_video = denoised_video + guider.delta(denoised_video, neg_denoised_video) + denoised_audio = denoised_audio + guider.delta(denoised_audio, neg_denoised_audio) + + return denoised_video, denoised_audio + + return guider_denoising_step + + +class DenoisingLoopFunc(Protocol): + def __call__( + self, sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol + ) -> tuple[torch.Tensor, torch.Tensor]: ... 
+ + +def denoise_audio_video( # noqa: PLR0913 + output_shape: VideoPixelShape, + conditionings: list[ConditioningItem], + noiser: Noiser, + sigmas: torch.Tensor, + stepper: DiffusionStepProtocol, + denoising_loop_fn: DenoisingLoopFunc, + components: PipelineComponents, + dtype: torch.dtype, + device: torch.device, + noise_scale: float = 1.0, + initial_video_latent: torch.Tensor | None = None, + initial_audio_latent: torch.Tensor | None = None, +) -> tuple[LatentState, LatentState]: + video_state, video_tools = noise_video_state( + output_shape=output_shape, + noiser=noiser, + conditionings=conditionings, + components=components, + dtype=dtype, + device=device, + noise_scale=noise_scale, + initial_latent=initial_video_latent, + ) + audio_state, audio_tools = noise_audio_state( + output_shape=output_shape, + noiser=noiser, + conditionings=[], + components=components, + dtype=dtype, + device=device, + noise_scale=noise_scale, + initial_latent=initial_audio_latent, + ) + + video_state, audio_state = denoising_loop_fn( + sigmas, + video_state, + audio_state, + stepper, + ) + + video_state = video_tools.clear_conditioning(video_state) + video_state = video_tools.unpatchify(video_state) + audio_state = audio_tools.clear_conditioning(audio_state) + audio_state = audio_tools.unpatchify(audio_state) + + return video_state, audio_state diff --git a/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_one_stage.py b/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_one_stage.py new file mode 100644 index 0000000000000000000000000000000000000000..98ee9e23027424fd6127db2dd2bc124e9aca0bfe --- /dev/null +++ b/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_one_stage.py @@ -0,0 +1,180 @@ +import torch + +from ltx_core.loader.primitives import LoraPathStrengthAndSDOps +from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP +from ltx_core.model.model_ledger import ModelLedger +from ltx_core.pipeline.components.diffusion_steps import EulerDiffusionStep +from ltx_core.pipeline.components.guiders import CFGGuider +from ltx_core.pipeline.components.noisers import GaussianNoiser +from ltx_core.pipeline.components.protocols import DiffusionStepProtocol, VideoPixelShape +from ltx_core.pipeline.components.schedulers import LTX2Scheduler +from ltx_core.pipeline.conditioning.item import LatentState +from ltx_pipelines import utils +from ltx_pipelines.constants import ( + AUDIO_SAMPLE_RATE, + DEFAULT_LORA_STRENGTH, +) +from ltx_pipelines.media_io import encode_video +from ltx_pipelines.pipeline_utils import ( + PipelineComponents, + denoise_audio_video, + encode_text, + euler_denoising_loop, + guider_denoising_func, +) +from ltx_pipelines.pipeline_utils import ( + decode_audio as vae_decode_audio, +) +from ltx_pipelines.pipeline_utils import ( + decode_video as vae_decode_video, +) +from ltx_pipelines.utils import image_conditionings_by_replacing_latent + +device = utils.get_device() + + +class TI2VidOneStagePipeline: + def __init__( + self, + checkpoint_path: str, + gemma_root: str, + loras: list[LoraPathStrengthAndSDOps], + device: torch.device = device, + fp8transformer: bool = False, + local_files_only: bool = True, + ): + self.dtype = torch.bfloat16 + self.device = device + self.model_ledger = ModelLedger( + dtype=self.dtype, + device=device, + checkpoint_path=checkpoint_path, + gemma_root_path=gemma_root, + loras=loras, + fp8transformer=fp8transformer, + local_files_only=local_files_only + ) + self.pipeline_components = PipelineComponents( + dtype=self.dtype, + device=device, + ) + + @torch.inference_mode() + 
def __call__( # noqa: PLR0913 + self, + prompt: str, + output_path: str, + negative_prompt: str, + seed: int, + height: int, + width: int, + num_frames: int, + frame_rate: float, + num_inference_steps: int, + cfg_guidance_scale: float, + images: list[tuple[str, int, float]], + ) -> None: + generator = torch.Generator(device=self.device).manual_seed(seed) + noiser = GaussianNoiser(generator=generator) + stepper = EulerDiffusionStep() + cfg_guider = CFGGuider(cfg_guidance_scale) + dtype = torch.bfloat16 + + text_encoder = self.model_ledger.text_encoder() + context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt]) + v_context_p, a_context_p = context_p + v_context_n, a_context_n = context_n + + torch.cuda.synchronize() + del text_encoder + utils.cleanup_memory() + + # Stage 1: Initial low resolution video generation. + video_encoder = self.model_ledger.video_encoder() + transformer = self.model_ledger.transformer() + sigmas = LTX2Scheduler().execute(steps=num_inference_steps).to(dtype=torch.float32, device=self.device) + + def first_stage_denoising_loop( + sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol + ) -> tuple[LatentState, LatentState]: + return euler_denoising_loop( + sigmas=sigmas, + video_state=video_state, + audio_state=audio_state, + stepper=stepper, + denoise_fn=guider_denoising_func( + cfg_guider, + v_context_p, + v_context_n, + a_context_p, + a_context_n, + transformer=transformer, # noqa: F821 + ), + ) + + stage_1_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate) + stage_1_conditionings = image_conditionings_by_replacing_latent( + images=images, + height=stage_1_output_shape.height, + width=stage_1_output_shape.width, + video_encoder=video_encoder, + dtype=dtype, + device=self.device, + ) + + video_state, audio_state = denoise_audio_video( + output_shape=stage_1_output_shape, + conditionings=stage_1_conditionings, + noiser=noiser, + sigmas=sigmas, + stepper=stepper, + denoising_loop_fn=first_stage_denoising_loop, + components=self.pipeline_components, + dtype=dtype, + device=self.device, + ) + + torch.cuda.synchronize() + del transformer + utils.cleanup_memory() + + decoded_video = vae_decode_video(video_state, self.model_ledger.video_decoder()) + decoded_audio = vae_decode_audio(audio_state, self.model_ledger.audio_decoder(), self.model_ledger.vocoder()) + + encode_video( + video=decoded_video, + fps=frame_rate, + audio=decoded_audio, + audio_sample_rate=AUDIO_SAMPLE_RATE, + output_path=output_path, + ) + + +def main() -> None: + parser = utils.default_1_stage_arg_parser() + args = parser.parse_args() + lora_strengths = (args.lora_strength + [DEFAULT_LORA_STRENGTH] * len(args.lora))[: len(args.lora)] + loras = [ + LoraPathStrengthAndSDOps(lora, strength, LTXV_LORA_COMFY_RENAMING_MAP) + for lora, strength in zip(args.lora, lora_strengths, strict=True) + ] + pipeline = TI2VidOneStagePipeline( + checkpoint_path=args.checkpoint_path, gemma_root=args.gemma_root, loras=loras, fp8transformer=args.enable_fp8 + ) + pipeline( + prompt=args.prompt, + output_path=args.output_path, + negative_prompt=args.negative_prompt, + seed=args.seed, + height=args.height, + width=args.width, + num_frames=args.num_frames, + frame_rate=args.frame_rate, + num_inference_steps=args.num_inference_steps, + cfg_guidance_scale=args.cfg_guidance_scale, + images=args.images, + ) + + +if __name__ == "__main__": + main() diff --git 
a/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py b/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py new file mode 100644 index 0000000000000000000000000000000000000000..1f3ebd1c1b14dc9a7e6d2737dcab5301ae7f555e --- /dev/null +++ b/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py @@ -0,0 +1,266 @@ +import torch + +from ltx_core.loader.primitives import LoraPathStrengthAndSDOps +from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP +from ltx_core.model.model_ledger import ModelLedger +from ltx_core.pipeline.components.diffusion_steps import EulerDiffusionStep +from ltx_core.pipeline.components.guiders import CFGGuider +from ltx_core.pipeline.components.noisers import GaussianNoiser +from ltx_core.pipeline.components.protocols import DiffusionStepProtocol, VideoPixelShape +from ltx_core.pipeline.components.schedulers import LTX2Scheduler +from ltx_core.pipeline.conditioning.item import LatentState +from ltx_core.tiling import TilingConfig +from ltx_pipelines import utils +from ltx_pipelines.constants import ( + AUDIO_SAMPLE_RATE, + DEFAULT_LORA_STRENGTH, + STAGE_2_DISTILLED_SIGMA_VALUES, +) +from ltx_pipelines.media_io import encode_video +from ltx_pipelines.pipeline_utils import ( + PipelineComponents, + denoise_audio_video, + encode_text, + euler_denoising_loop, + guider_denoising_func, + simple_denoising_func, +) +from ltx_pipelines.pipeline_utils import ( + decode_audio as vae_decode_audio, +) +from ltx_pipelines.pipeline_utils import ( + decode_video as vae_decode_video, +) +from ltx_pipelines.utils import image_conditionings_by_replacing_latent + + +class TI2VidTwoStagesPipeline: + def __init__( + self, + checkpoint_path: str, + distilled_lora_path: str, + distilled_lora_strength: float, + spatial_upsampler_path: str, + gemma_root: str, + loras: list[LoraPathStrengthAndSDOps], + device: str = utils.get_device(), + fp8transformer: bool = False, + local_files_only: bool = True, + ): + self.device = device + self.dtype = torch.bfloat16 + self.stage_1_model_ledger = ModelLedger( + dtype=self.dtype, + device=device, + checkpoint_path=checkpoint_path, + gemma_root_path=gemma_root, + spatial_upsampler_path=spatial_upsampler_path, + loras=loras, + fp8transformer=fp8transformer, + local_files_only=local_files_only + ) + + self.stage_2_model_ledger = self.stage_1_model_ledger.with_loras( + loras=[ + LoraPathStrengthAndSDOps( + path=distilled_lora_path, + strength=distilled_lora_strength, + sd_ops=LTXV_LORA_COMFY_RENAMING_MAP, + ) + ], + ) + + self.pipeline_components = PipelineComponents( + dtype=self.dtype, + device=device, + ) + + @torch.inference_mode() + def __call__( # noqa: PLR0913 + self, + prompt: str, + output_path: str, + negative_prompt: str, + seed: int, + height: int, + width: int, + num_frames: int, + frame_rate: float, + num_inference_steps: int, + cfg_guidance_scale: float, + images: list[tuple[str, int, float]], + tiling_config: TilingConfig | None = None, + ) -> None: + generator = torch.Generator(device=self.device).manual_seed(seed) + noiser = GaussianNoiser(generator=generator) + stepper = EulerDiffusionStep() + cfg_guider = CFGGuider(cfg_guidance_scale) + dtype = torch.bfloat16 + + text_encoder = self.stage_1_model_ledger.text_encoder() + context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt]) + v_context_p, a_context_p = context_p + v_context_n, a_context_n = context_n + + torch.cuda.synchronize() + del text_encoder + utils.cleanup_memory() + + # Stage 1: Initial low resolution video generation. 
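+        # The base transformer denoises the video and audio latents jointly at
+        # the requested resolution, stepping through the LTX2Scheduler sigma
+        # schedule with CFG guidance between the positive and negative prompt
+        # contexts; any input images are injected as latent conditionings.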
+ video_encoder = self.stage_1_model_ledger.video_encoder() + transformer = self.stage_1_model_ledger.transformer() + sigmas = LTX2Scheduler().execute(steps=num_inference_steps).to(dtype=torch.float32, device=self.device) + + def first_stage_denoising_loop( + sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol + ) -> tuple[LatentState, LatentState]: + return euler_denoising_loop( + sigmas=sigmas, + video_state=video_state, + audio_state=audio_state, + stepper=stepper, + denoise_fn=guider_denoising_func( + cfg_guider, + v_context_p, + v_context_n, + a_context_p, + a_context_n, + transformer=transformer, # noqa: F821 + ), + ) + + stage_1_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate) + stage_1_conditionings = image_conditionings_by_replacing_latent( + images=images, + height=stage_1_output_shape.height, + width=stage_1_output_shape.width, + video_encoder=video_encoder, + dtype=dtype, + device=self.device, + ) + video_state, audio_state = denoise_audio_video( + output_shape=stage_1_output_shape, + conditionings=stage_1_conditionings, + noiser=noiser, + sigmas=sigmas, + stepper=stepper, + denoising_loop_fn=first_stage_denoising_loop, + components=self.pipeline_components, + dtype=dtype, + device=self.device, + ) + + torch.cuda.synchronize() + del transformer + utils.cleanup_memory() + + # Stage 2: Upsample and refine the video at higher resolution with distilled LORA. + upscaled_video_latent = utils.upsample_video( + latent=video_state.latent[:1], + video_encoder=video_encoder, + upsampler=self.stage_2_model_ledger.spatial_upsampler(), + ) + + torch.cuda.synchronize() + utils.cleanup_memory() + + transformer = self.stage_2_model_ledger.transformer() + distilled_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device) + + def second_stage_denoising_loop( + sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol + ) -> tuple[LatentState, LatentState]: + return euler_denoising_loop( + sigmas=sigmas, + video_state=video_state, + audio_state=audio_state, + stepper=stepper, + denoise_fn=simple_denoising_func( + video_context=v_context_p, + audio_context=a_context_p, + transformer=transformer, # noqa: F821 + ), + ) + + stage_2_output_shape = VideoPixelShape( + batch=1, frames=num_frames, width=width * 2, height=height * 2, fps=frame_rate + ) + stage_2_conditionings = image_conditionings_by_replacing_latent( + images=images, + height=stage_2_output_shape.height, + width=stage_2_output_shape.width, + video_encoder=video_encoder, + dtype=dtype, + device=self.device, + ) + video_state, audio_state = denoise_audio_video( + output_shape=stage_2_output_shape, + conditionings=stage_2_conditionings, + noiser=noiser, + sigmas=distilled_sigmas, + stepper=stepper, + denoising_loop_fn=second_stage_denoising_loop, + components=self.pipeline_components, + dtype=dtype, + device=self.device, + noise_scale=distilled_sigmas[0], + initial_video_latent=upscaled_video_latent, + initial_audio_latent=audio_state.latent, + ) + + torch.cuda.synchronize() + del transformer + del video_encoder + utils.cleanup_memory() + + decoded_video = vae_decode_video(video_state, self.stage_2_model_ledger.video_decoder(), tiling_config) + + decoded_audio = vae_decode_audio( + audio_state, self.stage_2_model_ledger.audio_decoder(), self.stage_2_model_ledger.vocoder() + ) + + encode_video( + video=decoded_video, + fps=frame_rate, + audio=decoded_audio, + 
audio_sample_rate=AUDIO_SAMPLE_RATE, + output_path=output_path, + ) + + +def main() -> None: + parser = utils.default_2_stage_arg_parser() + args = parser.parse_args() + lora_strengths = (args.lora_strength + [DEFAULT_LORA_STRENGTH] * len(args.lora))[: len(args.lora)] + loras = [ + LoraPathStrengthAndSDOps(lora, strength, LTXV_LORA_COMFY_RENAMING_MAP) + for lora, strength in zip(args.lora, lora_strengths, strict=True) + ] + pipeline = TI2VidTwoStagesPipeline( + checkpoint_path=args.checkpoint_path, + distilled_lora_path=args.distilled_lora_path, + distilled_lora_strength=args.distilled_lora_strength, + spatial_upsampler_path=args.spatial_upsampler_path, + gemma_root=args.gemma_root, + loras=loras, + fp8transformer=args.enable_fp8, + ) + + pipeline( + prompt=args.prompt, + negative_prompt=args.negative_prompt, + output_path=args.output_path, + seed=args.seed, + height=args.height, + width=args.width, + num_frames=args.num_frames, + frame_rate=args.frame_rate, + num_inference_steps=args.num_inference_steps, + cfg_guidance_scale=args.cfg_guidance_scale, + images=args.images, + tiling_config=TilingConfig.default(), + ) + + +if __name__ == "__main__": + main() diff --git a/packages/ltx-pipelines/src/ltx_pipelines/utils.py b/packages/ltx-pipelines/src/ltx_pipelines/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fe0e40f315f079d197ae4770b7fb80d6b17a7a7c --- /dev/null +++ b/packages/ltx-pipelines/src/ltx_pipelines/utils.py @@ -0,0 +1,200 @@ +# Copyright (c) 2025 Lightricks. All rights reserved. +# Created by Amit Pintz. + +import argparse +import gc +import os +from pathlib import Path + +import torch + +from ltx_core.loader.primitives import LoraPathStrengthAndSDOps +from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP +from ltx_core.model.upsampler.model import LatentUpsampler +from ltx_core.model.video_vae.video_vae import Encoder as VideoEncoder +from ltx_core.pipeline.conditioning.item import ConditioningItem +from ltx_core.pipeline.conditioning.types.keyframe_cond import VideoConditionByKeyframeIndex +from ltx_core.pipeline.conditioning.types.latent_cond import VideoConditionByLatentIndex +from ltx_pipelines.constants import ( + DEFAULT_CFG_GUIDANCE_SCALE, + DEFAULT_FRAME_RATE, + DEFAULT_HEIGHT, + DEFAULT_LORA_STRENGTH, + DEFAULT_NEGATIVE_PROMPT, + DEFAULT_NUM_FRAMES, + DEFAULT_NUM_INFERENCE_STEPS, + DEFAULT_SEED, + DEFAULT_WIDTH, +) +from ltx_pipelines.media_io import load_image_conditioning + +DO_EXPAND_PATH = os.getenv("EXPAND_PATH", False) + +def get_device() -> torch.device: + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + +def cleanup_memory() -> None: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +def resolve_path(path: str) -> str: + return str(Path(path).expanduser().resolve().as_posix()) if DO_EXPAND_PATH else path + + +class VideoConditioningAction(argparse.Action): + def __call__( + self, + parser: argparse.ArgumentParser, # noqa: ARG002 + namespace: argparse.Namespace, + values: list[str], + option_string: str | None = None, # noqa: ARG002 + ) -> None: + path, strength_str = values + strength = float(strength_str) + current = getattr(namespace, self.dest) or [] + current.append((path, strength)) + setattr(namespace, self.dest, current) + + +class ImageAction(argparse.Action): + def __call__( + self, + parser: argparse.ArgumentParser, # noqa: ARG002 + namespace: argparse.Namespace, + values: list[str], + option_string: str | None = None, # noqa: ARG002 + ) -> None: + 
path, frame_idx, strength_str = values + frame_idx = int(frame_idx) + strength = float(strength_str) + current = getattr(namespace, self.dest) or [] + current.append((path, frame_idx, strength)) + setattr(namespace, self.dest, current) + + +class LoraAction(argparse.Action): + def __call__( + self, + parser: argparse.ArgumentParser, # noqa: ARG002 + namespace: argparse.Namespace, + values: list[str], + option_string: str | None = None, # noqa: ARG002 + ) -> None: + path, strength_str = values + strength = float(strength_str) + current = getattr(namespace, self.dest) or [] + current.append(LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)) + setattr(namespace, self.dest, current) + + +def image_conditionings_by_replacing_latent( + images: list[tuple[str, int, float]], + height: int, + width: int, + video_encoder: VideoEncoder, + dtype: torch.dtype, + device: torch.device, +) -> list[ConditioningItem]: + conditionings = [] + for image_path, frame_idx, strength in images: + image = load_image_conditioning( + image_path=image_path, + height=height, + width=width, + dtype=dtype, + device=device, + ) + encoded_image = video_encoder(image) + conditionings.append( + VideoConditionByLatentIndex( + latent=encoded_image, + strength=strength, + latent_idx=frame_idx, + ) + ) + + return conditionings + + +def image_conditionings_by_adding_guiding_latent( + images: list[tuple[str, int, float]], + height: int, + width: int, + video_encoder: VideoEncoder, + dtype: torch.dtype, + device: torch.device, +) -> list[ConditioningItem]: + conditionings = [] + for image_path, frame_idx, strength in images: + image = load_image_conditioning( + image_path=image_path, + height=height, + width=width, + dtype=dtype, + device=device, + ) + encoded_image = video_encoder(image) + conditionings.append( + VideoConditionByKeyframeIndex(keyframes=encoded_image, frame_idx=frame_idx, strength=strength) + ) + return conditionings + + +def upsample_video(latent: torch.Tensor, video_encoder: VideoEncoder, upsampler: LatentUpsampler) -> torch.Tensor: + latent = video_encoder.per_channel_statistics.un_normalize(latent) + latent = upsampler(latent) + latent = video_encoder.per_channel_statistics.normalize(latent) + return latent + + +def basic_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_path", type=resolve_path, required=True) + parser.add_argument("--gemma_root", type=resolve_path, required=True) + parser.add_argument("--prompt", type=str, required=True) + parser.add_argument("--output_path", type=resolve_path, required=True) + parser.add_argument("--seed", type=int, default=DEFAULT_SEED) + parser.add_argument("--height", type=int, default=DEFAULT_HEIGHT) + parser.add_argument("--width", type=int, default=DEFAULT_WIDTH) + parser.add_argument("--num_frames", type=int, default=DEFAULT_NUM_FRAMES) + parser.add_argument("--frame_rate", type=float, default=DEFAULT_FRAME_RATE) + parser.add_argument("--num_inference_steps", type=int, default=DEFAULT_NUM_INFERENCE_STEPS) + parser.add_argument( + "--image", + dest="images", + action=ImageAction, + nargs=3, + metavar=("PATH", "FRAME_IDX", "STRENGTH"), + default=[], + ) + parser.add_argument("--lora", type=resolve_path, action="append", default=[]) + parser.add_argument("--lora_strength", type=float, action="append", default=[]) + parser.add_argument("--enable_fp8", action="store_true") + return parser + + +def default_1_stage_arg_parser() -> argparse.ArgumentParser: + parser = basic_arg_parser() + 
parser.add_argument("--cfg_guidance_scale", type=float, default=DEFAULT_CFG_GUIDANCE_SCALE) + parser.add_argument("--negative_prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT) + + return parser + + +def default_2_stage_arg_parser() -> argparse.ArgumentParser: + parser = default_1_stage_arg_parser() + parser.add_argument("--distilled_lora_path", type=resolve_path, required=True) + parser.add_argument("--distilled_lora_strength", type=float, default=DEFAULT_LORA_STRENGTH) + parser.add_argument("--spatial_upsampler_path", type=resolve_path, required=True) + return parser + + +def default_2_stage_distilled_arg_parser() -> argparse.ArgumentParser: + parser = basic_arg_parser() + parser.add_argument("--spatial_upsampler_path", type=resolve_path, required=True) + return parser diff --git a/packages/ltx-pipelines/tests/__init__.py b/packages/ltx-pipelines/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-pipelines/tests/conftest.py b/packages/ltx-pipelines/tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..961f234f26ce0f236390c55ae302c71eed8bf501 --- /dev/null +++ b/packages/ltx-pipelines/tests/conftest.py @@ -0,0 +1,155 @@ +import gc +import os +from pathlib import Path +from typing import Callable + +import av +import pytest +import torch +import torch.nn.functional as F +from torch._prims_common import DeviceLikeType + +# ============================================================================= +# Model Paths +# ============================================================================= + +MODELS_PATH = Path(os.getenv("MODELS_PATH", "/models")) +CHECKPOINTS_DIR = MODELS_PATH / "comfyui_models" / "checkpoints" +LORAS_DIR = MODELS_PATH / "comfyui_models" / "loras" + +GEMMA_ROOT = MODELS_PATH / "comfyui_models" / "text_encoders" / "gemma-3-12b-it-qat-q4_0-unquantized_readout_proj" +DISTILLED_CHECKPOINT_PATH = CHECKPOINTS_DIR / "ltx-2-19b-distilled.safetensors" +AV_CHECKPOINT_SPLIT_PATH = CHECKPOINTS_DIR / "ltx-2-19b-dev.safetensors" +SPATIAL_UPSAMPLER_PATH = CHECKPOINTS_DIR / "ltx2-spatial-upscaler-x2-1.0.bf16.safetensors" +DISTILLED_LORA_PATH = LORAS_DIR / "ltxv" / "ltx2" / "ltx-av-distilled-from-42500-lora-384_comfy.safetensors" + +# ============================================================================= +# Prompts +# ============================================================================= + +IMG2VID_PROMPT = ( + "A medium close-up shot features a Caucasian man with a beard, wearing a green and white baseball cap " + "without any letters on the front, and a light blue shirt over a white t-shirt. He is positioned in the " + "center of the frame, looking intently directly at the camera, his eyes focused on camera. His facial " + "expression is one of deep concentration, with his brow slightly raised. As he looks straight at the " + "camera, a quick sniff sound is heard, and then he speaks with a deep male voice and a satisfied tone, " + "saying, 'I think it's so good.' The camera remains static throughout, maintaining a shallow depth of " + "field, which keeps the man in sharp focus while the background is softly blurred, showing a beige wall " + "behind him. After a brief pause, another short, audible sniff is heard. The man then continues to speak, " + "his voice maintaining the same quality, as he states, 'So good. So good.' 
He elaborates further, " + "emphasizing his point with a final statement, 'This got to be, it's got to be the best tool I've ever " + "seen.'" +) + +DEFAULT_NEGATIVE_PROMPT = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) + +torch.use_deterministic_algorithms(True) +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + +ROOT_DIR = Path(__file__).parent +OUTPUT_DIR = ROOT_DIR / "output" +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) +ASSETS_DIR = ROOT_DIR / "assets" + + +@pytest.fixture(autouse=True) +def pre_post_test() -> None: + """Fixture that runs before and after each test.""" + gc.collect() + torch.cuda.empty_cache() + + yield + + gc.collect() + torch.cuda.empty_cache() + + +def _psnr(pred: torch.Tensor, target: torch.Tensor, max_val: float = 1.0, eps: float = 1e-8) -> torch.Tensor: + """ + Compute Peak Signal-to-Noise Ratio (PSNR) between two images (or batches of images). + + Args: + pred: Predicted image tensor, shape (..., H, W) or (..., C, H, W) + target: Ground truth image tensor, same shape as `pred` + max_val: Maximum possible pixel value of the images. + For images in [0, 1] use 1.0, for [0, 255] use 255.0, etc. + eps: Small value to avoid log of zero. + + Returns: + psnr: PSNR value (in dB). 
+ """ + # Ensure same shape + if pred.shape != target.shape: + raise ValueError(f"Shape mismatch: pred {pred.shape}, target {target.shape}") + + # Convert to float for safety + pred = pred.float() + target = target.float() + + # Mean squared error per sample + # Flatten over all dims + if pred.dim() > 1: + mse = F.mse_loss(pred, target, reduction="none") + # Reduce over spatial (and channel) dims + dims = list(range(mse.dim())) + mse = mse.mean(dim=dims) + else: + # 1D case + mse = F.mse_loss(pred, target, reduction="mean") + + # PSNR computation + psnr_val = 10.0 * torch.log10((max_val**2) / (mse + eps)) + + return psnr_val + + +@pytest.fixture +def psnr() -> Callable[[torch.Tensor, torch.Tensor, float, float], float]: + """Fixture that returns the PSNR function.""" + return _psnr + + +def _decode_video_from_file(path: str, device: DeviceLikeType) -> tuple[torch.Tensor, torch.Tensor | None]: + container = av.open(path) + try: + video_stream = next(s for s in container.streams if s.type == "video") + audio_stream = next((s for s in container.streams if s.type == "audio"), None) + + frames = [] + audio = [] if audio_stream else None + + streams_to_decode = [video_stream] + if audio_stream: + streams_to_decode.append(audio_stream) + + for frame in container.decode(*streams_to_decode): + if isinstance(frame, av.VideoFrame): + tensor = torch.tensor(frame.to_rgb().to_ndarray(), dtype=torch.uint8, device=device).unsqueeze(0) + frames.append(tensor) + elif isinstance(frame, av.AudioFrame): + audio.append(torch.tensor(frame.to_ndarray(), dtype=torch.float32, device=device).unsqueeze(0)) + + if audio: + audio = torch.cat(audio) + finally: + container.close() + + return torch.cat(frames), audio + + +@pytest.fixture +def decode_video_from_file() -> Callable[[str], tuple[torch.Tensor, torch.Tensor | None]]: + """Fixture that returns the function to decode a video from a file.""" + return _decode_video_from_file diff --git a/packages/ltx-pipelines/tests/ltx_pipelines/test_distilled.py b/packages/ltx-pipelines/tests/ltx_pipelines/test_distilled.py new file mode 100644 index 0000000000000000000000000000000000000000..8b0c2f4dbec32a747fa1a9c03132b4eb33619b82 --- /dev/null +++ b/packages/ltx-pipelines/tests/ltx_pipelines/test_distilled.py @@ -0,0 +1,99 @@ +from typing import Callable + +import pytest +import torch +from tests.conftest import ( + ASSETS_DIR, + DISTILLED_CHECKPOINT_PATH, + GEMMA_ROOT, + IMG2VID_PROMPT, + OUTPUT_DIR, + SPATIAL_UPSAMPLER_PATH, +) + +from ltx_core.tiling import TilingConfig +from ltx_pipelines.constants import DEFAULT_FRAME_RATE, DEFAULT_HEIGHT, DEFAULT_NUM_FRAMES, DEFAULT_SEED, DEFAULT_WIDTH +from ltx_pipelines.distilled import DistilledPipeline +from ltx_pipelines.utils import get_device + +device = get_device() + + +@pytest.mark.e2e +def test_img2vid_distilled( + psnr: Callable[[torch.Tensor, torch.Tensor, float, float], float], + decode_video_from_file: Callable[[str], tuple[torch.Tensor, torch.Tensor | None]], +) -> None: + """Run img2vid distilled pipeline and verify output matches expected.""" + + output_path = OUTPUT_DIR / "img2vid_distilled_output.mp4" + image_path = ASSETS_DIR / "hat.png" + + pipeline = DistilledPipeline( + checkpoint_path=DISTILLED_CHECKPOINT_PATH.resolve().as_posix(), + spatial_upsampler_path=SPATIAL_UPSAMPLER_PATH.resolve().as_posix(), + gemma_root=GEMMA_ROOT.resolve().as_posix(), + loras=[], + ) + + pipeline( + prompt=IMG2VID_PROMPT, + output_path=output_path.as_posix(), + seed=DEFAULT_SEED, + height=DEFAULT_HEIGHT, + width=DEFAULT_WIDTH, + 
num_frames=DEFAULT_NUM_FRAMES, + frame_rate=DEFAULT_FRAME_RATE, + images=[(image_path.as_posix(), 0, 1.0)], + tiling_config=TilingConfig.default(), + ) + + # Compare to expected output + decoded_video, waveform = decode_video_from_file(path=output_path, device=device) + expected_video, expected_waveform = decode_video_from_file( + path=ASSETS_DIR / "expected_img2vid_distilled.mp4", device=device + ) + + assert psnr(decoded_video, expected_video, 255.0, 1e-8).item() >= 100.0 + assert psnr(waveform, expected_waveform, 1.0, 1e-8).item() >= 80.0 + + output_path.unlink() + + +@pytest.mark.e2e +def test_txt2vid_distilled( + psnr: Callable[[torch.Tensor, torch.Tensor, float, float], float], + decode_video_from_file: Callable[[str], tuple[torch.Tensor, torch.Tensor | None]], +) -> None: + """Run txt2vid distilled pipeline (no image conditioning) and verify output matches expected.""" + output_path = OUTPUT_DIR / "txt2vid_distilled_output.mp4" + + pipeline = DistilledPipeline( + checkpoint_path=DISTILLED_CHECKPOINT_PATH.resolve().as_posix(), + spatial_upsampler_path=SPATIAL_UPSAMPLER_PATH.resolve().as_posix(), + gemma_root=GEMMA_ROOT.resolve().as_posix(), + loras=[], + ) + + pipeline( + prompt=IMG2VID_PROMPT, + output_path=output_path.as_posix(), + seed=DEFAULT_SEED, + height=DEFAULT_HEIGHT, + width=DEFAULT_WIDTH, + num_frames=DEFAULT_NUM_FRAMES, + frame_rate=DEFAULT_FRAME_RATE, + images=[], + tiling_config=TilingConfig.default(), + ) + + # Compare to expected output + decoded_video, waveform = decode_video_from_file(path=output_path, device=device) + expected_video, expected_waveform = decode_video_from_file( + path=ASSETS_DIR / "expected_txt2vid_distilled.mp4", device=device + ) + + assert psnr(decoded_video, expected_video, 255.0, 1e-8).item() >= 100.0 + assert psnr(waveform, expected_waveform, 1.0, 1e-8).item() >= 80.0 + + output_path.unlink() diff --git a/packages/ltx-pipelines/tests/ltx_pipelines/test_ic_lora.py b/packages/ltx-pipelines/tests/ltx_pipelines/test_ic_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..f743b53231c200ffbae64081785e2aebb481ddc5 --- /dev/null +++ b/packages/ltx-pipelines/tests/ltx_pipelines/test_ic_lora.py @@ -0,0 +1,144 @@ +from typing import Callable + +import pytest +import torch +from tests.conftest import ( + ASSETS_DIR, + AV_CHECKPOINT_SPLIT_PATH, + DEFAULT_NEGATIVE_PROMPT, + DISTILLED_LORA_PATH, + GEMMA_ROOT, + LORAS_DIR, + OUTPUT_DIR, + SPATIAL_UPSAMPLER_PATH, +) + +from ltx_core.loader.primitives import LoraPathStrengthAndSDOps +from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP +from ltx_core.tiling import TilingConfig +from ltx_pipelines.constants import ( + DEFAULT_CFG_GUIDANCE_SCALE, + DEFAULT_FRAME_RATE, + DEFAULT_HEIGHT, + DEFAULT_NUM_FRAMES, + DEFAULT_NUM_INFERENCE_STEPS, + DEFAULT_SEED, + DEFAULT_WIDTH, +) +from ltx_pipelines.ic_lora import ICLoraPipeline +from ltx_pipelines.utils import get_device + +device = get_device() + +LORA_PATH = LORAS_DIR / "Internal" / "ltxv_apps" / "icloras_dev" / "depth_64_2k_split.safetensors" + + +@pytest.mark.e2e +def test_ic_lora_t2v( + psnr: Callable[[torch.Tensor, torch.Tensor, float, float], float], + decode_video_from_file: Callable[[str], tuple[torch.Tensor, torch.Tensor | None]], +) -> None: + """Run txt2vid IC-LoRA pipeline and verify output matches expected.""" + + output_path = OUTPUT_DIR / "ic_lora_t2v_output.mp4" + control_path = ASSETS_DIR / "depth_00001.mp4" + + pipeline = ICLoraPipeline( + checkpoint_path=AV_CHECKPOINT_SPLIT_PATH.resolve().as_posix(), + 
distilled_lora_path=DISTILLED_LORA_PATH.resolve().as_posix(), + distilled_lora_strength=1.0, + spatial_upsampler_path=SPATIAL_UPSAMPLER_PATH.resolve().as_posix(), + gemma_root=GEMMA_ROOT.resolve().as_posix(), + loras=[LoraPathStrengthAndSDOps(LORA_PATH.resolve().as_posix(), 2.0, LTXV_LORA_COMFY_RENAMING_MAP)], + ) + + pipeline( + prompt=( + "Two humanoid fish walk upright along the sandy bottom of the ocean, their finned legs moving with a " + "slow, deliberate rhythm. Their bodies are covered in textured scales that catch the filtered sunlight " + "drifting down from the surface above. Around them, coral formations, rocks, and swaying sea plants " + "create a quiet underwater landscape, while small schools of fish pass in the distance. Soft beams of " + "blue light cut through the water, and tiny particles float in the current, giving the scene a calm, " + "otherworldly atmosphere." + ), + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + output_path=output_path.as_posix(), + seed=DEFAULT_SEED, + height=DEFAULT_HEIGHT, + width=DEFAULT_WIDTH, + num_frames=DEFAULT_NUM_FRAMES, + frame_rate=DEFAULT_FRAME_RATE, + num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS, + cfg_guidance_scale=DEFAULT_CFG_GUIDANCE_SCALE, + images=[], + video_conditioning=[(control_path.as_posix(), 0.8)], + tiling_config=TilingConfig.default(), + ) + + # Compare to expected output + decoded_video, waveform = decode_video_from_file(path=output_path, device=device) + expected_video, expected_waveform = decode_video_from_file( + path=ASSETS_DIR / "expected_ic_lora_t2v.mp4", device=device + ) + + assert psnr(decoded_video, expected_video, 255.0, 1e-8).item() >= 100.0 + assert psnr(waveform, expected_waveform, 1.0, 1e-8).item() >= 80.0 + + output_path.unlink() + + +@pytest.mark.e2e +def test_ic_lora_i2v( + psnr: Callable[[torch.Tensor, torch.Tensor, float, float], float], + decode_video_from_file: Callable[[str], tuple[torch.Tensor, torch.Tensor | None]], +) -> None: + """Run img2vid IC-LoRA pipeline and verify output matches expected.""" + + output_path = OUTPUT_DIR / "ic_lora_i2v_output.mp4" + control_path = ASSETS_DIR / "depth_00001.mp4" + image_path = ASSETS_DIR / "astronauts.jpeg" + + pipeline = ICLoraPipeline( + checkpoint_path=AV_CHECKPOINT_SPLIT_PATH.resolve().as_posix(), + distilled_lora_path=DISTILLED_LORA_PATH.resolve().as_posix(), + distilled_lora_strength=1.3, + spatial_upsampler_path=SPATIAL_UPSAMPLER_PATH.resolve().as_posix(), + gemma_root=GEMMA_ROOT.resolve().as_posix(), + loras=[LoraPathStrengthAndSDOps(LORA_PATH.resolve().as_posix(), 2.0, LTXV_LORA_COMFY_RENAMING_MAP)], + ) + + pipeline( + prompt=( + "Cinematic tracking shot of two astronauts in detailed white NASA-style EVA space suits walking slowly " + "forward towards the camera through a dense, mysterious alien landscape. The terrain is rugged and " + "textured, covered in gray lunar dust and patches of frost. Thick, volumetric fog and swirling mist " + "envelop the scene, partially obscuring the jagged rock formations in the background. The scene is " + "backlit by a warm, hazy sun low on the horizon, creating dramatic rim lighting and lens flares on their " + "helmet visors. The atmosphere is ethereal and moody. The movement shows the astronauts stepping heavily " + "over the uneven ground, with the mist reacting to their motion. Photorealistic, 8k, unreal engine 5 " + "render style, sci-fi movie aesthetics, high contrast, volumetric lighting." 
+ ), + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + output_path=output_path.as_posix(), + seed=DEFAULT_SEED, + height=DEFAULT_HEIGHT, + width=DEFAULT_WIDTH, + num_frames=DEFAULT_NUM_FRAMES, + frame_rate=DEFAULT_FRAME_RATE, + num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS, + cfg_guidance_scale=DEFAULT_CFG_GUIDANCE_SCALE, + video_conditioning=[(control_path.as_posix(), 0.8)], + images=[(image_path.as_posix(), 0, 1.0)], + tiling_config=TilingConfig.default(), + ) + + # Compare to expected output + decoded_video, waveform = decode_video_from_file(path=output_path, device=device) + expected_video, expected_waveform = decode_video_from_file( + path=ASSETS_DIR / "expected_ic_lora_i2v.mp4", device=device + ) + + assert psnr(decoded_video, expected_video, 255.0, 1e-8).item() >= 100.0 + assert psnr(waveform, expected_waveform, 1.0, 1e-8).item() >= 80.0 + + output_path.unlink() diff --git a/packages/ltx-pipelines/tests/ltx_pipelines/test_img2vid.py b/packages/ltx-pipelines/tests/ltx_pipelines/test_img2vid.py new file mode 100644 index 0000000000000000000000000000000000000000..eba48b6e1890abeb4ab7814b7b33dba0c7c7bb8a --- /dev/null +++ b/packages/ltx-pipelines/tests/ltx_pipelines/test_img2vid.py @@ -0,0 +1,147 @@ +from dataclasses import replace +from typing import Callable + +import pytest +import torch +from tests.conftest import ASSETS_DIR, CHECKPOINTS_DIR, GEMMA_ROOT + +from ltx_core.loader.primitives import LoraPathStrengthAndSDOps +from ltx_core.model.model_ledger import ModelLedger +from ltx_core.pipeline.components.diffusion_steps import EulerDiffusionStep +from ltx_core.pipeline.components.guiders import CFGGuider +from ltx_core.pipeline.components.protocols import AudioLatentShape, VideoLatentShape +from ltx_core.pipeline.conditioning.tools import AudioLatentTools, LatentState, VideoLatentTools +from ltx_pipelines.constants import AUDIO_SAMPLE_RATE, DEFAULT_FRAME_RATE +from ltx_pipelines.media_io import encode_video +from ltx_pipelines.pipeline_utils import ( + PipelineComponents, + euler_denoising_loop, + guider_denoising_func, +) +from ltx_pipelines.pipeline_utils import decode_audio as vae_decode_audio +from ltx_pipelines.pipeline_utils import decode_video as vae_decode_video +from ltx_pipelines.utils import get_device + +device = get_device() + + +class Img2VidTestPipeline: + def __init__( + self, + checkpoint_path: str, + gemma_root: str, + loras: list[LoraPathStrengthAndSDOps], + device: torch.device = device, + ): + self.model_ledger = ModelLedger( + dtype=torch.bfloat16, + device=device, + checkpoint_path=checkpoint_path, + gemma_root_path=gemma_root, + loras=loras, + ) + self.pipeline_components = PipelineComponents( + dtype=torch.bfloat16, + device=device, + ) + + self.device = device + + @torch.inference_mode() + def __call__( + self, + cfg_guidance_scale: float, + ) -> None: + v_context_p = torch.load(ASSETS_DIR / "v_context_p.pt").to(self.device) + v_context_n = torch.load(ASSETS_DIR / "v_context_n.pt").to(self.device) + a_context_p = torch.load(ASSETS_DIR / "a_context_p.pt").to(self.device) + a_context_n = torch.load(ASSETS_DIR / "a_context_n.pt").to(self.device) + + video_clean_state = LatentState( + latent=torch.load(ASSETS_DIR / "v_latent_image.pt").to(self.device), + denoise_mask=torch.load(ASSETS_DIR / "v_denoise_mask.pt").to(self.device), + positions=torch.load(ASSETS_DIR / "video_positions.pt").to(self.device), + clean_latent=torch.load(ASSETS_DIR / "v_latent_image.pt").to(self.device), + ) + video_latent_tools = VideoLatentTools( + 
patchifier=self.pipeline_components.video_patchifier, + target_shape=VideoLatentShape.from_torch_shape(video_clean_state.latent.shape), + fps=25, + ) + video_clean_state = video_latent_tools.patchify(video_clean_state) + + audio_clean_state = LatentState( + latent=torch.load(ASSETS_DIR / "a_latent_image.pt").to(self.device), + denoise_mask=torch.load(ASSETS_DIR / "a_denoise_mask.pt").to(self.device), + positions=torch.load(ASSETS_DIR / "audio_positions.pt").to(self.device), + clean_latent=torch.load(ASSETS_DIR / "a_latent_image.pt").to(self.device), + ) + audio_latent_tools = AudioLatentTools( + patchifier=self.pipeline_components.audio_patchifier, + target_shape=AudioLatentShape.from_torch_shape(audio_clean_state.latent.shape), + ) + audio_clean_state = audio_latent_tools.patchify(audio_clean_state) + + video_state = replace( + video_clean_state, + latent=self.pipeline_components.video_patchifier.patchify(torch.load(ASSETS_DIR / "v_noised_latent.pt")).to( + self.device + ), + ) + audio_state = replace( + audio_clean_state, + latent=self.pipeline_components.audio_patchifier.patchify(torch.load(ASSETS_DIR / "a_noised_latent.pt")).to( + self.device + ), + ) + + sigmas = torch.load(ASSETS_DIR / "sigmas.pt").to(self.device) + stepper = EulerDiffusionStep() + cfg_guider = CFGGuider(cfg_guidance_scale) + + video_state, audio_state = euler_denoising_loop( + sigmas, + video_state, + audio_state, + stepper, + guider_denoising_func( + cfg_guider, v_context_p, v_context_n, a_context_p, a_context_n, self.model_ledger.transformer() + ), + ) + + video_state = video_latent_tools.clear_conditioning(video_state) + video_state = video_latent_tools.unpatchify(video_state) + audio_state = audio_latent_tools.clear_conditioning(audio_state) + audio_state = audio_latent_tools.unpatchify(audio_state) + decoded_video = vae_decode_video(video_state, self.model_ledger.video_decoder()) + waveform = vae_decode_audio(audio_state, self.model_ledger.audio_decoder(), self.model_ledger.vocoder()) + + encode_video( + video=decoded_video, + fps=DEFAULT_FRAME_RATE, + audio=waveform, + audio_sample_rate=AUDIO_SAMPLE_RATE, + output_path=ASSETS_DIR / "test_comfy.mp4", + ) + + +@pytest.mark.e2e +def test_comfy_inputs( + psnr: Callable[[torch.Tensor, torch.Tensor, float, float], float], + decode_video_from_file: Callable[[str], tuple[torch.Tensor, torch.Tensor | None]], +) -> None: + pipeline = Img2VidTestPipeline( + checkpoint_path=(CHECKPOINTS_DIR / "ltx-av-step-1932500-interleaved-new-vae.safetensors").resolve().as_posix(), + gemma_root=GEMMA_ROOT.resolve().as_posix(), + loras=[], + ) + + pipeline(cfg_guidance_scale=3.0) + + decoded_video, waveform = decode_video_from_file(path=ASSETS_DIR / "test_comfy.mp4", device=pipeline.device) + expected_video, expected_waveform = decode_video_from_file( + path=ASSETS_DIR / "expected_comfy.mp4", device=pipeline.device + ) + + assert psnr(decoded_video, expected_video, 255.0, 1e-8).item() > 35.0 + assert psnr(waveform[: expected_waveform.shape[0]], expected_waveform[: waveform.shape[0]], 1.0, 1e-8).item() > 20.0 diff --git a/packages/ltx-pipelines/tests/ltx_pipelines/test_keyframes.py b/packages/ltx-pipelines/tests/ltx_pipelines/test_keyframes.py new file mode 100644 index 0000000000000000000000000000000000000000..3f2948dde655b3eb05c1b844f518e17346844f70 --- /dev/null +++ b/packages/ltx-pipelines/tests/ltx_pipelines/test_keyframes.py @@ -0,0 +1,93 @@ +from typing import Callable + +import pytest +import torch +from tests.conftest import ( + ASSETS_DIR, + AV_CHECKPOINT_SPLIT_PATH, + 
DEFAULT_NEGATIVE_PROMPT, + DISTILLED_LORA_PATH, + GEMMA_ROOT, + LORAS_DIR, + OUTPUT_DIR, + SPATIAL_UPSAMPLER_PATH, +) + +from ltx_core.tiling import TilingConfig +from ltx_pipelines.keyframe_interpolation import KeyframeInterpolationPipeline +from ltx_pipelines.utils import DEFAULT_CFG_GUIDANCE_SCALE, DEFAULT_NUM_INFERENCE_STEPS, DEFAULT_SEED, get_device + +device = get_device() + +LORA_PATH = LORAS_DIR / "Internal" / "ltxv_apps" / "playground" / "depth_control.safetensors" + + +@pytest.mark.e2e +def test_key_frames( + psnr: Callable[[torch.Tensor, torch.Tensor, float, float], float], + decode_video_from_file: Callable[[str], tuple[torch.Tensor, torch.Tensor | None]], +) -> None: + """Run keyframes interpolation pipeline and verify output matches expected.""" + + output_path = OUTPUT_DIR / "keyframes_output.mp4" + image1_path = ASSETS_DIR / "dragon_1.png" + image2_path = ASSETS_DIR / "dragon_2.png" + + pipeline = KeyframeInterpolationPipeline( + checkpoint_path=AV_CHECKPOINT_SPLIT_PATH.resolve().as_posix(), + distilled_lora_path=DISTILLED_LORA_PATH.resolve().as_posix(), + distilled_lora_strength=1.0, + spatial_upsampler_path=SPATIAL_UPSAMPLER_PATH.resolve().as_posix(), + gemma_root=GEMMA_ROOT.resolve().as_posix(), + loras=[], + ) + + pipeline( + prompt=( + "A single continuous shot, cinematic epic fantasy battle, 10-12 seconds, 24 fps, 16:9. Start in a " + "hellish volcanic battlefield: ash-filled air, embers drifting, jagged black rocks and small fires " + "licking the ground. Camera begins behind a lone armored knight with a tattered crimson cape, sword " + "in right hand and shield in left, standing in a wide stance facing a colossal horned dragon. The " + "dragon dominates the frame ahead: charcoal-black scales with glowing orange fissures, massive wings " + "spread, molten sparks shedding from its body. It rears back and unleashes a torrent of fire—bright, " + "turbulent flame with heat distortion and swirling smoke—blasting toward the knight. The knight " + "braces, cape whipping violently in the hot wind, shield raised as the fire washes past, scattering " + "burning debris and kicking up dust. Camera slowly dollies forward and slightly upward, circling a " + "little to the knight's left, keeping both knight and dragon in view while the environment shakes " + "subtly from the force of the roar. The knight surges forward through the smoke and embers, closing " + "distance with determined, heavy steps. The dragon lunges down, claws scraping rock; the knight ducks " + "under a sweeping wing, sparks bursting where metal meets scale. The camera continues its smooth move, " + "drifting closer and tighter, following the knight's advance while maintaining the dragon's looming " + "head and wings in frame. The fire and sparks fade, the color temperature cools, smoke thins into a " + "gray overcast. The camera glides to the right and lowers, revealing the dragon's enormous head " + "collapsed on the rocky ground, eyes dim, blood dark against stone. The knight is now seated in " + "exhaustion beside the dragon's head, armor battered and smeared, cape pooled around him. His sword is " + "embedded upright in the dragon's skull/neck area, trembling slightly before settling. The knight's " + "helmeted head hangs, chest rising slowly; ash drifts like snow in the quiet. End on a lingering " + "close-medium composition: knight slumped against the dragon, ruined landscape and distant broken " + "silhouette of a fortress in the background under a heavy gray sky. 
No cuts, only one continuous camera " + "move and natural lighting shift from fiery chaos to cold silence." + ), + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + output_path=output_path.as_posix(), + seed=DEFAULT_SEED, + height=512, + width=384, + num_frames=161, + frame_rate=12.5, + num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS, + cfg_guidance_scale=DEFAULT_CFG_GUIDANCE_SCALE, + images=[(image1_path.as_posix(), 0, 1.0), (image2_path.as_posix(), 160, 1.0)], + tiling_config=TilingConfig.default(), + ) + + # Compare to expected output + decoded_video, waveform = decode_video_from_file(path=output_path, device=device) + expected_video, expected_waveform = decode_video_from_file( + path=ASSETS_DIR / "expected_keyframes.mp4", device=device + ) + + assert psnr(decoded_video, expected_video, 255.0, 1e-8).item() >= 100.0 + assert psnr(waveform, expected_waveform, 1.0, 1e-8).item() >= 80.0 + + output_path.unlink() diff --git a/packages/ltx-pipelines/tests/ltx_pipelines/test_media_io.py b/packages/ltx-pipelines/tests/ltx_pipelines/test_media_io.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b1fbb81c6a1d6425b0898834f514b52b58a951 --- /dev/null +++ b/packages/ltx-pipelines/tests/ltx_pipelines/test_media_io.py @@ -0,0 +1,77 @@ +import torch + +from ltx_pipelines.media_io import resize_and_center_crop + + +def test_resize_and_center_crop_centers_content_vertical() -> None: + """Verify content is actually cropped from center.""" + h, w, c = 100, 201, 3 + image = torch.zeros(h, w, c) + + # Paint left third red, center third green, right third blue + third = w // 3 # 67 pixels per third + image[:, :third, 0] = 1.0 # Left: red + image[:, third : 2 * third, 1] = 1.0 # Center: green + image[:, 2 * third :, 2] = 1.0 # Right: blue + + # Crop to same height but narrower width - should keep center (green) region + # Source: 100x201, Target: 100x67 (same height, 1/3 width) + # Scale factor: max(100/100, 67/201) = max(1.0, 0.33) = 1.0 + # After scale: 100x201, crop width from center: (201-67)//2 = 67 pixels from each side + result = resize_and_center_crop(image, height=100, width=67) + + # Result shape: (1, C, 1, H, W) -> extract the image + cropped = result[0, :, 0, :, :] # Shape: (C, H, W) + + # Center of cropped image should be predominantly green (from center third) + center_h, center_w = cropped.shape[1] // 2, cropped.shape[2] // 2 + center_pixel = cropped[:, center_h, center_w] + + # Green channel should be highest at center (center was green in original) + assert center_pixel[1] > center_pixel[0], "Center should have more green than red" + assert center_pixel[1] > center_pixel[2], "Center should have more green than blue" + + # Verify the left and right edges are NOT red/blue (they were cropped away) + left_edge = cropped[:, center_h, 0] + right_edge = cropped[:, center_h, -1] + + # Both edges should still be from the green center region, not red/blue edges + assert left_edge[1] >= left_edge[0], "Left edge should not be from red region" + assert right_edge[1] >= right_edge[2], "Right edge should not be from blue region" + + +def test_resize_and_center_crop_centers_content_horizontal() -> None: + """Verify content is actually cropped from center vertically (horizontal stripes).""" + h, w, c = 201, 100, 3 + image = torch.zeros(h, w, c) + + # Paint top third red, center third green, bottom third blue + third = h // 3 # 67 pixels per third + image[:third, :, 0] = 1.0 # Top: red + image[third : 2 * third, :, 1] = 1.0 # Center: green + image[2 * third :, :, 2] = 1.0 # Bottom: blue + + # 
Crop to same width but shorter height - should keep center (green) region + # Source: 201x100, Target: 67x100 (1/3 height, same width) + # Scale factor: max(67/201, 100/100) = max(0.33, 1.0) = 1.0 + # After scale: 201x100, crop height from center: (201-67)//2 = 67 pixels from top/bottom + result = resize_and_center_crop(image, height=67, width=100) + + # Result shape: (1, C, 1, H, W) -> extract the image + cropped = result[0, :, 0, :, :] # Shape: (C, H, W) + + # Center of cropped image should be predominantly green (from center third) + center_h, center_w = cropped.shape[1] // 2, cropped.shape[2] // 2 + center_pixel = cropped[:, center_h, center_w] + + # Green channel should be highest at center (center was green in original) + assert center_pixel[1] > center_pixel[0], "Center should have more green than red" + assert center_pixel[1] > center_pixel[2], "Center should have more green than blue" + + # Verify the top and bottom edges are NOT red/blue (they were cropped away) + top_edge = cropped[:, 0, center_w] + bottom_edge = cropped[:, -1, center_w] + + # Both edges should still be from the green center region, not red/blue edges + assert top_edge[1] >= top_edge[0], "Top edge should not be from red region" + assert bottom_edge[1] >= bottom_edge[2], "Bottom edge should not be from blue region" diff --git a/packages/ltx-pipelines/tests/ltx_pipelines/test_ti2vid_one_stage.py b/packages/ltx-pipelines/tests/ltx_pipelines/test_ti2vid_one_stage.py new file mode 100644 index 0000000000000000000000000000000000000000..e559d107f15863e2af105505ebf6633ac87f567b --- /dev/null +++ b/packages/ltx-pipelines/tests/ltx_pipelines/test_ti2vid_one_stage.py @@ -0,0 +1,109 @@ +from typing import Callable + +import pytest +import torch +from tests.conftest import ( + ASSETS_DIR, + AV_CHECKPOINT_SPLIT_PATH, + DEFAULT_NEGATIVE_PROMPT, + GEMMA_ROOT, + IMG2VID_PROMPT, + OUTPUT_DIR, +) + +from ltx_pipelines.constants import ( + DEFAULT_CFG_GUIDANCE_SCALE, + DEFAULT_FRAME_RATE, + DEFAULT_HEIGHT, + DEFAULT_NUM_FRAMES, + DEFAULT_NUM_INFERENCE_STEPS, + DEFAULT_SEED, + DEFAULT_WIDTH, +) +from ltx_pipelines.ti2vid_one_stage import TI2VidOneStagePipeline +from ltx_pipelines.utils import get_device + +device = get_device() + + +@pytest.mark.e2e +def test_img2vid_one_stage( + psnr: Callable[[torch.Tensor, torch.Tensor, float, float], float], + decode_video_from_file: Callable[[str], tuple[torch.Tensor, torch.Tensor | None]], +) -> None: + """Run img2vid one-stage pipeline (with image conditioning) and verify output matches expected.""" + + output_path = OUTPUT_DIR / "img2vid_one_stage_output.mp4" + image_path = ASSETS_DIR / "hat.png" + + pipeline = TI2VidOneStagePipeline( + checkpoint_path=AV_CHECKPOINT_SPLIT_PATH.resolve().as_posix(), + gemma_root=GEMMA_ROOT.resolve().as_posix(), + loras=[], + ) + + pipeline( + prompt=IMG2VID_PROMPT, + output_path=output_path.as_posix(), + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + seed=DEFAULT_SEED, + height=DEFAULT_HEIGHT, + width=DEFAULT_WIDTH, + num_frames=DEFAULT_NUM_FRAMES, + frame_rate=DEFAULT_FRAME_RATE, + num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS, + cfg_guidance_scale=DEFAULT_CFG_GUIDANCE_SCALE, + images=[(image_path.as_posix(), 0, 1.0)], + ) + + # Compare to expected output + decoded_video, waveform = decode_video_from_file(path=output_path, device=device) + expected_video, expected_waveform = decode_video_from_file( + path=ASSETS_DIR / "expected_img2vid_one_stage.mp4", device=device + ) + + assert psnr(decoded_video, expected_video, 255.0, 1e-8).item() >= 100.0 + assert 
psnr(waveform, expected_waveform, 1.0, 1e-8).item() >= 80.0 + + output_path.unlink() + + +@pytest.mark.e2e +def test_txt2vid_one_stage( + psnr: Callable[[torch.Tensor, torch.Tensor, float, float], float], + decode_video_from_file: Callable[[str], tuple[torch.Tensor, torch.Tensor | None]], +) -> None: + """Run txt2vid one-stage pipeline (no image conditioning) and verify output matches expected.""" + + output_path = OUTPUT_DIR / "txt2vid_one_stage_output.mp4" + + pipeline = TI2VidOneStagePipeline( + checkpoint_path=AV_CHECKPOINT_SPLIT_PATH.resolve().as_posix(), + gemma_root=GEMMA_ROOT.resolve().as_posix(), + loras=[], + ) + + pipeline( + prompt=IMG2VID_PROMPT, + output_path=output_path.as_posix(), + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + seed=DEFAULT_SEED, + height=DEFAULT_HEIGHT, + width=DEFAULT_WIDTH, + num_frames=DEFAULT_NUM_FRAMES, + frame_rate=DEFAULT_FRAME_RATE, + num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS, + cfg_guidance_scale=DEFAULT_CFG_GUIDANCE_SCALE, + images=[], + ) + + # Compare to expected output + decoded_video, waveform = decode_video_from_file(path=output_path, device=device) + expected_video, expected_waveform = decode_video_from_file( + path=ASSETS_DIR / "expected_txt2vid_one_stage.mp4", device=device + ) + + assert psnr(decoded_video, expected_video, 255.0, 1e-8).item() >= 100.0 + assert psnr(waveform, expected_waveform, 1.0, 1e-8).item() >= 80.0 + + output_path.unlink() diff --git a/packages/ltx-pipelines/tests/ltx_pipelines/test_ti2vid_two_stages.py b/packages/ltx-pipelines/tests/ltx_pipelines/test_ti2vid_two_stages.py new file mode 100644 index 0000000000000000000000000000000000000000..2d371c709133517e1d7b3c98434d2cf852f9ea21 --- /dev/null +++ b/packages/ltx-pipelines/tests/ltx_pipelines/test_ti2vid_two_stages.py @@ -0,0 +1,118 @@ +from typing import Callable + +import pytest +import torch +from tests.conftest import ( + ASSETS_DIR, + AV_CHECKPOINT_SPLIT_PATH, + DEFAULT_NEGATIVE_PROMPT, + DISTILLED_LORA_PATH, + GEMMA_ROOT, + IMG2VID_PROMPT, + OUTPUT_DIR, + SPATIAL_UPSAMPLER_PATH, +) + +from ltx_core.tiling import TilingConfig +from ltx_pipelines.constants import ( + DEFAULT_CFG_GUIDANCE_SCALE, + DEFAULT_FRAME_RATE, + DEFAULT_HEIGHT, + DEFAULT_NUM_FRAMES, + DEFAULT_NUM_INFERENCE_STEPS, + DEFAULT_SEED, + DEFAULT_WIDTH, +) +from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline +from ltx_pipelines.utils import get_device + +device = get_device() + + +@pytest.mark.e2e +def test_img2vid_two_stages( + psnr: Callable[[torch.Tensor, torch.Tensor, float, float], float], + decode_video_from_file: Callable[[str], tuple[torch.Tensor, torch.Tensor | None]], +) -> None: + """Run img2vid two-stages pipeline and verify output matches expected.""" + + output_path = OUTPUT_DIR / "img2vid_two_stages_output.mp4" + image_path = ASSETS_DIR / "hat.png" + + pipeline = TI2VidTwoStagesPipeline( + checkpoint_path=AV_CHECKPOINT_SPLIT_PATH.resolve().as_posix(), + distilled_lora_path=DISTILLED_LORA_PATH.resolve().as_posix(), + distilled_lora_strength=0.6, + spatial_upsampler_path=SPATIAL_UPSAMPLER_PATH.resolve().as_posix(), + gemma_root=GEMMA_ROOT.resolve().as_posix(), + loras=[], + ) + + pipeline( + prompt=IMG2VID_PROMPT, + output_path=output_path.as_posix(), + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + seed=DEFAULT_SEED, + height=DEFAULT_HEIGHT, + width=DEFAULT_WIDTH, + num_frames=DEFAULT_NUM_FRAMES, + frame_rate=DEFAULT_FRAME_RATE, + num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS, + cfg_guidance_scale=DEFAULT_CFG_GUIDANCE_SCALE, + images=[(image_path.as_posix(), 0, 1.0)], 
+ tiling_config=TilingConfig.default(), + ) + + decoded_video, waveform = decode_video_from_file(path=output_path, device=device) + expected_video, expected_waveform = decode_video_from_file( + path=ASSETS_DIR / "expected_img2vid_two_stages.mp4", device=device + ) + + assert psnr(decoded_video, expected_video, 255.0, 1e-8).item() >= 100.0 + assert psnr(waveform, expected_waveform, 1.0, 1e-8).item() >= 80.0 + + output_path.unlink() + + +@pytest.mark.e2e +def test_txt2vid_two_stages( + psnr: Callable[[torch.Tensor, torch.Tensor, float, float], float], + decode_video_from_file: Callable[[str], tuple[torch.Tensor, torch.Tensor | None]], +) -> None: + """Run txt2vid two-stages pipeline and verify output matches expected.""" + + output_path = OUTPUT_DIR / "txt2vid_two_stages_output.mp4" + + pipeline = TI2VidTwoStagesPipeline( + checkpoint_path=AV_CHECKPOINT_SPLIT_PATH.resolve().as_posix(), + distilled_lora_path=DISTILLED_LORA_PATH.resolve().as_posix(), + distilled_lora_strength=0.6, + spatial_upsampler_path=SPATIAL_UPSAMPLER_PATH.resolve().as_posix(), + gemma_root=GEMMA_ROOT.resolve().as_posix(), + loras=[], + ) + + pipeline( + prompt=IMG2VID_PROMPT, + output_path=output_path.as_posix(), + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + seed=DEFAULT_SEED, + height=DEFAULT_HEIGHT, + width=DEFAULT_WIDTH, + num_frames=DEFAULT_NUM_FRAMES, + frame_rate=DEFAULT_FRAME_RATE, + num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS, + cfg_guidance_scale=DEFAULT_CFG_GUIDANCE_SCALE, + images=[], + tiling_config=TilingConfig.default(), + ) + + decoded_video, waveform = decode_video_from_file(path=output_path, device=device) + expected_video, expected_waveform = decode_video_from_file( + path=ASSETS_DIR / "expected_txt2vid_two_stages.mp4", device=device + ) + + assert psnr(decoded_video, expected_video, 255.0, 1e-8).item() >= 100.0 + assert psnr(waveform, expected_waveform, 1.0, 1e-8).item() >= 80.0 + + output_path.unlink() diff --git a/packages/ltx-trainer/.gitignore b/packages/ltx-trainer/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..aa42d10c556fae1a3492d59cd27bd9398d938153 --- /dev/null +++ b/packages/ltx-trainer/.gitignore @@ -0,0 +1,6 @@ +configs/*.yaml +!configs/ltx2_av_lora.yaml +!configs/ltx2_v2v_ic_lora.yaml +datasets +outputs +wandb diff --git a/packages/ltx-trainer/AGENTS.md b/packages/ltx-trainer/AGENTS.md new file mode 100644 index 0000000000000000000000000000000000000000..79320bc439d83fe86aa634b48c1d76cfdc9af3ae --- /dev/null +++ b/packages/ltx-trainer/AGENTS.md @@ -0,0 +1,352 @@ +# AGENTS.md + +This file provides guidance to AI coding assistants (Claude, Cursor, etc.) when working with code in this repository. + +## Project Overview + +**LTX-2 Trainer** is a training toolkit for fine-tuning the Lightricks LTX-2 audio-video generation model. It supports: + +- **LoRA training** - Efficient fine-tuning with adapters +- **Full fine-tuning** - Complete model training +- **Audio-video training** - Joint audio and video generation +- **IC-LoRA training** - In-context control adapters for video-to-video transformations + +**Key Dependencies:** + +- **[`ltx-core`](../ltx-core/)** - Core model implementations (transformer, VAE, text encoder) +- **[`ltx-pipelines`](../ltx-pipelines/)** - Inference pipeline components + +> **Important:** This trainer only supports **LTX-2** (the audio-video model). The older LTXV models are not supported. 
+ +## Architecture Overview + +### Package Structure + +``` +packages/ltx-trainer/ +├── src/ltx_trainer/ # Main training module +│ ├── config.py # Pydantic configuration models +│ ├── trainer.py # Main training orchestration with Accelerate +│ ├── model_loader.py # Model loading using ltx-core +│ ├── validation_sampler.py # Inference for validation samples +│ ├── datasets.py # PrecomputedDataset for latent-based training +│ ├── training_strategies/ # Strategy pattern for different training modes +│ │ ├── __init__.py # Factory function: get_training_strategy() +│ │ ├── base_strategy.py # TrainingStrategy ABC, ModelInputs, TrainingStrategyConfigBase +│ │ ├── text_to_video.py # TextToVideoStrategy, TextToVideoConfig +│ │ └── video_to_video.py # VideoToVideoStrategy, VideoToVideoConfig +│ ├── timestep_samplers.py # Flow matching timestep sampling +│ ├── captioning.py # Video captioning utilities +│ ├── video_utils.py # Video processing utilities +│ └── hf_hub_utils.py # HuggingFace Hub integration +├── scripts/ # User-facing CLI tools +│ ├── train.py # Main training script +│ ├── process_dataset.py # Dataset preprocessing +│ ├── process_videos.py # Video latent encoding +│ ├── process_captions.py # Text embedding computation +│ ├── caption_videos.py # Automatic video captioning +│ ├── decode_latents.py # Latent decoding for debugging +│ ├── inference.py # Inference with trained models +│ ├── compute_reference.py # Generate IC-LoRA reference videos +│ └── split_scenes.py # Scene detection and splitting +├── configs/ # Example training configurations +│ ├── ltx2_av_lora.yaml # Audio-video LoRA training +│ ├── ltx2_v2v_ic_lora.yaml # IC-LoRA video-to-video +│ └── accelerate/ # Accelerate configs for distributed training +└── docs/ # Documentation +``` + +### Key Architectural Patterns + +**Model Loading:** + +- `ltx_trainer.model_loader` provides component loaders using `ltx-core` +- Individual loaders: `load_transformer()`, `load_video_vae_encoder()`, `load_video_vae_decoder()`, `load_text_encoder()`, etc. +- Combined loader: `load_model()` returns `LtxModelComponents` dataclass +- Uses `SingleGPUModelBuilder` from ltx-core internally + +**Training Flow:** + +1. Configuration loaded via Pydantic models in `config.py` +2. `Trainer` class orchestrates the training loop +3. Training strategies (`TextToVideoStrategy`, `VideoToVideoStrategy`) prepare inputs and compute loss +4. Accelerate handles distributed training and device placement +5. Data flows as precomputed latents through `PrecomputedDataset` + +**Model Interface (Modality-based):** + +```python +from ltx_core.model.transformer.modality import Modality + +# Create modality objects for video and audio +video = Modality( + enabled=True, + latent=video_latents, # [B, seq_len, 128] + timesteps=video_timesteps, # [B, seq_len] per-token + positions=video_positions, # [B, 3, seq_len, 2] + context=video_embeds, + context_mask=None, +) +audio = Modality( + enabled=True, + latent=audio_latents, + timesteps=audio_timesteps, + positions=audio_positions, # [B, 1, seq_len, 2] + context=audio_embeds, + context_mask=None, +) + +# Forward pass returns predictions for both modalities +video_pred, audio_pred = model(video=video, audio=audio, perturbations=None) +``` + +> **Note:** `Modality` is immutable (frozen dataclass). Use `dataclasses.replace()` to modify. 
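To make the training flow above concrete, here is a minimal sketch of a single video-only flow-matching step built on the `Modality` interface shown above. The helper name `training_step`, the rectified-flow velocity target (`noise - clean`), and the assumption that a disabled `Modality` ignores its tensor fields are illustrative guesses, not the trainer's actual implementation:

```python
import torch
import torch.nn.functional as F

from ltx_core.model.transformer.modality import Modality


def training_step(model, latents, timesteps, positions, context) -> torch.Tensor:
    """Hypothetical video-only flow-matching step (sketch, not the trainer's code).

    latents:   [B, seq_len, 128] clean video latents (precomputed by the VAE)
    timesteps: [B, seq_len] per-token timesteps in [0, 1]
    positions: [B, 3, seq_len, 2] patch positions
    context:   text-encoder embeddings for the prompts
    """
    noise = torch.randn_like(latents)
    sigma = timesteps.unsqueeze(-1)  # [B, seq_len, 1], broadcast over channels
    # Interpolate between clean latents (sigma=0) and pure noise (sigma=1).
    noisy_latents = (1.0 - sigma) * latents + sigma * noise

    video = Modality(
        enabled=True,
        latent=noisy_latents,
        timesteps=timesteps,
        positions=positions,
        context=context,
        context_mask=None,
    )
    # Assumption: a disabled modality's tensor fields are ignored by the model.
    audio = Modality(
        enabled=False,
        latent=None,
        timesteps=None,
        positions=None,
        context=None,
        context_mask=None,
    )

    video_pred, _ = model(video=video, audio=audio, perturbations=None)

    # Assumed rectified-flow velocity target; the real objective lives in the
    # active strategy's compute_loss() and may differ.
    target = noise - latents
    return F.mse_loss(video_pred.float(), target.float())
```

In the repository this logic is split between the active strategy's `prepare_training_inputs()` and `compute_loss()`, with `Trainer` and Accelerate handling device placement and mixed precision.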
+ +**Configuration System:** + +- All config in `src/ltx_trainer/config.py` +- Main class: `LtxTrainerConfig` +- Training strategy configs: `TextToVideoConfig`, `VideoToVideoConfig` +- Uses Pydantic field validators and model validators +- Config files in `configs/` directory + +## Development Commands + +### Setup and Installation + +```bash +# From the repository root +uv sync +cd packages/ltx-trainer +``` + +### Code Quality + +```bash +# Run ruff linting and formatting +uv run ruff check . +uv run ruff format . + +# Run pre-commit checks +uv run pre-commit run --all-files +``` + +### Running Tests + +```bash +cd packages/ltx-trainer +uv run pytest +``` + +### Running Training + +```bash +# Single GPU +uv run python scripts/train.py configs/ltx2_av_lora.yaml + +# Multi-GPU with Accelerate +uv run accelerate launch scripts/train.py configs/ltx2_av_lora.yaml +``` + +## Code Standards + +### Type Hints + +- **Always use type hints** for all function arguments and return values +- Use Python 3.10+ syntax: `list[str]` not `List[str]`, `str | Path` not `Union[str, Path]` +- Use `pathlib.Path` for file operations + +### Class Methods + +- Mark methods as `@staticmethod` if they don't access instance or class state +- Use `@classmethod` for alternative constructors + +### AI/ML Specific + +- Use `@torch.inference_mode()` for inference (prefer over `@torch.no_grad()`) +- Use `accelerator.device` for distributed compatibility +- Support mixed precision (`bfloat16` via dtype parameters) +- Use gradient checkpointing for memory-intensive training + +### Logging + +- Use `from ltx_trainer import logger` for all messages +- Avoid print statements in production code + +## Important Files & Modules + +### Configuration (CRITICAL) + +**`src/ltx_trainer/config.py`** - Master config definitions + +Key classes: +- `LtxTrainerConfig` - Main configuration container +- `ModelConfig` - Model paths and training mode +- `TrainingStrategyConfig` - Union of `TextToVideoConfig` | `VideoToVideoConfig` +- `LoraConfig` - LoRA hyperparameters +- `OptimizationConfig` - Learning rate, batch size, etc. +- `ValidationConfig` - Validation settings +- `WandbConfig` - W&B logging settings + +**⚠️ When modifying config.py:** +1. Update ALL config files in `configs/` +2. Update `docs/configuration-reference.md` +3. 
Test that all configs remain valid + +### Training Core + +**`src/ltx_trainer/trainer.py`** - Main training loop + +- Implements distributed training with Accelerate +- Handles mixed precision, gradient accumulation, checkpointing +- Uses training strategies for mode-specific logic + +**`src/ltx_trainer/training_strategies/`** - Strategy pattern + +- `base_strategy.py`: `TrainingStrategy` ABC, `ModelInputs` dataclass +- `text_to_video.py`: Standard text-to-video (with optional audio) +- `video_to_video.py`: IC-LoRA video-to-video transformations + +Key methods each strategy implements: +- `get_data_sources()` - Required data directories +- `prepare_training_inputs()` - Convert batch to `ModelInputs` +- `compute_loss()` - Calculate training loss +- `requires_audio` property - Whether audio components needed + +**`src/ltx_trainer/model_loader.py`** - Model loading + +Component loaders: +- `load_transformer()` → `LTXModel` +- `load_video_vae_encoder()` → `VideoVAEEncoder` +- `load_video_vae_decoder()` → `VideoVAEDecoder` +- `load_audio_vae_decoder()` → `AudioVAEDecoder` +- `load_vocoder()` → `Vocoder` +- `load_text_encoder()` → `AVGemmaTextEncoderModel` +- `load_model()` → `LtxModelComponents` (convenience wrapper) + +**`src/ltx_trainer/validation_sampler.py`** - Inference for validation + +Uses ltx-core components for denoising: +- `LTX2Scheduler` for sigma scheduling +- `EulerDiffusionStep` for diffusion steps +- `CFGGuider` for classifier-free guidance + +### Data + +**`src/ltx_trainer/datasets.py`** - Dataset handling + +- `PrecomputedDataset` loads pre-computed VAE latents +- Supports video latents, audio latents, text embeddings, reference latents + +## Common Development Tasks + +### Adding a New Configuration Parameter + +1. Add field to appropriate config class in `src/ltx_trainer/config.py` +2. Add validator if needed +3. Update ALL config files in `configs/` +4. Update `docs/configuration-reference.md` + +### Implementing a New Training Strategy + +1. Create new file in `src/ltx_trainer/training_strategies/` +2. Create config class inheriting `TrainingStrategyConfigBase` +3. Create strategy class inheriting `TrainingStrategy` +4. Implement: `get_data_sources()`, `prepare_training_inputs()`, `compute_loss()` +5. Add to `__init__.py`: import, add to `TrainingStrategyConfig` union, update factory +6. Add discriminator tag to config.py's `TrainingStrategyConfig` +7. 
Create example config file in `configs/` + +### Working with Modalities + +```python +from dataclasses import replace +from ltx_core.model.transformer.modality import Modality + +# Create modality +video = Modality( + enabled=True, + latent=latents, + timesteps=timesteps, + positions=positions, + context=context, + context_mask=None, +) + +# Update (immutable - must use replace) +video = replace(video, latent=new_latent, timesteps=new_timesteps) + +# Disable a modality +audio = replace(audio, enabled=False) +``` + +## Debugging Tips + +**Training Issues:** + +- Check logs first (rich logger provides context) +- GPU memory: Look for OOM errors, enable `enable_gradient_checkpointing: true` +- Distributed training: Check `accelerator.state` and device placement + +**Model Loading:** + +- Ensure `model_path` points to a local `.safetensors` file +- Ensure `text_encoder_path` points to a Gemma model directory +- URLs are NOT supported for model paths + +**Configuration:** + +- Validation errors: Check validators in `config.py` +- Unknown fields: Config uses `extra="forbid"` - all fields must be defined +- Strategy validation: IC-LoRA requires `reference_videos` in validation config + +## Key Constraints + +### LTX-2 Frame Requirements + +Frames must satisfy `frames % 8 == 1`: +- ✅ Valid: 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97, 121 +- ❌ Invalid: 24, 32, 48, 64, 100 + +### Resolution Requirements + +Width and height must be divisible by 32. + +### Model Paths + +- Must be local paths (URLs not supported) +- `model_path`: Path to `.safetensors` checkpoint +- `text_encoder_path`: Path to Gemma model directory + +### Platform Requirements + +- Linux required (uses `triton` which is Linux-only) +- CUDA GPU with 24GB+ VRAM recommended + +## Reference: ltx-core Key Components + +``` +packages/ltx-core/src/ltx_core/ +├── model/ +│ ├── transformer/ +│ │ ├── model.py # LTXModel +│ │ ├── modality.py # Modality dataclass +│ │ └── transformer.py # BasicAVTransformerBlock +│ ├── video_vae/ +│ │ └── video_vae.py # Encoder, Decoder +│ ├── audio_vae/ +│ │ ├── audio_vae.py # Decoder +│ │ └── vocoder.py # Vocoder +│ └── clip/gemma/ +│ └── encoders/av_encoder.py # AVGemmaTextEncoderModel +├── pipeline/ +│ ├── components/ +│ │ ├── schedulers.py # LTX2Scheduler +│ │ ├── diffusion_steps.py # EulerDiffusionStep +│ │ ├── guiders.py # CFGGuider +│ │ └── patchifiers.py # VideoLatentPatchifier, AudioPatchifier +│ └── conditioning/ # VideoLatentTools, AudioLatentTools +└── loader/ + ├── single_gpu_model_builder.py # SingleGPUModelBuilder + └── sd_ops.py # Key remapping (SDOps) +``` diff --git a/packages/ltx-trainer/CLAUDE.md b/packages/ltx-trainer/CLAUDE.md new file mode 100644 index 0000000000000000000000000000000000000000..47dc3e3d863cfb5727b87d785d09abf9743c0a72 --- /dev/null +++ b/packages/ltx-trainer/CLAUDE.md @@ -0,0 +1 @@ +AGENTS.md \ No newline at end of file diff --git a/packages/ltx-trainer/README.md b/packages/ltx-trainer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8ecd090e4a256d0cfabe4762377cbaa4edf6fad7 --- /dev/null +++ b/packages/ltx-trainer/README.md @@ -0,0 +1,61 @@ +# LTX-2 Trainer + +This package provides tools and scripts for training and fine-tuning +Lightricks' **LTX-2** audio-video generation model. It enables LoRA training, full +fine-tuning, and training of video-to-video transformations (IC-LoRA) on custom datasets. 
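As a quick sanity check for the frame and resolution constraints documented in AGENTS.md above (width and height divisible by 32, `frames % 8 == 1`), a tiny helper along these lines can be run before kicking off preprocessing or training. `check_ltx2_dims` is a hypothetical name, not part of the package:

```python
def check_ltx2_dims(width: int, height: int, num_frames: int) -> None:
    """Validate LTX-2 output dimensions up front (hypothetical helper)."""
    if width % 32 != 0 or height % 32 != 0:
        raise ValueError(f"width and height must be divisible by 32, got {width}x{height}")
    if num_frames % 8 != 1:
        raise ValueError(f"num_frames must satisfy num_frames % 8 == 1, got {num_frames}")


check_ltx2_dims(576, 576, 89)     # OK: matches the example validation config below
# check_ltx2_dims(576, 576, 100)  # would raise: 100 % 8 == 4, not a valid frame count
```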
+ + --- + + ## 📖 Documentation + + All detailed guides and technical documentation are in the [docs](./docs/) directory: + + - [⚡ Quick Start Guide](docs/quick-start.md) + - [🎬 Dataset Preparation](docs/dataset-preparation.md) + - [🛠️ Training Modes](docs/training-modes.md) + - [⚙️ Configuration Reference](docs/configuration-reference.md) + - [🚀 Training Guide](docs/training-guide.md) + - [🔧 Utility Scripts](docs/utility-scripts.md) + - [📚 LTX-Core API Guide](docs/ltx-core-api-guide.md) + - [🛡️ Troubleshooting Guide](docs/troubleshooting.md) + + --- + + ## 🔧 Requirements + + - **LTX-2 Model Checkpoint** - Local `.safetensors` file + - **Gemma Text Encoder** - Local Gemma model directory (required for LTX-2) + - **Linux with CUDA** - CUDA 13+ recommended for optimal performance + - **Nvidia GPU with 80GB+ VRAM** - Highly recommended; lower VRAM may work with gradient checkpointing and lower + resolutions + + --- + + ## 🤝 Contributing + + We welcome contributions from the community! Here's how you can help: + + - **Share Your Work**: If you've trained interesting LoRAs or achieved cool results, please share them with the + community. + - **Report Issues**: Found a bug or have a suggestion? Open an issue on GitHub. + - **Submit PRs**: Help improve the codebase with bug fixes or general improvements. + - **Feature Requests**: Have ideas for new features? Let us know through GitHub issues. + + --- + + ## 💬 Join the Community + + Have questions, want to share your results, or need real-time help? + + Join our [community Discord server](https://discord.gg/2mafsHjJ) to connect with other users and the development team! + + - Get troubleshooting help + - Share your training results and workflows + - Collaborate on new ideas and features + - Stay up to date with announcements and updates + + We look forward to seeing you there! + + --- + + Happy training! 
🎉 diff --git a/packages/ltx-trainer/configs/accelerate/ddp.yaml b/packages/ltx-trainer/configs/accelerate/ddp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2edb13252bf19959630a5d1d53186ff156c84c4f --- /dev/null +++ b/packages/ltx-trainer/configs/accelerate/ddp.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/packages/ltx-trainer/configs/accelerate/ddp_compile.yaml b/packages/ltx-trainer/configs/accelerate/ddp_compile.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9398f94928d2e769ca14908011cc3ad6949ca6af --- /dev/null +++ b/packages/ltx-trainer/configs/accelerate/ddp_compile.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +dynamo_config: + dynamo_backend: INDUCTOR + dynamo_mode: default + dynamo_use_fullgraph: false + dynamo_use_dynamic: true +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [ ] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/packages/ltx-trainer/configs/accelerate/fsdp.yaml b/packages/ltx-trainer/configs/accelerate/fsdp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6a8b0028c5d5d0f01766b927577eb91ceeb1b6f5 --- /dev/null +++ b/packages/ltx-trainer/configs/accelerate/fsdp.yaml @@ -0,0 +1,29 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: false + fsdp_offload_params: false + fsdp_reshard_after_forward: FULL_SHARD + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_sync_module_states: true + fsdp_transformer_layer_cls_to_wrap: BasicAVTransformerBlock + fsdp_use_orig_params: true + fsdp_version: 1 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/packages/ltx-trainer/configs/accelerate/fsdp_compile.yaml b/packages/ltx-trainer/configs/accelerate/fsdp_compile.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a6b7089d9a50e5590c5360c47fdbe8b6dd5f9560 --- /dev/null +++ b/packages/ltx-trainer/configs/accelerate/fsdp_compile.yaml @@ -0,0 +1,34 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +dynamo_config: + dynamo_backend: INDUCTOR + dynamo_mode: default + dynamo_use_fullgraph: false + dynamo_use_dynamic: true +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: false + fsdp_offload_params: false + fsdp_reshard_after_forward: FULL_SHARD + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_sync_module_states: true + 
fsdp_transformer_layer_cls_to_wrap: BasicAVTransformerBlock + fsdp_use_orig_params: true + fsdp_version: 1 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/packages/ltx-trainer/configs/ltx2_av_lora.yaml b/packages/ltx-trainer/configs/ltx2_av_lora.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e1e19013ccbcb8bce789947288ec1b317399d779 --- /dev/null +++ b/packages/ltx-trainer/configs/ltx2_av_lora.yaml @@ -0,0 +1,307 @@ +# ============================================================================= +# LTX-2 Audio-Video LoRA Training Configuration +# ============================================================================= +# +# This configuration is for training LoRA adapters on the LTX-2 model for +# text-to-video generation. It supports both video-only and joint audio-video +# training modes. +# +# Use this configuration when you want to: +# - Fine-tune LTX-2 on your own video dataset +# - Train with or without audio generation +# - Create custom video generation styles or audiovisual concepts +# +# Dataset structure for text-to-video training: +# preprocessed_data_root/ +# ├── latents/ # Video latents (VAE-encoded videos) +# ├── conditions/ # Text embeddings for each video +# └── audio_latents/ # Audio latents (only if with_audio: true) +# +# ============================================================================= + +# ----------------------------------------------------------------------------- +# Model Configuration +# ----------------------------------------------------------------------------- +# Specifies the base model to fine-tune and the training mode. +model: + # Path to the LTX-2 model checkpoint (.safetensors file) + # This should be a local path to your downloaded model + model_path: "path/to/ltx-2-model.safetensors" + + # Path to the text encoder model directory + # For LTX-2, this is typically the Gemma-based text encoder + text_encoder_path: "path/to/gemma-text-encoder" + + # Training mode: "lora" for efficient adapter training, "full" for full fine-tuning + # LoRA is recommended for most use cases (faster, less memory, prevents overfitting) + training_mode: "lora" + + # Optional: Path to resume training from a checkpoint + # Can be a checkpoint file (.safetensors) or directory (uses latest checkpoint) + load_checkpoint: null + +# ----------------------------------------------------------------------------- +# LoRA Configuration +# ----------------------------------------------------------------------------- +# Controls the Low-Rank Adaptation parameters for efficient fine-tuning. +lora: + # Rank of the LoRA matrices (higher = more capacity but more parameters) + # Typical values: 8, 16, 32, 64. Start with 32 for general fine-tuning. 
+ rank: 32 + + # Alpha scaling factor (usually set equal to rank) + # The effective scaling is alpha/rank, so alpha=rank means scaling of 1.0 + alpha: 32 + + # Dropout probability for LoRA layers (0.0 = no dropout) + # Can help with regularization if overfitting occurs + dropout: 0.0 + + # Which transformer modules to apply LoRA to + # The LTX-2 transformer has separate attention and FFN blocks for video and audio: + # + # VIDEO MODULES: + # - attn1.to_k, attn1.to_q, attn1.to_v, attn1.to_out.0 (video self-attention) + # - attn2.to_k, attn2.to_q, attn2.to_v, attn2.to_out.0 (video cross-attention to text) + # - ff.net.0.proj, ff.net.2 (video feed-forward) + # + # AUDIO MODULES: + # - audio_attn1.to_k, audio_attn1.to_q, audio_attn1.to_v, audio_attn1.to_out.0 (audio self-attention) + # - audio_attn2.to_k, audio_attn2.to_q, audio_attn2.to_v, audio_attn2.to_out.0 (audio cross-attention to text) + # - audio_ff.net.0.proj, audio_ff.net.2 (audio feed-forward) + # + # AUDIO-VIDEO CROSS-ATTENTION MODULES (for cross-modal interaction): + # - audio_to_video_attn.to_k, audio_to_video_attn.to_q, audio_to_video_attn.to_v, audio_to_video_attn.to_out.0 + # (Q from video, K/V from audio - allows video to attend to audio features) + # - video_to_audio_attn.to_k, video_to_audio_attn.to_q, video_to_audio_attn.to_v, video_to_audio_attn.to_out.0 + # (Q from audio, K/V from video - allows audio to attend to video features) + # + # Using short patterns like "to_k" matches ALL attention modules (video, audio, and cross-modal). + # For audio-video training, this is the recommended approach. + target_modules: + # Attention layers (matches both video and audio branches) + - "to_k" + - "to_q" + - "to_v" + - "to_out.0" + # Uncomment below to also train feed-forward layers (can increase the LoRA's capacity): + # - "ff.net.0.proj" + # - "ff.net.2" + # - "audio_ff.net.0.proj" + # - "audio_ff.net.2" + +# ----------------------------------------------------------------------------- +# Training Strategy Configuration +# ----------------------------------------------------------------------------- +# Defines the text-to-video training approach. +training_strategy: + # Strategy name: "text_to_video" for standard text-to-video training + name: "text_to_video" + + # Probability of conditioning on the first frame during training + # Higher values train the model to perform better in image-to-video (I2V) mode, + # where a clean first frame is provided and the model generates the rest of the video + # Increase this value to train the model to perform better in image-to-video (I2V) mode + first_frame_conditioning_p: 0.5 + + # Enable joint audio-video training + # Set to true if your dataset includes audio and you want to train the audio branch + with_audio: true + + # Directory name (within preprocessed_data_root) containing audio latents + # Only used when with_audio is true + audio_latents_dir: "audio_latents" + +# ----------------------------------------------------------------------------- +# Optimization Configuration +# ----------------------------------------------------------------------------- +# Controls the training optimization parameters. 
+optimization: + # Learning rate for the optimizer + # Typical range for LoRA: 1e-5 to 1e-4 + learning_rate: 1e-4 + + # Total number of training steps + steps: 2000 + + # Batch size per GPU + # Reduce if running out of memory + batch_size: 1 + + # Number of gradient accumulation steps + # Effective batch size = batch_size * gradient_accumulation_steps * num_gpus + gradient_accumulation_steps: 1 + + # Maximum gradient norm for clipping (helps training stability) + max_grad_norm: 1.0 + + # Optimizer type: "adamw" (standard) or "adamw8bit" (memory-efficient) + optimizer_type: "adamw" + + # Learning rate scheduler type + # Options: "constant", "linear", "cosine", "cosine_with_restarts", "polynomial" + scheduler_type: "linear" + + # Additional scheduler parameters (depends on scheduler_type) + scheduler_params: { } + + # Enable gradient checkpointing to reduce memory usage + # Recommended for training with limited GPU memory + enable_gradient_checkpointing: true + +# ----------------------------------------------------------------------------- +# Acceleration Configuration +# ----------------------------------------------------------------------------- +# Hardware acceleration and memory optimization settings. +acceleration: + # Mixed precision training mode + # Options: "no" (fp32), "fp16" (half precision), "bf16" (bfloat16, recommended) + mixed_precision_mode: "bf16" + + # Model quantization for reduced memory usage + # Options: null (none), "int8-quanto", "int4-quanto", "int2-quanto", "fp8-quanto", "fp8uz-quanto" + quantization: null + + # Load text encoder in 8-bit precision to save memory + # Useful when GPU memory is limited + load_text_encoder_in_8bit: false + +# ----------------------------------------------------------------------------- +# Data Configuration +# ----------------------------------------------------------------------------- +# Specifies the training data location and loading parameters. +data: + # Root directory containing preprocessed training data + # Should contain: latents/, conditions/, and optionally audio_latents/ + preprocessed_data_root: "/path/to/preprocessed/data" + + # Number of worker processes for data loading + # Used for parallel data loading to speed up data loading + num_dataloader_workers: 2 + +# ----------------------------------------------------------------------------- +# Validation Configuration +# ----------------------------------------------------------------------------- +# Controls validation video generation during training. +validation: + # Text prompts for validation video generation + # Provide prompts representative of your training data + # LTX-2 prefers longer, detailed prompts that describe both visual content and audio + prompts: + - "A woman with long brown hair sits at a wooden desk in a cozy home office, typing on a laptop while occasionally glancing at notes beside her. Soft natural light streams through a large window, casting warm shadows across the room. She pauses to take a sip from a ceramic mug, then continues working with focused concentration. The audio captures the gentle clicking of keyboard keys, the soft rustle of papers, and ambient room tone with occasional distant bird chirps from outside." + - "A chef in a white uniform stands in a professional kitchen, carefully plating a gourmet dish with precise movements. Steam rises from freshly cooked vegetables as he arranges them with tweezers. The stainless steel surfaces gleam under bright overhead lights, and various pots simmer on the stove behind him. 
The audio features the sizzling of pans, the clinking of utensils against plates, and the ambient hum of kitchen ventilation." + + # Negative prompt to avoid unwanted artifacts + negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted" + + # Optional: First frame images for image-to-video validation + # If provided, must have one image per prompt + images: null + + # Output video dimensions [width, height, frames] + # Width and height must be divisible by 32 + # Frames must satisfy: frames % 8 == 1 (e.g., 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, ...) + video_dims: [ 576, 576, 89 ] + + # Frame rate for generated videos + frame_rate: 25.0 + + # Random seed for reproducible validation outputs + seed: 42 + + # Number of denoising steps for validation inference + # Higher values = better quality but slower generation + inference_steps: 30 + + # Generate validation videos every N training steps + # Set to null to disable validation during training + interval: 100 + + # Number of videos to generate per prompt + videos_per_prompt: 1 + + # Classifier-free guidance scale + # Higher values = stronger adherence to prompt but may introduce artifacts + guidance_scale: 3.0 + + # STG (Spatio-Temporal Guidance) parameters for improved video quality + # STG is combined with CFG for better temporal coherence + stg_scale: 1.0 # Recommended: 1.0 (0.0 disables STG) + stg_blocks: [29] # Recommended: single block 29 + stg_mode: "stg_av" # "stg_av" perturbs both audio and video, "stg_v" video only + + # Whether to generate audio in validation samples + # Independent of training_strategy.with_audio - you can generate audio + # in validation even when not training the audio branch + generate_audio: true + + # Skip validation at the beginning of training (step 0) + skip_initial_validation: false + +# ----------------------------------------------------------------------------- +# Checkpoint Configuration +# ----------------------------------------------------------------------------- +# Controls model checkpoint saving during training. +checkpoints: + # Save a checkpoint every N steps + # Set to null to disable intermediate checkpoints + interval: 250 + + # Number of most recent checkpoints to keep + # Set to -1 to keep all checkpoints + keep_last_n: -1 + +# ----------------------------------------------------------------------------- +# Flow Matching Configuration +# ----------------------------------------------------------------------------- +# Parameters for the flow matching training objective. +flow_matching: + # Timestep sampling mode + # "shifted_logit_normal" is recommended for LTX-2 models + timestep_sampling_mode: "shifted_logit_normal" + + # Additional parameters for timestep sampling + timestep_sampling_params: { } + +# ----------------------------------------------------------------------------- +# Hugging Face Hub Configuration +# ----------------------------------------------------------------------------- +# Settings for uploading trained models to the Hugging Face Hub. +hub: + # Whether to push the trained model to the Hub + push_to_hub: false + + # Repository ID on Hugging Face Hub (e.g., "username/my-lora-model") + # Required if push_to_hub is true + hub_model_id: null + +# ----------------------------------------------------------------------------- +# Weights & Biases Configuration +# ----------------------------------------------------------------------------- +# Settings for experiment tracking with W&B. 
+wandb: + # Enable W&B logging + enabled: false + + # W&B project name + project: "ltx-2-trainer" + + # W&B username or team (null uses default account) + entity: null + + # Tags to help organize runs + tags: [ "ltx2", "lora" ] + + # Log validation videos to W&B + log_validation_videos: true + +# ----------------------------------------------------------------------------- +# General Configuration +# ----------------------------------------------------------------------------- +# Global settings for the training run. + +# Random seed for reproducibility +seed: 42 + +# Directory to save outputs (checkpoints, validation videos, logs) +output_dir: "outputs/ltx2_av_lora" diff --git a/packages/ltx-trainer/configs/ltx2_v2v_ic_lora.yaml b/packages/ltx-trainer/configs/ltx2_v2v_ic_lora.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f5607009084d50ae48eac494b5361a6e00901eba --- /dev/null +++ b/packages/ltx-trainer/configs/ltx2_v2v_ic_lora.yaml @@ -0,0 +1,317 @@ +# ============================================================================= +# LTX-2 Video-to-Video (IC-LoRA) Training Configuration +# ============================================================================= +# +# This configuration is for training In-Context LoRA (IC-LoRA) adapters that +# enable video-to-video transformations. IC-LoRA learns to apply visual +# transformations (e.g., depth-to-video, pose control, style transfer, etc.) +# by conditioning on reference videos. +# +# Key differences from text-to-video LoRA: +# - Uses reference videos as conditioning input alongside text prompts +# - Requires preprocessed reference latents in addition to target latents +# - Validation requires reference videos to demonstrate the transformation +# +# Dataset structure for IC-LoRA training: +# preprocessed_data_root/ +# ├── latents/ # Target video latents (what the model learns to generate) +# ├── conditions/ # Text embeddings for each video +# └── reference_latents/ # Reference video latents (conditioning input) +# +# ============================================================================= + +# ----------------------------------------------------------------------------- +# Model Configuration +# ----------------------------------------------------------------------------- +# Specifies the base model to fine-tune and the training mode. +model: + # Path to the LTX-2 model checkpoint (.safetensors file) + # This should be a local path to your downloaded model + model_path: "path/to/ltx-2-model.safetensors" + + # Path to the text encoder model directory + # For LTX-2, this is typically the Gemma-based text encoder + text_encoder_path: "path/to/gemma-text-encoder" + + # Training mode: "lora" for efficient adapter training, "full" for full fine-tuning + # Note: video_to_video strategy requires "lora" mode + training_mode: "lora" + + # Optional: Path to resume training from a checkpoint + # Can be a checkpoint file (.safetensors) or directory (uses latest checkpoint) + load_checkpoint: null + +# ----------------------------------------------------------------------------- +# LoRA Configuration +# ----------------------------------------------------------------------------- +# Controls the Low-Rank Adaptation parameters for efficient fine-tuning. +lora: + # Rank of the LoRA matrices (higher = more capacity but more parameters) + # Typical values: 8, 16, 32, 64. Start with 16-32 for IC-LoRA. 
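+  # Rule of thumb: each adapted linear layer gains rank * (d_in + d_out) trainable
+  # parameters, so doubling the rank roughly doubles the size of the adapter.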
+ rank: 32 + + # Alpha scaling factor (usually set equal to rank) + # The effective scaling is alpha/rank, so alpha=rank means scaling of 1.0 + alpha: 32 + + # Dropout probability for LoRA layers (0.0 = no dropout) + # Can help with regularization if overfitting occurs + dropout: 0.0 + + # Which transformer modules to apply LoRA to + # The LTX-2 transformer has separate attention and FFN blocks for video and audio: + # + # VIDEO MODULES: + # - attn1.to_k, attn1.to_q, attn1.to_v, attn1.to_out.0 (video self-attention) + # - attn2.to_k, attn2.to_q, attn2.to_v, attn2.to_out.0 (video cross-attention to text) + # - ff.net.0.proj, ff.net.2 (video feed-forward) + # + # AUDIO MODULES (not used for video-only IC-LoRA): + # - audio_attn1.to_k, audio_attn1.to_q, audio_attn1.to_v, audio_attn1.to_out.0 (audio self-attention) + # - audio_attn2.to_k, audio_attn2.to_q, audio_attn2.to_v, audio_attn2.to_out.0 (audio cross-attention to text) + # - audio_ff.net.0.proj, audio_ff.net.2 (audio feed-forward) + # + # AUDIO-VIDEO CROSS-ATTENTION MODULES (for cross-modal interaction, not used for video-only IC-LoRA): + # - audio_to_video_attn.to_k, audio_to_video_attn.to_q, audio_to_video_attn.to_v, audio_to_video_attn.to_out.0 + # (Q from video, K/V from audio - allows video to attend to audio features) + # - video_to_audio_attn.to_k, video_to_audio_attn.to_q, video_to_audio_attn.to_v, video_to_audio_attn.to_out.0 + # (Q from audio, K/V from video - allows audio to attend to video features) + # + # For IC-LoRA (video-only), we explicitly target video modules. + # Including FFN layers often improves transformation quality. + target_modules: + # Video self-attention + - "attn1.to_k" + - "attn1.to_q" + - "attn1.to_v" + - "attn1.to_out.0" + # Video cross-attention + - "attn2.to_k" + - "attn2.to_q" + - "attn2.to_v" + - "attn2.to_out.0" + # Video feed-forward (often improves transformation quality) + - "ff.net.0.proj" + - "ff.net.2" + +# ----------------------------------------------------------------------------- +# Training Strategy Configuration +# ----------------------------------------------------------------------------- +# Defines the video-to-video (IC-LoRA) training approach. +training_strategy: + # Strategy name: "video_to_video" for IC-LoRA training + name: "video_to_video" + + # Probability of conditioning on the first frame during training + # Higher values train the model to perform better in image-to-video (I2V) mode, + # where a clean first frame is provided and the model generates the rest of the video + # Increase this value to train the model to perform better in image-to-video (I2V) mode + first_frame_conditioning_p: 0.2 + + # Directory name (within preprocessed_data_root) containing reference video latents + # These are the conditioning inputs that guide the transformation + reference_latents_dir: "reference_latents" + +# ----------------------------------------------------------------------------- +# Optimization Configuration +# ----------------------------------------------------------------------------- +# Controls the training optimization parameters. 
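+# Note: video_to_video training loads reference latents for every sample in addition
+# to the target latents, so memory use is higher than text_to_video training at the
+# same resolution bucket; lower the bucket size or keep gradient checkpointing
+# enabled if you run into out-of-memory errors.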
+optimization: + # Learning rate for the optimizer + # Typical range for LoRA: 1e-5 to 1e-4 + learning_rate: 2e-4 + + # Total number of training steps + steps: 3000 + + # Batch size per GPU + # Reduce if running out of memory + batch_size: 1 + + # Number of gradient accumulation steps + # Effective batch size = batch_size * gradient_accumulation_steps * num_gpus + gradient_accumulation_steps: 1 + + # Maximum gradient norm for clipping (helps training stability) + max_grad_norm: 1.0 + + # Optimizer type: "adamw" (standard) or "adamw8bit" (memory-efficient) + optimizer_type: "adamw" + + # Learning rate scheduler type + # Options: "constant", "linear", "cosine", "cosine_with_restarts", "polynomial" + scheduler_type: "linear" + + # Additional scheduler parameters (depends on scheduler_type) + scheduler_params: { } + + # Enable gradient checkpointing to reduce memory usage + # Recommended for training with limited GPU memory + enable_gradient_checkpointing: true + +# ----------------------------------------------------------------------------- +# Acceleration Configuration +# ----------------------------------------------------------------------------- +# Hardware acceleration and memory optimization settings. +acceleration: + # Mixed precision training mode + # Options: "no" (fp32), "fp16" (half precision), "bf16" (bfloat16, recommended) + mixed_precision_mode: "bf16" + + # Model quantization for reduced memory usage + # Options: null (none), "int8-quanto", "int4-quanto", "int2-quanto", "fp8-quanto", "fp8uz-quanto" + quantization: null + + # Load text encoder in 8-bit precision to save memory + # Useful when GPU memory is limited + load_text_encoder_in_8bit: false + +# ----------------------------------------------------------------------------- +# Data Configuration +# ----------------------------------------------------------------------------- +# Specifies the training data location and loading parameters. +data: + # Root directory containing preprocessed training data + # Should contain: latents/, conditions/, and reference_latents/ subdirectories + preprocessed_data_root: "/path/to/preprocessed/data" + + # Number of worker processes for data loading + # Used for parallel data loading to speed up data loading + num_dataloader_workers: 2 + +# ----------------------------------------------------------------------------- +# Validation Configuration +# ----------------------------------------------------------------------------- +# Controls validation video generation during training. +validation: + # Text prompts for validation video generation + # Provide prompts representative of your training data + # LTX-2 prefers longer, detailed prompts that describe both visual content and audio + prompts: + - "A man in a casual blue jacket walks along a winding path through a lush green park on a bright sunny afternoon. Tall oak trees line the pathway, their leaves rustling gently in the breeze. Dappled sunlight creates shifting patterns on the ground as he strolls at a relaxed pace, occasionally looking up at the scenery around him. The audio captures footsteps on gravel, birds singing in the trees, distant children playing, and the soft whisper of wind through the foliage." + - "A fluffy orange tabby cat sits perfectly still on a wooden windowsill, its green eyes intently tracking small birds hopping on a branch just outside the glass. The cat's ears twitch and rotate, following every movement. Warm afternoon light illuminates its fur, creating a soft golden glow. 
Behind the cat, a cozy living room with a bookshelf and houseplants is visible. The audio features gentle purring, occasional soft meows, muffled bird chirps through the window, and quiet ambient room sounds." + + # Reference videos for validation (REQUIRED for video_to_video strategy) + # Must provide one reference video per prompt + # These are the conditioning inputs for generating validation outputs + reference_videos: + - "/path/to/reference_video_1.mp4" + - "/path/to/reference_video_2.mp4" + + # Negative prompt to avoid unwanted artifacts + negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted" + + # Optional: First frame images for additional conditioning + # If provided, must have one image per prompt + images: null + + # Output video dimensions [width, height, frames] + # Width and height must be divisible by 32 + # Frames must satisfy: frames % 8 == 1 (e.g., 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, ...) + video_dims: [ 512, 512, 81 ] + + # Frame rate for generated videos + frame_rate: 25.0 + + # Random seed for reproducible validation outputs + seed: 42 + + # Number of denoising steps for validation inference + # Higher values = better quality but slower generation + inference_steps: 30 + + # Generate validation videos every N training steps + # Set to null to disable validation during training + interval: 100 + + # Number of videos to generate per prompt + videos_per_prompt: 1 + + # Classifier-free guidance scale + # Higher values = stronger adherence to prompt but may introduce artifacts + guidance_scale: 3.0 + + # STG (Spatio-Temporal Guidance) parameters for improved video quality + # STG is combined with CFG for better temporal coherence + stg_scale: 1.0 # Recommended: 1.0 (0.0 disables STG) + stg_blocks: [29] # Recommended: single block 29 + stg_mode: "stg_v" # "stg_v" for video-only (no audio training) + + # Whether to generate audio in validation samples + # Can be enabled even when not training the audio branch + generate_audio: false + + # Skip validation at the beginning of training (step 0) + skip_initial_validation: false + + # Concatenate reference video side-by-side with generated output + # Useful for visually comparing the transformation quality + include_reference_in_output: true + +# ----------------------------------------------------------------------------- +# Checkpoint Configuration +# ----------------------------------------------------------------------------- +# Controls model checkpoint saving during training. +checkpoints: + # Save a checkpoint every N steps + # Set to null to disable intermediate checkpoints + interval: 250 + + # Number of most recent checkpoints to keep + # Set to -1 to keep all checkpoints + keep_last_n: 3 + +# ----------------------------------------------------------------------------- +# Flow Matching Configuration +# ----------------------------------------------------------------------------- +# Parameters for the flow matching training objective. +flow_matching: + # Timestep sampling mode + # "shifted_logit_normal" is recommended for LTX-2 models + timestep_sampling_mode: "shifted_logit_normal" + + # Additional parameters for timestep sampling + timestep_sampling_params: { } + +# ----------------------------------------------------------------------------- +# Hugging Face Hub Configuration +# ----------------------------------------------------------------------------- +# Settings for uploading trained models to the Hugging Face Hub. 
+hub: + # Whether to push the trained model to the Hub + push_to_hub: false + + # Repository ID on Hugging Face Hub (e.g., "username/my-ic-lora-model") + # Required if push_to_hub is true + hub_model_id: null + +# ----------------------------------------------------------------------------- +# Weights & Biases Configuration +# ----------------------------------------------------------------------------- +# Settings for experiment tracking with W&B. +wandb: + # Enable W&B logging + enabled: false + + # W&B project name + project: "ltx-2-trainer" + + # W&B username or team (null uses default account) + entity: null + + # Tags to help organize runs + tags: [ "ltx2", "ic-lora", "video-to-video" ] + + # Log validation videos to W&B + log_validation_videos: true + +# ----------------------------------------------------------------------------- +# General Configuration +# ----------------------------------------------------------------------------- +# Global settings for the training run. + +# Random seed for reproducibility +seed: 42 + +# Directory to save outputs (checkpoints, validation videos, logs) +output_dir: "outputs/ltx2_v2v_ic_lora" diff --git a/packages/ltx-trainer/docs/configuration-reference.md b/packages/ltx-trainer/docs/configuration-reference.md new file mode 100644 index 0000000000000000000000000000000000000000..e924151b9e4fd6d9f4f33f7b82bd5cc8101ddaee --- /dev/null +++ b/packages/ltx-trainer/docs/configuration-reference.md @@ -0,0 +1,366 @@ +# Configuration Reference + +The trainer uses structured Pydantic models for configuration, making it easy to customize training parameters. +This guide covers all available configuration options and their usage. + +## 📋 Overview + +The main configuration class is [`LtxTrainerConfig`](../src/ltx_trainer/config.py), which includes the following sub-configurations: + +- **ModelConfig**: Base model and training mode settings +- **LoraConfig**: LoRA training parameters +- **TrainingStrategyConfig**: Training strategy settings (text-to-video or video-to-video) +- **OptimizationConfig**: Learning rate, batch sizes, and scheduler settings +- **AccelerationConfig**: Mixed precision and quantization settings +- **DataConfig**: Data loading parameters +- **ValidationConfig**: Validation and inference settings +- **CheckpointsConfig**: Checkpoint saving frequency and retention settings +- **HubConfig**: Hugging Face Hub integration settings +- **WandbConfig**: Weights & Biases logging settings +- **FlowMatchingConfig**: Timestep sampling parameters + +## 📄 Example Configuration Files + +Check out our example configurations in the `configs` directory: + +- 📄 [Audio-Video LoRA Training](../configs/ltx2_av_lora.yaml) - Joint audio-video to generation training +- 📄 [IC-LoRA Training](../configs/ltx2_v2v_ic_lora.yaml) - Video-to-video transformation training + +## ⚙️ Configuration Sections + +### ModelConfig + +Controls the base model and training mode settings. + +```yaml +model: + model_path: "/path/to/ltx-2-model.safetensors" # Local path to model checkpoint + text_encoder_path: "/path/to/gemma-model" # Path to Gemma text encoder directory + training_mode: "lora" # "lora" or "full" + load_checkpoint: null # Path to checkpoint to resume from +``` + +**Key parameters:** + +| Parameter | Description | +|-----------|-------------| +| `model_path` | **Required.** Local path to the LTX-2 model checkpoint (`.safetensors` file). URLs are not supported. | +| `text_encoder_path` | **Required.** Path to the Gemma text encoder model directory. 
Download from [HuggingFace](https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized/). | +| `training_mode` | Training approach - `"lora"` for LoRA training or `"full"` for full-rank fine-tuning. | +| `load_checkpoint` | Optional path to resume training from a checkpoint file or directory. | + +> [!NOTE] +> LTX-2 requires both a model checkpoint and a Gemma text encoder. Both must be local paths. + +### LoraConfig + +LoRA-specific fine-tuning parameters (only used when `training_mode: "lora"`). + +```yaml +lora: + rank: 32 # LoRA rank (higher = more parameters) + alpha: 32 # LoRA alpha scaling factor + dropout: 0.0 # Dropout probability (0.0-1.0) + target_modules: # Modules to apply LoRA to + - "to_k" + - "to_q" + - "to_v" + - "to_out.0" +``` + +**Key parameters:** + +| Parameter | Description | +|-----------|-------------| +| `rank` | LoRA rank - higher values mean more trainable parameters (typical range: 8-128) | +| `alpha` | Alpha scaling factor - typically set equal to rank | +| `dropout` | Dropout probability for regularization | +| `target_modules` | List of transformer modules to apply LoRA adapters to (see below) | + +#### Understanding Target Modules + +The LTX-2 transformer has separate attention and feed-forward blocks for video and audio, as well as cross-attention +modules that enable the two modalities to exchange information. Choosing the right `target_modules` is critical for +achieving good results, especially when training with audio. + +**Video-only modules:** + +| Module Pattern | Description | +|----------------|-------------| +| `attn1.to_k`, `attn1.to_q`, `attn1.to_v`, `attn1.to_out.0` | Video self-attention | +| `attn2.to_k`, `attn2.to_q`, `attn2.to_v`, `attn2.to_out.0` | Video cross-attention (to text) | +| `ff.net.0.proj`, `ff.net.2` | Video feed-forward network | + +**Audio-only modules:** + +| Module Pattern | Description | +|----------------|-------------| +| `audio_attn1.to_k`, `audio_attn1.to_q`, `audio_attn1.to_v`, `audio_attn1.to_out.0` | Audio self-attention | +| `audio_attn2.to_k`, `audio_attn2.to_q`, `audio_attn2.to_v`, `audio_attn2.to_out.0` | Audio cross-attention (to text) | +| `audio_ff.net.0.proj`, `audio_ff.net.2` | Audio feed-forward network | + +**Audio-video cross-attention modules:** + +These modules enable bidirectional information flow between the audio and video modalities: + +| Module Pattern | Description | +|----------------|-------------| +| `audio_to_video_attn.to_k`, `audio_to_video_attn.to_q`, `audio_to_video_attn.to_v`, `audio_to_video_attn.to_out.0` | Video attends to audio (Q from video, K/V from audio) | +| `video_to_audio_attn.to_k`, `video_to_audio_attn.to_q`, `video_to_audio_attn.to_v`, `video_to_audio_attn.to_out.0` | Audio attends to video (Q from audio, K/V from video) | + +**Recommended configurations:** + +For **video-only training**, target the video attention layers: + +```yaml +target_modules: + - "attn1.to_k" + - "attn1.to_q" + - "attn1.to_v" + - "attn1.to_out.0" + - "attn2.to_k" + - "attn2.to_q" + - "attn2.to_v" + - "attn2.to_out.0" +``` + +For **audio-video training**, use patterns that match both branches: + +```yaml +target_modules: + - "to_k" + - "to_q" + - "to_v" + - "to_out.0" +``` + +> [!NOTE] +> Using shorter patterns like `"to_k"` will match all attention modules including `attn1.to_k`, `audio_attn1.to_k`, +> `audio_to_video_attn.to_k`, and `video_to_audio_attn.to_k`, effectively training video, audio, and cross-modal +> attention branches together. 
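+
+If you are unsure which modules a given pattern will hit, it can help to check your patterns offline.
+The snippet below is a minimal sketch only: the module names are hypothetical examples, and it assumes a
+PEFT-style rule where a pattern matches any module whose qualified name ends with it; the trainer's actual
+resolution logic may differ.
+
+```python
+# Illustrative only: hypothetical module names and an assumed suffix-matching rule.
+example_modules = [
+    "transformer_blocks.0.attn1.to_k",                # video self-attention
+    "transformer_blocks.0.audio_attn1.to_k",          # audio self-attention
+    "transformer_blocks.0.audio_to_video_attn.to_k",  # audio-to-video cross-attention
+    "transformer_blocks.0.ff.net.0.proj",             # video feed-forward
+]
+
+def matches(pattern: str, name: str) -> bool:
+    # A pattern matches a module whose qualified name equals it or ends with ".<pattern>".
+    return name == pattern or name.endswith("." + pattern)
+
+for pattern in ("to_k", "attn1.to_k", "ff.net.0.proj"):
+    hits = [name for name in example_modules if matches(pattern, name)]
+    print(f"{pattern!r} -> {hits}")
+# 'to_k' hits all three attention projections above, while 'attn1.to_k' hits only
+# the video self-attention projection.
+```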
+ +> [!TIP] +> You can also target the feed-forward (FFN) modules (`ff.net.0.proj`, `ff.net.2` for video, +> `audio_ff.net.0.proj`, `audio_ff.net.2` for audio) to increase the LoRA's capacity and potentially +> help it capture the target distribution better. + +### TrainingStrategyConfig + +Configures the training strategy. This replaces the legacy `ConditioningConfig`. + +#### Text-to-Video Strategy + +```yaml +training_strategy: + name: "text_to_video" + first_frame_conditioning_p: 0.1 # Probability of first-frame conditioning + with_audio: false # Enable joint audio-video training + audio_latents_dir: "audio_latents" # Directory for audio latents (when with_audio: true) +``` + +#### Video-to-Video Strategy (IC-LoRA) + +```yaml +training_strategy: + name: "video_to_video" + first_frame_conditioning_p: 0.1 + reference_latents_dir: "reference_latents" # Directory for reference video latents +``` + +**Key parameters:** + +| Parameter | Description | +|-----------|-------------| +| `name` | Strategy type: `"text_to_video"` or `"video_to_video"` | +| `first_frame_conditioning_p` | Probability of using first frame as conditioning (0.0-1.0) | +| `with_audio` | (text_to_video only) Enable joint audio-video training | +| `audio_latents_dir` | (text_to_video only) Directory name for audio latents | +| `reference_latents_dir` | (video_to_video only) Directory name for reference video latents | + +### OptimizationConfig + +Training optimization parameters including learning rates, batch sizes, and schedulers. + +```yaml +optimization: + learning_rate: 1e-4 # Learning rate + steps: 2000 # Total training steps + batch_size: 1 # Batch size per GPU + gradient_accumulation_steps: 1 # Steps to accumulate gradients + max_grad_norm: 1.0 # Gradient clipping threshold + optimizer_type: "adamw" # "adamw" or "adamw8bit" + scheduler_type: "linear" # Scheduler type + scheduler_params: {} # Additional scheduler parameters + enable_gradient_checkpointing: true # Memory optimization +``` + +**Key parameters:** + +| Parameter | Description | +|-----------|-------------| +| `learning_rate` | Learning rate for optimization (typical range: 1e-5 to 1e-3) | +| `steps` | Total number of training steps | +| `batch_size` | Batch size per GPU (reduce if running out of memory) | +| `gradient_accumulation_steps` | Accumulate gradients over multiple steps | +| `scheduler_type` | LR scheduler: `"constant"`, `"linear"`, `"cosine"`, `"cosine_with_restarts"`, `"polynomial"` | +| `enable_gradient_checkpointing` | Trade training speed for GPU memory savings (recommended for large models) | + +### AccelerationConfig + +Hardware acceleration and compute optimization settings. + +```yaml +acceleration: + mixed_precision_mode: "bf16" # "no", "fp16", or "bf16" + quantization: null # Quantization options + load_text_encoder_in_8bit: false # Load text encoder in 8-bit +``` + +**Key parameters:** + +| Parameter | Description | +|-----------|-------------| +| `mixed_precision_mode` | Precision mode - `"bf16"` recommended for modern GPUs | +| `quantization` | Model quantization: `null`, `"int8-quanto"`, `"int4-quanto"`, `"fp8-quanto"`, etc. | +| `load_text_encoder_in_8bit` | Load the Gemma text encoder in 8-bit to save GPU memory | + +### DataConfig + +Data loading and processing configuration. 
+ +```yaml +data: + preprocessed_data_root: "/path/to/preprocessed/data" # Path to precomputed dataset + num_dataloader_workers: 2 # Background data loading workers +``` + +**Key parameters:** + +| Parameter | Description | +|-----------|-------------| +| `preprocessed_data_root` | Path to your preprocessed dataset (contains `latents/`, `conditions/`, etc.) | +| `num_dataloader_workers` | Number of parallel data loading processes (0 = synchronous loading, useful when debugging) | + +### ValidationConfig + +Validation and inference settings for monitoring training progress. + +```yaml +validation: + prompts: # Validation prompts + - "A cat playing with a ball" + - "A dog running in a field" + negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted" + images: null # Optional image paths for image-to-video + reference_videos: null # Reference video paths (IC-LoRA only) + video_dims: [576, 576, 89] # Video dimensions [width, height, frames] + frame_rate: 25.0 # Frame rate for generated videos + seed: 42 # Random seed for reproducibility + inference_steps: 30 # Number of inference steps + interval: 100 # Steps between validation runs + videos_per_prompt: 1 # Videos generated per prompt + guidance_scale: 3.0 # CFG guidance strength + stg_scale: 1.0 # STG guidance strength (0.0 to disable) + stg_blocks: [29] # Transformer blocks to perturb for STG + stg_mode: "stg_av" # "stg_av" or "stg_v" (video only) + generate_audio: true # Whether to generate audio + skip_initial_validation: false # Skip validation at step 0 + include_reference_in_output: false # Include reference video side-by-side (IC-LoRA) +``` + +**Key parameters:** + +| Parameter | Description | +|-----------|-------------| +| `prompts` | List of text prompts for validation video generation | +| `images` | List of image paths for image-to-video validation (must match number of prompts) | +| `reference_videos` | List of reference video paths for IC-LoRA validation (must match number of prompts) | +| `video_dims` | Output dimensions `[width, height, frames]`. Width/height must be divisible by 32, frames must satisfy `frames % 8 == 1` | +| `interval` | Steps between validation runs (set to `null` to disable) | +| `guidance_scale` | CFG (Classifier-Free Guidance) scale. Recommended: 3.0 | +| `stg_scale` | STG (Spatio-Temporal Guidance) scale. 0.0 disables STG. Recommended: 1.0 | +| `stg_blocks` | Transformer blocks to perturb for STG. Recommended: `[29]` (single block) | +| `stg_mode` | STG mode: `"stg_av"` perturbs both audio and video, `"stg_v"` perturbs video only | +| `generate_audio` | Whether to generate audio in validation samples | +| `include_reference_in_output` | For IC-LoRA: concatenate reference video side-by-side with output | + +### CheckpointsConfig + +Model checkpointing configuration. + +```yaml +checkpoints: + interval: 250 # Steps between checkpoint saves (null = disabled) + keep_last_n: 3 # Number of recent checkpoints to retain +``` + +**Key parameters:** + +| Parameter | Description | +|-----------|-------------| +| `interval` | Steps between intermediate checkpoint saves (set to `null` to disable) | +| `keep_last_n` | Number of most recent checkpoints to keep (-1 = keep all) | + +### HubConfig + +Hugging Face Hub integration for automatic model uploads. 
+ +```yaml +hub: + push_to_hub: false # Enable Hub uploading + hub_model_id: "username/model-name" # Hub repository ID +``` + +**Key parameters:** + +| Parameter | Description | +|-----------|-------------| +| `push_to_hub` | Whether to automatically push trained models to Hugging Face Hub | +| `hub_model_id` | Repository ID in format `"username/repository-name"` | + +### WandbConfig + +Weights & Biases logging configuration. + +```yaml +wandb: + enabled: false # Enable W&B logging + project: "ltx-2-trainer" # W&B project name + entity: null # W&B username or team + tags: [] # Tags for the run + log_validation_videos: true # Log validation videos to W&B +``` + +**Key parameters:** + +| Parameter | Description | +|-----------|-------------| +| `enabled` | Whether to enable W&B logging | +| `project` | W&B project name | +| `entity` | W&B username or team (null uses default account) | +| `log_validation_videos` | Whether to log validation videos to W&B | + +### FlowMatchingConfig + +Flow matching training configuration for timestep sampling. + +```yaml +flow_matching: + timestep_sampling_mode: "shifted_logit_normal" # Timestep sampling strategy + timestep_sampling_params: {} # Additional sampling parameters +``` + +**Key parameters:** + +| Parameter | Description | +|-----------|-------------| +| `timestep_sampling_mode` | Sampling strategy: `"uniform"` or `"shifted_logit_normal"` | +| `timestep_sampling_params` | Additional parameters for the sampling strategy | + +## 🚀 Next Steps + +Once you've configured your training parameters: + +- Set up your dataset using [Dataset Preparation](dataset-preparation.md) +- Choose your training approach in [Training Modes](training-modes.md) +- Start training with the [Training Guide](training-guide.md) diff --git a/packages/ltx-trainer/docs/dataset-preparation.md b/packages/ltx-trainer/docs/dataset-preparation.md new file mode 100644 index 0000000000000000000000000000000000000000..89413038c8d843aa8c06a6309e95e9547e6bdd90 --- /dev/null +++ b/packages/ltx-trainer/docs/dataset-preparation.md @@ -0,0 +1,331 @@ +# Dataset Preparation Guide + +This guide covers the complete workflow for preparing and preprocessing your dataset for training. + +## 📋 Overview + +The general dataset preparation workflow is: + +1. **(Optional)** Split long videos into scenes using `split_scenes.py` +2. **(Optional)** Generate captions for your videos using `caption_videos.py` +3. **Preprocess your dataset** using `process_dataset.py` to compute and cache video/audio latents and text embeddings +4. **Run the trainer** with your preprocessed dataset + +## 🎬 Step 1: Split Scenes + +If you're starting with raw, long-form videos (e.g., downloaded from YouTube), you should first split them into shorter, coherent scenes. + +```bash +uv run python scripts/split_scenes.py input.mp4 scenes_output_dir/ \ + --filter-shorter-than 5s +``` + +This will create multiple video clips in `scenes_output_dir`. +These clips will be the input for the captioning step, if you choose to use it. + +The script supports many configuration options for scene detection (detector algorithms, thresholds, minimum scene lengths, etc.): + +```bash +uv run python scripts/split_scenes.py --help +``` + +## 📝 Step 2: Caption Videos + +If your dataset doesn't include captions, you can automatically generate them using multimodal models that understand both video and audio. 
+ +```bash +uv run python scripts/caption_videos.py scenes_output_dir/ \ + --output scenes_output_dir/dataset.json +``` + +If you're running into VRAM issues, try enabling 8-bit quantization to reduce memory usage: + +```bash +uv run python scripts/caption_videos.py scenes_output_dir/ \ + --output scenes_output_dir/dataset.json \ + --use-8bit +``` + +This will create a `dataset.json` file containing video paths and their captions. + +**Captioning options:** + +| Option | Description | +|--------|-------------| +| `--captioner-type` | `qwen_omni` (default, local) or `gemini_flash` (API) | +| `--use-8bit` | Enable 8-bit quantization for lower VRAM usage | +| `--no-audio` | Disable audio processing (video-only captions) | +| `--override` | Re-caption files that already have captions | +| `--api-key` | API key for Gemini Flash (or set `GOOGLE_API_KEY` env var) | + +**Caption format:** + +The captioner produces structured captions with sections for: +- **Visual content**: People, objects, actions, settings, colors, movements +- **Speech transcription**: Word-for-word transcription of spoken content +- **Sounds**: Music, ambient sounds, sound effects +- **On-screen text**: Any visible text overlays + +> [!NOTE] +> The automatically generated captions may contain inaccuracies or hallucinated content. +> We recommend reviewing and correcting the generated captions in your `dataset.json` file before proceeding to preprocessing. + +## ⚡ Step 3: Dataset Preprocessing + +This step preprocesses your video dataset by: + +1. Resizing and cropping videos to fit specified resolution buckets +2. Computing and caching video latent representations +3. Computing and caching text embeddings for captions +4. (Optional) Computing and caching audio latents + +### Basic Usage + +```bash +uv run python scripts/process_dataset.py dataset.json \ + --resolution-buckets "960x544x49" \ + --model-path /path/to/ltx-2-model.safetensors \ + --text-encoder-path /path/to/gemma-model +``` + +### With Audio Processing + +For audio-video training, add the `--with-audio` flag: + +```bash +uv run python scripts/process_dataset.py dataset.json \ + --resolution-buckets "960x544x49" \ + --model-path /path/to/ltx-2-model.safetensors \ + --text-encoder-path /path/to/gemma-model \ + --with-audio +``` + +### 📊 Dataset Format + +The trainer supports either videos or single images. +Note that your dataset must be homogeneous - either all videos or all images, mixing is not supported. + +> [!TIP] +> **Image Datasets:** When using images, follow the same preprocessing steps and format requirements as with videos, +> but use `1` for the frame count in the resolution bucket (e.g., `960x544x1`). + +The dataset must be a CSV, JSON, or JSONL metadata file with columns for captions and video paths: + +**JSON format example:** + +```json +[ + { + "caption": "A cat playing with a ball of yarn", + "media_path": "videos/cat_playing.mp4" + }, + { + "caption": "A dog running in the park", + "media_path": "videos/dog_running.mp4" + } +] +``` + +**JSONL format example:** + +```jsonl +{"caption": "A cat playing with a ball of yarn", "media_path": "videos/cat_playing.mp4"} +{"caption": "A dog running in the park", "media_path": "videos/dog_running.mp4"} +``` + +**CSV format example:** + +```csv +caption,media_path +"A cat playing with a ball of yarn","videos/cat_playing.mp4" +"A dog running in the park","videos/dog_running.mp4" +``` + +### 📐 Resolution Buckets + +Videos are organized into "buckets" of specific dimensions (width × height × frames). 
+Each video is assigned to the nearest matching bucket. +You can preprocess with one or multiple resolution buckets. +When training with multiple resolution buckets, you must use a batch size of 1. + +The dimensions of each bucket must follow these constraints due to LTX-2's VAE architecture: + +- **Spatial dimensions** (width and height) must be multiples of 32 +- **Number of frames** must satisfy `frames % 8 == 1` (e.g., 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97, 121, etc.) + +**Guidelines for choosing training resolution:** + +- For high-quality, detailed videos: use larger spatial dimensions (e.g. 768x448) with fewer frames (e.g. 89) +- For longer, motion-focused videos: use smaller spatial dimensions (512×512) with more frames (121) +- Memory usage increases with both spatial and temporal dimensions + +**Example usage:** + +```bash +uv run python scripts/process_dataset.py dataset.json \ + --resolution-buckets "960x544x49" \ + --model-path /path/to/ltx-2-model.safetensors \ + --text-encoder-path /path/to/gemma-model +``` + +Multiple buckets are supported by separating entries with `;`: + +```bash +uv run python scripts/process_dataset.py dataset.json \ + --resolution-buckets "960x544x49;512x512x49" \ + --model-path /path/to/ltx-2-model.safetensors \ + --text-encoder-path /path/to/gemma-model +``` + +**Video processing workflow:** + +1. Videos are **resized** maintaining aspect ratio until either width or height matches the target +2. The larger dimension is **center cropped** to match the bucket's dimensions +3. Only the **first X frames are taken** to match the bucket's frame count, remaining frames are ignored + +> [!NOTE] +> The sequence length processed by the transformer model can be calculated as: +> +> ``` +> sequence_length = (H/32) * (W/32) * ((F-1)/8 + 1) +> ``` +> +> Where: +> - H = Height of video +> - W = Width of video +> - F = Number of frames +> - 32 = VAE's spatial downsampling factor +> - 8 = VAE's temporal downsampling factor +> +> For example, a 768×448×89 video would have sequence length: +> ``` +> (768/32) * (448/32) * ((89-1)/8 + 1) = 24 * 14 * 12 = 4,032 +> ``` +> +> Keep this in mind when choosing video dimensions, as longer sequences require more GPU memory. + +> [!WARNING] +> When training with multiple resolution buckets, you must use a batch size of 1 +> (i.e., set `optimization.batch_size: 1` in your training config). + +### 📁 Output Structure + +The preprocessed data is saved in a `.precomputed` directory: + +``` +dataset/ +└── .precomputed/ + ├── latents/ # Cached video latents + ├── conditions/ # Cached text embeddings + ├── audio_latents/ # (only if --with-audio) Cached audio latents + └── reference_latents/ # (only for IC-LoRA) Cached reference video latents +``` + +## 🪄 IC-LoRA Reference Video Preprocessing + +For IC-LoRA training, you need to preprocess datasets that include reference videos. +Reference videos provide the conditioning input while target videos represent the desired transformed output. 
+ +### Dataset Format with Reference Videos + +**JSON format:** + +```json +[ + { + "caption": "A cat playing with a ball of yarn", + "media_path": "videos/cat_playing.mp4", + "reference_path": "references/cat_playing_depth.mp4" + } +] +``` + +**JSONL format:** + +```jsonl +{"caption": "A cat playing with a ball of yarn", "media_path": "videos/cat_playing.mp4", "reference_path": "references/cat_playing_depth.mp4"} +{"caption": "A dog running in the park", "media_path": "videos/dog_running.mp4", "reference_path": "references/dog_running_depth.mp4"} +``` + +### Preprocessing with Reference Videos + +To preprocess a dataset with reference videos, add the `--reference-column` argument specifying the name of the field +in your dataset JSON/JSONL/CSV that contains the reference video paths: + +```bash +uv run python scripts/process_dataset.py dataset.json \ + --resolution-buckets "960x544x49" \ + --model-path /path/to/ltx-2-model.safetensors \ + --text-encoder-path /path/to/gemma-model \ + --reference-column "reference_path" +``` + +This will create an additional `reference_latents/` directory containing the preprocessed reference video latents. + + +### Generating Reference Videos + +**Dataset Requirements for IC-LoRA:** + +- Your dataset must contain paired videos where each target video has a corresponding reference video +- Reference and target videos must have *identical* resolution and length +- Both reference and target videos should be preprocessed together using the same resolution buckets + +We provide an example script, [`scripts/compute_reference.py`](../scripts/compute_reference.py), to generate reference +videos for a given dataset. The default implementation generates Canny edge reference videos. + +```bash +uv run python scripts/compute_reference.py scenes_output_dir/ \ + --output scenes_output_dir/dataset.json +``` + +The script accepts a JSON file as the dataset configuration and updates it in-place by adding the filenames of the generated reference videos. + +If you want to generate a different type of condition (depth maps, pose skeletons, etc.), modify or replace the `compute_reference()` function within this script. + +### Example Dataset + +For reference, see our **[Canny Control Dataset](https://huggingface.co/datasets/Lightricks/Canny-Control-Dataset)** which demonstrates proper IC-LoRA dataset structure with paired videos and Canny edge maps. + + +## 🎯 LoRA Trigger Words + +When training a LoRA, you can specify a trigger token that will be prepended to all captions: + +```bash +uv run python scripts/process_dataset.py dataset.json \ + --resolution-buckets "960x544x49" \ + --model-path /path/to/ltx-2-model.safetensors \ + --text-encoder-path /path/to/gemma-model \ + --lora-trigger "MYTRIGGER" +``` + +This acts as a trigger word that activates the LoRA during inference when you include the same token in your prompts. + +> [!NOTE] +> There is no need to manually insert the trigger word into your dataset JSON/JSONL/CSV file. +> The trigger word specified with `--lora-trigger` is automatically prepended to each caption during preprocessing. + +## 🔍 Decoding Videos for Verification + +If you add the `--decode` flag, the script will VAE-decode the precomputed latents and save the resulting videos +in `.precomputed/decoded_videos`. When audio preprocessing is enabled (`--with-audio`), audio latents will also be +decoded and saved to `.precomputed/decoded_audio`. This allows you to visually and audibly inspect the processed data. 
+ +```bash +uv run python scripts/process_dataset.py dataset.json \ + --resolution-buckets "960x544x49" \ + --model-path /path/to/ltx-2-model.safetensors \ + --text-encoder-path /path/to/gemma-model \ + --decode +``` + +For single-frame images, the decoded latents will be saved as PNG files rather than MP4 videos. + +## 🚀 Next Steps + +Once your dataset is preprocessed, you can proceed to: + +- Configure your training parameters in [Configuration Reference](configuration-reference.md) +- Choose your training approach in [Training Modes](training-modes.md) +- Start training with the [Training Guide](training-guide.md) diff --git a/packages/ltx-trainer/docs/ltx-core-api-guide.md b/packages/ltx-trainer/docs/ltx-core-api-guide.md new file mode 100644 index 0000000000000000000000000000000000000000..ae4e812e62e0a2c9119a4560d03b9dfaf18df63b --- /dev/null +++ b/packages/ltx-trainer/docs/ltx-core-api-guide.md @@ -0,0 +1,598 @@ +# LTX-Core Model API Guide + +This guide explains the core concepts and APIs used in the LTX-2 Audio-Video diffusion model. Understanding these concepts is essential for training, fine-tuning, and running inference with LTX models. + +## Table of Contents + +1. [Overview](#overview) +2. [Core Concepts](#core-concepts) + - [Modality](#modality---the-input-container) + - [Patchifiers](#patchifiers---format-conversion) + - [Latent Tools](#latent-tools---preparing-inputs) + - [Conditioning Items](#conditioning-items---adding-constraints) + - [Perturbations](#perturbations---fine-grained-control) +3. [Model Architecture](#model-architecture) +4. [Usage Patterns](#usage-patterns) + - [Text-to-Video Generation](#text-to-video-generation) + - [Image-to-Video Generation](#image-to-video-generation) + - [Video-to-Video (IC-LoRA)](#video-to-video-ic-lora) + - [Audio-Video Generation](#audio-video-generation) +5. [Common Pitfalls](#common-pitfalls) + +--- + +## Overview + +The LTX-2 model is a **joint Audio-Video diffusion transformer**. Unlike traditional models that handle one modality at a time, LTX-2 processes **video and audio simultaneously** in a unified architecture, enabling cross-modal attention between them. + +Key characteristics: +- **Dual-stream architecture**: Separate processing paths for video and audio that interact via cross-attention +- **Per-token timesteps**: Different tokens can have different noise levels (enables advanced conditioning) +- **Flexible conditioning**: Supports text, image, and video conditioning + +--- + +## Core Concepts + +### Modality - The Input Container + +The `Modality` dataclass wraps all information needed to process either video or audio: + +```python +from ltx_core.model.transformer.modality import Modality + +@dataclass +class Modality: + enabled: bool # Whether this modality should be processed + latent: torch.Tensor # Shape: (B, seq_len, D) - patchified tokens + timesteps: torch.Tensor # Shape: (B, seq_len) - noise level per token + positions: torch.Tensor # Shape: (B, dims, seq_len, 2) - spatial/temporal coordinates + context: torch.Tensor # Text embeddings + context_mask: torch.Tensor | None +``` + +**Field descriptions:** + +| Field | Description | +|-------|-------------| +| `enabled` | Set to `False` to skip processing this modality | +| `latent` | Sequence of tokens in patchified format (not spatial `[B,C,F,H,W]`) | +| `timesteps` | Per-token noise levels (sigma values). Enables token-level conditioning | +| `positions` | Coordinates for RoPE (Rotary Position Embeddings). 
Video: `[B, 3, seq, 2]`, Audio: `[B, 1, seq, 2]` | +| `context` | Text prompt embeddings from the Gemma encoder | +| `context_mask` | Optional attention mask for the context | + +### Patchifiers - Format Conversion + +Patchifiers convert between spatial format and sequence format: + +```python +from ltx_core.pipeline.components.patchifiers import ( + VideoLatentPatchifier, + AudioPatchifier, + VideoLatentShape, + AudioLatentShape, +) + +# Video patchification +video_patchifier = VideoLatentPatchifier(patch_size=1) + +# Spatial to sequence: [B, C, F, H, W] → [B, F*H*W, C] +patchified = video_patchifier.patchify(video_latent) + +# Sequence to spatial: [B, seq_len, C] → [B, C, F, H, W] +spatial = video_patchifier.unpatchify( + patchified, + output_shape=VideoLatentShape( + batch=1, channels=128, frames=7, height=16, width=24 + ) +) + +# Audio patchification +audio_patchifier = AudioPatchifier(patch_size=1) + +# [B, C, T, mel_bins] → [B, T, C*mel_bins] +patchified_audio = audio_patchifier.patchify(audio_latent) +``` + +### Latent Tools - Preparing Inputs + +Latent tools handle the setup of initial latents, masks, and positions. Combined with conditioning items, they provide flexible input preparation: + +```python +from ltx_core.pipeline.conditioning.tools import ( + VideoLatentTools, + AudioLatentTools, + LatentState, +) +from ltx_core.pipeline.components.patchifiers import VideoLatentShape, AudioLatentShape +from ltx_core.pipeline.components.protocols import VideoPixelShape + +# Create video latent tools +pixel_shape = VideoPixelShape( + batch=1, + frames=49, # Must be k*8 + 1 (e.g., 49, 97, 121) + height=512, + width=768, + fps=25.0, +) +video_tools = VideoLatentTools( + patchifier=video_patchifier, + target_shape=VideoLatentShape.from_pixel_shape(shape=pixel_shape), + fps=25.0, +) + +# Create an empty latent state (zeros with positions computed) +video_state = video_tools.create_initial_state(device=device, dtype=torch.bfloat16) +# video_state.latent: [B, seq_len, 128] - zeros (will be replaced with noise) +# video_state.denoise_mask: [B, seq_len, 1] - ones (all tokens to denoise) +# video_state.positions: [B, 3, seq_len, 2] - pixel coordinates for RoPE + +# Audio latent tools (similar pattern) +audio_tools = AudioLatentTools( + patchifier=audio_patchifier, + target_shape=AudioLatentShape.from_duration( + batch=1, + duration=2.0, # seconds + channels=8, + mel_bins=16, + ), +) +audio_state = audio_tools.create_initial_state(device, dtype) +``` + +### Conditioning Items - Adding Constraints + +Conditioning items modify latent states to add constraints like first-frame conditioning: + +```python +from ltx_core.pipeline.conditioning.types.latent_cond import VideoConditionByLatentIndex +from ltx_core.pipeline.conditioning.types.keyframe_cond import VideoConditionByKeyframeIndex + +# Option 1: Condition by latent index (replaces tokens in-place) +first_frame_cond = VideoConditionByLatentIndex( + latent=encoded_image, # VAE-encoded image [B, C, 1, H, W] + strength=1.0, # 1.0 = fully conditioned, 0.0 = fully denoised + latent_idx=0, # Which latent frame to condition +) +video_state = first_frame_cond.apply_to(video_state, video_tools) + +# Option 2: Condition by keyframe (appends conditioning tokens) +keyframe_cond = VideoConditionByKeyframeIndex( + keyframes=encoded_image, # VAE-encoded keyframe(s) + frame_idx=0, # Target frame index + strength=1.0, +) +video_state = keyframe_cond.apply_to(video_state, video_tools) +``` + +**Key concepts:** +- `LatentState` is a frozen dataclass containing 
`latent`, `denoise_mask`, and `positions` +- `denoise_mask` values: `1.0` = denoise this token, `0.0` = keep this token fixed +- Conditioning items return a new `LatentState` (immutable pattern) + +### Perturbations - Fine-Grained Control + +Perturbations allow you to selectively skip operations at the per-sample, per-block level: + +```python +from ltx_core.guidance.perturbations import ( + Perturbation, + PerturbationType, + PerturbationConfig, + BatchedPerturbationConfig, +) + +# Available perturbation types +PerturbationType.SKIP_A2V_CROSS_ATTN # Skip audio→video cross attention +PerturbationType.SKIP_V2A_CROSS_ATTN # Skip video→audio cross attention +PerturbationType.SKIP_VIDEO_SELF_ATTN # Skip video self attention +PerturbationType.SKIP_AUDIO_SELF_ATTN # Skip audio self attention + +# Example: Skip audio→video attention in specific blocks +perturbation = Perturbation( + type=PerturbationType.SKIP_A2V_CROSS_ATTN, + blocks=[0, 1, 2, 3], # Skip in blocks 0-3, or None for all blocks +) +config = PerturbationConfig(perturbations=[perturbation]) + +# For batched inputs +batched_config = BatchedPerturbationConfig([config, config]) # batch_size=2 + +# Or use empty config for normal operation +batched_config = BatchedPerturbationConfig.empty(batch_size=2) +``` + +**Use cases for perturbations:** +- **STG (Spatio-Temporal Guidance)**: Skip self-attention in block 29 to improve video quality +- Ablation studies (disable specific attention paths) +- Custom guidance strategies +- Debugging model behavior + +**STG (Spatio-Temporal Guidance) Example:** + +STG uses perturbations to improve video generation quality by running an additional forward pass with self-attention skipped: + +```python +from ltx_core.guidance.perturbations import ( + Perturbation, PerturbationType, PerturbationConfig, BatchedPerturbationConfig +) +from ltx_core.pipeline.components.guiders import STGGuider + +# Create STG perturbation config (recommended: block 29) +stg_perturbation = Perturbation( + type=PerturbationType.SKIP_VIDEO_SELF_ATTN, + blocks=[29], # Recommended: single block 29 +) +stg_config = BatchedPerturbationConfig([PerturbationConfig([stg_perturbation])]) + +# In your denoising loop: +stg_guider = STGGuider(scale=1.0) # Recommended scale + +# Normal forward pass +pos_video, pos_audio = model(video=video, audio=audio, perturbations=None) + +# Perturbed forward pass (for STG) +perturbed_video, perturbed_audio = model(video=video, audio=audio, perturbations=stg_config) + +# Apply STG guidance +denoised_video = pos_video + stg_guider.delta(pos_video, perturbed_video) +``` + +--- + +## Model Architecture + +The LTX-2 transformer consists of 48 blocks, each with the following structure: + +``` +┌─────────────────────────────────────────────────────────────┐ +│ VIDEO STREAM AUDIO STREAM │ +│ ─────────── ──────────── │ +│ │ +│ 1. Video Self-Attention 1. Audio Self-Attention │ +│ (attends to all video) (attends to all audio) │ +│ │ +│ 2. Video Cross-Attention 2. Audio Cross-Attention │ +│ (attends to text prompt) (attends to text prompt)│ +│ │ +│ ╔═══════════════════════════════════╗ │ +│ ║ 3. AUDIO-VIDEO CROSS ATTENTION ║ │ +│ ║ ║ │ +│ ║ • Audio-to-Video (A→V): ║ │ +│ ║ Video queries, Audio keys/vals ║ │ +│ ║ ║ │ +│ ║ • Video-to-Audio (V→A): ║ │ +│ ║ Audio queries, Video keys/vals ║ │ +│ ╚═══════════════════════════════════╝ │ +│ │ +│ 4. Video Feed-Forward 4. 
Audio Feed-Forward │ +│ │ +└─────────────────────────────────────────────────────────────┘ +``` + +**Key insight**: Video and audio "talk" to each other through bidirectional cross-attention in every block, enabling synchronized audio-video generation. + +### Forward Pass + +```python +from ltx_core.model.transformer.model import LTXModel + +# The transformer takes both modalities and returns predictions for both +video_velocity, audio_velocity = model( + video=video_modality, + audio=audio_modality, + perturbations=None, # or BatchedPerturbationConfig +) +# Returns velocity predictions used in the Euler diffusion step +``` + +--- + +## Usage Patterns + +### Text-to-Video Generation + +Basic text-to-video generation flow: + +```python +from dataclasses import replace +from ltx_core.pipeline.components.schedulers import LTX2Scheduler +from ltx_core.pipeline.components.diffusion_steps import EulerDiffusionStep +from ltx_core.pipeline.components.guiders import CFGGuider +from ltx_core.pipeline.conditioning.tools import VideoLatentTools +from ltx_core.pipeline.components.patchifiers import VideoLatentShape + +# 1. Encode text prompt +video_context, audio_context, mask = text_encoder(prompt) + +# 2. Create video latent tools and initial state +pixel_shape = VideoPixelShape(batch=1, frames=49, height=512, width=768, fps=25.0) +video_tools = VideoLatentTools( + patchifier=video_patchifier, + target_shape=VideoLatentShape.from_pixel_shape(shape=pixel_shape), + fps=25.0, +) +video_state = video_tools.create_initial_state(device, dtype) + +# 3. Add noise to the latent +noise = torch.randn_like(video_state.latent) +noised_latent = noise # Start from pure noise + +# 4. Create video modality +video = Modality( + enabled=True, + latent=noised_latent, + timesteps=video_state.denoise_mask, # Will be updated each step + positions=video_state.positions, + context=video_context, + context_mask=None, +) + +# 5. Setup scheduler and diffusion components +scheduler = LTX2Scheduler() +sigmas = scheduler.execute(steps=30).to(device) +stepper = EulerDiffusionStep() + +# 6. Denoising loop +for step_idx, sigma in enumerate(sigmas[:-1]): + # Update timesteps with current sigma (use replace for immutable Modality) + video = replace(video, timesteps=sigma * video_state.denoise_mask) + + # Forward pass + video_vel, _ = model(video=video, audio=disabled_audio, perturbations=None) + + # Euler step + new_latent = stepper.step(video.latent, video_vel, sigmas, step_idx) + video = replace(video, latent=new_latent) + +# 7. 
Decode to pixels +video_spatial = video_tools.unpatchify( + replace(video_state, latent=video.latent) +).latent # [B, C, F, H, W] +video_pixels = vae_decoder(video_spatial) # [B, 3, F, H, W] +``` + +### Image-to-Video Generation + +Condition the first frame with an image: + +```python +from ltx_core.pipeline.conditioning.types.latent_cond import VideoConditionByLatentIndex + +# Encode the conditioning image +image_latent = vae_encoder(image) # [B, C, 1, H, W] + +# Create video tools and initial state +pixel_shape = VideoPixelShape(batch=1, frames=49, height=512, width=768, fps=25.0) +video_tools = VideoLatentTools( + patchifier=video_patchifier, + target_shape=VideoLatentShape.from_pixel_shape(shape=pixel_shape), + fps=25.0, +) +video_state = video_tools.create_initial_state(device, dtype) + +# Apply first-frame conditioning +first_frame_cond = VideoConditionByLatentIndex( + latent=image_latent, + strength=1.0, # 1.0 = fully conditioned (no denoising on first frame) + latent_idx=0, # Condition frame 0 +) +video_state = first_frame_cond.apply_to(video_state, video_tools) +# The denoise_mask will be 0.0 for first-frame tokens, 1.0 for the rest + +# Proceed with denoising as usual... +``` + +### Video-to-Video (IC-LoRA) + +IC-LoRA enables video-to-video transformation by conditioning on a reference video. The key insight is that reference tokens are included in the sequence but kept at timestep=0 (clean, no denoising). + +```python +from dataclasses import replace +from ltx_core.pipeline.conditioning.tools import VideoLatentTools +from ltx_core.pipeline.components.patchifiers import VideoLatentShape +from ltx_core.pipeline.components.protocols import VideoPixelShape + +# 1. Create video tools for target +pixel_shape = VideoPixelShape(batch=1, frames=49, height=512, width=768, fps=25.0) +video_tools = VideoLatentTools( + patchifier=video_patchifier, + target_shape=VideoLatentShape.from_pixel_shape(shape=pixel_shape), + fps=25.0, +) + +# 2. Encode reference video to latents and patchify +ref_latents = vae_encoder(reference_video) # [B, C, F, H, W] +patchified_ref = video_patchifier.patchify(ref_latents) # [B, ref_seq_len, C] +ref_seq_len = patchified_ref.shape[1] + +# 3. Create target video state (positions computed automatically) +target_state = video_tools.create_initial_state(device, dtype) + +# 4. Compute positions for reference (SAME grid as target!) +# Reference positions are identical to target - this tells the model they correspond +ref_positions = target_state.positions.clone() + +# 5. CONCATENATE reference + target +combined_latent = torch.cat([patchified_ref, torch.randn_like(target_state.latent)], dim=1) +combined_positions = torch.cat([ref_positions, target_state.positions], dim=2) + +# 6. Create denoise mask: 0 for reference (keep clean), 1 for target (denoise) +ref_denoise_mask = torch.zeros(1, ref_seq_len, 1, device=device) +combined_denoise_mask = torch.cat([ref_denoise_mask, target_state.denoise_mask], dim=1) + +# 7. Create modality with combined inputs +video = Modality( + enabled=True, + latent=combined_latent, + timesteps=combined_denoise_mask, # Will be updated with sigma + positions=combined_positions, + context=video_context, + context_mask=None, +) + +# 8. 
Denoising loop - only update target portion +for step_idx, sigma in enumerate(sigmas[:-1]): + # Timesteps: 0 for reference, sigma for target + ref_timesteps = torch.zeros(1, ref_seq_len, 1, device=device) + target_timesteps = sigma * target_state.denoise_mask + new_timesteps = torch.cat([ref_timesteps, target_timesteps], dim=1) + video = replace(video, timesteps=new_timesteps) + + # Forward pass + video_vel, _ = model(video=video, audio=audio, perturbations=None) + + # Euler step - ONLY update target portion + target_latent = video.latent[:, ref_seq_len:] + target_vel = video_vel[:, ref_seq_len:] + updated_target = stepper.step(target_latent, target_vel, sigmas, step_idx) + + # Reconstruct (reference stays fixed) + new_latent = torch.cat([patchified_ref, updated_target], dim=1) + video = replace(video, latent=new_latent) + +# 9. Extract and decode only the target portion +final_target = video.latent[:, ref_seq_len:] +target_state_with_output = replace(target_state, latent=final_target) +target_spatial = video_tools.unpatchify(target_state_with_output).latent +video_pixels = vae_decoder(target_spatial) +``` + +**Why this works:** +- Self-attention sees both reference and target tokens +- Reference tokens have `timestep=0` (clean signal) - model learns to "copy" from them +- Shared positions tell the model "frame N of reference = frame N of target" +- Only target portion is updated during denoising + +### Audio-Video Generation + +Generate synchronized audio and video: + +```python +from dataclasses import replace +from ltx_core.pipeline.conditioning.tools import VideoLatentTools, AudioLatentTools +from ltx_core.pipeline.components.patchifiers import VideoLatentShape, AudioLatentShape +from ltx_core.pipeline.components.protocols import VideoPixelShape + +# Create latent tools for both modalities +pixel_shape = VideoPixelShape(batch=1, frames=49, height=512, width=768, fps=25.0) +video_tools = VideoLatentTools( + patchifier=video_patchifier, + target_shape=VideoLatentShape.from_pixel_shape(shape=pixel_shape), + fps=25.0, +) +audio_tools = AudioLatentTools( + patchifier=audio_patchifier, + target_shape=AudioLatentShape.from_duration(batch=1, duration=2.0, channels=8, mel_bins=16), +) + +# Create initial states +video_state = video_tools.create_initial_state(device, dtype) +audio_state = audio_tools.create_initial_state(device, dtype) + +# Encode text (returns separate embeddings for each modality) +video_context, audio_context, mask = text_encoder(prompt) + +# Create both modalities with noise +video = Modality( + enabled=True, + latent=torch.randn_like(video_state.latent), + timesteps=video_state.denoise_mask, + positions=video_state.positions, + context=video_context, + context_mask=None, +) +audio = Modality( + enabled=True, + latent=torch.randn_like(audio_state.latent), + timesteps=audio_state.denoise_mask, + positions=audio_state.positions, + context=audio_context, + context_mask=None, +) + +# Denoising loop - update both (use replace for immutable Modality) +for step_idx, sigma in enumerate(sigmas[:-1]): + video = replace(video, timesteps=sigma * video_state.denoise_mask) + audio = replace(audio, timesteps=sigma * audio_state.denoise_mask) + + # Forward pass returns both predictions + video_vel, audio_vel = model(video=video, audio=audio, perturbations=None) + + # Update both latents + video = replace(video, latent=stepper.step(video.latent, video_vel, sigmas, step_idx)) + audio = replace(audio, latent=stepper.step(audio.latent, audio_vel, sigmas, step_idx)) + +# Decode both 
+video_spatial = video_tools.unpatchify(replace(video_state, latent=video.latent)).latent +video_pixels = vae_decoder(video_spatial) +audio_spatial = audio_tools.unpatchify(replace(audio_state, latent=audio.latent)).latent +audio_mel = audio_decoder(audio_spatial) +audio_waveform = vocoder(audio_mel) +``` + +--- + +## Common Pitfalls + +### 1. Frame Count Constraints + +Video frame count must satisfy `num_frames % 8 == 1`: +- ✅ Valid: 49, 97, 121, 145 +- ❌ Invalid: 48, 50, 100 + +```python +# The "+1" accounts for causal padding in the VAE +latent_frames = (num_frames - 1) // 8 + 1 +``` + +### 2. Resolution Constraints + +Height and width must be divisible by 32: +- ✅ Valid: 512×768, 768×1024 +- ❌ Invalid: 500×750 + +### 3. Position Tensor Shapes + +Different modalities have different position tensor shapes: +- Video: `[B, 3, seq_len, 2]` - 3 dimensions for (time, height, width) +- Audio: `[B, 1, seq_len, 2]` - 1 dimension for time only + +### 4. Separate Context Embeddings + +Video and audio modalities receive **different** context embeddings from the text encoder: + +```python +# The text encoder returns separate embeddings +video_context, audio_context, mask = text_encoder(prompt) + +# Use the appropriate one for each modality +video = Modality(context=video_context, ...) # NOT audio_context! +audio = Modality(context=audio_context, ...) # NOT video_context! +``` + +### 5. Immutable Modality + +The `Modality` dataclass is **frozen** (immutable). Use `dataclasses.replace()` to create modified copies: + +```python +from dataclasses import replace + +# ❌ Wrong - will raise an error +video.latent = new_latent + +# ✅ Correct - create a new Modality with updated field +video = replace(video, latent=new_latent) + +# ✅ Update multiple fields at once +video = replace(video, latent=new_latent, timesteps=new_timesteps) +``` + +--- + +## Additional Resources + +- [Training Guide](./training-guide.md) - How to fine-tune LTX-2 models +- [Configuration Reference](./configuration-reference.md) - All configuration options +- [Training Modes](./training-modes.md) - LoRA, audio-video, and IC-LoRA training diff --git a/packages/ltx-trainer/docs/quick-start.md b/packages/ltx-trainer/docs/quick-start.md new file mode 100644 index 0000000000000000000000000000000000000000..76fa9bd83d4edb46481dba57aa14bfcea623e505 --- /dev/null +++ b/packages/ltx-trainer/docs/quick-start.md @@ -0,0 +1,124 @@ +# Quick Start Guide + +Get up and running with LTX-2 training in just a few steps! + +## 📋 Prerequisites + +Before you begin, ensure you have: + +1. **LTX-2 Model Checkpoint** - A local `.safetensors` file containing the LTX-2 model weights +2. **Gemma Text Encoder** - A local directory containing the Gemma model (required for LTX-2). + Download from: [HuggingFace Hub](https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized/) +3. **Linux with CUDA** - The trainer requires `triton` which is Linux-only +4. **GPU with sufficient VRAM** - 80GB recommended. Lower VRAM may work with gradient checkpointing and lower + resolutions + +## ⚡ Installation + +First, install [uv](https://docs.astral.sh/uv/getting-started/installation/) if you haven't already. +Then clone the repository and install the dependencies: + +```bash +git clone https://github.com/Lightricks/LTX-Video +``` + +The `ltx-trainer` package is part of the `LTX-2` monorepo. 
Install the dependencies from the repository root, +then navigate to the trainer package: + +```bash +# From the repository root +uv sync +cd packages/ltx-trainer +``` + +> [!NOTE] +> The trainer depends on [`ltx-core`](../../ltx-core/) and [`ltx-pipelines`](../../ltx-pipelines/) packages which are automatically installed from the monorepo. + +## 🏋 Training Workflow + +### 1. Prepare Your Dataset + +Organize your videos and captions, then preprocess them: + +```bash +# Split long videos into scenes (optional) +uv run python scripts/split_scenes.py input.mp4 scenes_output_dir/ --filter-shorter-than 5s + +# Generate captions for videos (optional) +uv run python scripts/caption_videos.py scenes_output_dir/ --output dataset.json + +# Preprocess the dataset (compute latents and embeddings) +uv run python scripts/process_dataset.py dataset.json \ + --resolution-buckets "960x544x49" \ + --model-path /path/to/ltx-2-model.safetensors \ + --text-encoder-path /path/to/gemma-model +``` + +See [Dataset Preparation](dataset-preparation.md) for detailed instructions. + +### 2. Configure Training + +Create or modify a configuration YAML file. Start with one of the example configs: + +- [`configs/ltx2_av_lora.yaml`](../configs/ltx2_av_lora.yaml) - Audio-video LoRA training +- [`configs/ltx2_v2v_ic_lora.yaml`](../configs/ltx2_v2v_ic_lora.yaml) - IC-LoRA video-to-video + +Key settings to update: + +```yaml +model: + model_path: "/path/to/ltx-2-model.safetensors" + text_encoder_path: "/path/to/gemma-model" + +data: + preprocessed_data_root: "/path/to/preprocessed/data" + +output_dir: "outputs/my_training_run" +``` + +See [Configuration Reference](configuration-reference.md) for all available options. + +### 3. Start Training + +```bash +uv run python scripts/train.py configs/ltx2_av_lora.yaml +``` + +For multi-GPU training: + +```bash +uv run accelerate launch scripts/train.py configs/ltx2_av_lora.yaml +``` + +See [Training Guide](training-guide.md) for distributed training and advanced options. + +## 🎯 Training Modes + +The trainer supports several training modes: + +| Mode | Description | Config Example | +|----------------------|--------------------------------|--------------------------------------------| +| **LoRA** | Efficient adapter training | `training_strategy.name: "text_to_video"` | +| **Audio-Video LoRA** | Joint audio-video training | `training_strategy.with_audio: true` | +| **IC-LoRA** | Video-to-video transformations | `training_strategy.name: "video_to_video"` | +| **Full Fine-tuning** | Full model training | `model.training_mode: "full"` | + +See [Training Modes](training-modes.md) for detailed explanations. + +## Next Steps + +Once you've completed your first training run, you can: + +- **Use your trained LoRA for inference** - The [`ltx-pipelines`](../../ltx-pipelines/) package provides production-ready inference + pipelines for various use cases (T2V, I2V, IC-LoRA, etc.). See the package documentation for details. +- Learn more about [Dataset Preparation](dataset-preparation.md) for advanced preprocessing +- Explore different [Training Modes](training-modes.md) (LoRA, Audio-Video, IC-LoRA) +- Dive deeper into [Training Configuration](configuration-reference.md) +- Understand the model architecture in [LTX-Core API Guide](ltx-core-api-guide.md) + +## Need Help? + +If you run into issues at any step, see the [Troubleshooting Guide](troubleshooting.md) for solutions to common +problems. + +Join our [Discord community](https://discord.gg/2mafsHjJ) for real-time help and discussion! 
diff --git a/packages/ltx-trainer/docs/training-guide.md b/packages/ltx-trainer/docs/training-guide.md new file mode 100644 index 0000000000000000000000000000000000000000..809e03757aba77e20586e29b8ee6d9a833250875 --- /dev/null +++ b/packages/ltx-trainer/docs/training-guide.md @@ -0,0 +1,202 @@ +# Training Guide + +This guide covers how to run training jobs, from basic single-GPU training to advanced distributed setups and automatic +model uploads. + +## ⚡ Basic Training (Single GPU) + +After preprocessing your dataset and preparing a configuration file, you can start training using the trainer script: + +```bash +uv run python scripts/train.py configs/ltx2_av_lora.yaml +``` + +The trainer will: + +1. **Load your configuration** and validate all parameters +2. **Initialize models** and apply optimizations +3. **Run the training loop** with progress tracking +4. **Generate validation videos** (if configured) +5. **Save the trained weights** in your output directory + +### Output Files + +**For LoRA training:** + +- `lora_weights.safetensors` - Main LoRA weights file +- `training_config.yaml` - Copy of training configuration +- `validation_samples/` - Generated validation videos (if enabled) + +**For full model fine-tuning:** + +- `model_weights.safetensors` - Full model weights +- `training_config.yaml` - Copy of training configuration +- `validation_samples/` - Generated validation videos (if enabled) + +## 🖥️ Distributed / Multi-GPU Training + +We use Hugging Face 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) for multi-GPU DDP and FSDP. + +### Configure Accelerate + +Run the interactive wizard once to set up your environment (DDP / FSDP, GPU count, etc.): + +```bash +uv run accelerate config +``` + +This stores your preferences in `~/.cache/huggingface/accelerate/default_config.yaml`. + +### Use the Provided Accelerate Configs (Recommended) + +We include ready-to-use Accelerate config files in `configs/accelerate/`: + +- [ddp.yaml](../configs/accelerate/ddp.yaml) — Standard DDP +- [ddp_compile.yaml](../configs/accelerate/ddp_compile.yaml) — DDP with `torch.compile` (Inductor) +- [fsdp.yaml](../configs/accelerate/fsdp.yaml) — Standard FSDP (auto-wraps `BasicAVTransformerBlock`) +- [fsdp_compile.yaml](../configs/accelerate/fsdp_compile.yaml) — FSDP with `torch.compile` (Inductor) + +Launch with a specific config using `--config_file`: + +```bash +# DDP (2 GPUs shown as example) +CUDA_VISIBLE_DEVICES=0,1 \ +uv run accelerate launch --config_file configs/accelerate/ddp.yaml \ + scripts/train.py configs/ltx2_av_lora.yaml + +# DDP + torch.compile +CUDA_VISIBLE_DEVICES=0,1 \ +uv run accelerate launch --config_file configs/accelerate/ddp_compile.yaml \ + scripts/train.py configs/ltx2_av_lora.yaml + +# FSDP (4 GPUs shown as example) +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +uv run accelerate launch --config_file configs/accelerate/fsdp.yaml \ + scripts/train.py configs/ltx2_av_lora.yaml + +# FSDP + torch.compile +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +uv run accelerate launch --config_file configs/accelerate/fsdp_compile.yaml \ + scripts/train.py configs/ltx2_av_lora.yaml +``` + +**Notes:** + +- The number of processes is taken from the Accelerate config (`num_processes`). Override with `--num_processes X` or + restrict GPUs with `CUDA_VISIBLE_DEVICES`. +- The compile variants enable `torch.compile` with the Inductor backend via Accelerate's `dynamo_config`. +- FSDP configs auto-wrap the transformer blocks (`fsdp_transformer_layer_cls_to_wrap: BasicAVTransformerBlock`). 
+ +### Launch with Your Default Accelerate Config + +If you prefer to use your default Accelerate profile: + +```bash +# Use settings from your default accelerate config +uv run accelerate launch scripts/train.py configs/ltx2_av_lora.yaml + +# Override number of processes on the fly (e.g., 2 GPUs) +uv run accelerate launch --num_processes 2 scripts/train.py configs/ltx2_av_lora.yaml + +# Select specific GPUs +CUDA_VISIBLE_DEVICES=0,1 uv run accelerate launch scripts/train.py configs/ltx2_av_lora.yaml +``` + +> [!TIP] +> You can disable the in-terminal progress bars with `--disable_progress_bars` flag in the trainer CLI if desired. + +### Benefits of Distributed Training + +- **Faster training**: Distribute workload across multiple GPUs +- **Larger effective batch sizes**: Combine gradients from multiple GPUs +- **Memory efficiency**: Each GPU handles a portion of the batch + +> [!NOTE] +> Distributed training requires that all GPUs have sufficient memory for the model and batch size. The effective batch +> size becomes `batch_size × num_processes`. + +## 🤗 Pushing Models to Hugging Face Hub + +You can automatically push your trained models to the Hugging Face Hub by adding the following to your configuration: + +```yaml +hub: + push_to_hub: true + hub_model_id: "your-username/your-model-name" +``` + +### Prerequisites + +Before pushing, make sure you: + +1. **Have a Hugging Face account** - Sign up at [huggingface.co](https://huggingface.co) +2. **Are logged in** via `huggingface-cli login` or have set the `HUGGING_FACE_HUB_TOKEN` environment variable +3. **Have write access** to the specified repository (it will be created if it doesn't exist) + +### Login Options + +**Option 1: Interactive login** + +```bash +uv run huggingface-cli login +``` + +**Option 2: Environment variable** + +```bash +export HUGGING_FACE_HUB_TOKEN="your_token_here" +``` + +### What Gets Uploaded + +The trainer will automatically: + +- **Create a model card** with training details and sample outputs +- **Upload model weights** +- **Push sample videos as GIFs** in the model card +- **Include training configuration and prompts** + +## 📊 Weights & Biases Logging + +Enable experiment tracking with W&B by adding to your configuration: + +```yaml +wandb: + enabled: true + project: "ltx-2-trainer" + entity: null # Your W&B username or team + tags: [ "ltx2", "lora" ] + log_validation_videos: true +``` + +This will log: + +- Training loss and learning rate +- Validation videos +- Model configuration +- Training progress + +## 🚀 Next Steps + +After training completes: + +- **Run inference with your trained LoRA** - The [`ltx-pipelines`](../../ltx-pipelines/) package provides production-ready inference + pipelines that support loading custom LoRAs. Available pipelines include text-to-video, image-to-video, + IC-LoRA video-to-video, and more. See the [`ltx-pipelines`](../../ltx-pipelines/) package for usage details. 
+- **Test your model** with validation prompts +- **Iterate and improve** based on validation results +- **Share your results** by pushing to Hugging Face Hub + +## 💡 Tips for Successful Training + +- **Start small**: Begin with a small dataset and a few hundred steps to verify everything works +- **Monitor validation**: Keep an eye on validation samples to catch overfitting +- **Adjust learning rate**: Lower learning rates often produce better results +- **Use gradient checkpointing**: Essential for training with limited GPU memory +- **Save checkpoints**: Regular checkpoints help recover from interruptions + +## Need Help? + +If you encounter issues during training, see the [Troubleshooting Guide](troubleshooting.md). + +Join our [Discord community](https://discord.gg/2mafsHjJ) for real-time help! diff --git a/packages/ltx-trainer/docs/training-modes.md b/packages/ltx-trainer/docs/training-modes.md new file mode 100644 index 0000000000000000000000000000000000000000..8e8b4bdde30fb990e846ad0153bbacb4405ec1e8 --- /dev/null +++ b/packages/ltx-trainer/docs/training-modes.md @@ -0,0 +1,216 @@ +# Training Modes Guide + +The trainer supports several training modes, each suited for different use cases and requirements. + +## 🎯 Standard LoRA Training (Video-Only) + +Standard LoRA (Low-Rank Adaptation) training fine-tunes the model by adding small, trainable adapter layers while +keeping the base model frozen. This approach: + +- **Requires significantly less memory and compute** than full fine-tuning +- **Produces small, portable weight files** (typically a few hundred MB) +- **Is ideal for learning specific styles, effects, or concepts** +- **Can be easily combined with other LoRAs** during inference + +Configure standard LoRA training with: + +```yaml +model: + training_mode: "lora" + +training_strategy: + name: "text_to_video" + first_frame_conditioning_p: 0.1 + with_audio: false # Video-only training +``` + +## 🔊 Audio-Video LoRA Training + +LTX-2 supports joint audio-video generation. You can train LoRA adapters that affect both video and audio output: + +- **Synchronized audio-video generation** - Audio matches the visual content +- **Same efficient LoRA approach** - Just enable audio training +- **Requires audio latents** - Dataset must include preprocessed audio + +Configure audio-video training with: + +```yaml +model: + training_mode: "lora" + +training_strategy: + name: "text_to_video" + first_frame_conditioning_p: 0.1 + with_audio: true # Enable audio training + audio_latents_dir: "audio_latents" # Directory containing audio latents +``` + +**Example configuration file:** + +- 📄 [Audio-Video LoRA Training](../configs/ltx2_av_lora.yaml) + +**Dataset structure for audio-video training:** + +``` +preprocessed_data_root/ +├── latents/ # Video latents +├── conditions/ # Text embeddings +└── audio_latents/ # Audio latents (required when with_audio: true) +``` + +> [!IMPORTANT] +> When training audio-video LoRAs, ensure your `target_modules` configuration captures video, audio, and +> cross-modal attention branches. Use patterns like `"to_k"` instead of `"attn1.to_k"` to match: +> - Video modules: `attn1.to_k`, `attn2.to_k` +> - Audio modules: `audio_attn1.to_k`, `audio_attn2.to_k` +> - Cross-modal modules: `audio_to_video_attn.to_k`, `video_to_audio_attn.to_k` +> +> The cross-modal attention modules (`audio_to_video_attn` and `video_to_audio_attn`) enable bidirectional +> information flow between audio and video, which is critical for synchronized audiovisual generation. 
+> See [Understanding Target Modules](configuration-reference.md#understanding-target-modules) for detailed guidance. + +> [!NOTE] +> You can generate audio during validation even if you're not training the audio branch. +> Set `validation.generate_audio: true` independently of `training_strategy.with_audio`. + +## 🔥 Full Model Fine-tuning + +Full model fine-tuning updates all parameters of the base model, providing maximum flexibility but +requiring substantial computational resources and larger training datasets: + +- **Offers the highest potential quality and capability improvements** +- **Requires multiple GPUs** and distributed training techniques (e.g., FSDP) +- **Produces large checkpoint files** (several GB) +- **Best for major model adaptations** or when LoRA limitations are reached + +Configure full fine-tuning with: + +```yaml +model: + training_mode: "full" + +training_strategy: + name: "text_to_video" + first_frame_conditioning_p: 0.1 +``` + +> [!IMPORTANT] +> Full fine-tuning of LTX-2 requires multiple high-end GPUs (e.g., 4-8× H100 80GB) and distributed +> training with FSDP. See [Training Guide](training-guide.md) for multi-GPU setup instructions. + +## 🔄 In-Context LoRA (IC-LoRA) Training + +IC-LoRA is a specialized training mode for video-to-video transformations. +Unlike standard training modes that learn from individual videos, IC-LoRA learns transformations from pairs of videos. +IC-LoRA enables a wide range of advanced video-to-video applications, such as: + +- **Control adapters** (e.g., Depth, Pose): Learn to map from a control signal (like a depth map or pose skeleton) to a + target video +- **Video deblurring**: Transform blurry input videos into sharp, high-quality outputs +- **Style transfer**: Apply the style of a reference video to a target video sequence +- **Colorization**: Convert grayscale reference videos into colorized outputs +- **Restoration and enhancement**: Denoise, upscale, or restore old or degraded videos + +By providing paired reference and target videos, IC-LoRA can learn complex transformations that go beyond caption-based conditioning. 
+ +IC-LoRA training fundamentally differs from standard LoRA and full fine-tuning: + +- **Reference videos** provide clean, unnoised conditioning input showing the "before" state +- **Target videos** are noised during training and represent the desired "after" state +- **The model learns transformations** from reference videos to target videos +- **Loss is applied only to the target portion**, not the reference +- **Training and inference time increase significantly** due to the doubled sequence length + +To enable IC-LoRA training, configure your YAML file with: + +```yaml +model: + training_mode: "lora" # Required: IC-LoRA uses LoRA mode + +training_strategy: + name: "video_to_video" + first_frame_conditioning_p: 0.1 + reference_latents_dir: "reference_latents" # Directory for reference video latents +``` + +**Example configuration file:** + +- 📄 [IC-LoRA Training](../configs/ltx2_v2v_ic_lora.yaml) - Video-to-video transformation training + +### Dataset Requirements for IC-LoRA + +- Your dataset must contain **paired videos** where each target video has a corresponding reference video +- Reference and target videos must have **identical resolution and length** +- Both reference and target videos should be **preprocessed together** using the same resolution buckets + +**Dataset structure for IC-LoRA training:** + +``` +preprocessed_data_root/ +├── latents/ # Target video latents (what the model learns to generate) +├── conditions/ # Text embeddings for each video +└── reference_latents/ # Reference video latents (conditioning input) +``` + +### Generating Reference Videos + +We provide an example script to generate reference videos (e.g., Canny edge maps) for a given dataset. +The script takes a JSON file as input (e.g., output of `caption_videos.py`) and updates it with the generated reference +video paths. + +```bash +uv run python scripts/compute_reference.py scenes_output_dir/ \ + --output scenes_output_dir/dataset.json +``` + +To compute a different condition (depth maps, pose skeletons, etc.), modify the `compute_reference()` function in the +script. 
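+
+For illustration, here is a minimal sketch of what a swapped-in `compute_reference()` could look like. It keeps the
+same contract as the script's default Canny implementation (a `[B, C, H, W]` batch of frames in, a 3-channel
+`[B, 3, H, W]` batch in the 0-255 range out), but produces a simple grayscale (luminance) condition as a stand-in for
+whatever transformation you actually need (depth, pose, blur, etc.):
+
+```python
+import torch
+import torchvision.transforms.functional as TF  # noqa: N812
+
+
+def compute_reference(images: torch.Tensor) -> torch.Tensor:
+    """Example replacement: grayscale (luminance) reference frames.
+
+    Args:
+        images: Batch of frames of shape [B, C, H, W]
+
+    Returns:
+        3-channel reference frames of shape [B, 3, H, W], values in [0, 255]
+    """
+    # Normalize to [0, 1] if the frames arrive as 0-255 values
+    if images.max() > 1.0:
+        images = images / 255.0
+
+    # Collapse RGB to a single luminance channel
+    if images.shape[1] == 3:
+        images = TF.rgb_to_grayscale(images)  # [B, 1, H, W]
+
+    # Repeat the channel so the output matches the 3-channel, 0-255 layout
+    # produced by the default Canny implementation
+    return images.repeat(1, 3, 1, 1) * 255.0
+```
+
+The surrounding batching and saving logic in the script stays unchanged; any per-frame transformation that respects
+this shape and value-range contract can be dropped in.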
+ +### Configuration Requirements for IC-LoRA + +- You **must** provide `reference_videos` in your validation configuration when using IC-LoRA training +- The number of reference videos must match the number of validation prompts + +Example validation configuration for IC-LoRA: + +```yaml +validation: + prompts: + - "First prompt describing the desired output" + - "Second prompt describing the desired output" + reference_videos: + - "/path/to/reference1.mp4" + - "/path/to/reference2.mp4" + include_reference_in_output: true # Show reference side-by-side with output +``` + +## 📊 Training Mode Comparison + +| Aspect | LoRA | Audio-Video LoRA | Full Fine-tuning | IC-LoRA | +|----------------------|------------|------------------|------------------|----------------| +| **Memory Usage** | Low | Low-Medium | High | Medium | +| **Training Speed** | Fast | Fast | Slow | Medium | +| **Output Size** | 100MB-few GB (depends on rank) | 100MB-few GB (depends on rank) | Tens of GB | 100MB-few GB (depends on rank) | +| **Flexibility** | Medium | Medium | High | Specialized | +| **Audio Support** | Optional | Yes | Optional | No | +| **Reference Videos** | No | No | No | Yes (required) | + +## 🎬 Using Trained Models for Inference + +After training, use the [`ltx-pipelines`](../../ltx-pipelines/) package for production inference with your trained LoRAs: + +| Training Mode | Recommended Pipeline | +|---------------|---------------------| +| LoRA / Audio-Video LoRA | `TI2VidOneStagePipeline` or `TI2VidTwoStagesPipeline` | +| IC-LoRA | `ICLoraPipeline` | + +All pipelines support loading custom LoRAs via the `loras` parameter. See the [`ltx-pipelines`](../../ltx-pipelines/) package +documentation for detailed usage instructions. + +## 🚀 Next Steps + +Once you've chosen your training mode: + +- Set up your dataset using [Dataset Preparation](dataset-preparation.md) +- Configure your training parameters in [Configuration Reference](configuration-reference.md) +- Start training with the [Training Guide](training-guide.md) diff --git a/packages/ltx-trainer/docs/troubleshooting.md b/packages/ltx-trainer/docs/troubleshooting.md new file mode 100644 index 0000000000000000000000000000000000000000..fd265d4908cb1626c74fdcdc592fb5e9043b6975 --- /dev/null +++ b/packages/ltx-trainer/docs/troubleshooting.md @@ -0,0 +1,293 @@ +# Troubleshooting Guide + +This guide covers common issues and solutions when training with the LTX-2 trainer. + +## 🔧 VRAM and Memory Issues + +Memory management is crucial for successful training with LTX-2. + +### Memory Optimization Techniques + +#### 1. Enable Gradient Checkpointing + +Gradient checkpointing trades training speed for memory savings. **Highly recommended** for most training runs: + +```yaml +optimization: + enable_gradient_checkpointing: true +``` + +#### 2. Enable 8-bit Text Encoder + +Load the Gemma text encoder in 8-bit precision to save GPU memory: + +```yaml +acceleration: + load_text_encoder_in_8bit: true +``` + +#### 3. Reduce Batch Size + +Lower the batch size if you encounter out-of-memory errors: + +```yaml +optimization: + batch_size: 1 # Start with 1 and increase gradually +``` + +Use gradient accumulation to maintain a larger effective batch size: + +```yaml +optimization: + batch_size: 1 + gradient_accumulation_steps: 4 # Effective batch size = 4 +``` + +#### 4. 
Use Lower Resolution + +Reduce spatial or temporal dimensions to save memory: + +```bash +# Smaller spatial resolution +uv run python scripts/process_dataset.py dataset.json \ + --resolution-buckets "512x512x49" \ + --model-path /path/to/model.safetensors \ + --text-encoder-path /path/to/gemma + +# Fewer frames +uv run python scripts/process_dataset.py dataset.json \ + --resolution-buckets "960x544x25" \ + --model-path /path/to/model.safetensors \ + --text-encoder-path /path/to/gemma +``` + +#### 5. Enable Model Quantization + +Use quantization to reduce memory usage: + +```yaml +acceleration: + quantization: "int8-quanto" # Options: int8-quanto, int4-quanto, fp8-quanto +``` + +#### 6. Use 8-bit Optimizer + +The 8-bit AdamW optimizer uses less memory: + +```yaml +optimization: + optimizer_type: "adamw8bit" +``` + +--- + +## ⚠️ Common Usage Issues + +### Issue: "No module named 'ltx_trainer'" Error + +**Solution:** +Ensure you've installed the dependencies and are using `uv run` to execute scripts: + +```bash +# From the repository root +uv sync +cd packages/ltx-trainer +uv run python scripts/train.py configs/ltx2_av_lora.yaml +``` + +> [!TIP] +> Always use `uv run` to execute Python scripts. This automatically uses the correct virtual environment +> without requiring manual activation. + +### Issue: "Gemma model path is not a directory" Error + +**Solution:** +The `text_encoder_path` must point to a directory containing the Gemma model, not a file: + +```yaml +model: + model_path: "/path/to/ltx-2-model.safetensors" # File path + text_encoder_path: "/path/to/gemma-model/" # Directory path +``` + +### Issue: "Model path does not exist" Error + +**Solution:** +LTX-2 requires local model paths. URLs are not supported: + +```yaml +# ✅ Correct - local path +model: + model_path: "/path/to/ltx-2-model.safetensors" + +# ❌ Wrong - URL not supported +model: + model_path: "https://huggingface.co/..." +``` + +### Issue: "Frames must satisfy frames % 8 == 1" Error + +**Solution:** +LTX-2 requires the number of frames to satisfy `frames % 8 == 1`: + +- ✅ Valid: 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97, 121 +- ❌ Invalid: 24, 32, 48, 64, 100 + +### Issue: Slow Training Speed + +**Optimizations:** + +1. **Disable gradient checkpointing** (if you have enough VRAM): + + ```yaml + optimization: + enable_gradient_checkpointing: false + ``` + + +2. **Use torch.compile** via Accelerate: + + ```bash + uv run accelerate launch --config_file configs/accelerate/ddp_compile.yaml \ + scripts/train.py configs/ltx2_av_lora.yaml + ``` + +### Issue: Poor Quality Validation Outputs + +**Solutions:** + +1. **Use Image-to-Video Validation:** + For more reliable validation, use image-to-video (first-frame conditioning) rather than pure text-to-video: + + ```yaml + validation: + prompts: + - "a professional portrait video of a person" + images: + - "/path/to/first_frame.png" # One image per prompt + ``` + +2. **Increase inference steps:** + + ```yaml + validation: + inference_steps: 50 # Default is 30 + ``` + +3. **Adjust guidance settings:** + + ```yaml + validation: + guidance_scale: 3.0 # CFG scale (recommended: 3.0) + stg_scale: 1.0 # STG scale for temporal coherence (recommended: 1.0) + stg_blocks: [29] # Transformer block to perturb + ``` + +4. **Check caption quality:** + Review and manually edit captions for accuracy if using auto-generated captions. + LTX-2 prefers long, detailed captions that describe both visual content and audio (e.g., ambient sounds, speech, music). + +5. 
**Check target modules:** + Ensure your `target_modules` configuration matches your training goals. For audio-video training, + use patterns that match both branches (e.g., `"to_k"` instead of `"attn1.to_k"`). + See [Understanding Target Modules](configuration-reference.md#understanding-target-modules) for details. + +6. **Adjust LoRA rank:** + Try higher values for more capacity: + + ```yaml + lora: + rank: 64 # Or 128 for more capacity + ``` + +7. **Increase training steps:** + + ```yaml + optimization: + steps: 3000 + ``` + +--- + +## 🔍 Debugging Tools + +### Monitor GPU Memory Usage + +Track memory usage during training: + +```bash +# Watch GPU memory in real-time +watch -n 1 nvidia-smi + +# Log memory usage to file +nvidia-smi --query-gpu=memory.used,memory.total --format=csv --loop=5 > memory_log.csv +``` + +### Verify Preprocessed Data + +Decode latents to visualize the preprocessed videos: + +```bash +uv run python scripts/decode_latents.py dataset/.precomputed/latents debug_output \ + --model-path /path/to/model.safetensors +``` + +To also decode audio latents, add the `--with-audio` flag: + +```bash +uv run python scripts/decode_latents.py dataset/.precomputed/latents debug_output \ + --model-path /path/to/model.safetensors \ + --with-audio +``` + +Compare decoded videos and audio with originals to ensure quality. + +--- + +## 💡 Best Practices + +### Before Training + +- [ ] Test preprocessing with a small subset first +- [ ] Verify all video files are accessible +- [ ] Check available GPU memory +- [ ] Review configuration against hardware capabilities +- [ ] Ensure model and text encoder paths are correct + +### During Training + +- [ ] Monitor GPU memory usage +- [ ] Check loss convergence regularly +- [ ] Review validation samples periodically +- [ ] Save checkpoints frequently + +### After Training + +- [ ] Test trained model with diverse prompts +- [ ] Document training parameters and results +- [ ] Archive training data and configs + +## 🆘 Getting Help + +If you're still experiencing issues: + +1. **Check logs:** Review console output for error details +2. **Search issues:** Look through GitHub issues for similar problems +3. **Provide details:** When reporting issues, include: + - Hardware specifications (GPU model, VRAM) + - Configuration file used + - Complete error message + - Steps to reproduce the issue + +--- + +## 🤝 Join the Community + +Have questions, want to share your results, or need real-time help? +Join our [community Discord server](https://discord.gg/2mafsHjJ) to connect with other users and the development team! + +- Get troubleshooting help +- Share your training results and workflows +- Stay up to date with announcements and updates + +We look forward to seeing you there! diff --git a/packages/ltx-trainer/docs/utility-scripts.md b/packages/ltx-trainer/docs/utility-scripts.md new file mode 100644 index 0000000000000000000000000000000000000000..124163eac42c09ecb62ee9a3e586a97a84022aa6 --- /dev/null +++ b/packages/ltx-trainer/docs/utility-scripts.md @@ -0,0 +1,274 @@ +# Utility Scripts Reference + +This guide covers the various utility scripts available for preprocessing, conversion, and debugging tasks. + +## 🎬 Dataset Processing Scripts + +### Video Scene Splitting + +The `scripts/split_scenes.py` script automatically splits long videos into shorter, coherent scenes. 
+ +```bash +# Basic scene splitting +uv run python scripts/split_scenes.py input.mp4 output_dir/ --filter-shorter-than 5s +``` + +**Key features:** + +- **Automatic scene detection**: Uses PySceneDetect for intelligent splitting +- **Multiple algorithms**: Content-based, adaptive, threshold, and histogram detection +- **Filtering options**: Remove scenes shorter than specified duration +- **Customizable parameters**: Thresholds, window sizes, and detection modes + +**Common options:** + +```bash +# See all available options +uv run python scripts/split_scenes.py --help + +# Use adaptive detection with custom threshold +uv run python scripts/split_scenes.py video.mp4 scenes/ --detector adaptive --threshold 30.0 + +# Limit to maximum number of scenes +uv run python scripts/split_scenes.py video.mp4 scenes/ --max-scenes 50 +``` + +### Automatic Video Captioning + +The `scripts/caption_videos.py` script generates captions for videos (with audio) using multimodal models. + +```bash +# Generate captions for all videos in a directory (uses Qwen2.5-Omni by default) +uv run python scripts/caption_videos.py videos_dir/ --output dataset.json + +# Use 8-bit quantization to reduce VRAM usage +uv run python scripts/caption_videos.py videos_dir/ --output dataset.json --use-8bit + +# Use Gemini Flash API instead (requires API key) +uv run python scripts/caption_videos.py videos_dir/ --output dataset.json \ + --captioner-type gemini_flash --api-key YOUR_API_KEY + +# Caption without audio processing (video-only) +uv run python scripts/caption_videos.py videos_dir/ --output dataset.json --no-audio + +# Force re-caption all files +uv run python scripts/caption_videos.py videos_dir/ --output dataset.json --override +``` + +**Key features:** + +- **Audio-visual captioning**: Processes both video and audio content, including speech transcription +- **Multiple backends**: + - `qwen_omni` (default): Local Qwen2.5-Omni model - processes video + audio locally + - `gemini_flash`: Google Gemini Flash API - cloud-based, requires API key +- **Structured output**: Captions include visual description, speech transcription, sounds, and on-screen text +- **Memory optimization**: 8-bit quantization option for limited VRAM +- **Incremental processing**: Skips already-captioned files by default +- **Multiple output formats**: JSON, JSONL, CSV, or TXT + +**Caption format:** + +The captioner produces structured captions with four sections: +- `[VISUAL]`: Detailed description of visual content +- `[SPEECH]`: Word-for-word transcription of spoken content +- `[SOUNDS]`: Description of music, ambient sounds, sound effects +- `[TEXT]`: Any on-screen text visible in the video + +**Environment variables (for Gemini Flash):** + +Set one of these to use Gemini Flash without passing `--api-key`: +- `GOOGLE_API_KEY` +- `GEMINI_API_KEY` + +### Dataset Preprocessing + +The `scripts/process_dataset.py` script processes videos and caches latents for training. 
+ +```bash +# Basic preprocessing +uv run python scripts/process_dataset.py dataset.json \ + --resolution-buckets "960x544x49" \ + --model-path /path/to/ltx-2-model.safetensors \ + --text-encoder-path /path/to/gemma-model + +# With audio processing +uv run python scripts/process_dataset.py dataset.json \ + --resolution-buckets "960x544x49" \ + --model-path /path/to/ltx-2-model.safetensors \ + --text-encoder-path /path/to/gemma-model \ + --with-audio + +# With video decoding for verification +uv run python scripts/process_dataset.py dataset.json \ + --resolution-buckets "960x544x49" \ + --model-path /path/to/ltx-2-model.safetensors \ + --text-encoder-path /path/to/gemma-model \ + --decode +``` + +Multiple resolution buckets can be specified, separated by `;`: + +```bash +uv run python scripts/process_dataset.py dataset.json \ + --resolution-buckets "960x544x49;512x512x81" \ + --model-path /path/to/ltx-2-model.safetensors \ + --text-encoder-path /path/to/gemma-model +``` + +> [!NOTE] +> When training with multiple resolution buckets, set `optimization.batch_size: 1`. + +For detailed usage, see the [Dataset Preparation Guide](dataset-preparation.md). + +### Reference Video Generation + +The `scripts/compute_reference.py` script provides a template for creating reference videos needed for IC-LoRA training. +The default implementation generates Canny edge reference videos. + +```bash +# Generate Canny edge reference videos +uv run python scripts/compute_reference.py videos_dir/ --output dataset.json +``` + +**Key features:** + +- **Canny edge detection**: Creates edge-based reference videos +- **In-place editing**: Updates existing dataset JSON files +- **Customizable**: Modify the `compute_reference()` function for different conditions (depth, pose, etc.) + +> [!TIP] +> You can edit this script to generate other types of reference videos for IC-LoRA training, +> such as depth maps, segmentation masks, or any custom video transformation. + +## 🔍 Debugging and Verification Scripts + +### Latents Decoding + +The `scripts/decode_latents.py` script decodes precomputed video latents back into video files for visual inspection. + +```bash +# Basic usage +uv run python scripts/decode_latents.py /path/to/latents/dir \ + --output-dir /path/to/output \ + --model-path /path/to/ltx-2-model.safetensors + +# With VAE tiling for large videos +uv run python scripts/decode_latents.py /path/to/latents/dir \ + --output-dir /path/to/output \ + --model-path /path/to/ltx-2-model.safetensors \ + --vae-tiling + +# Decode both video and audio latents +uv run python scripts/decode_latents.py /path/to/latents/dir \ + --output-dir /path/to/output \ + --model-path /path/to/ltx-2-model.safetensors \ + --with-audio +``` + +**The script will:** + +1. **Load the VAE model** from the specified path +2. **Process all `.pt` latent files** in the input directory +3. **Decode each latent** back into a video using the VAE +4. **Save resulting videos** as MP4 files in the output directory + +**When to use:** + +- **Verify preprocessing quality**: Check that your videos were encoded correctly +- **Debug training data**: Visualize what the model actually sees during training +- **Quality assessment**: Ensure latent encoding preserves important visual details + + +### Inference Script + +The `scripts/inference.py` script runs inference with a trained model. 
+ +> [!TIP] +> For production inference, consider using the [`ltx-pipelines`](../../ltx-pipelines/) package which provides optimized, +> feature-rich pipelines for various use cases: +> - **Text/Image-to-Video**: `TI2VidOneStagePipeline`, `TI2VidTwoStagesPipeline` +> - **Distilled (fast) inference**: `DistilledPipeline` +> - **IC-LoRA video-to-video**: `ICLoraPipeline` +> - **Keyframe interpolation**: `KeyframeInterpolationPipeline` +> +> All pipelines support loading custom LoRAs trained with this trainer. + +```bash +# Text-to-video inference (with audio by default) +# By default, uses CFG scale 3.0 and STG scale 1.0 with block 29 +uv run python scripts/inference.py \ + --checkpoint /path/to/model.safetensors \ + --text-encoder-path /path/to/gemma \ + --prompt "A cat playing with a ball" \ + --output output.mp4 + +# Video-only (skip audio generation) +uv run python scripts/inference.py \ + --checkpoint /path/to/model.safetensors \ + --text-encoder-path /path/to/gemma \ + --prompt "A cat playing with a ball" \ + --skip-audio \ + --output output.mp4 + +# Image-to-video with conditioning image +uv run python scripts/inference.py \ + --checkpoint /path/to/model.safetensors \ + --text-encoder-path /path/to/gemma \ + --prompt "A cat walking" \ + --condition-image first_frame.png \ + --output output.mp4 + +# Custom guidance settings +uv run python scripts/inference.py \ + --checkpoint /path/to/model.safetensors \ + --text-encoder-path /path/to/gemma \ + --prompt "A cat playing with a ball" \ + --guidance-scale 3.0 \ + --stg-scale 1.0 \ + --stg-blocks 29 \ + --output output.mp4 + +# Disable STG (CFG only) +uv run python scripts/inference.py \ + --checkpoint /path/to/model.safetensors \ + --text-encoder-path /path/to/gemma \ + --prompt "A cat playing with a ball" \ + --stg-scale 0.0 \ + --output output.mp4 +``` + +**Guidance parameters:** + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `--guidance-scale` | 3.0 | CFG (Classifier-Free Guidance) scale | +| `--stg-scale` | 1.0 | STG (Spatio-Temporal Guidance) scale. 0.0 disables STG | +| `--stg-blocks` | 29 | Transformer block(s) to perturb for STG | +| `--stg-mode` | stg_av | `stg_av` perturbs both audio and video, `stg_v` video only | + +## 🚀 Training Scripts + +### Basic and Distributed Training + +Use `scripts/train.py` for both single GPU and multi-GPU runs: + +```bash +# Single-GPU training +uv run python scripts/train.py configs/ltx2_av_lora.yaml + +# Multi-GPU (uses your accelerate config) +uv run accelerate launch scripts/train.py configs/ltx2_av_lora.yaml + +# Override number of processes +uv run accelerate launch --num_processes 4 scripts/train.py configs/ltx2_av_lora.yaml +``` + +For detailed usage, see the [Training Guide](training-guide.md). 
+ +## 💡 Tips for Using Utility Scripts + +- **Start with `--help`**: Always check available options for each script +- **Test on small datasets**: Verify workflows with a few files before processing large datasets +- **Use decode verification**: Always decode a few samples to verify preprocessing quality +- **Monitor VRAM usage**: Use `--use-8bit` or quantization flags when running into memory issues +- **Keep backups**: Make copies of important dataset files before running conversion scripts diff --git a/packages/ltx-trainer/pyproject.toml b/packages/ltx-trainer/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..3866a8655a6ef5a583cce46d7bf0bb910d5c0139 --- /dev/null +++ b/packages/ltx-trainer/pyproject.toml @@ -0,0 +1,89 @@ +[project] +name = "ltx-trainer" +version = "0.1.0" +description = "LTX-2 training, democratized." +readme = "README.md" +authors = [ + { name = "Matan Ben-Yosef", email = "mbyosef@lightricks.com" } +] +requires-python = ">=3.12" +dependencies = [ + "ltx-core", + "accelerate>=1.2.1", + "av>=14.2.1", + "bitsandbytes >=0.45.2; sys_platform == 'linux'", + "diffusers>=0.32.1", + "huggingface-hub[hf-xet]>=0.31.4", + "imageio>=2.37.0", + "imageio-ffmpeg>=0.6.0", + "opencv-python>=4.11.0.86", + "optimum-quanto>=0.2.6", + "pandas>=2.2.3", + "peft>=0.14.0", + "pillow-heif>=0.21.0", + "pydantic>=2.10.4", + "rich>=13.9.4", + "safetensors>=0.5.0", + "scenedetect>=0.6.5.2", + "sentencepiece>=0.2.0", + "torch>=2.6.0", + "torchaudio>=2.9.0", + "torchcodec>=0.8.1", + "torchvision>=0.21.0", + "typer>=0.15.1", + "wandb>=0.19.11", +] + +[dependency-groups] +dev = [ + "pre-commit>=4.0.1", + "ruff>=0.8.6", +] + + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + + + +[tool.ruff] +target-version = "py311" +line-length = 120 + +[tool.ruff.lint] +select = [ + "E", # pycodestyle + "F", # pyflakes + "W", # pycodestyle (warnings) + "I", # isort + "N", # pep8-naming + "ANN", # flake8-annotations + "B", # flake8-bugbear + "A", # flake8-builtins + "COM", # flake8-commas + "C4", # flake8-comprehensions + "DTZ", # flake8-datetimez + "EXE", # flake8-executable + "PIE", # flake8-pie + "T20", # flake8-print + "PT", # flake8-pytest + "SIM", # flake8-simplify + "ARG", # flake8-unused-arguments + "PTH", # flake8--use-pathlib + "ERA", # flake8-eradicate + "RUF", # ruff specific rules + "PL", # pylint +] +ignore = [ + "ANN002", # Missing type annotation for *args + "ANN003", # Missing type annotation for **kwargs + "ANN204", # Missing type annotation for special method + "COM812", # Missing trailing comma + "PTH123", # `open()` should be replaced by `Path.open()` + "PLR2004", # Magic value used in comparison, consider replacing with a constant variable +] +[tool.ruff.lint.pylint] +max-args = 10 +[tool.ruff.lint.isort] +known-first-party = ["ltx_trainer", "ltx_core", "ltx_pipelines"] diff --git a/packages/ltx-trainer/scripts/caption_videos.py b/packages/ltx-trainer/scripts/caption_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..293c2ce7a4b9e66e054d6d4535ea30f8e3c3a63a --- /dev/null +++ b/packages/ltx-trainer/scripts/caption_videos.py @@ -0,0 +1,515 @@ +#!/usr/bin/env python3 + +""" +Auto-caption videos with audio using multimodal models. + +This script provides a command-line interface for generating captions for videos +(including audio) using multimodal models. 
It supports: + +- Qwen2.5-Omni: Local model for audio-visual captioning (default) +- Gemini Flash: Cloud-based API for audio-visual captioning + +The paths to videos in the generated dataset/captions file will be RELATIVE to the +directory where the output file is stored. This makes the dataset more portable and +easier to use in different environments. + +Basic usage: + # Caption a single video (includes audio by default) + caption_videos.py video.mp4 --output captions.json + + # Caption all videos in a directory + caption_videos.py videos_dir/ --output captions.csv + + # Caption with custom instruction + caption_videos.py video.mp4 --instruction "Describe what happens in this video in detail." + +Advanced usage: + # Use Gemini Flash API (requires GEMINI_API_KEY or GOOGLE_API_KEY env var) + caption_videos.py videos_dir/ --captioner-type gemini_flash + + # Disable audio processing (video-only captions) + caption_videos.py videos_dir/ --no-audio + + # Process videos with specific extensions and save as JSON + caption_videos.py videos_dir/ --extensions mp4,mov,avi --output captions.json +""" + +import csv +import json +from enum import Enum +from pathlib import Path + +import torch +import typer +from rich.console import Console +from rich.progress import ( + BarColumn, + MofNCompleteColumn, + Progress, + SpinnerColumn, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) +from transformers.utils.logging import disable_progress_bar + +from ltx_trainer.captioning import ( + CaptionerType, + MediaCaptioningModel, + create_captioner, +) + +VIDEO_EXTENSIONS = ["mp4", "avi", "mov", "mkv", "webm"] +IMAGE_EXTENSIONS = ["jpg", "jpeg", "png"] +MEDIA_EXTENSIONS = VIDEO_EXTENSIONS + IMAGE_EXTENSIONS +SAVE_INTERVAL = 5 + +console = Console() +app = typer.Typer( + pretty_exceptions_enable=False, + no_args_is_help=True, + help="Auto-caption videos with audio using multimodal models.", +) + +disable_progress_bar() + + +class OutputFormat(str, Enum): + """Available output formats for captions.""" + + TXT = "txt" # Separate files for captions and video paths, one caption / video path per line + CSV = "csv" # CSV file with video path and caption columns + JSON = "json" # JSON file with video paths as keys and captions as values + JSONL = "jsonl" # JSON Lines file with one JSON object per line + + +def caption_media( + input_path: Path, + output_path: Path, + captioner: MediaCaptioningModel, + extensions: list[str], + recursive: bool, + fps: int, + include_audio: bool, + clean_caption: bool, + output_format: OutputFormat, + override: bool, +) -> None: + """Caption videos and images using the provided captioning model. 
+ + Args: + input_path: Path to input video file or directory + output_path: Path to output caption file + captioner: Media captioning model + extensions: List of media file extensions to include + recursive: Whether to search subdirectories recursively + fps: Frames per second to sample from videos (ignored for images) + include_audio: Whether to include audio in captioning + clean_caption: Whether to clean up captions + output_format: Format to save the captions in + override: Whether to override existing captions + """ + + # Get list of media files to process + media_files = _get_media_files(input_path, extensions, recursive) + + if not media_files: + console.print("[bold yellow]No media files found to process.[/]") + return + + console.print(f"Found [bold]{len(media_files)}[/] media files to process.") + + # Load existing captions and determine which files need processing + base_dir = output_path.parent.resolve() + existing_captions = _load_existing_captions(output_path, output_format) + existing_abs_paths = {str((base_dir / p).resolve()) for p in existing_captions} + + if override: + media_to_process = media_files + else: + media_to_process = [f for f in media_files if str(f.resolve()) not in existing_abs_paths] + if skipped := len(media_files) - len(media_to_process): + console.print(f"[bold yellow]Skipping {skipped} media that already have captions.[/]") + + if not media_to_process: + console.print("[bold yellow]All media already have captions. Use --override to recaption.[/]") + return + + # Process media files + captions = existing_captions.copy() + successfully_captioned = 0 + progress = Progress( + SpinnerColumn(), + TextColumn("{task.description}"), + BarColumn(bar_width=40), + MofNCompleteColumn(), + TimeElapsedColumn(), + TextColumn("•"), + TimeRemainingColumn(), + console=console, + ) + + with progress: + task = progress.add_task("Captioning", total=len(media_to_process)) + + for i, media_file in enumerate(media_to_process): + progress.update(task, description=f"Captioning [bold blue]{media_file.name}[/]") + + try: + # Generate caption for the media + caption = captioner.caption( + path=media_file, + fps=fps, + include_audio=include_audio, + clean_caption=clean_caption, + ) + + # Convert absolute path to relative path (relative to the output file's directory) + rel_path = str(media_file.resolve().relative_to(base_dir)) + # Store the caption with the relative path as key + captions[rel_path] = caption + successfully_captioned += 1 + except Exception as e: + console.print(f"[bold red]Error captioning {media_file}: {e}[/]") + + if i % SAVE_INTERVAL == 0: + _save_captions(captions, output_path, output_format) + + # Advance progress bar + progress.advance(task) + + # Save captions to file + _save_captions(captions, output_path, output_format) + + # Print summary + console.print( + f"[bold green]✓[/] Captioned [bold]{successfully_captioned}/{len(media_to_process)}[/] media successfully.", + ) + + +def _get_media_files( + input_path: Path, + extensions: list[str] = MEDIA_EXTENSIONS, + recursive: bool = False, +) -> list[Path]: + """Get all media files from the input path.""" + input_path = Path(input_path) + # Normalize extensions to lowercase without dots + extensions = [ext.lower().lstrip(".") for ext in extensions] + + if input_path.is_file(): + # If input is a file, check if it has a valid extension + if input_path.suffix.lstrip(".").lower() in extensions: + return [input_path] + else: + typer.echo(f"Warning: {input_path} is not a recognized media file. 
Skipping.") + return [] + elif input_path.is_dir(): + # If input is a directory, find all media files + media_files = [] + + # Define the glob pattern based on whether we're searching recursively + glob_pattern = "**/*" if recursive else "*" + + # Find all files with the specified extensions + for ext in extensions: + media_files.extend(input_path.glob(f"{glob_pattern}.{ext}")) + + return sorted(media_files) + else: + typer.echo(f"Error: {input_path} does not exist.") + raise typer.Exit(code=1) + + +def _save_captions( + captions: dict[str, str], + output_path: Path, + format_type: OutputFormat, +) -> None: + """Save captions to a file in the specified format. + + Args: + captions: Dictionary mapping media paths to captions + output_path: Path to save the output file + format_type: Format to save the captions in + """ + # Create parent directories if they don't exist + output_path.parent.mkdir(parents=True, exist_ok=True) + + console.print("[bold blue]Saving captions...[/]") + + match format_type: + case OutputFormat.TXT: + # Create two separate files for captions and media paths + captions_file = output_path.with_stem(f"{output_path.stem}_captions") + paths_file = output_path.with_stem(f"{output_path.stem}_paths") + + with captions_file.open("w", encoding="utf-8") as f: + for caption in captions.values(): + f.write(f"{caption}\n") + + with paths_file.open("w", encoding="utf-8") as f: + for media_path in captions: + f.write(f"{media_path}\n") + + console.print(f"[bold green]✓[/] Captions saved to [cyan]{captions_file}[/]") + console.print(f"[bold green]✓[/] Media paths saved to [cyan]{paths_file}[/]") + + case OutputFormat.CSV: + with output_path.open("w", encoding="utf-8", newline="") as f: + writer = csv.writer(f) + writer.writerow(["caption", "media_path"]) + for media_path, caption in captions.items(): + writer.writerow([caption, media_path]) + + console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]") + + case OutputFormat.JSON: + # Format as list of dictionaries with caption and media_path keys + json_data = [{"caption": caption, "media_path": media_path} for media_path, caption in captions.items()] + + with output_path.open("w", encoding="utf-8") as f: + json.dump(json_data, f, indent=2, ensure_ascii=False) + + console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]") + + case OutputFormat.JSONL: + with output_path.open("w", encoding="utf-8") as f: + for media_path, caption in captions.items(): + f.write(json.dumps({"caption": caption, "media_path": media_path}, ensure_ascii=False) + "\n") + + console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]") + + case _: + raise ValueError(f"Unsupported output format: {format_type}") + + +def _load_existing_captions( # noqa: PLR0912 + output_path: Path, + format_type: OutputFormat, +) -> dict[str, str]: + """Load existing captions from a file. 
+ + Args: + output_path: Path to the captions file + format_type: Format of the captions file + + Returns: + Dictionary mapping media paths to captions, or empty dict if file doesn't exist + """ + if not output_path.exists(): + return {} + + console.print(f"[bold blue]Loading existing captions from [cyan]{output_path}[/]...[/]") + + existing_captions = {} + + try: + match format_type: + case OutputFormat.TXT: + # For TXT format, we have two separate files + captions_file = output_path.with_stem(f"{output_path.stem}_captions") + paths_file = output_path.with_stem(f"{output_path.stem}_paths") + + if captions_file.exists() and paths_file.exists(): + captions = captions_file.read_text(encoding="utf-8").splitlines() + paths = paths_file.read_text(encoding="utf-8").splitlines() + + if len(captions) == len(paths): + existing_captions = dict(zip(paths, captions, strict=False)) + + case OutputFormat.CSV: + with output_path.open("r", encoding="utf-8", newline="") as f: + reader = csv.reader(f) + # Skip header + next(reader, None) + for row in reader: + if len(row) >= 2: + caption, media_path = row[0], row[1] + existing_captions[media_path] = caption + + case OutputFormat.JSON: + with output_path.open("r", encoding="utf-8") as f: + json_data = json.load(f) + for item in json_data: + if "caption" in item and "media_path" in item: + existing_captions[item["media_path"]] = item["caption"] + + case OutputFormat.JSONL: + with output_path.open("r", encoding="utf-8") as f: + for line in f: + item = json.loads(line) + if "caption" in item and "media_path" in item: + existing_captions[item["media_path"]] = item["caption"] + + case _: + raise ValueError(f"Unsupported output format: {format_type}") + + console.print(f"[bold green]✓[/] Loaded [bold]{len(existing_captions)}[/] existing captions") + return existing_captions + + except Exception as e: + console.print(f"[bold yellow]Warning: Could not load existing captions: {e}[/]") + return {} + + +@app.command() +def main( # noqa: PLR0913 + input_path: Path = typer.Argument( # noqa: B008 + ..., + help="Path to input video/image file or directory containing media files", + exists=True, + ), + output: Path | None = typer.Option( # noqa: B008 + None, + "--output", + "-o", + help="Path to output file for captions. Format determined by file extension.", + ), + captioner_type: CaptionerType = typer.Option( # noqa: B008 + CaptionerType.QWEN_OMNI, + "--captioner-type", + "-c", + help="Type of captioner to use. Valid values: 'qwen_omni' (local), 'gemini_flash' (API)", + case_sensitive=False, + ), + device: str | None = typer.Option( + None, + "--device", + "-d", + help="Device to use for inference (e.g., 'cuda', 'cuda:0', 'cpu'). Only for local models.", + ), + use_8bit: bool = typer.Option( + False, + "--use-8bit", + help="Whether to use 8-bit precision for the captioning model (reduces memory usage)", + ), + instruction: str | None = typer.Option( + None, + "--instruction", + "-i", + help="Custom instruction for the captioning model. 
If not provided, uses an appropriate default.", + ), + extensions: str = typer.Option( + ",".join(MEDIA_EXTENSIONS), + "--extensions", + "-e", + help="Comma-separated list of media file extensions to process", + ), + recursive: bool = typer.Option( + False, + "--recursive", + "-r", + help="Search for media files in subdirectories recursively", + ), + fps: int = typer.Option( + 3, + "--fps", + "-f", + help="Frames per second to sample from videos (ignored for images)", + ), + include_audio: bool = typer.Option( + True, + "--audio/--no-audio", + help="Whether to include audio in captioning (for videos with audio tracks)", + ), + clean_caption: bool = typer.Option( + True, + "--clean-caption/--raw-caption", + help="Whether to clean up captions by removing common VLM patterns", + ), + override: bool = typer.Option( + False, + "--override", + help="Whether to override existing captions for media", + ), + api_key: str | None = typer.Option( + None, + "--api-key", + envvar=["GOOGLE_API_KEY", "GEMINI_API_KEY"], + help="API key for Gemini Flash (can also use GOOGLE_API_KEY or GEMINI_API_KEY env var)", + ), +) -> None: + """Auto-caption videos with audio using multimodal models. + + This script supports audio-visual captioning using: + - Qwen2.5-Omni: Local model (default) - processes both video and audio + - Gemini Flash: Cloud API - requires GOOGLE_API_KEY environment variable + + The paths in the output file will be relative to the output file's directory. + + Examples: + # Caption videos with audio using Qwen2.5-Omni (default) + caption_videos.py videos_dir/ -o captions.json + + # Caption using Gemini Flash API + caption_videos.py videos_dir/ -o captions.json -c gemini_flash + + # Caption without audio (video-only) + caption_videos.py videos_dir/ -o captions.json --no-audio + + # Caption with custom instruction + caption_videos.py video.mp4 -o captions.json -i "Describe this video in detail" + + """ + + # Determine device for local models + device_str = device or ("cuda" if torch.cuda.is_available() else "cpu") + + # Parse extensions + ext_list = [ext.strip() for ext in extensions.split(",")] + + # Determine output path and format + if output is None: + output_format = OutputFormat.JSON + if input_path.is_file(): # noqa: SIM108 + # Default to a JSON file with the same name as the input media + output = input_path.with_suffix(".dataset.json") + else: + # Default to a JSON file in the input directory + output = input_path / "dataset.json" + else: + # Determine format from file extension + output_format = OutputFormat(Path(output).suffix.lstrip(".").lower()) + + # Ensure output path is absolute + output = Path(output).resolve() + console.print(f"Output will be saved to [bold blue]{output}[/]") + + # Initialize captioning model + with console.status("Loading captioning model...", spinner="dots"): + if captioner_type == CaptionerType.QWEN_OMNI: + captioner = create_captioner( + captioner_type=captioner_type, + device=device_str, + use_8bit=use_8bit, + instruction=instruction, + ) + elif captioner_type == CaptionerType.GEMINI_FLASH: + captioner = create_captioner( + captioner_type=captioner_type, + api_key=api_key, + instruction=instruction, + ) + else: + raise ValueError(f"Unsupported captioner type: {captioner_type}") + + console.print(f"[bold green]✓[/] {captioner_type.value} captioning model loaded successfully") + + # Caption media files + caption_media( + input_path=input_path, + output_path=output, + captioner=captioner, + extensions=ext_list, + recursive=recursive, + fps=fps, + 
include_audio=include_audio, + clean_caption=clean_caption, + output_format=output_format, + override=override, + ) + + +if __name__ == "__main__": + app() diff --git a/packages/ltx-trainer/scripts/compute_reference.py b/packages/ltx-trainer/scripts/compute_reference.py new file mode 100644 index 0000000000000000000000000000000000000000..7fdec1c8286f1fc8b052eb3d6686bfc24ae739bb --- /dev/null +++ b/packages/ltx-trainer/scripts/compute_reference.py @@ -0,0 +1,298 @@ +""" +Compute reference videos for IC-LoRA training. + +This script provides a command-line interface for generating reference videos to be used for IC-LoRA training. +Note that it reads and writes to the same file (the output of caption_videos.py), +where it adds the "reference_path" field to the JSON. + +Basic usage: + # Compute reference videos for all videos in a directory + compute_reference.py videos_dir/ --output videos_dir/captions.json +""" + +# Standard library imports +import json +from pathlib import Path +from typing import Dict + +# Third-party imports +import cv2 +import torch +import torchvision.transforms.functional as TF # noqa: N812 +import typer +from rich.console import Console +from rich.progress import ( + BarColumn, + MofNCompleteColumn, + Progress, + SpinnerColumn, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) +from transformers.utils.logging import disable_progress_bar + +# Local imports +from ltx_trainer.video_utils import read_video, save_video + +# Initialize console and disable progress bars +console = Console() +disable_progress_bar() + + +def compute_reference( + images: torch.Tensor, +) -> torch.Tensor: + """Compute Canny edge detection on a batch of images. + + Args: + images: Batch of images tensor of shape [B, C, H, W] + + Returns: + Binary edge masks tensor of shape [B, H, W] + """ + # Convert to grayscale if needed + if images.shape[1] == 3: + images = TF.rgb_to_grayscale(images) + + # Ensure images are in [0, 1] range + if images.max() > 1.0: + images = images / 255.0 + + # Compute Canny edges + edge_masks = [] + for image in images: + # Convert to numpy for OpenCV + image_np = (image.squeeze().cpu().numpy() * 255).astype("uint8") + + # Apply Canny edge detection + edges = cv2.Canny( + image_np, + threshold1=100, + threshold2=200, + ) + + # Convert back to tensor + edge_mask = torch.from_numpy(edges).float() + edge_masks.append(edge_mask) + + edges = torch.stack(edge_masks) + edges = torch.stack([edges] * 3, dim=1) # Convert to 3-channel + return edges + + +def _get_meta_data( + output_path: Path, +) -> Dict[str, str]: + """Get set of existing reference video paths without loading the actual files. + + Args: + output_path: Path to the reference video paths file + + Returns: + Dictionary mapping media paths to reference video paths + """ + if not output_path.exists(): + return {} + + console.print(f"[bold blue]Reading meta data from [cyan]{output_path}[/]...[/]") + + try: + with output_path.open("r", encoding="utf-8") as f: + json_data = json.load(f) + return json_data + + except Exception as e: + console.print(f"[bold yellow]Warning: Could not check meta data: {e}[/]") + return {} + + +def _save_dataset_json( + reference_paths: Dict[str, str], + output_path: Path, +) -> None: + """Save dataset json with reference video paths. 
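+
+    For illustration (hypothetical paths), an entry such as
+        {"caption": "...", "media_path": "clips/cat.mp4"}
+    becomes
+        {"caption": "...", "media_path": "clips/cat.mp4", "reference_path": "clips/cat_reference.mp4"}
+    following the "<stem>_reference<suffix>" naming convention used by this script.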
+ + Args: + reference_paths: Dictionary mapping media paths to reference video paths + output_path: Path to save the output file + """ + + with output_path.open("r", encoding="utf-8") as f: + json_data = json.load(f) + new_json_data = json_data.copy() + for i, item in enumerate(json_data): + media_path = item["media_path"] + reference_path = reference_paths[media_path] + new_json_data[i]["reference_path"] = reference_path + + with output_path.open("w", encoding="utf-8") as f: + json.dump(new_json_data, f, indent=2, ensure_ascii=False) + + console.print(f"[bold green]✓[/] Reference video paths saved to [cyan]{output_path}[/]") + console.print("[bold yellow]Note:[/] Use these files with ImageOrVideoDataset by setting:") + console.print(" reference_column='[cyan]reference_path[/]'") + console.print(" video_column='[cyan]media_path[/]'") + + +def process_media( + input_path: Path, + output_path: Path, + override: bool, + batch_size: int = 100, +) -> None: + """Process videos and images to compute condition on videos. + + Args: + input_path: Path to input video/image file or directory + output_path: Path to output reference video file + override: Whether to override existing reference video files + """ + if not output_path.exists(): + raise FileNotFoundError( + f"Output file does not exist: {output_path}. This is also the input file for the dataset." + ) + + # Check for existing reference video files + meta_data = _get_meta_data(output_path) + + base_dir = input_path.resolve() + console.print(f"Using [bold blue]{base_dir}[/] as base directory for relative paths") + + # Filter media files + media_to_process = [] + skipped_media = [] + + def media_path_to_reference_path(media_file: Path) -> Path: + return media_file.parent / (media_file.stem + "_reference" + media_file.suffix) + + media_files = [base_dir / Path(sample["media_path"]) for sample in meta_data] + for media_file in media_files: + reference_path = media_path_to_reference_path(media_file) + media_to_process.append(media_file) + + console.print(f"Processing [bold]{len(media_to_process)}[/] media.") + + # Initialize progress tracking + progress = Progress( + SpinnerColumn(), + TextColumn("{task.description}"), + BarColumn(bar_width=40), + MofNCompleteColumn(), + TimeElapsedColumn(), + TextColumn("•"), + TimeRemainingColumn(), + console=console, + ) + + # Process media files + media_paths = [item["media_path"] for item in meta_data] + reference_paths = {rel_path: str(media_path_to_reference_path(Path(rel_path))) for rel_path in media_paths} + + with progress: + task = progress.add_task("Computing condition on videos", total=len(media_to_process)) + + for media_file in media_to_process: + progress.update(task, description=f"Processing [bold blue]{media_file.name}[/]") + + rel_path = str(media_file.resolve().relative_to(base_dir)) + reference_path = media_path_to_reference_path(media_file) + reference_paths[rel_path] = str(reference_path.relative_to(base_dir)) + + if not reference_path.resolve().exists() or override: + try: + video, fps = read_video(media_file) + + # Process frames in batches + condition_frames = [] + + for i in range(0, len(video), batch_size): + batch = video[i : i + batch_size] + condition_batch = compute_reference(batch) + condition_frames.append(condition_batch) + + # Concatenate all edge frames + all_condition = torch.cat(condition_frames, dim=0) + + # Save the edge video + save_video(all_condition, reference_path.resolve(), fps=fps) + + except Exception as e: + console.print(f"[bold red]Error processing [bold 
blue]{media_file}[/]: {e}[/]") + reference_paths.pop(rel_path) + else: + skipped_media.append(media_file) + + progress.advance(task) + + # Save results + _save_dataset_json(reference_paths, output_path) + + # Print summary + total_to_process = len(media_files) - len(skipped_media) + console.print( + f"[bold green]✓[/] Processed [bold]{total_to_process}/{len(media_files)}[/] media successfully.", + ) + + +app = typer.Typer( + pretty_exceptions_enable=False, + no_args_is_help=True, + help="Compute reference videos for IC-LoRA training.", +) + + +@app.command() +def main( + input_path: Path = typer.Argument( # noqa: B008 + ..., + help="Path to input video/image file or directory containing media files", + exists=True, + ), + output: Path | None = typer.Option( # noqa: B008 + None, + "--output", + "-o", + help="Path to json output file for reference video paths. " + "This is also the input file for the dataset, the output of compute_captions.py.", + ), + override: bool = typer.Option( + False, + "--override", + help="Whether to override existing reference video files", + ), + batch_size: int = typer.Option( + 100, + "--batch-size", + help="Batch size for processing videos", + ), +) -> None: + """Compute reference videos for IC-LoRA training. + + This script generates reference videos (e.g., Canny edge maps) for given videos. + The paths in the output file will be relative to the output file's directory. + + Examples: + # Process all videos in a directory + compute_reference.py videos_dir/ -o videos_dir/captions.json + """ + + # Ensure output path is absolute + output = Path(output).resolve() + console.print(f"Output will be saved to [bold blue]{output}[/]") + + # Verify output path exists + if not output.exists(): + raise FileNotFoundError(f"Output file does not exist: {output}. This is also the input file for the dataset.") + + # Process media files + process_media( + input_path=input_path, + output_path=output, + override=override, + batch_size=batch_size, + ) + + +if __name__ == "__main__": + app() diff --git a/packages/ltx-trainer/scripts/decode_latents.py b/packages/ltx-trainer/scripts/decode_latents.py new file mode 100644 index 0000000000000000000000000000000000000000..81da232d88f9fc53aa38d0df09937990ff858416 --- /dev/null +++ b/packages/ltx-trainer/scripts/decode_latents.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 + +""" +Decode precomputed video latents back into videos using the VAE. + +This script loads latent files saved during preprocessing and decodes them +back into video clips using the same VAE model. 
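+
+Each video latent file is expected to be a dictionary saved with torch.save, containing at
+least the keys "latents", "num_frames", "height" and "width" (and optionally "fps"). Both the
+older patchified [seq_len, C] layout and the newer [C, F, H, W] layout are handled.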
+ +Basic usage: + python scripts/decode_latents.py /path/to/latents/dir /path/to/output \ + --model-source /path/to/ltx2.safetensors +""" + +from pathlib import Path + +import torch +import torchaudio +import torchvision.utils +import typer +from rich.console import Console +from rich.progress import ( + BarColumn, + MofNCompleteColumn, + Progress, + SpinnerColumn, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) +from transformers.utils.logging import disable_progress_bar + +from ltx_trainer import logger +from ltx_trainer.model_loader import load_audio_vae_decoder, load_video_vae_decoder, load_vocoder +from ltx_trainer.video_utils import save_video + +disable_progress_bar() +console = Console() +app = typer.Typer( + pretty_exceptions_enable=False, + no_args_is_help=True, + help="Decode precomputed video latents back into videos using the VAE.", +) + + +class LatentsDecoder: + def __init__( + self, + model_path: str, + device: str = "cuda", + vae_tiling: bool = False, + with_audio: bool = False, + ): + """Initialize the decoder with model configuration. + + Args: + model_path: Path to LTX-2 checkpoint (.safetensors) + device: Device to use for computation + vae_tiling: Whether to enable VAE tiling for larger video resolutions + with_audio: Whether to load audio VAE for audio decoding + """ + self.device = torch.device(device) + self.model_path = model_path + self.vae = None + self.audio_vae = None + self.vocoder = None + self._load_model(model_path, vae_tiling, with_audio) + + def _load_model(self, model_path: str, vae_tiling: bool, with_audio: bool = False) -> None: + """Initialize and load the VAE model(s).""" + with console.status(f"[bold]Loading video VAE decoder from {model_path}...", spinner="dots"): + self.vae = load_video_vae_decoder(model_path, device=self.device, dtype=torch.bfloat16) + + if vae_tiling: + self.vae.enable_tiling() + + if with_audio: + with console.status(f"[bold]Loading audio VAE decoder from {model_path}...", spinner="dots"): + self.audio_vae = load_audio_vae_decoder(model_path, device=self.device, dtype=torch.bfloat16) + + with console.status(f"[bold]Loading vocoder from {model_path}...", spinner="dots"): + self.vocoder = load_vocoder(model_path, device=self.device) + + @torch.inference_mode() + def decode(self, latents_dir: Path, output_dir: Path, seed: int | None = None) -> None: + """Decode all latent files in the directory recursively. 
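+
+        A minimal usage sketch (paths are illustrative):
+
+            decoder = LatentsDecoder("/path/to/ltx2.safetensors", device="cuda")
+            decoder.decode(Path("latents"), Path("decoded_videos"), seed=42)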
+ + Args: + latents_dir: Directory containing latent files (.pt) + output_dir: Directory to save decoded videos + seed: Optional random seed for noise generation + """ + # Find all .pt files recursively + latent_files = list(latents_dir.rglob("*.pt")) + + if not latent_files: + logger.warning(f"No .pt files found in {latents_dir}") + return + + logger.info(f"Found {len(latent_files):,} latent files to decode") + + # Process files with progress bar + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + MofNCompleteColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), + console=console, + ) as progress: + task = progress.add_task("Decoding latents", total=len(latent_files)) + + for latent_file in latent_files: + # Calculate relative path to maintain directory structure + rel_path = latent_file.relative_to(latents_dir) + output_subdir = output_dir / rel_path.parent + output_subdir.mkdir(parents=True, exist_ok=True) + + try: + self._process_file(latent_file, output_subdir, seed) + except Exception as e: + logger.error(f"Error processing {latent_file}: {e}") + continue + + progress.advance(task) + + logger.info(f"Decoding complete! Videos saved to {output_dir}") + + def _process_file(self, latent_file: Path, output_dir: Path, seed: int | None) -> None: + """Process a single latent file.""" + # Load the latent data + data = torch.load(latent_file, map_location=self.device, weights_only=False) + + # Get latents - handle both old patchified [seq_len, C] and new [C, F, H, W] formats + latents = data["latents"] + num_frames = data["num_frames"] + height = data["height"] + width = data["width"] + + # Check if latents need reshaping (old patchified format) + if latents.dim() == 2: + # Old format: [seq_len, C] -> reshape to [C, F, H, W] + _seq_len, channels = latents.shape + latents = latents.reshape(num_frames, height, width, channels) + latents = latents.permute(3, 0, 1, 2) # [F, H, W, C] -> [C, F, H, W] + + # Add batch dimension: [C, F, H, W] -> [1, C, F, H, W] + latents = latents.unsqueeze(0).to(device=self.device, dtype=torch.bfloat16) + + # Create generator only if seed is provided + generator = None + if seed is not None: + generator = torch.Generator(device=self.device) + generator.manual_seed(seed) + + # Decode the video (VAE decoder uses forward/call, not decode method) + video = self.vae(latents) # [B, C, F, H, W] + + # Convert to [F, C, H, W] format and normalize to [0, 1] + video = video[0] # Remove batch dimension -> [C, F, H, W] + video = video.permute(1, 0, 2, 3) # [C, F, H, W] -> [F, C, H, W] + video = (video + 1) / 2 # Denormalize from [-1, 1] to [0, 1] + video = video.clamp(0, 1) + + # Determine output format and save + is_image = video.shape[0] == 1 + if is_image: + # Save as PNG for single frame + output_path = output_dir / f"{latent_file.stem}.png" + torchvision.utils.save_image( + video[0], # [C, H, W] in [0, 1] + str(output_path), + ) + else: + # Save as MP4 for video using PyAV-based save_video + output_path = output_dir / f"{latent_file.stem}.mp4" + fps = data.get("fps", 24) # Use stored FPS or default to 24 + save_video( + video_tensor=video, # [F, C, H, W] in [0, 1] + output_path=output_path, + fps=fps, + ) + + @torch.inference_mode() + def decode_audio(self, latents_dir: Path, output_dir: Path) -> None: + """Decode all audio latent files in the directory recursively. 
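+
+        Requires the decoder to have been constructed with with_audio=True; otherwise the
+        audio VAE and vocoder are not loaded and this method only logs a warning and returns.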
+ + Args: + latents_dir: Directory containing audio latent files (.pt) + output_dir: Directory to save decoded audio files + """ + # Check if audio VAE is loaded + if self.audio_vae is None or self.vocoder is None: + logger.warning("Audio VAE or vocoder not loaded. Skipping audio decoding.") + return + + # Find all .pt files recursively + latent_files = list(latents_dir.rglob("*.pt")) + + if not latent_files: + logger.warning(f"No .pt files found in {latents_dir}") + return + + logger.info(f"Found {len(latent_files):,} audio latent files to decode") + + # Process files with progress bar + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + MofNCompleteColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), + console=console, + ) as progress: + task = progress.add_task("Decoding audio latents", total=len(latent_files)) + + for latent_file in latent_files: + # Calculate relative path to maintain directory structure + rel_path = latent_file.relative_to(latents_dir) + output_subdir = output_dir / rel_path.parent + output_subdir.mkdir(parents=True, exist_ok=True) + + try: + self._process_audio_file(latent_file, output_subdir) + except Exception as e: + logger.error(f"Error processing audio {latent_file}: {e}") + continue + + progress.advance(task) + + logger.info(f"Audio decoding complete! Audio files saved to {output_dir}") + + def _process_audio_file(self, latent_file: Path, output_dir: Path) -> None: + """Process a single audio latent file.""" + # Load the latent data + data = torch.load(latent_file, map_location=self.device, weights_only=False) + + latents = data["latents"].to(device=self.device, dtype=torch.float32) + num_time_steps = data["num_time_steps"] + freq_bins = data["frequency_bins"] + + # Handle both old patchified [seq_len, C] and new [C, T, F] formats + if latents.dim() == 2: + # Old format: [seq_len, channels] where seq_len = time * freq + # Reshape to [C, T, F] + latents = latents.reshape(num_time_steps, freq_bins, -1) # [T, F, C] + latents = latents.permute(2, 0, 1) # [T, F, C] -> [C, T, F] + + # Add batch dimension: [C, T, F] -> [1, C, T, F] + latents = latents.unsqueeze(0) + + # Set correct dtype for audio VAE + latents = latents.to(dtype=torch.bfloat16) + + # Decode audio using audio VAE decoder (produces mel spectrogram) + mel_spectrogram = self.audio_vae(latents) + + # Convert mel spectrogram to waveform using vocoder + waveform = self.vocoder(mel_spectrogram) + + # Save as WAV + output_path = output_dir / f"{latent_file.stem}.wav" + sample_rate = self.vocoder.output_sample_rate + torchaudio.save(str(output_path), waveform[0].cpu(), sample_rate) + + +@app.command() +def main( + latents_dir: str = typer.Argument( + ..., + help="Directory containing the precomputed latent files (searched recursively)", + ), + output_dir: str = typer.Argument( + ..., + help="Directory to save the decoded videos (maintains same folder hierarchy as input)", + ), + model_path: str = typer.Option( + ..., + help="Path to LTX-2 checkpoint (.safetensors file)", + ), + device: str = typer.Option( + default="cuda", + help="Device to use for computation", + ), + vae_tiling: bool = typer.Option( + default=False, + help="Enable VAE tiling for larger video resolutions", + ), + seed: int | None = typer.Option( + default=None, + help="Random seed for noise generation during decoding", + ), + with_audio: bool = typer.Option( + default=False, + help="Also decode audio latents (requires audio_latents directory)", + ), + audio_latents_dir: str | 
None = typer.Option( + default=None, + help="Directory containing audio latent files (defaults to 'audio_latents' sibling of latents_dir)", + ), +) -> None: + """Decode precomputed video latents back into videos using the VAE. + + This script recursively searches for .pt latent files in the input directory + and decodes them to videos, maintaining the same folder hierarchy in the output. + + Examples: + # Basic usage + python scripts/decode_latents.py /path/to/latents /path/to/videos \\ + --model-path /path/to/ltx2.safetensors + + # With VAE tiling for large videos + python scripts/decode_latents.py /path/to/latents /path/to/videos \\ + --model-path /path/to/ltx2.safetensors --vae-tiling + + # With audio decoding + python scripts/decode_latents.py /path/to/latents /path/to/videos \\ + --model-path /path/to/ltx2.safetensors --with-audio + """ + latents_path = Path(latents_dir) + output_path = Path(output_dir) + + if not latents_path.exists() or not latents_path.is_dir(): + raise typer.BadParameter(f"Latents directory does not exist: {latents_path}") + + decoder = LatentsDecoder( + model_path=model_path, + device=device, + vae_tiling=vae_tiling, + with_audio=with_audio, + ) + decoder.decode(latents_path, output_path, seed=seed) + + # Decode audio if requested + if with_audio: + audio_path = Path(audio_latents_dir) if audio_latents_dir else latents_path.parent / "audio_latents" + + if audio_path.exists(): + audio_output_path = output_path.parent / "decoded_audio" + decoder.decode_audio(audio_path, audio_output_path) + else: + logger.warning(f"Audio latents directory not found: {audio_path}") + + +if __name__ == "__main__": + app() diff --git a/packages/ltx-trainer/scripts/inference.py b/packages/ltx-trainer/scripts/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..c90881cca397d3390115582a5d96a8c9867a3840 --- /dev/null +++ b/packages/ltx-trainer/scripts/inference.py @@ -0,0 +1,453 @@ +#!/usr/bin/env python3 +# ruff: noqa: T201 +""" +CLI script for running LTX video/audio generation inference. 
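+
+Audio is generated by default and embedded in the output .mp4; pass --skip-audio to disable it,
+or --audio-output to additionally write the audio track to a separate .wav file.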
+ +Usage: + # Text-to-Video + Audio (default behavior) + python scripts/inference.py --checkpoint path/to/model.safetensors \ + --text-encoder-path path/to/gemma \ + --prompt "A cat playing with a ball" --output output.mp4 + + # Video only (skip audio) + python scripts/inference.py --checkpoint path/to/model.safetensors \ + --text-encoder-path path/to/gemma \ + --prompt "A cat playing with a ball" --skip-audio --output output.mp4 + + # Image-to-Video + python scripts/inference.py --checkpoint path/to/model.safetensors \ + --text-encoder-path path/to/gemma \ + --prompt "A cat walking" --condition-image first_frame.png --output output.mp4 + + # Video-to-Video (IC-LoRA style) + python scripts/inference.py --checkpoint path/to/model.safetensors \ + --text-encoder-path path/to/gemma \ + --prompt "A cat turning into a dog" --reference-video input.mp4 --output output.mp4 + + # With LoRA weights + python scripts/inference.py --checkpoint path/to/model.safetensors \ + --text-encoder-path path/to/gemma \ + --lora-path path/to/lora.safetensors \ + --prompt "A cat in my custom style" --output output.mp4 +""" + +import argparse +import re +from pathlib import Path + +import torch +import torchaudio +from peft import LoraConfig, get_peft_model, set_peft_model_state_dict +from safetensors.torch import load_file +from torchvision import transforms + +from ltx_trainer.model_loader import load_model +from ltx_trainer.progress import StandaloneSamplingProgress +from ltx_trainer.utils import open_image_as_srgb +from ltx_trainer.validation_sampler import GenerationConfig, ValidationSampler +from ltx_trainer.video_utils import read_video, save_video + + +def load_image(image_path: str) -> torch.Tensor: + """Load an image and convert to tensor [C, H, W] in [0, 1].""" + image = open_image_as_srgb(image_path) + transform = transforms.ToTensor() + return transform(image) + + +def extract_lora_target_modules(state_dict: dict[str, torch.Tensor]) -> list[str]: + """Extract target module names from LoRA checkpoint keys. + + LoRA keys follow the pattern (after removing "diffusion_model." prefix): + - transformer_blocks.0.attn1.to_k.lora_A.weight + - transformer_blocks.0.ff.net.0.proj.lora_B.weight + + This extracts the full module path like "transformer_blocks.0.attn1.to_k". + Using full paths is more robust than partial patterns. + """ + target_modules = set() + # Pattern to extract everything before .lora_A or .lora_B + pattern = re.compile(r"(.+)\.lora_[AB]\.") + + for key in state_dict: + match = pattern.match(key) + if match: + module_path = match.group(1) + target_modules.add(module_path) + + return sorted(target_modules) + + +def load_lora_weights(transformer: torch.nn.Module, lora_path: str | Path) -> torch.nn.Module: + """Load LoRA weights into the transformer model. + + The LoRA rank and target modules are automatically detected from the checkpoint. + Alpha is set equal to rank (standard practice for inference). + + Args: + transformer: The base transformer model + lora_path: Path to the LoRA weights (.safetensors) + + Returns: + The transformer model with LoRA weights applied + """ + print(f"Loading LoRA weights from {lora_path}...") + + # Load the LoRA state dict + state_dict = load_file(str(lora_path)) + + # Remove "diffusion_model." 
prefix (ComfyUI-compatible format) + state_dict = {k.replace("diffusion_model.", "", 1): v for k, v in state_dict.items()} + + # Extract target modules from the checkpoint + target_modules = extract_lora_target_modules(state_dict) + if not target_modules: + raise ValueError(f"Could not extract target modules from LoRA checkpoint: {lora_path}") + print(f" Detected {len(target_modules)} target modules") + + # Auto-detect rank from the first lora_A weight shape + lora_rank = None + for key, value in state_dict.items(): + if "lora_A" in key and value.ndim == 2: + lora_rank = value.shape[0] + break + if lora_rank is None: + raise ValueError("Could not auto-detect LoRA rank from weights") + print(f" LoRA rank: {lora_rank}") + + # Create LoRA config and wrap the model + # Alpha = rank is standard for inference (maintains the trained scale) + lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_rank, + target_modules=target_modules, + lora_dropout=0.0, + init_lora_weights=True, + ) + + # Wrap the transformer with PEFT to add LoRA layers + transformer = get_peft_model(transformer, lora_config) + + # Load the LoRA weights + base_model = transformer.get_base_model() + set_peft_model_state_dict(base_model, state_dict) + + print("✓ LoRA weights loaded successfully") + return transformer + + +def main() -> None: # noqa: PLR0912, PLR0915 + parser = argparse.ArgumentParser( + description="LTX Video/Audio Generation", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Model arguments + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to model checkpoint (.safetensors)", + ) + parser.add_argument( + "--text-encoder-path", + type=str, + required=True, + help="Path to Gemma text encoder directory", + ) + + # LoRA arguments + parser.add_argument( + "--lora-path", + type=str, + default=None, + help="Path to LoRA weights (.safetensors)", + ) + + # Generation arguments + parser.add_argument( + "--prompt", + type=str, + required=True, + help="Text prompt for generation", + ) + parser.add_argument( + "--negative-prompt", + type=str, + default="", + help="Negative prompt", + ) + parser.add_argument( + "--height", + type=int, + default=544, + help="Video height (must be divisible by 32)", + ) + parser.add_argument( + "--width", + type=int, + default=960, + help="Video width (must be divisible by 32)", + ) + parser.add_argument( + "--num-frames", + type=int, + default=97, + help="Number of video frames (must be k*8 + 1)", + ) + parser.add_argument( + "--frame-rate", + type=float, + default=25.0, + help="Video frame rate", + ) + parser.add_argument( + "--num-inference-steps", + type=int, + default=30, + help="Number of denoising steps", + ) + parser.add_argument( + "--guidance-scale", + type=float, + default=3.0, + help="Classifier-free guidance scale (CFG)", + ) + parser.add_argument( + "--stg-scale", + type=float, + default=1.0, + help="STG (Spatio-Temporal Guidance) scale. 0.0 disables STG. Default: 1.0", + ) + parser.add_argument( + "--stg-blocks", + type=int, + nargs="*", + default=[29], + help="Which transformer blocks to perturb for STG. 
Default: 29 (single block).", + ) + parser.add_argument( + "--stg-mode", + type=str, + default="stg_av", + choices=["stg_av", "stg_v"], + help="STG mode: 'stg_av' perturbs both audio and video, 'stg_v' perturbs video only", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducibility", + ) + + # Conditioning arguments + parser.add_argument( + "--condition-image", + type=str, + default=None, + help="Path to conditioning image for image-to-video generation", + ) + parser.add_argument( + "--reference-video", + type=str, + default=None, + help="Path to reference video for video-to-video generation (IC-LoRA style)", + ) + parser.add_argument( + "--include-reference-in-output", + action="store_true", + help="Include reference video side-by-side with generated output (only for V2V)", + ) + + # Audio arguments + parser.add_argument( + "--skip-audio", + action="store_true", + help="Skip audio generation (by default, audio is generated alongside video)", + ) + + # Output arguments + parser.add_argument( + "--output", + type=str, + required=True, + help="Output video path (.mp4)", + ) + parser.add_argument( + "--audio-output", + type=str, + default=None, + help="Output audio path (.wav, optional - if not provided, audio will be embedded in video)", + ) + + # Device arguments + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to run on (cuda/cpu)", + ) + + args = parser.parse_args() + + # Validate conditioning arguments + if args.include_reference_in_output and args.reference_video is None: + parser.error("--include-reference-in-output requires --reference-video") + + # Validate arguments + generate_audio = not args.skip_audio + + print("=" * 80) + print("LTX Video/Audio Generation") + print("=" * 80) + + # Determine if we need VAE encoder (for image or video conditioning) + need_vae_encoder = args.condition_image is not None or args.reference_video is not None + + components = load_model( + checkpoint_path=args.checkpoint, + device="cpu", # Load to CPU first, sampler will move to device as needed + dtype=torch.bfloat16, + with_video_vae_encoder=need_vae_encoder, + with_video_vae_decoder=True, + with_audio_vae_decoder=generate_audio, + with_vocoder=generate_audio, + with_text_encoder=True, + text_encoder_path=args.text_encoder_path, + ) + + # Apply LoRA weights if provided + transformer = components.transformer + if args.lora_path is not None: + transformer = load_lora_weights(transformer, args.lora_path) + + # Load conditioning image if provided + condition_image = None + if args.condition_image: + print(f"Loading conditioning image from {args.condition_image}...") + condition_image = load_image(args.condition_image) + + # Load reference video if provided + reference_video = None + if args.reference_video: + print(f"Loading reference video from {args.reference_video}...") + reference_video, ref_fps = read_video(args.reference_video, max_frames=args.num_frames) + print(f" Loaded {reference_video.shape[0]} frames @ {ref_fps:.1f} fps") + + # Determine generation mode + if args.reference_video is not None and args.condition_image is not None: + mode = "Video-to-Video + Image Conditioning (V2V+I2V)" + elif args.reference_video is not None: + mode = "Video-to-Video (V2V)" + elif args.condition_image is not None: + mode = "Image-to-Video (I2V)" + else: + mode = "Text-to-Video (T2V)" + + print("\n" + "=" * 80) + print("Generation Parameters") + print("=" * 80) + print(f"Mode: {mode}") + print(f"Prompt: {args.prompt}") + if 
args.negative_prompt: + print(f"Negative prompt: {args.negative_prompt}") + print(f"Resolution: {args.width}x{args.height}") + print(f"Frames: {args.num_frames} @ {args.frame_rate} fps") + print(f"Inference steps: {args.num_inference_steps}") + print(f"CFG scale: {args.guidance_scale}") + if args.stg_scale > 0: + blocks_str = args.stg_blocks if args.stg_blocks else "all" + print(f"STG scale: {args.stg_scale} (mode: {args.stg_mode}, blocks: {blocks_str})") + else: + print("STG: disabled") + print(f"Seed: {args.seed}") + if args.lora_path: + print(f"LoRA: {args.lora_path}") + if condition_image is not None: + print(f"Conditioning: Image ({args.condition_image})") + if reference_video is not None: + print(f"Reference: Video ({args.reference_video})") + if args.include_reference_in_output: + print(" → Will include reference side-by-side in output") + if generate_audio: + video_duration = args.num_frames / args.frame_rate + print(f"Audio: Enabled (duration will match video: {video_duration:.2f}s)") + print("=" * 80) + + print(f"\nGenerating {'video + audio' if generate_audio else 'video'}...") + + # Create generation config + gen_config = GenerationConfig( + prompt=args.prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, + frame_rate=args.frame_rate, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + seed=args.seed, + condition_image=condition_image, + reference_video=reference_video, + generate_audio=generate_audio, + include_reference_in_output=args.include_reference_in_output, + stg_scale=args.stg_scale, + stg_blocks=args.stg_blocks, + stg_mode=args.stg_mode, + ) + + # Generate with progress bar + with StandaloneSamplingProgress(num_steps=args.num_inference_steps) as progress: + # Create sampler with progress context + sampler = ValidationSampler( + transformer=transformer, + vae_decoder=components.video_vae_decoder, + vae_encoder=components.video_vae_encoder, + text_encoder=components.text_encoder, + audio_decoder=components.audio_vae_decoder if generate_audio else None, + vocoder=components.vocoder if generate_audio else None, + sampling_context=progress, + ) + video, audio = sampler.generate( + config=gen_config, + device=args.device, + ) + + # Save video + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Get audio sample rate from vocoder if audio was generated + audio_sample_rate = None + if audio is not None and components.vocoder is not None: + audio_sample_rate = components.vocoder.output_sample_rate + + save_video( + video_tensor=video, + output_path=output_path, + fps=args.frame_rate, + audio=audio, + audio_sample_rate=audio_sample_rate, + ) + print(f"✓ Video saved to {args.output}") + + # Save separate audio file if requested + if audio is not None and args.audio_output is not None: + audio_output_path = Path(args.audio_output) + audio_output_path.parent.mkdir(parents=True, exist_ok=True) + + torchaudio.save( + str(audio_output_path), + audio.cpu(), + sample_rate=audio_sample_rate, + ) + duration = audio.shape[1] / audio_sample_rate + print(f"✓ Audio saved: {duration:.2f}s at {audio_sample_rate}Hz") + + print("\n" + "=" * 80) + print("Generation complete!") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/packages/ltx-trainer/scripts/process_captions.py b/packages/ltx-trainer/scripts/process_captions.py new file mode 100644 index 0000000000000000000000000000000000000000..f70906dc25c1e1034cd7d8bd2e4488e5aa7a4594 
--- /dev/null +++ b/packages/ltx-trainer/scripts/process_captions.py @@ -0,0 +1,425 @@ +#!/usr/bin/env python + +""" +Compute text embeddings for video generation training. + +This module provides functionality for processing text captions, including: +- Loading captions from various file formats (CSV, JSON, JSONL) +- Cleaning and preprocessing text (removing LLM prefixes, adding ID tokens) +- CaptionsDataset for caption-only preprocessing workflows + +Can be used as a standalone script: + python scripts/process_captions.py dataset.json --output-dir /path/to/output \ + --model-source /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma +""" + +import json +import os +from pathlib import Path +from typing import Any + +import pandas as pd +import torch +import typer +from rich.console import Console +from rich.progress import ( + BarColumn, + MofNCompleteColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) +from torch.utils.data import DataLoader, Dataset +from transformers.utils.logging import disable_progress_bar + +from ltx_trainer import logger +from ltx_trainer.model_loader import load_text_encoder + +# Disable tokenizers parallelism to avoid warnings +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +disable_progress_bar() + +# Common phrases that LLMs often add to captions that we might want to remove +COMMON_BEGINNING_PHRASES: tuple[str, ...] = ( + "This video", + "The video", + "This clip", + "The clip", + "The animation", + "This image", + "The image", + "This picture", + "The picture", +) + +COMMON_CONTINUATION_WORDS: tuple[str, ...] = ( + "shows", + "depicts", + "features", + "captures", + "highlights", + "introduces", + "presents", +) + +COMMON_LLM_START_PHRASES: tuple[str, ...] = ( + "In the video,", + "In this video,", + "In this video clip,", + "In the clip,", + "Caption:", + *( + f"{beginning} {continuation}" + for beginning in COMMON_BEGINNING_PHRASES + for continuation in COMMON_CONTINUATION_WORDS + ), +) + +app = typer.Typer( + pretty_exceptions_enable=False, + no_args_is_help=True, + help="Process text captions and save embeddings for video generation training.", +) + + +class CaptionsDataset(Dataset): + """ + Dataset for processing text captions only. + + This dataset is designed for caption preprocessing workflows where you only need + to process text without loading videos. Useful for: + - Precomputing text embeddings + - Caption cleaning and preprocessing + - Text-only preprocessing pipelines + """ + + def __init__( + self, + dataset_file: str | Path, + caption_column: str, + media_column: str = "media_path", + lora_trigger: str | None = None, + remove_llm_prefixes: bool = False, + ) -> None: + """ + Initialize the captions dataset. 
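+
+        For illustration (hypothetical values, default column names), a JSON entry such as
+            {"caption": "A cat playing piano", "media_path": "videos/cat.mp4"}
+        yields the prompt "A cat playing piano" (with any trigger word prepended) and the
+        output embedding path "videos/cat.pt".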
+ + Args: + dataset_file: Path to CSV/JSON/JSONL metadata file + caption_column: Column name for captions in the metadata file + media_column: Column name for media paths (used for output naming) + lora_trigger: Optional trigger word to prepend to each caption + remove_llm_prefixes: Whether to remove common LLM-generated prefixes + """ + super().__init__() + + self.dataset_file = Path(dataset_file) + self.caption_column = caption_column + self.media_column = media_column + self.lora_trigger = f"{lora_trigger.strip()} " if lora_trigger else "" + + # Load captions with their corresponding output embedding paths + self.caption_data = self._load_caption_data() + + # Convert to lists for indexing + self.output_paths = list(self.caption_data.keys()) + self.prompts = list(self.caption_data.values()) + + # Clean LLM start phrases if requested + if remove_llm_prefixes: + self._clean_llm_prefixes() + + def __len__(self) -> int: + return len(self.prompts) + + def __getitem__(self, index: int) -> dict[str, Any]: + """Get a single caption with optional trigger word prepended and output path.""" + prompt = self.lora_trigger + self.prompts[index] + return { + "prompt": prompt, + "output_path": self.output_paths[index], + "index": index, + } + + def _load_caption_data(self) -> dict[str, str]: + """Load captions and compute their output embedding paths.""" + if self.dataset_file.suffix == ".csv": + return self._load_caption_data_from_csv() + elif self.dataset_file.suffix == ".json": + return self._load_caption_data_from_json() + elif self.dataset_file.suffix == ".jsonl": + return self._load_caption_data_from_jsonl() + else: + raise ValueError("Expected `dataset_file` to be a path to a CSV, JSON, or JSONL file.") + + def _load_caption_data_from_csv(self) -> dict[str, str]: + """Load captions from a CSV file and compute output embedding paths.""" + df = pd.read_csv(self.dataset_file) + + if self.caption_column not in df.columns: + raise ValueError(f"Column '{self.caption_column}' not found in CSV file") + if self.media_column not in df.columns: + raise ValueError(f"Column '{self.media_column}' not found in CSV file") + + caption_data = {} + for _, row in df.iterrows(): + media_path = Path(row[self.media_column].strip()) + # Convert media path to embedding output path (same structure, .pt extension) + output_path = str(media_path.with_suffix(".pt")) + caption_data[output_path] = row[self.caption_column] + + return caption_data + + def _load_caption_data_from_json(self) -> dict[str, str]: + """Load captions from a JSON file and compute output embedding paths.""" + with open(self.dataset_file, "r", encoding="utf-8") as file: + data = json.load(file) + + if not isinstance(data, list): + raise ValueError("JSON file must contain a list of objects") + + caption_data = {} + for entry in data: + if self.caption_column not in entry: + raise ValueError(f"Key '{self.caption_column}' not found in JSON entry: {entry}") + if self.media_column not in entry: + raise ValueError(f"Key '{self.media_column}' not found in JSON entry: {entry}") + + media_path = Path(entry[self.media_column].strip()) + # Convert media path to embedding output path (same structure, .pt extension) + output_path = str(media_path.with_suffix(".pt")) + caption_data[output_path] = entry[self.caption_column] + + return caption_data + + def _load_caption_data_from_jsonl(self) -> dict[str, str]: + """Load captions from a JSONL file and compute output embedding paths.""" + caption_data = {} + with open(self.dataset_file, "r", encoding="utf-8") as file: + for 
line in file: + entry = json.loads(line) + if self.caption_column not in entry: + raise ValueError(f"Key '{self.caption_column}' not found in JSONL entry: {entry}") + if self.media_column not in entry: + raise ValueError(f"Key '{self.media_column}' not found in JSONL entry: {entry}") + + media_path = Path(entry[self.media_column].strip()) + # Convert media path to embedding output path (same structure, .pt extension) + output_path = str(media_path.with_suffix(".pt")) + caption_data[output_path] = entry[self.caption_column] + + return caption_data + + def _clean_llm_prefixes(self) -> None: + """Remove common LLM-generated prefixes from captions.""" + for i in range(len(self.prompts)): + self.prompts[i] = self.prompts[i].strip() + for phrase in COMMON_LLM_START_PHRASES: + if self.prompts[i].startswith(phrase): + self.prompts[i] = self.prompts[i].removeprefix(phrase).strip() + break + + +def compute_captions_embeddings( + dataset_file: str | Path, + output_dir: str, + model_path: str, + text_encoder_path: str, + caption_column: str = "caption", + media_column: str = "media_path", + lora_trigger: str | None = None, + remove_llm_prefixes: bool = False, + batch_size: int = 8, + device: str = "cuda", +) -> None: + """ + Process captions and save text embeddings. + + Args: + dataset_file: Path to metadata file (CSV/JSON/JSONL) containing captions and media paths + output_dir: Directory to save embeddings + model_path: Path to LTX-2 checkpoint (.safetensors) + text_encoder_path: Path to Gemma text encoder directory + caption_column: Column name containing captions in the metadata file + media_column: Column name containing media paths (used for output naming) + lora_trigger: Optional trigger word to prepend to each caption + remove_llm_prefixes: Whether to remove common LLM-generated prefixes + batch_size: Batch size for processing + device: Device to use for computation + """ + + console = Console() + + # Create dataset + dataset = CaptionsDataset( + dataset_file=dataset_file, + caption_column=caption_column, + media_column=media_column, + lora_trigger=lora_trigger, + remove_llm_prefixes=remove_llm_prefixes, + ) + logger.info(f"Loaded {len(dataset):,} captions") + + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + # Load text encoder + with console.status("[bold]Loading Gemma text encoder...", spinner="dots"): + text_encoder = load_text_encoder(model_path, text_encoder_path, device=device, dtype=torch.bfloat16) + + logger.info("Text encoder loaded successfully") + + # TODO(batch-tokenization): The current Gemma tokenizer doesn't support batched tokenization. + if batch_size > 1: + logger.warning( + "Batch size greater than 1 is not currently supported with the Gemma tokenizer. " + "Overriding batch_size to 1. This will be fixed in a future update." 
+ ) + batch_size = 1 + + # Create dataloader + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2) + + # Process batches + total_batches = len(dataloader) + logger.info(f"Processing captions in {total_batches:,} batches...") + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), + console=console, + ) as progress: + task = progress.add_task("Processing captions", total=len(dataloader)) + for batch in dataloader: + # Encode prompts using _preprocess_text (returns embeddings before connector) + # This is what we want to save - the connector is applied during training + with torch.inference_mode(): + # TODO(batch-tokenization): When tokenizer supports batching, encode all prompts at once: + # prompt_embeds, prompt_attention_mask = text_encoder._preprocess_text(batch["prompt"]) # noqa: ERA001 + # For now, process one at a time: + for i in range(len(batch["prompt"])): + prompt_embeds, prompt_attention_mask = text_encoder._preprocess_text( + batch["prompt"][i], padding_side="left" + ) + + output_rel_path = Path(batch["output_path"][i]) + + # Create output directory maintaining structure + output_dir_path = output_path / output_rel_path.parent + output_dir_path.mkdir(parents=True, exist_ok=True) + + embedding_data = { + "prompt_embeds": prompt_embeds[0].cpu().contiguous(), + "prompt_attention_mask": prompt_attention_mask[0].cpu().contiguous(), + } + + output_file = output_path / output_rel_path + torch.save(embedding_data, output_file) + + progress.advance(task) + + logger.info(f"Processed {len(dataset):,} captions. Embeddings saved to {output_path}") + + +@app.command() +def main( + dataset_file: str = typer.Argument( + ..., + help="Path to metadata file (CSV/JSON/JSONL) containing captions and media paths", + ), + output_dir: str = typer.Option( + ..., + help="Output directory to save text embeddings", + ), + model_path: str = typer.Option( + ..., + help="Path to LTX-2 checkpoint (.safetensors file)", + ), + text_encoder_path: str = typer.Option( + ..., + help="Path to Gemma text encoder directory", + ), + caption_column: str = typer.Option( + default="caption", + help="Column name containing captions in the dataset JSON/JSONL/CSV file", + ), + media_column: str = typer.Option( + default="media_path", + help="Column name in the dataset JSON/JSONL/CSV file containing media paths " + "(used for output file naming and folder structure)", + ), + batch_size: int = typer.Option( + default=8, + help="Batch size for processing", + ), + device: str = typer.Option( + default="cuda", + help="Device to use for computation", + ), + lora_trigger: str | None = typer.Option( + default=None, + help="Optional trigger word to prepend to each caption (activates the LoRA during inference)", + ), + remove_llm_prefixes: bool = typer.Option( + default=False, + help="Remove common LLM-generated prefixes from captions", + ), +) -> None: + """Process text captions and save embeddings for video generation training. + + This script processes captions from metadata files and saves text embeddings + that can be used for training video generation models. The output embeddings + will maintain the same folder structure and naming as the corresponding media files. + + Note: This script is designed for LTX-2 models which use the Gemma text encoder. 
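+
+    Each caption is saved as a .pt file that mirrors the media file's relative path
+    (e.g. "videos/cat.mp4" -> "videos/cat.pt") and contains the "prompt_embeds" and
+    "prompt_attention_mask" tensors.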
+ + Examples: + # Process captions with LTX-2 model + python scripts/process_captions.py dataset.json --output-dir ./embeddings \\ + --model-path /path/to/ltx2_checkpoint.safetensors \\ + --text-encoder-path /path/to/gemma + + # Add a trigger word for LoRA training + python scripts/process_captions.py dataset.json --output-dir ./embeddings \\ + --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\ + --lora-trigger "mytoken" + + # Remove LLM-generated prefixes from captions + python scripts/process_captions.py dataset.json --output-dir ./embeddings \\ + --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\ + --remove-llm-prefixes + """ + + # Validate dataset file + if not Path(dataset_file).is_file(): + raise typer.BadParameter(f"Dataset file not found: {dataset_file}") + + if lora_trigger: + logger.info(f'LoRA trigger word "{lora_trigger}" will be prepended to all captions') + + # Process embeddings + compute_captions_embeddings( + dataset_file=dataset_file, + output_dir=output_dir, + model_path=model_path, + text_encoder_path=text_encoder_path, + caption_column=caption_column, + media_column=media_column, + lora_trigger=lora_trigger, + remove_llm_prefixes=remove_llm_prefixes, + batch_size=batch_size, + device=device, + ) + + +if __name__ == "__main__": + app() diff --git a/packages/ltx-trainer/scripts/process_dataset.py b/packages/ltx-trainer/scripts/process_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2da0edaaf34edf8cdcbe7be5b60e1c94cfde253e --- /dev/null +++ b/packages/ltx-trainer/scripts/process_dataset.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 + +""" +Preprocess a video dataset by computing video clips latents and text captions embeddings. + +This script provides a command-line interface for preprocessing video datasets by computing +latent representations of video clips and text embeddings of their captions. The preprocessed +data can be used to accelerate training of video generation models and to save GPU memory. + +Basic usage: + python scripts/process_dataset.py /path/to/dataset.json --resolution-buckets 768x768x49 \ + --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma + +The dataset must be a CSV, JSON, or JSONL file with columns for captions and video paths. +""" + +from pathlib import Path + +import typer +from decode_latents import LatentsDecoder +from process_captions import compute_captions_embeddings +from process_videos import compute_latents, parse_resolution_buckets +from rich.console import Console + +from ltx_trainer import logger + +console = Console() +app = typer.Typer( + pretty_exceptions_enable=False, + no_args_is_help=True, + help="Preprocess a video dataset by computing video clips latents and text captions embeddings. 
" + "The dataset must be a CSV, JSON, or JSONL file with columns for captions and video paths.", +) + + +def preprocess_dataset( # noqa: PLR0913 + dataset_file: str, + caption_column: str, + video_column: str, + resolution_buckets: list[tuple[int, int, int]], + batch_size: int, + output_dir: str | None, + lora_trigger: str | None, + vae_tiling: bool, + decode: bool, + model_path: str, + text_encoder_path: str, + device: str, + remove_llm_prefixes: bool = False, + reference_column: str | None = None, + with_audio: bool = False, +) -> None: + """Run the preprocessing pipeline with the given arguments.""" + # Validate dataset file + _validate_dataset_file(dataset_file) + + # Set up output directories + output_base = Path(output_dir) if output_dir else Path(dataset_file).parent / ".precomputed" + conditions_dir = output_base / "conditions" + latents_dir = output_base / "latents" + + if lora_trigger: + logger.info(f'LoRA trigger word "{lora_trigger}" will be prepended to all captions') + + # Process captions using the dedicated function + compute_captions_embeddings( + dataset_file=dataset_file, + output_dir=str(conditions_dir), + model_path=model_path, + text_encoder_path=text_encoder_path, + caption_column=caption_column, + media_column=video_column, + lora_trigger=lora_trigger, + remove_llm_prefixes=remove_llm_prefixes, + batch_size=batch_size, + device=device, + ) + + # Process videos using the dedicated function + audio_latents_dir = None + if with_audio: + logger.info("Audio preprocessing enabled - will extract and encode audio from videos") + audio_latents_dir = output_base / "audio_latents" + + compute_latents( + dataset_file=dataset_file, + video_column=video_column, + resolution_buckets=resolution_buckets, + output_dir=str(latents_dir), + model_path=model_path, + batch_size=batch_size, + device=device, + vae_tiling=vae_tiling, + with_audio=with_audio, + audio_output_dir=str(audio_latents_dir) if audio_latents_dir else None, + ) + + # Process reference videos if reference_column is provided + if reference_column: + logger.info("Processing reference videos for IC-LoRA training...") + reference_latents_dir = output_base / "reference_latents" + + compute_latents( + dataset_file=dataset_file, + main_media_column=video_column, + video_column=reference_column, + resolution_buckets=resolution_buckets, + output_dir=str(reference_latents_dir), + model_path=model_path, + batch_size=batch_size, + device=device, + vae_tiling=vae_tiling, + ) + + # Handle decoding if requested (for verification) + if decode: + logger.info("Decoding latents for verification...") + + decoder = LatentsDecoder( + model_path=model_path, + device=device, + vae_tiling=vae_tiling, + with_audio=with_audio, + ) + decoder.decode(latents_dir, output_base / "decoded_videos") + + # Also decode reference videos if they exist + if reference_column: + reference_latents_dir = output_base / "reference_latents" + if reference_latents_dir.exists(): + logger.info("Decoding reference videos...") + decoder.decode(reference_latents_dir, output_base / "decoded_reference_videos") + + # Decode audio latents if they exist + if with_audio and audio_latents_dir and audio_latents_dir.exists(): + logger.info("Decoding audio latents...") + decoder.decode_audio(audio_latents_dir, output_base / "decoded_audio") + + # Print summary + logger.info(f"Dataset preprocessing complete! 
Results saved to {output_base}") + if reference_column: + logger.info("Reference videos processed and saved to reference_latents/ directory for IC-LoRA training") + if with_audio: + logger.info("Audio latents saved to audio_latents/ directory for audio-video training") + + +def _validate_dataset_file(dataset_path: str) -> None: + """Validate that the dataset file exists and has the correct format.""" + dataset_file = Path(dataset_path) + + if not dataset_file.exists(): + raise FileNotFoundError(f"Dataset file does not exist: {dataset_file}") + + if not dataset_file.is_file(): + raise ValueError(f"Dataset path must be a file, not a directory: {dataset_file}") + + if dataset_file.suffix.lower() not in [".csv", ".json", ".jsonl"]: + raise ValueError(f"Dataset file must be CSV, JSON, or JSONL format: {dataset_file}") + + +@app.command() +def main( # noqa: PLR0913 + dataset_path: str = typer.Argument( + ..., + help="Path to metadata file (CSV/JSON/JSONL) containing captions and video paths", + ), + resolution_buckets: str = typer.Option( + ..., + help='Resolution buckets in format "WxHxF;WxHxF;..." (e.g. "768x768x25;512x512x49")', + ), + model_path: str = typer.Option( + ..., + help="Path to LTX-2 checkpoint (.safetensors file)", + ), + text_encoder_path: str = typer.Option( + ..., + help="Path to Gemma text encoder directory", + ), + caption_column: str = typer.Option( + default="caption", + help="Column name containing captions in the dataset JSON/JSONL/CSV file", + ), + video_column: str = typer.Option( + default="media_path", + help="Column name containing video paths in the dataset JSON/JSONL/CSV file", + ), + batch_size: int = typer.Option( + default=1, + help="Batch size for preprocessing", + ), + device: str = typer.Option( + default="cuda", + help="Device to use for computation", + ), + vae_tiling: bool = typer.Option( + default=False, + help="Enable VAE tiling for larger video resolutions", + ), + output_dir: str | None = typer.Option( + default=None, + help="Output directory (defaults to .precomputed in dataset directory)", + ), + lora_trigger: str | None = typer.Option( + default=None, + help="Optional trigger word to prepend to each caption (activates the LoRA during inference)", + ), + decode: bool = typer.Option( + default=False, + help="Decode and save latents after encoding (videos and audio) for verification", + ), + remove_llm_prefixes: bool = typer.Option( + default=False, + help="Remove LLM prefixes from captions", + ), + reference_column: str | None = typer.Option( + default=None, + help="Column name containing reference video paths (for video-to-video training)", + ), + with_audio: bool = typer.Option( + default=False, + help="Extract and encode audio from video files", + ), +) -> None: + """Preprocess a video dataset by computing and saving latents and text embeddings. + + The dataset must be a CSV, JSON, or JSONL file with columns for captions and video paths. + This script is designed for LTX-2 models which use the Gemma text encoder. 
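+
+    Outputs are written under --output-dir (or a ".precomputed" folder next to the dataset
+    file): text embeddings in "conditions/", video latents in "latents/", and, when enabled,
+    audio latents in "audio_latents/" and reference-video latents in "reference_latents/".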
+ + Examples: + # Process a dataset with LTX-2 model + python scripts/process_dataset.py dataset.json --resolution-buckets 768x768x25 \\ + --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma + + # Process dataset with custom column names + python scripts/process_dataset.py dataset.json --resolution-buckets 768x768x25 \\ + --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\ + --caption-column "text" --video-column "video_path" + + # Process dataset with reference videos for IC-LoRA training + python scripts/process_dataset.py dataset.json --resolution-buckets 768x768x25 \\ + --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\ + --reference-column "reference_path" + + # Process dataset with audio for audio-video training + python scripts/process_dataset.py dataset.json --resolution-buckets 768x512x97 \\ + --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\ + --with-audio + """ + parsed_resolution_buckets = parse_resolution_buckets(resolution_buckets) + + if len(parsed_resolution_buckets) > 1: + logger.warning( + "Using multiple resolution buckets. " + "When training with multiple resolution buckets, you must use a batch size of 1." + ) + + preprocess_dataset( + dataset_file=dataset_path, + caption_column=caption_column, + video_column=video_column, + resolution_buckets=parsed_resolution_buckets, + batch_size=batch_size, + output_dir=output_dir, + lora_trigger=lora_trigger, + vae_tiling=vae_tiling, + decode=decode, + model_path=model_path, + text_encoder_path=text_encoder_path, + device=device, + remove_llm_prefixes=remove_llm_prefixes, + reference_column=reference_column, + with_audio=with_audio, + ) + + +if __name__ == "__main__": + app() diff --git a/packages/ltx-trainer/scripts/process_videos.py b/packages/ltx-trainer/scripts/process_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..cf6a6da394a702e4ab1c18bf9e897f9b073839b2 --- /dev/null +++ b/packages/ltx-trainer/scripts/process_videos.py @@ -0,0 +1,840 @@ +#!/usr/bin/env python3 + +""" +Compute latent representations for video generation training. 
+
+This module provides functionality for processing video and image files, including:
+- Loading videos/images from various file formats (CSV, JSON, JSONL)
+- Resizing, cropping, and transforming media
+- MediaDataset for video-only preprocessing workflows
+- BucketSampler for grouping videos by resolution
+
+Can be used as a standalone script:
+    python scripts/process_videos.py dataset.csv --resolution-buckets 768x768x25 \
+        --output-dir /path/to/output --model-path /path/to/ltx2.safetensors
+"""
+
+import json
+import math
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import pandas as pd
+import torch
+import torchaudio
+import typer
+from pillow_heif import register_heif_opener
+from rich.console import Console
+from rich.progress import (
+    BarColumn,
+    MofNCompleteColumn,
+    Progress,
+    SpinnerColumn,
+    TaskProgressColumn,
+    TextColumn,
+    TimeElapsedColumn,
+    TimeRemainingColumn,
+)
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import crop, resize, to_tensor
+from transformers.utils.logging import disable_progress_bar
+
+from ltx_core.model.audio_vae.ops import AudioProcessor
+from ltx_trainer import logger
+from ltx_trainer.model_loader import load_audio_vae_encoder, load_video_vae_encoder
+from ltx_trainer.utils import open_image_as_srgb
+from ltx_trainer.video_utils import get_video_frame_count, read_video
+
+disable_progress_bar()
+
+# Register HEIF/HEIC support
+register_heif_opener()
+
+# Constants for validation
+VAE_SPATIAL_FACTOR = 32
+VAE_TEMPORAL_FACTOR = 8
+
+# Audio constants
+AUDIO_LATENT_CHANNELS = 8
+AUDIO_FREQUENCY_BINS = 16
+
+app = typer.Typer(
+    pretty_exceptions_enable=False,
+    no_args_is_help=True,
+    help="Process videos/images and save latent representations for video generation training.",
+)
+
+
+class MediaDataset(Dataset):
+    """
+    Dataset for processing video and image files.
+
+    This dataset is designed for media preprocessing workflows where you need to:
+    - Load and preprocess videos/images
+    - Apply resizing and cropping transformations
+    - Handle different resolution buckets
+    - Filter out invalid media files
+    - Optionally extract audio from video files
+    """
+
+    def __init__(
+        self,
+        dataset_file: str | Path,
+        main_media_column: str,
+        video_column: str,
+        resolution_buckets: list[tuple[int, int, int]],
+        reshape_mode: str = "center",
+        with_audio: bool = False,
+    ) -> None:
+        """
+        Initialize the media dataset.
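+
+        Example (illustrative paths; buckets are (frames, height, width) tuples):
+
+            dataset = MediaDataset(
+                dataset_file="dataset.json",
+                main_media_column="media_path",
+                video_column="media_path",
+                resolution_buckets=[(25, 768, 768)],
+            )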
+ + Args: + dataset_file: Path to CSV/JSON/JSONL metadata file + video_column: Column name for video paths in the metadata file + resolution_buckets: List of (frames, height, width) tuples + reshape_mode: How to crop videos ("center", "random") + with_audio: Whether to extract audio from video files + """ + super().__init__() + + self.dataset_file = Path(dataset_file) + self.main_media_column = main_media_column + self.resolution_buckets = resolution_buckets + self.reshape_mode = reshape_mode + self.with_audio = with_audio + + # First load main media paths + self.main_media_paths = self._load_video_paths(main_media_column) + + # Then load reference video paths + self.video_paths = self._load_video_paths(video_column) + + # Filter out videos with insufficient frames + self._filter_valid_videos() + + self.max_target_frames = max(self.resolution_buckets, key=lambda x: x[0])[0] + + # Set up video transforms + self.transforms = transforms.Compose( + [ + transforms.Lambda(lambda x: x.clamp_(0, 1)), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + def __len__(self) -> int: + return len(self.video_paths) + + def __getitem__(self, index: int) -> dict[str, Any]: + """Get a single video/image with metadata, and optionally audio.""" + if isinstance(index, list): + # Special case for BucketSampler - return cached data + return index + + video_path: Path = self.video_paths[index] + + # Compute relative path of the video + data_root = self.dataset_file.parent + relative_path = str(video_path.relative_to(data_root)) + media_relative_path = str(self.main_media_paths[index].relative_to(data_root)) + + if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]: + media_tensor = self._preprocess_image(video_path) + fps = 1.0 + audio_data = None # Images don't have audio + else: + media_tensor, fps = self._preprocess_video(video_path) + + # Extract audio if enabled + if self.with_audio: + # Calculate target duration from the processed video frames + # This ensures audio is trimmed to match the exact video duration + # media_tensor is [C, F, H, W] so shape[1] is num_frames + target_duration = media_tensor.shape[1] / fps + audio_data = self._extract_audio(video_path, target_duration) + else: + audio_data = None + + # media_tensor is [C, F, H, W] format for VAE compatibility + _, num_frames, height, width = media_tensor.shape + + result = { + "video": media_tensor, + "relative_path": relative_path, + "main_media_relative_path": media_relative_path, + "video_metadata": { + "num_frames": num_frames, + "height": height, + "width": width, + "fps": fps, + }, + } + + # Add audio data if available + if audio_data is not None: + result["audio"] = audio_data + + return result + + @staticmethod + def _extract_audio(video_path: Path, target_duration: float) -> dict[str, torch.Tensor | int] | None: + """Extract audio track from a video file, trimmed to match video duration.""" + try: + # torchaudio can extract audio from video files directly + # waveform shape: [channels, samples] + waveform, sample_rate = torchaudio.load(str(video_path)) + + # Trim or pad to target duration + target_samples = int(target_duration * sample_rate) + current_samples = waveform.shape[-1] + + if current_samples > target_samples: + # Trim to target duration + waveform = waveform[..., :target_samples] + elif current_samples < target_samples: + # Pad with zeros to target duration + padding = target_samples - current_samples + waveform = torch.nn.functional.pad(waveform, (0, padding)) + 
logger.warning(f"Padded audio to {target_duration:.2f} seconds for {video_path}") + + return {"waveform": waveform, "sample_rate": sample_rate} + + except Exception as e: + logger.debug(f"Could not extract audio from {video_path}: {e}") + return None + + def _load_video_paths(self, column: str) -> list[Path]: + """Load video paths from the specified data source.""" + if self.dataset_file.suffix == ".csv": + return self._load_video_paths_from_csv(column) + elif self.dataset_file.suffix == ".json": + return self._load_video_paths_from_json(column) + elif self.dataset_file.suffix == ".jsonl": + return self._load_video_paths_from_jsonl(column) + else: + raise ValueError("Expected `dataset_file` to be a path to a CSV, JSON, or JSONL file.") + + def _load_video_paths_from_csv(self, column: str) -> list[Path]: + """Load video paths from a CSV file.""" + df = pd.read_csv(self.dataset_file) + if column not in df.columns: + raise ValueError(f"Column '{column}' not found in CSV file") + + data_root = self.dataset_file.parent + video_paths = [data_root / Path(line.strip()) for line in df[column].tolist()] + + # Validate that all paths exist + invalid_paths = [path for path in video_paths if not path.is_file()] + if invalid_paths: + raise ValueError(f"Found {len(invalid_paths)} invalid video paths. First few: {invalid_paths[:5]}") + + return video_paths + + def _load_video_paths_from_json(self, column: str) -> list[Path]: + """Load video paths from a JSON file.""" + with open(self.dataset_file, "r", encoding="utf-8") as file: + data = json.load(file) + + if not isinstance(data, list): + raise ValueError("JSON file must contain a list of objects") + + data_root = self.dataset_file.parent + video_paths = [] + for entry in data: + if column not in entry: + raise ValueError(f"Key '{column}' not found in JSON entry") + video_paths.append(data_root / Path(entry[column].strip())) + + # Validate that all paths exist + invalid_paths = [path for path in video_paths if not path.is_file()] + if invalid_paths: + raise ValueError(f"Found {len(invalid_paths)} invalid video paths. First few: {invalid_paths[:5]}") + + return video_paths + + def _load_video_paths_from_jsonl(self, column: str) -> list[Path]: + """Load video paths from a JSONL file.""" + data_root = self.dataset_file.parent + video_paths = [] + with open(self.dataset_file, "r", encoding="utf-8") as file: + for line in file: + entry = json.loads(line) + if column not in entry: + raise ValueError(f"Key '{column}' not found in JSONL entry") + video_paths.append(data_root / Path(entry[column].strip())) + + # Validate that all paths exist + invalid_paths = [path for path in video_paths if not path.is_file()] + if invalid_paths: + raise ValueError(f"Found {len(invalid_paths)} invalid video paths. 
First few: {invalid_paths[:5]}") + + return video_paths + + def _filter_valid_videos(self) -> None: + """Filter out videos with insufficient frames.""" + original_length = len(self.video_paths) + valid_video_paths = [] + valid_main_media_paths = [] + min_frames_required = min(self.resolution_buckets, key=lambda x: x[0])[0] + + for i, video_path in enumerate(self.video_paths): + if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]: + valid_video_paths.append(video_path) + valid_main_media_paths.append(self.main_media_paths[i]) + continue + + try: + frame_count = get_video_frame_count(video_path) + + if frame_count >= min_frames_required: + valid_video_paths.append(video_path) + valid_main_media_paths.append(self.main_media_paths[i]) + else: + logger.warning( + f"Skipping video at {video_path} - has {frame_count} frames, " + f"which is less than the minimum required frames ({min_frames_required})" + ) + except Exception as e: + logger.warning(f"Failed to read video at {video_path}: {e!s}") + + # Update both path lists to maintain synchronization + self.video_paths = valid_video_paths + self.main_media_paths = valid_main_media_paths + + if len(self.video_paths) < original_length: + logger.warning( + f"Filtered out {original_length - len(self.video_paths)} videos with insufficient frames. " + f"Proceeding with {len(self.video_paths)} valid videos." + ) + + def _preprocess_image(self, path: Path) -> torch.Tensor: + """Preprocess a single image by resizing and applying transforms.""" + image = open_image_as_srgb(path) + image = to_tensor(image) + image = image.unsqueeze(0) # Add frame dimension [1, C, H, W] for bucket selection + + # Find nearest resolution bucket and resize + nearest_bucket = self._get_resolution_bucket_for_item(image) + _, target_height, target_width = nearest_bucket + image_resized = self._resize_and_crop(image, target_height, target_width) + # _resize_and_crop returns [C, H, W] for single-frame input (squeeze removes dim 0) + + # Apply transforms + image = self.transforms(image_resized) # [C, H, W] -> [C, H, W] + + # Add frame dimension in VAE format: [C, H, W] -> [C, 1, H, W] + image = image.unsqueeze(1) + return image + + def _preprocess_video(self, path: Path) -> tuple[torch.Tensor, float]: + """Preprocess a video by loading, resizing, and applying transforms. 
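+
+        The video is loaded up to the largest bucket's frame count, matched to the nearest
+        resolution bucket, resized and cropped to that bucket's spatial size, trimmed to its
+        frame count, and normalized to [-1, 1].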
+ + Returns: + Tuple of (video tensor in [C, F, H, W] format, fps) + """ + # Load video frames up to max_target_frames + video, fps = read_video(path, max_frames=self.max_target_frames) + + nearest_bucket = self._get_resolution_bucket_for_item(video) + target_num_frames, target_height, target_width = nearest_bucket + frames_resized = self._resize_and_crop(video, target_height, target_width) + + # Trim video to target number of frames + frames_resized = frames_resized[:target_num_frames] + + # Apply transforms to each frame and stack + video = torch.stack([self.transforms(frame) for frame in frames_resized], dim=0) + + # Permute [F,C,H,W] -> [C,F,H,W] for VAE compatibility + # After DataLoader batching, this becomes [B,C,F,H,W] which VAE expects + video = video.permute(1, 0, 2, 3).contiguous() + + return video, fps + + def _get_resolution_bucket_for_item(self, media_tensor: torch.Tensor) -> tuple[int, int, int]: + """Get the nearest resolution bucket for the given media tensor.""" + num_frames, _, height, width = media_tensor.shape + + def distance(bucket: tuple[int, int, int]) -> tuple: + bucket_num_frames, bucket_height, bucket_width = bucket + # Lexicographic key: + # 1) minimize aspect-ratio diff (in log-scale, for invariance to shorter/longer ARs) + # 2) prefer buckets with more frames (by using negative) + # 3) prefer buckets with larger spatial area (by using negative) + return ( + abs(math.log(width / height) - math.log(bucket_width / bucket_height)), + -bucket_num_frames, + -(bucket_height * bucket_width), + ) + + # Keep only buckets with <= available frames + relevant_buckets = [b for b in self.resolution_buckets if b[0] <= num_frames] + if not relevant_buckets: + raise ValueError(f"No resolution buckets have <= {num_frames} frames. Available: {self.resolution_buckets}") + + # Find the bucket with the minimal distance (according to the function above) to the media item's shape. 
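+        # For example (illustrative values): a 120-frame clip that is 768 px wide and 512 px
+        # high, with buckets [(25, 768, 768), (97, 512, 768)], selects (97, 512, 768): its
+        # 768/512 width/height ratio matches the clip exactly, which dominates the sort key.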
+ nearest_bucket = min(relevant_buckets, key=distance) + + return nearest_bucket + + def _resize_and_crop(self, media_tensor: torch.Tensor, target_height: int, target_width: int) -> torch.Tensor: + """Resize and crop tensor to target size.""" + # Get current dimensions + current_height, current_width = media_tensor.shape[2], media_tensor.shape[3] + + # Calculate aspect ratios to determine which dimension to resize first + current_aspect = current_width / current_height + target_aspect = target_width / target_height + + # Resize while maintaining aspect ratio - scale to make the smaller dimension fit + if current_aspect > target_aspect: + # Current is wider than target, so scale by height + new_width = int(current_width * target_height / current_height) + media_tensor = resize( + media_tensor, + size=[target_height, new_width], # type: ignore + interpolation=InterpolationMode.BICUBIC, + ) + else: + # Current is taller than target, so scale by width + new_height = int(current_height * target_width / current_width) + media_tensor = resize( + media_tensor, + size=[new_height, target_width], + interpolation=InterpolationMode.BICUBIC, + ) + + # Update dimensions after resize + current_height, current_width = media_tensor.shape[2], media_tensor.shape[3] + media_tensor = media_tensor.squeeze(0) + + # Calculate how much we need to crop from each dimension + delta_h = current_height - target_height + delta_w = current_width - target_width + + # Determine crop position based on reshape mode + if self.reshape_mode == "random": + # Random crop position + top = np.random.randint(0, delta_h + 1) + left = np.random.randint(0, delta_w + 1) + elif self.reshape_mode == "center": + # Center crop + top, left = delta_h // 2, delta_w // 2 + else: + raise ValueError(f"Unsupported reshape mode: {self.reshape_mode}") + + # Perform the final crop to exact target dimensions + media_tensor = crop(media_tensor, top=top, left=left, height=target_height, width=target_width) + return media_tensor + + +def compute_latents( # noqa: PLR0913, PLR0915 + dataset_file: str | Path, + video_column: str, + resolution_buckets: list[tuple[int, int, int]], + output_dir: str, + model_path: str, + main_media_column: str | None = None, + reshape_mode: str = "center", + batch_size: int = 1, + device: str = "cuda", + vae_tiling: bool = False, + with_audio: bool = False, + audio_output_dir: str | None = None, +) -> None: + """ + Process videos and save latent representations. 
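+
+    For each processed video, a .pt file mirroring the source video's relative path (with a
+    .pt suffix) is written under output_dir with the keys "latents", "num_frames", "height",
+    "width" and "fps"; when with_audio is set, a matching file with "latents",
+    "num_time_steps", "frequency_bins" and "duration" is written under audio_output_dir.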
+ + Args: + dataset_file: Path to metadata file (CSV/JSON/JSONL) containing video paths + video_column: Column name for video paths in the metadata file + resolution_buckets: List of (frames, height, width) tuples + output_dir: Directory to save video latents + model_path: Path to LTX-2 checkpoint (.safetensors) + reshape_mode: How to crop videos ("center", "random") + main_media_column: Column name for main media paths (if different from video_column) + batch_size: Batch size for processing + device: Device to use for computation + vae_tiling: Whether to enable VAE tiling + with_audio: Whether to extract and encode audio from videos + audio_output_dir: Directory to save audio latents (required if with_audio=True) + """ + # Validate audio parameters + if with_audio and audio_output_dir is None: + raise ValueError("audio_output_dir must be provided when with_audio=True") + + console = Console() + torch_device = torch.device(device) + + # Create dataset + dataset = MediaDataset( + dataset_file=dataset_file, + main_media_column=main_media_column or video_column, + video_column=video_column, + resolution_buckets=resolution_buckets, + reshape_mode=reshape_mode, + with_audio=with_audio, + ) + logger.info(f"Loaded {len(dataset)} valid media files") + + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + # Set up audio output directory if needed + audio_output_path = None + if with_audio: + audio_output_path = Path(audio_output_dir) + audio_output_path.mkdir(parents=True, exist_ok=True) + + # Load video VAE encoder + with console.status(f"[bold]Loading video VAE encoder from [cyan]{model_path}[/]...", spinner="dots"): + vae = load_video_vae_encoder(model_path, device=torch_device, dtype=torch.bfloat16) + + if vae_tiling: + vae.enable_tiling() + + # Load audio VAE encoder and audio processor if needed + audio_vae_encoder = None + audio_processor = None + if with_audio: + with console.status(f"[bold]Loading audio VAE encoder from [cyan]{model_path}[/]...", spinner="dots"): + audio_vae_encoder = load_audio_vae_encoder( + checkpoint_path=model_path, + device=torch_device, + dtype=torch.float32, # Audio VAE needs float32 for quality. TODO: re-test with bfloat16. + ) + # Create audio processor for waveform-to-spectrogram conversion + audio_processor = AudioProcessor( + sample_rate=audio_vae_encoder.sample_rate, + mel_bins=audio_vae_encoder.mel_bins, + mel_hop_length=audio_vae_encoder.mel_hop_length, + n_fft=audio_vae_encoder.n_fft, + ).to(torch_device) + + # Create dataloader + # Note: batch_size=1 required when with_audio because audio extraction can fail for some videos, + # and the default collate function can't handle mixed None/dict values across a batch. + if with_audio and batch_size > 1: + logger.warning("Audio processing requires batch_size=1. 
Overriding batch_size to 1.") + batch_size = 1 + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4) + + # Track audio statistics + audio_success_count = 0 + audio_skip_count = 0 + + # Process batches + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), + console=console, + ) as progress: + task = progress.add_task("Processing videos", total=len(dataloader)) + + for batch in dataloader: + # Get video tensor - shape is [B, F, C, H, W] from DataLoader + video = batch["video"] + + # Encode video + with torch.inference_mode(): + video_latent_data = encode_video(vae=vae, video=video) + + # Save latents for each item in batch + for i in range(len(batch["relative_path"])): + output_rel_path = Path(batch["main_media_relative_path"][i]).with_suffix(".pt") + output_file = output_path / output_rel_path + + # Create output directory maintaining structure + output_file.parent.mkdir(parents=True, exist_ok=True) + + # Index into batch to get this item's latents + latent_data = { + "latents": video_latent_data["latents"][i].cpu().contiguous(), # [C, F', H', W'] + "num_frames": video_latent_data["num_frames"], + "height": video_latent_data["height"], + "width": video_latent_data["width"], + "fps": batch["video_metadata"]["fps"][i].item(), + } + + torch.save(latent_data, output_file) + + # Process audio if enabled (audio is already extracted by the dataset) + if with_audio: + audio_batch = batch.get("audio") + if audio_batch is not None: + # Extract the i-th item from batched audio data + # DataLoader collates [channels, samples] -> [batch, channels, samples] + audio_data = { + "waveform": audio_batch["waveform"][i], + "sample_rate": audio_batch["sample_rate"][i].item(), + } + + # Encode audio + with torch.inference_mode(): + audio_latents = encode_audio(audio_vae_encoder, audio_processor, audio_data) + + # Save audio latents + audio_output_file = audio_output_path / output_rel_path + audio_output_file.parent.mkdir(parents=True, exist_ok=True) + + audio_save_data = { + "latents": audio_latents["latents"].cpu().contiguous(), + "num_time_steps": audio_latents["num_time_steps"], + "frequency_bins": audio_latents["frequency_bins"], + "duration": audio_latents["duration"], + } + + torch.save(audio_save_data, audio_output_file) + audio_success_count += 1 + else: + # Video has no audio track + audio_skip_count += 1 + + progress.advance(task) + + # Log summary + logger.info(f"Processed {len(dataset)} videos. Latents saved to {output_path}") + if with_audio: + logger.info( + f"Audio processing: {audio_success_count} videos with audio, " + f"{audio_skip_count} videos without audio (skipped)" + ) + + +def encode_video( + vae: torch.nn.Module, + video: torch.Tensor, + dtype: torch.dtype | None = None, +) -> dict[str, torch.Tensor | int]: + """Encode video into non-patchified latent representation. + + Args: + vae: Video VAE encoder model + video: Input tensor of shape [B, C, F, H, W] (batch, channels, frames, height, width) + This is the format expected by the VAE encoder. 
+ dtype: Target dtype for output latents + + Returns: + Dict containing non-patchified latents and shape information: + { + "latents": Tensor[B, C, F', H', W'], # Non-patchified format with batch dim + "num_frames": int, # Latent frame count + "height": int, # Latent height + "width": int, # Latent width + } + """ + device = next(vae.parameters()).device + vae_dtype = next(vae.parameters()).dtype + + # Add batch dimension if needed + if video.ndim == 4: + video = video.unsqueeze(0) # [C, F, H, W] -> [B, C, F, H, W] + + video = video.to(device=device, dtype=vae_dtype) + + # Encode video - VAE expects [B, C, F, H, W], returns [B, C, F', H', W'] + latents = vae(video) + + if dtype is not None: + latents = latents.to(dtype=dtype) + + _, _, num_frames, height, width = latents.shape + + return { + "latents": latents, # [B, C, F', H', W'] + "num_frames": num_frames, + "height": height, + "width": width, + } + + +def encode_audio( + audio_vae_encoder: torch.nn.Module, + audio_processor: torch.nn.Module, + audio_data: dict[str, torch.Tensor | int], +) -> dict[str, torch.Tensor | int | float]: + """Encode audio waveform into latent representation. + + Args: + audio_vae_encoder: Audio VAE encoder model from ltx-core + audio_processor: AudioProcessor for waveform-to-spectrogram conversion + audio_data: Dict with {"waveform": Tensor[channels, samples], "sample_rate": int} + + Returns: + Dict containing audio latents and shape information: + { + "latents": Tensor[C, T, F], # Non-patchified format + "num_time_steps": int, + "frequency_bins": int, + "duration": float, + } + """ + device = next(audio_vae_encoder.parameters()).device + dtype = next(audio_vae_encoder.parameters()).dtype + + waveform = audio_data["waveform"].to(device=device, dtype=dtype) + sample_rate = audio_data["sample_rate"] + + # Add batch dimension if needed: [channels, samples] -> [batch, channels, samples] + if waveform.dim() == 2: + waveform = waveform.unsqueeze(0) + + # Calculate duration + duration = waveform.shape[-1] / sample_rate + + # Convert waveform to mel spectrogram using AudioProcessor + mel_spectrogram = audio_processor.waveform_to_mel(waveform, waveform_sample_rate=sample_rate) + mel_spectrogram = mel_spectrogram.to(dtype=dtype) + + # Encode mel spectrogram to latents + latents = audio_vae_encoder(mel_spectrogram) + + # latents shape: [batch, channels, time, freq] = [1, 8, T, 16] + _, _channels, time_steps, freq_bins = latents.shape + + return { + "latents": latents.squeeze(0), # [C, T, F] - remove batch dim + "num_time_steps": time_steps, + "frequency_bins": freq_bins, + "duration": duration, + } + + +def parse_resolution_buckets(resolution_buckets_str: str) -> list[tuple[int, int, int]]: + """Parse resolution buckets from string format to list of tuples (frames, height, width)""" + resolution_buckets = [] + for bucket_str in resolution_buckets_str.split(";"): + w, h, f = map(int, bucket_str.split("x")) + + if w % VAE_SPATIAL_FACTOR != 0 or h % VAE_SPATIAL_FACTOR != 0: + raise typer.BadParameter( + f"Width and height must be multiples of {VAE_SPATIAL_FACTOR}, got {w}x{h}", + param_hint="resolution-buckets", + ) + + if f % VAE_TEMPORAL_FACTOR != 1: + raise typer.BadParameter( + f"Number of frames must be a multiple of {VAE_TEMPORAL_FACTOR} plus 1, got {f}", + param_hint="resolution-buckets", + ) + + resolution_buckets.append((f, h, w)) + return resolution_buckets + + +@app.command() +def main( # noqa: PLR0913 + dataset_file: str = typer.Argument( + ..., + help="Path to metadata file (CSV/JSON/JSONL) containing video 
paths", + ), + resolution_buckets: str = typer.Option( + ..., + help='Resolution buckets in format "WxHxF;WxHxF;..." (e.g. "768x768x25;512x512x49")', + ), + output_dir: str = typer.Option( + ..., + help="Output directory to save video latents", + ), + model_path: str = typer.Option( + ..., + help="Path to LTX-2 checkpoint (.safetensors file)", + ), + video_column: str = typer.Option( + default="media_path", + help="Column name in the dataset JSON/JSONL/CSV file containing video paths", + ), + batch_size: int = typer.Option( + default=1, + help="Batch size for processing", + ), + device: str = typer.Option( + default="cuda", + help="Device to use for computation", + ), + vae_tiling: bool = typer.Option( + default=False, + help="Enable VAE tiling for larger video resolutions", + ), + reshape_mode: str = typer.Option( + default="center", + help="How to crop videos: 'center' or 'random'", + ), + with_audio: bool = typer.Option( + default=False, + help="Extract and encode audio from video files", + ), + audio_output_dir: str | None = typer.Option( + default=None, + help="Output directory for audio latents (required if --with-audio is set)", + ), +) -> None: + """Process videos/images and save latent representations for video generation training. + + This script processes videos and images from metadata files and saves latent representations + that can be used for training video generation models. The output latents will maintain + the same folder structure and naming as the corresponding media files. + + Examples: + # Process videos from a CSV file + python scripts/process_videos.py dataset.csv --resolution-buckets 768x768x25 \\ + --output-dir ./latents --model-path /path/to/ltx2.safetensors + + # Process videos from a JSON file with custom video column + python scripts/process_videos.py dataset.json --resolution-buckets 768x768x25 \\ + --output-dir ./latents --model-path /path/to/ltx2.safetensors --video-column "video_path" + + # Enable VAE tiling to save GPU VRAM + python scripts/process_videos.py dataset.csv --resolution-buckets 1024x1024x25 \\ + --output-dir ./latents --model-path /path/to/ltx2.safetensors --vae-tiling + + # Process videos with audio + python scripts/process_videos.py dataset.csv --resolution-buckets 768x768x25 \\ + --output-dir ./latents --model-path /path/to/ltx2.safetensors \\ + --with-audio --audio-output-dir ./audio_latents + """ + + # Validate dataset file exists + if not Path(dataset_file).is_file(): + raise typer.BadParameter(f"Dataset file not found: {dataset_file}") + + # Validate audio parameters + if with_audio and audio_output_dir is None: + raise typer.BadParameter("--audio-output-dir is required when --with-audio is set") + + # Parse resolution buckets + parsed_resolution_buckets = parse_resolution_buckets(resolution_buckets) + + if len(parsed_resolution_buckets) > 1: + logger.warning( + "Using multiple resolution buckets. " + "When training with multiple resolution buckets, you must use a batch size of 1." 
+    )
+
+    # Process latents
+    compute_latents(
+        dataset_file=dataset_file,
+        video_column=video_column,
+        resolution_buckets=parsed_resolution_buckets,
+        output_dir=output_dir,
+        model_path=model_path,
+        reshape_mode=reshape_mode,
+        batch_size=batch_size,
+        device=device,
+        vae_tiling=vae_tiling,
+        with_audio=with_audio,
+        audio_output_dir=audio_output_dir,
+    )
+
+
+if __name__ == "__main__":
+    app()
diff --git a/packages/ltx-trainer/scripts/split_scenes.py b/packages/ltx-trainer/scripts/split_scenes.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce59826488158bd2d9a449ace39d05cfd9092d0d
--- /dev/null
+++ b/packages/ltx-trainer/scripts/split_scenes.py
@@ -0,0 +1,433 @@
+#!/usr/bin/env python3
+
+"""
+Split video into scenes using PySceneDetect.
+
+This script provides a command-line interface for splitting videos into scenes using various detection algorithms.
+It supports multiple detection methods, preview image generation, and customizable parameters for fine-tuning
+the scene detection process.
+
+Basic usage:
+    # Split video using default content-based detection
+    split_scenes.py input.mp4 output_dir/
+
+    # Save 3 preview images per scene
+    split_scenes.py input.mp4 output_dir/ --save-images 3
+
+    # Process specific duration and filter short scenes
+    split_scenes.py input.mp4 output_dir/ --duration 60s --filter-shorter-than 2s
+
+Advanced usage:
+    # Content detection with minimum scene length and frame skip
+    split_scenes.py input.mp4 output_dir/ --detector content --min-scene-length 30 --frame-skip 2
+
+    # Use adaptive detection with custom detector parameters
+    split_scenes.py input.mp4 output_dir/ --detector adaptive --threshold 3.0 --adaptive-window 10
+"""
+
+from enum import Enum
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+import typer
+from scenedetect import (
+    AdaptiveDetector,
+    ContentDetector,
+    HistogramDetector,
+    SceneManager,
+    ThresholdDetector,
+    open_video,
+)
+from scenedetect.frame_timecode import FrameTimecode
+from scenedetect.scene_manager import SceneDetector, write_scene_list_html
+from scenedetect.scene_manager import save_images as save_scene_images
+from scenedetect.stats_manager import StatsManager
+from scenedetect.video_splitter import split_video_ffmpeg
+
+app = typer.Typer(no_args_is_help=True, help="Split video into scenes using PySceneDetect.")
+
+
+class DetectorType(str, Enum):
+    """Available scene detection algorithms."""
+
+    CONTENT = "content"  # Detects fast cuts using HSV color space
+    ADAPTIVE = "adaptive"  # Detects fast two-phase cuts
+    THRESHOLD = "threshold"  # Detects fast cuts/slow fades in from and out to a given threshold level
+    HISTOGRAM = "histogram"  # Detects based on YUV histogram differences in adjacent frames
+
+
+def create_detector(
+    detector_type: DetectorType,
+    threshold: Optional[float] = None,
+    min_scene_len: Optional[int] = None,
+    luma_only: Optional[bool] = None,
+    adaptive_window: Optional[int] = None,
+    fade_bias: Optional[float] = None,
+) -> SceneDetector:
+    """Create a scene detector based on the specified type and parameters.
+
+    Args:
+        detector_type: Type of detector to create
+        threshold: Detection threshold (meaning varies by detector)
+        min_scene_len: Minimum scene length in frames
+        luma_only: If True, only use brightness for content detection
+        adaptive_window: Window size for adaptive detection
+        fade_bias: Bias for fade in/out detection (-1.0 to 1.0)
+
+    Note: Parameters set to None will use the detector's built-in default values.
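+
+    Example (illustrative parameter values):
+
+        detector = create_detector(DetectorType.CONTENT, threshold=27.0, min_scene_len=15)
+        detector = create_detector(DetectorType.ADAPTIVE, threshold=3.0, adaptive_window=10)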
+ + Returns: + Configured scene detector instance + """ + # Set common arguments + kwargs = {} + if threshold is not None: + kwargs["threshold"] = threshold + + if min_scene_len is not None: + kwargs["min_scene_len"] = min_scene_len + + match detector_type: + case DetectorType.CONTENT: + if luma_only is not None: + kwargs["luma_only"] = luma_only + return ContentDetector(**kwargs) + case DetectorType.ADAPTIVE: + if adaptive_window is not None: + kwargs["window_width"] = adaptive_window + if luma_only is not None: + kwargs["luma_only"] = luma_only + if "threshold" in kwargs: + # Special case for adaptive detector which uses different param name + kwargs["adaptive_threshold"] = kwargs.pop("threshold") + return AdaptiveDetector(**kwargs) + case DetectorType.THRESHOLD: + if fade_bias is not None: + kwargs["fade_bias"] = fade_bias + return ThresholdDetector(**kwargs) + case DetectorType.HISTOGRAM: + return HistogramDetector(**kwargs) + case _: + raise ValueError(f"Unknown detector type: {detector_type}") + + +def validate_output_dir(output_dir: str) -> Path: + """Validate and create output directory if it doesn't exist. + + Args: + output_dir: Path to the output directory + + Returns: + Path object of the validated output directory + """ + path = Path(output_dir) + + if path.exists() and not path.is_dir(): + raise typer.BadParameter(f"{output_dir} exists but is not a directory") + + return path + + +def parse_timecode(video: any, time_str: Optional[str]) -> Optional[FrameTimecode]: + """Parse a timecode string into a FrameTimecode object. + + Supports formats: + - Frames: '123' + - Seconds: '123s' or '123.45s' + - Timecode: '00:02:03' or '00:02:03.456' + + Args: + video: Video object to get framerate from + time_str: String to parse, or None + + Returns: + FrameTimecode object or None if input is None + """ + if time_str is None: + return None + + try: + if time_str.endswith("s"): + # Seconds format + seconds = float(time_str[:-1]) + return FrameTimecode(timecode=seconds, fps=video.frame_rate) + elif ":" in time_str: + # Timecode format + return FrameTimecode(timecode=time_str, fps=video.frame_rate) + else: + # Frame number format + return FrameTimecode(timecode=int(time_str), fps=video.frame_rate) + except ValueError as e: + raise typer.BadParameter( + f"Invalid timecode format: {time_str}. Use frames (123), " + f"seconds (123s/123.45s), or timecode (HH:MM:SS[.nnn])", + ) from e + + +def detect_and_split_scenes( # noqa: PLR0913 + video_path: str, + output_dir: Path, + detector_type: DetectorType, + threshold: Optional[float] = None, + min_scene_len: Optional[int] = None, + max_scenes: Optional[int] = None, + filter_shorter_than: Optional[str] = None, + skip_start: Optional[int] = None, # noqa: ARG001 + skip_end: Optional[int] = None, # noqa: ARG001 + save_images_per_scene: int = 0, + stats_file: Optional[str] = None, + luma_only: bool = False, + adaptive_window: Optional[int] = None, + fade_bias: Optional[float] = None, + downscale_factor: Optional[int] = None, + frame_skip: int = 0, + duration: Optional[str] = None, +) -> List[Tuple[FrameTimecode, FrameTimecode]]: + """Detect and split scenes in a video using the specified parameters. + + Args: + video_path: Path to input video. + output_dir: Directory to save output split scenes. + detector_type: Type of scene detector to use. + threshold: Detection threshold. + min_scene_len: Minimum scene length in frames. + max_scenes: Maximum number of scenes to detect. 
+ filter_shorter_than: Filter out scenes shorter than this duration (frames/seconds/timecode) + skip_start: Number of frames to skip at start. + skip_end: Number of frames to skip at end. + save_images_per_scene: Number of images to save per scene (0 to disable). + stats_file: Path to save detection statistics (optional). + luma_only: Only use brightness for content detection. + adaptive_window: Window size for adaptive detection. + fade_bias: Bias for fade detection (-1.0 to 1.0). + downscale_factor: Factor to downscale frames by during detection. + frame_skip: Number of frames to skip (i.e. process every 1 in N+1 frames, + where N is frame_skip, processing only 1/N+1 percent of the video, + speeding up the detection time at the expense of accuracy). + frame_skip must be 0 (the default) when using a StatsManager. + duration: How much of the video to process from start position. + Can be specified as frames (123), seconds (123s/123.45s), + or timecode (HH:MM:SS[.nnn]). + + Returns: + List of detected scenes as (start, end) FrameTimecode pairs. + """ + # Create video stream + video = open_video(video_path, backend="opencv") + + # Parse duration if specified + duration_tc = parse_timecode(video, duration) + + # Parse filter_shorter_than if specified + filter_shorter_than_tc = parse_timecode(video, filter_shorter_than) + + # Initialize scene manager with optional stats manager + stats_manager = StatsManager() if stats_file else None + scene_manager = SceneManager(stats_manager) + + # Configure scene manager + if downscale_factor: + scene_manager.auto_downscale = False + scene_manager.downscale = downscale_factor + + # Create and add detector + detector = create_detector( + detector_type=detector_type, + threshold=threshold, + min_scene_len=min_scene_len, + luma_only=luma_only, + adaptive_window=adaptive_window, + fade_bias=fade_bias, + ) + scene_manager.add_detector(detector) + + # Detect scenes + typer.echo("Detecting scenes...") + scene_manager.detect_scenes( + video=video, + show_progress=True, + frame_skip=frame_skip, + duration=duration_tc, + ) + + # Get scene list + scenes = scene_manager.get_scene_list() + + # Filter out scenes that are too short if filter_shorter_than is specified + if filter_shorter_than_tc: + original_count = len(scenes) + scenes = [ + (start, end) + for start, end in scenes + if (end.get_frames() - start.get_frames()) >= filter_shorter_than_tc.get_frames() + ] + if len(scenes) < original_count: + typer.echo( + f"Filtered out {original_count - len(scenes)} scenes shorter " + f"than {filter_shorter_than_tc.get_seconds():.1f} seconds " + f"({filter_shorter_than_tc.get_frames()} frames)", + ) + + # Apply max scenes limit if specified + if max_scenes and len(scenes) > max_scenes: + typer.echo(f"Dropping last {len(scenes) - max_scenes} scenes to meet max_scenes ({max_scenes}) limit") + scenes = scenes[:max_scenes] + + # Print scene information + typer.echo(f"Found {len(scenes)} scenes:") + for i, (start, end) in enumerate(scenes, 1): + typer.echo( + f"Scene {i}: {start.get_timecode()} to {end.get_timecode()} " + f"({end.get_frames() - start.get_frames()} frames)", + ) + + # Save stats if requested + if stats_file: + typer.echo(f"Saving detection stats to {stats_file}") + stats_manager.save_to_csv(stats_file) + + # Split video into scenes + typer.echo("Splitting video into scenes...") + try: + split_video_ffmpeg( + input_video_path=video_path, + scene_list=scenes, + output_dir=output_dir, + show_progress=True, + ) + typer.echo(f"Scenes have been saved to: {output_dir}") + 
except Exception as e: + raise typer.BadParameter(f"Error splitting video: {e}") from e + + # Save preview images if requested + if save_images_per_scene > 0: + typer.echo(f"Saving {save_images_per_scene} preview images per scene...") + image_filenames = save_scene_images( + scene_list=scenes, + video=video, + num_images=save_images_per_scene, + output_dir=str(output_dir), + show_progress=True, + ) + + # Generate HTML report with scene information and previews + html_path = output_dir / "scene_report.html" + write_scene_list_html( + output_html_filename=str(html_path), + scene_list=scenes, + image_filenames=image_filenames, + ) + typer.echo(f"Scene report saved to: {html_path}") + + return scenes + + +@app.command() +def main( # noqa: PLR0913 + video_path: Path = typer.Argument( # noqa: B008 + ..., + help="Path to the input video file", + exists=True, + dir_okay=False, + ), + output_dir: str = typer.Argument( + ..., + help="Directory where split scenes will be saved", + ), + detector: DetectorType = typer.Option( # noqa: B008 + DetectorType.CONTENT, + help="Scene detection algorithm to use", + ), + threshold: Optional[float] = typer.Option( + None, + help="Detection threshold (meaning varies by detector)", + ), + max_scenes: Optional[int] = typer.Option( + None, + help="Maximum number of scenes to produce", + ), + min_scene_length: Optional[int] = typer.Option( + None, + help="Minimum scene length during detection. Forces the detector to make scenes at least this many frames. " + "This affects scene detection behavior but does not filter out short scenes.", + ), + filter_shorter_than: Optional[str] = typer.Option( + None, + help="Filter out scenes shorter than this duration. Can be specified as frames (123), " + "seconds (123s/123.45s), or timecode (HH:MM:SS[.nnn]). These scenes will be detected but not saved.", + ), + skip_start: Optional[int] = typer.Option( + None, + help="Number of frames to skip at the start of the video", + ), + skip_end: Optional[int] = typer.Option( + None, + help="Number of frames to skip at the end of the video", + ), + duration: Optional[str] = typer.Option( + None, + "-d", + help="How much of the video to process. 
Can be specified as frames (123), " + "seconds (123s/123.45s), or timecode (HH:MM:SS[.nnn])", + ), + save_images: int = typer.Option( + 0, + help="Number of preview images to save per scene (0 to disable)", + ), + stats_file: Optional[str] = typer.Option( + None, + help="Path to save detection statistics CSV", + ), + luma_only: bool = typer.Option( + False, + help="Only use brightness for content detection", + ), + adaptive_window: Optional[int] = typer.Option( + None, + help="Window size for adaptive detection", + ), + fade_bias: Optional[float] = typer.Option( + None, + help="Bias for fade detection (-1.0 to 1.0)", + ), + downscale: Optional[int] = typer.Option( + None, + help="Factor to downscale frames by during detection", + ), + frame_skip: int = typer.Option( + 0, + help="Number of frames to skip during processing", + ), +) -> None: + """Split video into scenes using PySceneDetect.""" + if skip_start or skip_end: + typer.echo("Skipping start and end frames is not supported yet.") + return + + # Validate output directory + output_path = validate_output_dir(output_dir) + + # Detect and split scenes + detect_and_split_scenes( + video_path=str(video_path), + output_dir=output_path, + detector_type=detector, + threshold=threshold, + min_scene_len=min_scene_length, + max_scenes=max_scenes, + filter_shorter_than=filter_shorter_than, + skip_start=skip_start, + skip_end=skip_end, + duration=duration, + save_images_per_scene=save_images, + stats_file=stats_file, + luma_only=luma_only, + adaptive_window=adaptive_window, + fade_bias=fade_bias, + downscale_factor=downscale, + frame_skip=frame_skip, + ) + + +if __name__ == "__main__": + app() diff --git a/packages/ltx-trainer/scripts/train.py b/packages/ltx-trainer/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..6eecfec573dfdbb1985b1f7da65873ba35327c8a --- /dev/null +++ b/packages/ltx-trainer/scripts/train.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python + +""" +Train LTXV models using configuration from YAML files. + +This script provides a command-line interface for training LTXV models using +either LoRA fine-tuning or full model fine-tuning. It loads configuration from +a YAML file and passes it to the trainer. 
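+
+The YAML file is read with yaml.safe_load and its top-level keys must match the fields
+of LtxTrainerConfig (see ltx_trainer.config).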
+ +Basic usage: + python scripts/train.py CONFIG_PATH [--disable-progress-bars] + +For multi-GPU/FSDP training, configure and launch via Accelerate: + accelerate config + accelerate launch scripts/train.py CONFIG_PATH +""" + +from pathlib import Path + +import typer +import yaml +from rich.console import Console + +from ltx_trainer.config import LtxTrainerConfig +from ltx_trainer.trainer import LtxvTrainer + +console = Console() +app = typer.Typer( + pretty_exceptions_enable=False, + no_args_is_help=True, + help="Train LTXV models using configuration from YAML files.", +) + + +@app.command() +def main( + config_path: str = typer.Argument(..., help="Path to YAML configuration file"), + disable_progress_bars: bool = typer.Option( + False, + "--disable-progress-bars", + help="Disable progress bars (useful for multi-process runs)", + ), +) -> None: + """Train the model using the provided configuration file.""" + # Load the configuration from the YAML file + config_path = Path(config_path) + if not config_path.exists(): + typer.echo(f"Error: Configuration file {config_path} does not exist.") + raise typer.Exit(code=1) + + with open(config_path, "r") as file: + config_data = yaml.safe_load(file) + + # Convert the loaded data to the LtxTrainerConfig object + try: + trainer_config = LtxTrainerConfig(**config_data) + except Exception as e: + typer.echo(f"Error: Invalid configuration data: {e}") + raise typer.Exit(code=1) from e + + # Initialize the training process + trainer = LtxvTrainer(trainer_config) + trainer.train(disable_progress_bars=disable_progress_bars) + + +if __name__ == "__main__": + app() diff --git a/packages/ltx-trainer/src/ltx_trainer/__init__.py b/packages/ltx-trainer/src/ltx_trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56efa403f9384f9dc4f631b4d63a52d2df6a7680 --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/__init__.py @@ -0,0 +1,44 @@ +import logging +import os +import sys +from logging import getLogger +from pathlib import Path + +from rich.logging import RichHandler + +# Get the process rank +IS_MULTI_GPU = os.environ.get("LOCAL_RANK") is not None +RANK = int(os.environ.get("LOCAL_RANK", "0")) + +# Configure with Rich +logging.basicConfig( + level="INFO", + format=f"\\[rank {RANK}] %(message)s" if IS_MULTI_GPU else "%(message)s", + handlers=[ + RichHandler( + rich_tracebacks=True, + show_time=False, + markup=True, + ) + ], +) + +# Get the logger and configure it +logger = getLogger("ltxv_trainer") +logger.setLevel(logging.DEBUG) +logger.propagate = True + +# Set level based on process +if RANK != 0: + logger.setLevel(logging.WARNING) + +# Expose common logging functions directly +debug = logger.debug +info = logger.info +warning = logger.warning +error = logger.error +critical = logger.critical + + +# Add the root directory to the Python path so we can import from scripts. +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) diff --git a/packages/ltx-trainer/src/ltx_trainer/captioning.py b/packages/ltx-trainer/src/ltx_trainer/captioning.py new file mode 100644 index 0000000000000000000000000000000000000000..41dc6227f1e3deb9de60bbedfa3c841784730ebc --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/captioning.py @@ -0,0 +1,420 @@ +""" +Audio-visual media captioning using multimodal models. 
+ +This module provides captioning capabilities for videos with audio using: +- Qwen2.5-Omni: Local model supporting text, audio, image, and video inputs (default) +- Gemini Flash: Cloud-based API for audio-visual captioning + +Requirements: +- Qwen2.5-Omni: transformers>=4.50, torch +- Gemini Flash: google-generativeai (pip install google-generativeai) + Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable +""" + +import itertools +import re +from abc import ABC, abstractmethod +from enum import Enum +from pathlib import Path + +import torch + +# Instruction for audio-visual captioning (default) - includes speech transcription and sounds +DEFAULT_CAPTION_INSTRUCTION = """\ +Analyze this media and provide a detailed caption in the following EXACT format. Fill in ALL sections: + +[VISUAL]: +[SPEECH]: +[SOUNDS]: +[TEXT]: + +You MUST fill in all four sections. For [SPEECH], transcribe the actual words spoken, not a summary.""" + +# Instruction for video-only captioning (no audio processing) +VIDEO_ONLY_CAPTION_INSTRUCTION = """\ +Analyze this media and provide a detailed caption in the following EXACT format. Fill in ALL sections: + +[VISUAL]: +[TEXT]: + +You MUST fill in both sections.""" + + +class CaptionerType(str, Enum): + """Enum for different types of media captioners.""" + + QWEN_OMNI = "qwen_omni" # Local Qwen2.5-Omni model (audio + video) + GEMINI_FLASH = "gemini_flash" # Gemini Flash API (audio + video) + + +def create_captioner(captioner_type: CaptionerType, **kwargs) -> "MediaCaptioningModel": + """Factory function to create a media captioner. + + Args: + captioner_type: The type of captioner to create + **kwargs: Additional arguments to pass to the captioner constructor + + Returns: + An instance of a MediaCaptioningModel + """ + match captioner_type: + case CaptionerType.QWEN_OMNI: + return QwenOmniCaptioner(**kwargs) + case CaptionerType.GEMINI_FLASH: + return GeminiFlashCaptioner(**kwargs) + case _: + raise ValueError(f"Unsupported captioner type: {captioner_type}") + + +class MediaCaptioningModel(ABC): + """Abstract base class for audio-visual media captioning models.""" + + @abstractmethod + def caption(self, path: str | Path, **kwargs) -> str: + """Generate a caption for the given video or image. + + Args: + path: Path to the video/image file to caption + + Returns: + A string containing the generated caption + """ + + @property + @abstractmethod + def supports_audio(self) -> bool: + """Whether this captioner supports audio input.""" + + @staticmethod + def _is_image_file(path: str | Path) -> bool: + """Check if the file is an image based on extension.""" + return str(path).lower().endswith((".png", ".jpg", ".jpeg", ".heic", ".heif", ".webp")) + + @staticmethod + def _is_video_file(path: str | Path) -> bool: + """Check if the file is a video based on extension.""" + return str(path).lower().endswith((".mp4", ".avi", ".mov", ".mkv", ".webm")) + + @staticmethod + def _clean_raw_caption(caption: str) -> str: + """Clean up the raw caption by removing common VLM patterns.""" + start = ["The", "This"] + kind = ["video", "image", "scene", "animated sequence", "clip", "footage"] + act = ["displays", "shows", "features", "depicts", "presents", "showcases", "captures", "contains"] + + for x, y, z in itertools.product(start, kind, act): + caption = caption.replace(f"{x} {y} {z} ", "", 1) + + return caption + + +class QwenOmniCaptioner(MediaCaptioningModel): + """Audio-visual captioning using Alibaba's Qwen2.5-Omni model. 
+ + Qwen2.5-Omni is an end-to-end multimodal model that can perceive text, images, audio, and video. + It uses a Thinker-Talker architecture where the Thinker generates text and the Talker can + generate speech. For captioning, we use only the Thinker component for text generation. + + Key features: + - Block-wise processing for streaming multimodal inputs + - TMRoPE (Time-aligned Multimodal RoPE) for synchronizing video and audio timestamps + - Can extract and process audio directly from video files + + See: https://huggingface.co/docs/transformers/en/model_doc/qwen2_5_omni + + Model: Qwen/Qwen2.5-Omni-7B (7B parameters) + """ + + MODEL_ID = "Qwen/Qwen2.5-Omni-7B" + + # Default system prompt required by Qwen2.5-Omni for proper audio processing + DEFAULT_SYSTEM_PROMPT = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, " + "capable of perceiving auditory and visual inputs, as well as generating text and speech." + ) + + def __init__( + self, + device: str | torch.device | None = None, + use_8bit: bool = False, + instruction: str | None = None, + ): + """ + Initialize the Qwen2.5-Omni captioner. + + Args: + device: Device to use for inference (e.g., 'cuda', 'cuda:0', 'cpu') + use_8bit: Whether to use 8-bit quantization for reduced memory usage + instruction: Custom instruction prompt. If None, uses the default instruction + """ + self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) + self.instruction = instruction + self._load_model(use_8bit=use_8bit) + + @property + def supports_audio(self) -> bool: + return True + + def caption( + self, + path: str | Path, + fps: int = 1, + include_audio: bool = True, + clean_caption: bool = True, + ) -> str: + """Generate a caption for the given video or image. 
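+
+        Example (illustrative path):
+
+            captioner = QwenOmniCaptioner(use_8bit=True)
+            text = captioner.caption("clips/scene_001.mp4", fps=2)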
+ + Args: + path: Path to the video/image file to caption + fps: Frames per second to sample from videos + include_audio: Whether to include audio in the captioning (for videos) + clean_caption: Whether to clean up the raw caption by removing common VLM patterns + + Returns: + A string containing the generated caption + """ + path = Path(path) + is_image = self._is_image_file(path) + is_video = self._is_video_file(path) + + # Determine if we should process audio + use_audio = include_audio and is_video + + # Use custom instruction if provided, otherwise pick appropriate default + if self.instruction is not None: + instruction = self.instruction + else: + instruction = DEFAULT_CAPTION_INSTRUCTION if use_audio else VIDEO_ONLY_CAPTION_INSTRUCTION + + # Build the user content based on media type + # Based on HuggingFace docs: https://huggingface.co/docs/transformers/en/model_doc/qwen2_5_omni + user_content = [] + + if is_image: + user_content.append({"type": "image", "image": str(path)}) + elif is_video: + user_content.append({"type": "video", "video": str(path)}) + + # Add the instruction text + user_content.append({"type": "text", "text": instruction}) + + # Build conversation - use the default system prompt required by Qwen2.5-Omni + # Using a custom system prompt causes warnings and may affect audio processing + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": self.DEFAULT_SYSTEM_PROMPT}], + }, + {"role": "user", "content": user_content}, + ] + + # Process inputs using the processor's apply_chat_template + # For videos with audio, use load_audio_from_video=True and use_audio_in_video=True + inputs = self.processor.apply_chat_template( + messages, + load_audio_from_video=use_audio, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + fps=fps, + padding=True, + use_audio_in_video=use_audio, + ).to(self.model.device) + + # Generate caption (text only, using Thinker-only model) + # Note: For Qwen2_5OmniThinkerForConditionalGeneration, use standard generate params + # (not thinker_ prefixed ones, those are for the full Qwen2_5OmniForConditionalGeneration) + input_len = inputs["input_ids"].shape[1] + + output_tokens = self.model.generate( + **inputs, + use_audio_in_video=use_audio, + do_sample=False, + max_new_tokens=1024, + ) + + # Extract only the generated tokens (exclude the input/prompt tokens) + generated_tokens = output_tokens[:, input_len:] + + # Decode only the generated response + caption_raw = self.processor.batch_decode( + generated_tokens, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + )[0] + + # Remove hallucinated conversation turns (e.g., "Human\nHuman\n..." or "Human: ...") + # This is a known issue with chat models continuing to generate fake turns + # We look for patterns that are clearly hallucinated chat turns, not legitimate uses of "human" + + # Match "\nHuman" followed by ":", "\n", or end of string (chat turn patterns) + # This won't match "A human walks..." or "...the human body..." + caption_raw = re.split(r"\nHuman(?::|(?:\s*\n)|$)", caption_raw, maxsplit=1)[0] + caption_raw = caption_raw.strip() + + # Clean up caption if requested + return self._clean_raw_caption(caption_raw) if clean_caption else caption_raw + + def _load_model(self, use_8bit: bool) -> None: + """Load the Qwen2.5-Omni model and processor. + + Uses the Thinker-only model (Qwen2_5OmniThinkerForConditionalGeneration) for text generation + to save compute by not loading the audio generation components. 
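+
+        When use_8bit is True, weights are loaded with bitsandbytes 8-bit quantization
+        via BitsAndBytesConfig to reduce memory usage.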
+ """ + from transformers import ( # noqa: PLC0415 + BitsAndBytesConfig, + Qwen2_5OmniProcessor, + Qwen2_5OmniThinkerForConditionalGeneration, + ) + + quantization_config = BitsAndBytesConfig(load_in_8bit=True) if use_8bit else None + + # Use Thinker-only model for text generation (saves memory by not loading Talker) + self.model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained( + self.MODEL_ID, + dtype=torch.bfloat16, + low_cpu_mem_usage=True, + quantization_config=quantization_config, + device_map="auto", + ) + + self.processor = Qwen2_5OmniProcessor.from_pretrained(self.MODEL_ID) + + +class GeminiFlashCaptioner(MediaCaptioningModel): + """Audio-visual captioning using Google's Gemini Flash API. + + Gemini Flash is a cloud-based multimodal model that natively supports + audio and video understanding. Requires a Google API key. + + Note: This captioner requires the `google-generativeai` package and a valid API key. + Set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable, or pass the key directly. + """ + + MODEL_ID = "gemini-flash-lite-latest" + + def __init__( + self, + api_key: str | None = None, + instruction: str | None = None, + ): + """Initialize the Gemini Flash captioner. + + Args: + api_key: Google API key. If not provided, will look for + GEMINI_API_KEY or GOOGLE_API_KEY environment variable. + instruction: Custom instruction prompt. If None, uses the default instruction + """ + self.instruction = instruction + self._init_client(api_key) + + @property + def supports_audio(self) -> bool: + return True + + def caption( + self, + path: str | Path, + fps: int = 3, # noqa: ARG002 - kept for API compatibility + include_audio: bool = True, + clean_caption: bool = True, + ) -> str: + """Generate a caption for the given video or image. + + Args: + path: Path to the video/image file to caption + fps: Frames per second (not used for Gemini, kept for API compatibility) + include_audio: Whether to include audio content in the caption + clean_caption: Whether to clean up the raw caption + + Returns: + A string containing the generated caption + """ + import time # noqa: PLC0415 + + path = Path(path) + is_video = self._is_video_file(path) + use_audio = include_audio and is_video + + # Use custom instruction if provided, otherwise pick appropriate default + if self.instruction is not None: + instruction = self.instruction + else: + instruction = DEFAULT_CAPTION_INSTRUCTION if use_audio else VIDEO_ONLY_CAPTION_INSTRUCTION + + # Upload the file to Gemini + uploaded_file = self._genai.upload_file(path) + + # Wait for processing to complete (videos need time to process) + while uploaded_file.state.name == "PROCESSING": + time.sleep(1) + uploaded_file = self._genai.get_file(uploaded_file.name) + + if uploaded_file.state.name == "FAILED": + raise RuntimeError(f"File processing failed: {uploaded_file.state.name}") + + # Generate caption + response = self._model.generate_content([uploaded_file, instruction]) + + caption_raw = response.text + + # Clean up the uploaded file + self._genai.delete_file(uploaded_file.name) + + # Clean up caption if requested + return self._clean_raw_caption(caption_raw) if clean_caption else caption_raw + + def _init_client(self, api_key: str | None) -> None: + """Initialize the Gemini API client.""" + import os # noqa: PLC0415 + + try: + import google.generativeai as genai # noqa: PLC0415 + except ImportError as e: + raise ImportError( + "The `google-generativeai` package is required for Gemini Flash captioning. 
" + "Install it with: `uv pip install google-generativeai`" + ) from e + + # Get API key from argument or environment + # GEMINI_API_KEY is the recommended variable, GOOGLE_API_KEY also works + resolved_api_key = api_key or os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY") + + if not resolved_api_key: + raise ValueError( + "Gemini API key is required. Provide it via the `api_key` argument " + "or set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable." + ) + + # Configure the genai library with the API key + genai.configure(api_key=resolved_api_key) + + # Store reference to genai module for file operations + self._genai = genai + + # Initialize the model + self._model = genai.GenerativeModel(self.MODEL_ID) + + +def example() -> None: + """Example usage of the captioning module.""" + import sys # noqa: PLC0415 + + if len(sys.argv) < 2: + print(f"Usage: python {sys.argv[0]} [captioner_type]") # noqa: T201 + print(" captioner_type: qwen_omni (default) or gemini_flash") # noqa: T201 + sys.exit(1) + + video_path = sys.argv[1] + captioner_type = CaptionerType(sys.argv[2]) if len(sys.argv) > 2 else CaptionerType.QWEN_OMNI + + print(f"Using {captioner_type.value} captioner:") # noqa: T201 + captioner = create_captioner(captioner_type) + caption = captioner.caption(video_path) + print(f"CAPTION: {caption}") # noqa: T201 + + +if __name__ == "__main__": + example() diff --git a/packages/ltx-trainer/src/ltx_trainer/config.py b/packages/ltx-trainer/src/ltx_trainer/config.py new file mode 100644 index 0000000000000000000000000000000000000000..72a4d47eb6eb8903f5df83cdb9260dd3ce06a492 --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/config.py @@ -0,0 +1,472 @@ +from pathlib import Path +from typing import Annotated, Literal + +from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, ValidationInfo, field_validator, model_validator + +from ltx_trainer.quantization import QuantizationOptions +from ltx_trainer.training_strategies.base_strategy import TrainingStrategyConfigBase +from ltx_trainer.training_strategies.text_to_video import TextToVideoConfig +from ltx_trainer.training_strategies.video_to_video import VideoToVideoConfig + + +class ConfigBaseModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class ModelConfig(ConfigBaseModel): + """Configuration for the base model and training mode""" + + model_path: str | Path = Field( + ..., + description="Model path - local path to safetensors checkpoint file", + ) + + text_encoder_path: str | Path | None = Field( + default=None, + description="Path to text encoder (required for LTX-2/Gemma models, optional for LTXV/T5 models)", + ) + + training_mode: Literal["lora", "full"] = Field( + default="lora", + description="Training mode - either LoRA fine-tuning or full model fine-tuning", + ) + + load_checkpoint: str | Path | None = Field( + default=None, + description="Path to a checkpoint file or directory to load from. 
" + "If a directory is provided, the latest checkpoint will be used.", + ) + + @field_validator("model_path") + @classmethod + def validate_model_path(cls, v: str | Path) -> str | Path: + """Validate that model_path is either a valid URL or an existing local path.""" + is_url = str(v).startswith(("http://", "https://")) + + if is_url: + raise ValueError(f"Model path cannot be a URL: {v}") + + if not Path(v).exists(): + raise ValueError(f"Model path does not exist: {v}") + + return v + + +class LoraConfig(ConfigBaseModel): + """Configuration for LoRA fine-tuning""" + + rank: int = Field( + default=64, + description="Rank of LoRA adaptation", + ge=2, + ) + + alpha: int = Field( + default=64, + description="Alpha scaling factor for LoRA", + ge=1, + ) + + dropout: float = Field( + default=0.0, + description="Dropout probability for LoRA layers", + ge=0.0, + le=1.0, + ) + + target_modules: list[str] = Field( + default=["to_k", "to_q", "to_v", "to_out.0"], + description="List of modules to target with LoRA", + ) + + +def _get_strategy_discriminator(v: dict | TrainingStrategyConfigBase) -> str: + """Discriminator function for strategy config union.""" + if isinstance(v, dict): + return v.get("name", "text_to_video") + return v.name + + +# Union type for all strategy configs with discriminator +TrainingStrategyConfig = Annotated[ + Annotated[TextToVideoConfig, Tag("text_to_video")] | Annotated[VideoToVideoConfig, Tag("video_to_video")], + Discriminator(_get_strategy_discriminator), +] + + +class OptimizationConfig(ConfigBaseModel): + """Configuration for optimization parameters""" + + learning_rate: float = Field( + default=5e-4, + description="Learning rate for optimization", + ) + + steps: int = Field( + default=3000, + description="Number of training steps", + ) + + batch_size: int = Field( + default=2, + description="Batch size for training", + ) + + gradient_accumulation_steps: int = Field( + default=1, + description="Number of steps to accumulate gradients", + ) + + max_grad_norm: float = Field( + default=1.0, + description="Maximum gradient norm for clipping", + ) + + optimizer_type: Literal["adamw", "adamw8bit"] = Field( + default="adamw", + description="Type of optimizer to use for training", + ) + + scheduler_type: Literal[ + "constant", + "linear", + "cosine", + "cosine_with_restarts", + "polynomial", + ] = Field( + default="linear", + description="Type of scheduler to use for training", + ) + + scheduler_params: dict = Field( + default_factory=dict, + description="Parameters for the scheduler", + ) + + enable_gradient_checkpointing: bool = Field( + default=False, + description="Enable gradient checkpointing to save memory at the cost of slower training", + ) + + +class AccelerationConfig(ConfigBaseModel): + """Configuration for hardware acceleration and compute optimization""" + + mixed_precision_mode: Literal["no", "fp16", "bf16"] | None = Field( + default="bf16", + description="Mixed precision training mode", + ) + + quantization: QuantizationOptions | None = Field( + default=None, + description="Quantization precision to use", + ) + + load_text_encoder_in_8bit: bool = Field( + default=False, + description="Whether to load the text encoder in 8-bit precision to save memory", + ) + + +class DataConfig(ConfigBaseModel): + """Configuration for data loading and processing""" + + preprocessed_data_root: str = Field( + description="Path to folder containing preprocessed training data", + ) + + num_dataloader_workers: int = Field( + default=2, + description="Number of background processes 
for data loading (0 means synchronous loading)", + ge=0, + ) + + +class ValidationConfig(ConfigBaseModel): + """Configuration for validation during training""" + + prompts: list[str] = Field( + default_factory=list, + description="List of prompts to use for validation", + ) + + negative_prompt: str = Field( + default="worst quality, inconsistent motion, blurry, jittery, distorted", + description="Negative prompt to use for validation examples", + ) + + images: list[str] | None = Field( + default=None, + description="List of image paths to use for validation. " + "One image path must be provided for each validation prompt", + ) + + reference_videos: list[str] | None = Field( + default=None, + description="List of reference video paths to use for validation. " + "One video path must be provided for each validation prompt", + ) + + video_dims: tuple[int, int, int] = Field( + default=(960, 544, 97), + description="Dimensions of validation videos (width, height, frames). " + "Width and height must be divisible by 32. Frames must satisfy frames % 8 == 1 for LTX-2.", + ) + + @field_validator("video_dims") + @classmethod + def validate_video_dims(cls, v: tuple[int, int, int]) -> tuple[int, int, int]: + """Validate video dimensions for LTX-2 compatibility.""" + width, height, frames = v + + if width % 32 != 0: + raise ValueError(f"Width ({width}) must be divisible by 32") + if height % 32 != 0: + raise ValueError(f"Height ({height}) must be divisible by 32") + if frames % 8 != 1: + raise ValueError(f"Frames ({frames}) must satisfy frames % 8 == 1 for LTX-2 (e.g., 1, 9, 17, 25, ...)") + + return v + + frame_rate: float = Field( + default=25.0, + description="Frame rate for validation videos", + gt=0, + ) + + seed: int = Field( + default=42, + description="Random seed used when sampling validation videos", + ) + + inference_steps: int = Field( + default=50, + description="Number of inference steps for validation", + gt=0, + ) + + interval: int | None = Field( + default=100, + description="Number of steps between validation runs. If None, validation is disabled.", + gt=0, + ) + + videos_per_prompt: int = Field( + default=1, + description="Number of videos to generate per validation prompt", + gt=0, + ) + + guidance_scale: float = Field( + default=3.0, + description="CFG guidance scale to use during validation", + ge=1.0, + ) + + stg_scale: float = Field( + default=1.0, + description="STG (Spatio-Temporal Guidance) scale. 0.0 disables STG. " + "Recommended value is 1.0. STG is combined with CFG for improved video quality.", + ge=0.0, + ) + + stg_blocks: list[int] | None = Field( + default=[29], + description="Which transformer blocks to perturb for STG. " + "None means all blocks are perturbed. Recommended for LTX-2: [29].", + ) + + stg_mode: Literal["stg_av", "stg_v"] = Field( + default="stg_av", + description="STG mode: 'stg_av' skips both audio and video self-attention, " + "'stg_v' skips only video self-attention.", + ) + + generate_audio: bool = Field( + default=True, + description="Whether to generate audio in validation samples. 
" + "Independent of training strategy setting - you can generate audio " + "in validation even when not training the audio branch.", + ) + + skip_initial_validation: bool = Field( + default=False, + description="Skip validation video sampling at step 0 (beginning of training)", + ) + + include_reference_in_output: bool = Field( + default=False, + description="For video-to-video training: concatenate the original reference video side-by-side " + "with the generated output. The reference comes from the input video, not from the model's output.", + ) + + @field_validator("images") + @classmethod + def validate_images(cls, v: list[str] | None, info: ValidationInfo) -> list[str] | None: + """Validate that number of images (if provided) matches number of prompts.""" + if v is None: + return None + + num_prompts = len(info.data.get("prompts", [])) + if v is not None and len(v) != num_prompts: + raise ValueError(f"Number of images ({len(v)}) must match number of prompts ({num_prompts})") + + for image_path in v: + if not Path(image_path).exists(): + raise ValueError(f"Image path '{image_path}' does not exist") + + return v + + @field_validator("reference_videos") + @classmethod + def validate_reference_videos(cls, v: list[str] | None, info: ValidationInfo) -> list[str] | None: + """Validate that number of reference videos (if provided) matches number of prompts.""" + if v is None: + return None + + num_prompts = len(info.data.get("prompts", [])) + if v is not None and len(v) != num_prompts: + raise ValueError(f"Number of reference videos ({len(v)}) must match number of prompts ({num_prompts})") + + for video_path in v: + if not Path(video_path).exists(): + raise ValueError(f"Reference video path '{video_path}' does not exist") + + return v + + +class CheckpointsConfig(ConfigBaseModel): + """Configuration for model checkpointing during training""" + + interval: int | None = Field( + default=None, + description="Number of steps between checkpoint saves. If None, intermediate checkpoints are disabled.", + gt=0, + ) + + keep_last_n: int = Field( + default=1, + description="Number of most recent checkpoints to keep. 
Set to -1 to keep all checkpoints.", + ge=-1, + ) + + +class HubConfig(ConfigBaseModel): + """Configuration for Hugging Face Hub integration""" + + push_to_hub: bool = Field(default=False, description="Whether to push the model weights to the Hugging Face Hub") + hub_model_id: str | None = Field( + default=None, description="Hugging Face Hub repository ID (e.g., 'username/repo-name')" + ) + + @model_validator(mode="after") + def validate_hub_config(self) -> "HubConfig": + """Validate that hub_model_id is not None when push_to_hub is True.""" + if self.push_to_hub and not self.hub_model_id: + raise ValueError("hub_model_id must be specified when push_to_hub is True") + return self + + +class WandbConfig(ConfigBaseModel): + """Configuration for Weights & Biases logging""" + + enabled: bool = Field( + default=False, + description="Whether to enable W&B logging", + ) + + project: str = Field( + default="ltxv-trainer", + description="W&B project name", + ) + + entity: str | None = Field( + default=None, + description="W&B username or team", + ) + + tags: list[str] = Field( + default_factory=list, + description="Tags to add to the W&B run", + ) + + log_validation_videos: bool = Field( + default=True, + description="Whether to log validation videos to W&B", + ) + + +class FlowMatchingConfig(ConfigBaseModel): + """Configuration for flow matching training""" + + timestep_sampling_mode: Literal["uniform", "shifted_logit_normal"] = Field( + default="shifted_logit_normal", + description="Mode to use for timestep sampling", + ) + + timestep_sampling_params: dict = Field( + default_factory=dict, + description="Parameters for timestep sampling", + ) + + +class LtxTrainerConfig(ConfigBaseModel): + """Unified configuration for LTXV training""" + + # Sub-configurations + model: ModelConfig = Field(default_factory=ModelConfig) + lora: LoraConfig | None = Field(default=None) + training_strategy: TrainingStrategyConfig = Field( + default_factory=TextToVideoConfig, + description="Training strategy configuration. 
Determines the training mode and its parameters.", + ) + optimization: OptimizationConfig = Field(default_factory=OptimizationConfig) + acceleration: AccelerationConfig = Field(default_factory=AccelerationConfig) + data: DataConfig + validation: ValidationConfig = Field(default_factory=ValidationConfig) + checkpoints: CheckpointsConfig = Field(default_factory=CheckpointsConfig) + hub: HubConfig = Field(default_factory=HubConfig) + flow_matching: FlowMatchingConfig = Field(default_factory=FlowMatchingConfig) + wandb: WandbConfig = Field(default_factory=WandbConfig) + + # General configuration + seed: int = Field( + default=42, + description="Random seed for reproducibility", + ) + + output_dir: str = Field( + default="outputs", + description="Directory to save model outputs", + ) + + # noinspection PyNestedDecorators + @field_validator("output_dir") + @classmethod + def expand_output_path(cls, v: str) -> str: + """Expand user home directory in output path.""" + return str(Path(v).expanduser().resolve()) + + @model_validator(mode="after") + def validate_strategy_compatibility(self) -> "LtxTrainerConfig": + """Validate that training strategy and other configurations are compatible.""" + + # Check that reference videos are provided when using video_to_video strategy + if ( + self.training_strategy.name == "video_to_video" + and self.validation.interval + and not self.validation.reference_videos + ): + raise ValueError( + "reference_videos must be provided in validation config when using video_to_video strategy" + ) + + # Check that LoRA config is provided when training mode is lora + if self.model.training_mode == "lora" and self.lora is None: + raise ValueError("LoRA configuration must be provided when training_mode is 'lora'") + + # Check that LoRA config is provided when using video_to_video strategy + if self.training_strategy.name == "video_to_video" and self.model.training_mode != "lora": + raise ValueError("Training mode must be 'lora' when using video_to_video strategy") + + return self diff --git a/packages/ltx-trainer/src/ltx_trainer/config_display.py b/packages/ltx-trainer/src/ltx_trainer/config_display.py new file mode 100644 index 0000000000000000000000000000000000000000..7eb28ed72ddf2bd4fa546f18152fc3db0c2725d4 --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/config_display.py @@ -0,0 +1,156 @@ +"""Display utilities for training configuration. + +This module provides formatted console output for LtxTrainerConfig. +""" + +from rich import box +from rich.console import Console +from rich.table import Table + +from ltx_trainer.config import LtxTrainerConfig + + +def print_config(config: LtxTrainerConfig) -> None: + """Print configuration as a nicely formatted table with sections.""" + + def fmt(v: object, max_len: int = 55) -> str: + """Format any value for display.""" + if v is None: + return "[dim]—[/]" + if isinstance(v, bool): + return "[green]✓[/]" if v else "[dim]✗[/]" + if isinstance(v, (list, tuple)): + if not v: + return "[dim]—[/]" + return ", ".join(str(x) for x in v) + s = str(v) + return s[: max_len - 3] + "..." 
if len(s) > max_len else s + + cfg = config + opt = cfg.optimization + val = cfg.validation + accel = cfg.acceleration + + # Build sections: list of (section_title, [(key, value), ...]) + sections: list[tuple[str, list[tuple[str, str]]]] = [ + ( + "🎬 Model", + [ + ("Base", fmt(cfg.model.model_path)), + ("Text Encoder", fmt(cfg.model.text_encoder_path) or "[dim]Built-in[/]"), + ("Training Mode", f"[bold green]{cfg.model.training_mode.upper()}[/]"), + ("Load Checkpoint", fmt(cfg.model.load_checkpoint) if cfg.model.load_checkpoint else "[dim]—[/]"), + ], + ), + ] + + if cfg.lora: + sections.append( + ( + "🔗 LoRA", + [ + ("Rank / Alpha", f"{cfg.lora.rank} / {cfg.lora.alpha}"), + ("Dropout", str(cfg.lora.dropout)), + ("Target Modules", fmt(cfg.lora.target_modules)), + ], + ) + ) + + # Strategy section - include strategy-specific fields + strategy_items: list[tuple[str, str]] = [("Name", cfg.training_strategy.name)] + if hasattr(cfg.training_strategy, "with_audio"): + strategy_items.append(("Audio", fmt(cfg.training_strategy.with_audio))) + if hasattr(cfg.training_strategy, "first_frame_conditioning_p"): + strategy_items.append(("First Frame Cond P", str(cfg.training_strategy.first_frame_conditioning_p))) + + sections.append(("🎯 Strategy", strategy_items)) + + sections.extend( + [ + ( + "⚡ Optimization", + [ + ("Steps", f"[bold]{opt.steps:,}[/]"), + ("Learning Rate", f"{opt.learning_rate:.2e}"), + ("Batch Size", str(opt.batch_size)), + ("Grad Accumulation", str(opt.gradient_accumulation_steps)), + ("Optimizer", opt.optimizer_type), + ("Scheduler", opt.scheduler_type), + ("Max Grad Norm", str(opt.max_grad_norm)), + ("Grad Checkpointing", fmt(opt.enable_gradient_checkpointing)), + ], + ), + ( + "🚀 Acceleration", + [ + ("Mixed Precision", accel.mixed_precision_mode or "[dim]—[/]"), + ("Quantization", str(accel.quantization) if accel.quantization else "[dim]—[/]"), + ("Text Encoder 8bit", fmt(accel.load_text_encoder_in_8bit)), + ], + ), + ( + "🎥 Validation", + [ + ("Prompts", f"{len(val.prompts)} prompt(s)" if val.prompts else "[dim]—[/]"), + ("Interval", f"Every {val.interval} steps" if val.interval else "[dim]Disabled[/]"), + ("Video Dims", f"{val.video_dims[0]}x{val.video_dims[1]}, {val.video_dims[2]} frames"), + ("Frame Rate", f"{val.frame_rate} fps"), + ("Inference Steps", str(val.inference_steps)), + ("CFG Scale", str(val.guidance_scale)), + ( + "STG", + f"scale={val.stg_scale}; blocks={fmt(val.stg_blocks)}; mode={val.stg_mode}" + if val.stg_scale > 0 + else "[dim]Disabled[/]", + ), + ("Seed", str(val.seed)), + ], + ), + ( + "📂 Data & Output", + [ + ("Dataset", fmt(cfg.data.preprocessed_data_root)), + ("Dataloader Workers", str(cfg.data.num_dataloader_workers)), + ("Output Dir", fmt(cfg.output_dir)), + ("Seed", str(cfg.seed)), + ], + ), + ( + "🔌 Integrations", + [ + ( + "Checkpoints", + f"Every {cfg.checkpoints.interval} steps (keep {cfg.checkpoints.keep_last_n})" + if cfg.checkpoints.interval + else "[dim]Disabled[/]", + ), + ("W&B", f"{cfg.wandb.project}" if cfg.wandb.enabled else "[dim]Disabled[/]"), + ("HF Hub", cfg.hub.hub_model_id if cfg.hub.push_to_hub else "[dim]Disabled[/]"), + ], + ), + ] + ) + + # Build table with section headers + table = Table( + title="[bold]⚙️ Training Configuration[/]", + show_header=False, + box=box.ROUNDED, + border_style="bright_blue", + padding=(0, 1), + title_style="bold bright_blue", + ) + table.add_column("Key", style="white", width=20) + table.add_column("Value", style="cyan") + + for i, (section_title, items) in enumerate(sections): + if i > 0: + 
table.add_row("", "") # Blank line between sections + table.add_row(f"[bold yellow]{section_title}[/]", "") + for key, value in items: + table.add_row(f" {key}", value) + + console = Console() + console.print() + console.print(table) + console.print() diff --git a/packages/ltx-trainer/src/ltx_trainer/datasets.py b/packages/ltx-trainer/src/ltx_trainer/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a260560c695c5072b7d710d56ab2b2d2fcad73 --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/datasets.py @@ -0,0 +1,275 @@ +from pathlib import Path + +import torch +from einops import rearrange +from torch import Tensor +from torch.utils.data import Dataset + +from ltx_trainer import logger + +# Constants for precomputed data directories +PRECOMPUTED_DIR_NAME = ".precomputed" + + +class DummyDataset(Dataset): + """Produce random latents and prompt embeddings. For minimal demonstration and benchmarking purposes""" + + def __init__( + self, + width: int = 1024, + height: int = 1024, + num_frames: int = 25, + fps: int = 24, + dataset_length: int = 200, + latent_dim: int = 128, + latent_spatial_compression_ratio: int = 32, + latent_temporal_compression_ratio: int = 8, + prompt_embed_dim: int = 4096, + prompt_sequence_length: int = 256, + ) -> None: + if width % 32 != 0: + raise ValueError(f"Width must be divisible by 32, got {width=}") + + if height % 32 != 0: + raise ValueError(f"Height must be divisible by 32, got {height=}") + + if num_frames % 8 != 1: + raise ValueError(f"Number of frames must have a remainder of 1 when divided by 8, got {num_frames=}") + + self.width = width + self.height = height + self.num_frames = num_frames + self.fps = fps + self.dataset_length = dataset_length + self.latent_dim = latent_dim + self.num_latent_frames = (num_frames - 1) // latent_temporal_compression_ratio + 1 + self.latent_height = height // latent_spatial_compression_ratio + self.latent_width = width // latent_spatial_compression_ratio + self.latent_sequence_length = self.num_latent_frames * self.latent_height * self.latent_width + self.prompt_embed_dim = prompt_embed_dim + self.prompt_sequence_length = prompt_sequence_length + + def __len__(self) -> int: + return self.dataset_length + + def __getitem__(self, idx: int) -> dict[str, dict[str, Tensor]]: + return { + "latent_conditions": { + "latents": torch.randn( + self.latent_dim, + self.num_latent_frames, + self.latent_height, + self.latent_width, + ), + "num_frames": self.num_latent_frames, + "height": self.latent_height, + "width": self.latent_width, + "fps": self.fps, + }, + "text_conditions": { + "prompt_embeds": torch.randn( + self.prompt_sequence_length, + self.prompt_embed_dim, + ), # random text embeddings + "prompt_attention_mask": torch.ones( + self.prompt_sequence_length, + dtype=torch.bool, + ), # random attention mask + }, + } + + +class PrecomputedDataset(Dataset): + def __init__(self, data_root: str, data_sources: dict[str, str] | list[str] | None = None) -> None: + """ + Generic dataset for loading precomputed data from multiple sources. 
+ + Args: + data_root: Root directory containing preprocessed data + data_sources: Either: + - Dict mapping directory names to output keys + - List of directory names (keys will equal values) + - None (defaults to ["latents", "conditions"]) + + Example: + # Standard mode (list) + dataset = PrecomputedDataset("data/", ["latents", "conditions"]) + + # Standard mode (dict) + dataset = PrecomputedDataset("data/", {"latents": "latent_conditions", "conditions": "text_conditions"}) + + # IC-LoRA mode + dataset = PrecomputedDataset("data/", ["latents", "conditions", "reference_latents"]) + + Note: + Latents are always returned in non-patchified format [C, F, H, W]. + Legacy patchified format [seq_len, C] is automatically converted. + """ + super().__init__() + + self.data_root = self._setup_data_root(data_root) + self.data_sources = self._normalize_data_sources(data_sources) + self.source_paths = self._setup_source_paths() + self.sample_files = self._discover_samples() + self._validate_setup() + + @staticmethod + def _setup_data_root(data_root: str) -> Path: + """Setup and validate the data root directory.""" + data_root = Path(data_root).expanduser().resolve() + + if not data_root.exists(): + raise FileNotFoundError(f"Data root directory does not exist: {data_root}") + + # If the given path is the dataset root, use the precomputed subdirectory + if (data_root / PRECOMPUTED_DIR_NAME).exists(): + data_root = data_root / PRECOMPUTED_DIR_NAME + + return data_root + + @staticmethod + def _normalize_data_sources(data_sources: dict[str, str] | list[str] | None) -> dict[str, str]: + """Normalize data_sources input to a consistent dict format.""" + if data_sources is None: + # Default sources + return {"latents": "latent_conditions", "conditions": "text_conditions"} + elif isinstance(data_sources, list): + # Convert list to dict where keys equal values + return {source: source for source in data_sources} + elif isinstance(data_sources, dict): + return data_sources.copy() + else: + raise TypeError(f"data_sources must be dict, list, or None, got {type(data_sources)}") + + def _setup_source_paths(self) -> dict[str, Path]: + """Map data source names to their actual directory paths.""" + source_paths = {} + + for dir_name in self.data_sources: + source_path = self.data_root / dir_name + source_paths[dir_name] = source_path + + # Check that all sources exist. 
+ if not source_path.exists(): + raise FileNotFoundError(f"Required {dir_name} directory does not exist: {source_path}") + + return source_paths + + def _discover_samples(self) -> dict[str, list[Path]]: + """Discover all valid sample files across all data sources.""" + # Use first data source as the reference to discover samples + data_key = "latents" if "latents" in self.data_sources else next(iter(self.data_sources.keys())) + data_path = self.source_paths[data_key] + data_files = list(data_path.glob("**/*.pt")) + + if not data_files: + raise ValueError(f"No data files found in {data_path}") + + # Initialize sample files dict + sample_files = {output_key: [] for output_key in self.data_sources.values()} + + # For each data file, find corresponding files in other sources + for data_file in data_files: + rel_path = data_file.relative_to(data_path) + + # Check if corresponding files exist in ALL sources + if self._all_source_files_exist(data_file, rel_path): + self._fill_sample_data_files(data_file, rel_path, sample_files) + + return sample_files + + def _all_source_files_exist(self, data_file: Path, rel_path: Path) -> bool: + """Check if corresponding files exist in all data sources.""" + for dir_name in self.data_sources: + expected_path = self._get_expected_file_path(dir_name, data_file, rel_path) + if not expected_path.exists(): + logger.warning( + f"No matching {dir_name} file found for: {data_file.name} (expected in: {expected_path})" + ) + return False + + return True + + def _get_expected_file_path(self, dir_name: str, data_file: Path, rel_path: Path) -> Path: + """Get the expected file path for a given data source.""" + source_path = self.source_paths[dir_name] + + # For conditions, handle legacy naming where latent_X.pt maps to condition_X.pt + if dir_name == "conditions" and data_file.name.startswith("latent_"): + return source_path / f"condition_{data_file.stem[7:]}.pt" + + return source_path / rel_path + + def _fill_sample_data_files(self, data_file: Path, rel_path: Path, sample_files: dict[str, list[Path]]) -> None: + """Add a valid sample to the sample_files tracking.""" + for dir_name, output_key in self.data_sources.items(): + expected_path = self._get_expected_file_path(dir_name, data_file, rel_path) + sample_files[output_key].append(expected_path.relative_to(self.source_paths[dir_name])) + + def _validate_setup(self) -> None: + """Validate that the dataset setup is correct.""" + if not self.sample_files: + raise ValueError("No valid samples found - all data sources must have matching files") + + # Verify all output keys have the same number of samples + sample_counts = {key: len(files) for key, files in self.sample_files.items()} + if len(set(sample_counts.values())) > 1: + raise ValueError(f"Mismatched sample counts across sources: {sample_counts}") + + def __len__(self) -> int: + # Use the first output key as reference count + first_key = next(iter(self.sample_files.keys())) + return len(self.sample_files[first_key]) + + def __getitem__(self, index: int) -> dict[str, torch.Tensor]: + result = {} + + for dir_name, output_key in self.data_sources.items(): + source_path = self.source_paths[dir_name] + file_rel_path = self.sample_files[output_key][index] + file_path = source_path / file_rel_path + + try: + data = torch.load(file_path, map_location="cpu", weights_only=True) + + # Normalize video latent format if this is a latent source + if "latent" in dir_name.lower(): + data = self._normalize_video_latents(data) + + result[output_key] = data + except Exception as e: + raise 
RuntimeError(f"Failed to load {output_key} from {file_path}: {e}") from e + + # Add index for debugging + result["idx"] = index + return result + + @staticmethod + def _normalize_video_latents(data: dict) -> dict: + """ + Normalize video latents to non-patchified format [C, F, H, W]. + Used for keeping backward compatibility with legacy datasets. + """ + latents = data["latents"] + + # Check if latents are in legacy patchified format [seq_len, C] + if latents.dim() == 2: + # Legacy format: [seq_len, C] where seq_len = F * H * W + num_frames = data["num_frames"] + height = data["height"] + width = data["width"] + + # Unpatchify: [seq_len, C] -> [C, F, H, W] + latents = rearrange( + latents, + "(f h w) c -> c f h w", + f=num_frames, + h=height, + w=width, + ) + + # Update the data dict with unpatchified latents + data = data.copy() + data["latents"] = latents + + return data diff --git a/packages/ltx-trainer/src/ltx_trainer/hf_hub_utils.py b/packages/ltx-trainer/src/ltx_trainer/hf_hub_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..63ed56b36b3ad5715fa523e001d08ddac7db19bf --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/hf_hub_utils.py @@ -0,0 +1,208 @@ +import shutil +import tempfile +from pathlib import Path +from typing import List, Union + +import imageio +from huggingface_hub import HfApi, create_repo +from huggingface_hub.utils import are_progress_bars_disabled, disable_progress_bars, enable_progress_bars +from rich.progress import Progress, SpinnerColumn, TextColumn + +from ltx_trainer import logger +from ltx_trainer.config import LtxTrainerConfig + + +def push_to_hub(weights_path: Path, sampled_videos_paths: List[Path], config: LtxTrainerConfig) -> None: + """Push the trained LoRA weights to HuggingFace Hub.""" + if not config.hub.hub_model_id: + logger.warning("⚠️ HuggingFace hub_model_id not specified, skipping push to hub") + return + + api = HfApi() + + # Save original progress bar state + original_progress_state = are_progress_bars_disabled() + disable_progress_bars() # Disable during our custom progress tracking + + try: + # Try to create repo if it doesn't exist + try: + repo = create_repo( + repo_id=config.hub.hub_model_id, + repo_type="model", + exist_ok=True, # Don't raise error if repo exists + ) + repo_id = repo.repo_id + logger.info(f"🤗 Successfully created HuggingFace model repository at: {repo.url}") + except Exception as e: + logger.error(f"❌ Failed to create HuggingFace model repository: {e}") + return + + # Create a single temporary directory for all files + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + transient=True, + ) as progress: + try: + # Copy weights + task_copy = progress.add_task("Copying weights...", total=None) + weights_dest = temp_path / weights_path.name + shutil.copy2(weights_path, weights_dest) + progress.update(task_copy, description="✓ Weights copied") + + # Create model card and save samples + task_card = progress.add_task("Creating model card and samples...", total=None) + _create_model_card( + output_dir=temp_path, + videos=sampled_videos_paths, + config=config, + ) + progress.update(task_card, description="✓ Model card and samples created") + + # Upload everything at once + task_upload = progress.add_task("Pushing files to HuggingFace Hub...", total=None) + api.upload_folder( + folder_path=str(temp_path), + repo_id=repo_id, + repo_type="model", + ) + progress.update(task_upload, 
description="✓ Files pushed to HuggingFace Hub") + logger.info("✅ Successfully pushed files to HuggingFace Hub") + + except Exception as e: + logger.error(f"❌ Failed to process and push files to HuggingFace Hub: {e}") + raise # Re-raise to handle in outer try block + + finally: + # Restore original progress bar state + if not original_progress_state: + enable_progress_bars() + + +def convert_video_to_gif(video_path: Path, output_path: Path) -> None: + """Convert a video file to GIF format.""" + try: + # Read the video file + reader = imageio.get_reader(str(video_path)) + fps = reader.get_meta_data()["fps"] + + # Write GIF file with infinite loop + writer = imageio.get_writer( + str(output_path), + fps=min(fps, 15), # Cap FPS at 15 for reasonable file size + loop=0, # 0 means infinite loop + ) + + for frame in reader: + writer.append_data(frame) + + writer.close() + reader.close() + except Exception as e: + logger.error(f"Failed to convert video to GIF: {e}") + + +def _create_model_card( + output_dir: Union[str, Path], + videos: List[Path], + config: LtxTrainerConfig, +) -> Path: + """Generate and save a model card for the trained model.""" + + repo_id = config.hub.hub_model_id + pretrained_model_name_or_path = config.model.model_path + validation_prompts = config.validation.prompts + output_dir = Path(output_dir) + template_path = Path(__file__).parent.parent.parent / "templates" / "model_card.md" + + # Read the template + template = template_path.read_text() + + # Get model name from repo_id + model_name = repo_id.split("/")[-1] + + # Get base model information + base_model_link = str(pretrained_model_name_or_path) + model_path_str = str(pretrained_model_name_or_path) + is_url = model_path_str.startswith(("http://", "https://")) + + # For URLs, extract the filename from the URL. For local paths, use the filename stem + base_model_name = model_path_str.split("/")[-1] if is_url else Path(pretrained_model_name_or_path).name + + # Format validation prompts and create grid layout + prompts_text = "" + sample_grid = [] + + if validation_prompts and videos: + prompts_text = "Example prompts used during validation:\n\n" + + # Create samples directory + samples_dir = output_dir / "samples" + samples_dir.mkdir(exist_ok=True, parents=True) + + # Process videos and create cells + cells = [] + for i, (prompt, video) in enumerate(zip(validation_prompts, videos, strict=False)): + if video.exists(): + # Add prompt to text section + prompts_text += f"- `{prompt}`\n" + + # Convert video to GIF + gif_path = samples_dir / f"sample_{i}.gif" + try: + convert_video_to_gif(video, gif_path) + + # Create grid cell with collapsible description + cell = ( + f"![example{i + 1}](./samples/sample_{i}.gif)" + "
" + '
' + f"Prompt" + f"{prompt}" + "
" + ) + cells.append(cell) + except Exception as e: + logger.error(f"Failed to process video {video}: {e}") + + # Calculate optimal grid dimensions + num_cells = len(cells) + if num_cells > 0: + # Aim for a roughly square grid, with max 4 columns + num_cols = min(4, num_cells) + num_rows = (num_cells + num_cols - 1) // num_cols # Ceiling division + + # Create grid rows + for row in range(num_rows): + start_idx = row * num_cols + end_idx = min(start_idx + num_cols, num_cells) + row_cells = cells[start_idx:end_idx] + # Properly format the row with table markers and exact number of cells + formatted_row = "| " + " | ".join(row_cells) + " |" + sample_grid.append(formatted_row) + + # Join grid rows with just the content, no headers needed + grid_text = "\n".join(sample_grid) if sample_grid else "" + + # Fill in the template + model_card_content = template.format( + base_model=base_model_name, + base_model_link=base_model_link, + model_name=model_name, + training_type="LoRA fine-tuning" if config.model.training_mode == "lora" else "Full model fine-tuning", + training_steps=config.optimization.steps, + learning_rate=config.optimization.learning_rate, + batch_size=config.optimization.batch_size, + validation_prompts=prompts_text, + sample_grid=grid_text, + ) + + # Save the model card directly + model_card_path = output_dir / "README.md" + model_card_path.write_text(model_card_content) + + return model_card_path diff --git a/packages/ltx-trainer/src/ltx_trainer/model_loader.py b/packages/ltx-trainer/src/ltx_trainer/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..f07e119387341414cc8ba395c3a6f694e4068ee0 --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/model_loader.py @@ -0,0 +1,371 @@ +# ruff: noqa: PLC0415 + +""" +Model loader for LTX-2 trainer using the new ltx-core package. + +This module provides a unified interface for loading LTX-2 model components +for training, using SingleGPUModelBuilder from ltx-core. 
+ +Example usage: + # Load individual components + vae_encoder = load_video_vae_encoder("/path/to/checkpoint.safetensors", device="cuda") + vae_decoder = load_video_vae_decoder("/path/to/checkpoint.safetensors", device="cuda") + text_encoder = load_text_encoder("/path/to/checkpoint.safetensors", "/path/to/gemma", device="cuda") + + # Load all components at once + components = load_model("/path/to/checkpoint.safetensors", text_encoder_path="/path/to/gemma") +""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +import torch + +from ltx_trainer import logger + +# Type alias for device specification +Device = str | torch.device + +# Type checking imports (not loaded at runtime) +if TYPE_CHECKING: + from ltx_core.model.audio_vae.audio_vae import Decoder as AudioVAEDecoder + from ltx_core.model.audio_vae.audio_vae import Encoder as AudioVAEEncoder + from ltx_core.model.audio_vae.vocoder import Vocoder + from ltx_core.model.clip.gemma.encoders.av_encoder import AVGemmaTextEncoderModel + from ltx_core.model.transformer.model import LTXModel + from ltx_core.model.video_vae.video_vae import Decoder as VideoVAEDecoder + from ltx_core.model.video_vae.video_vae import Encoder as VideoVAEEncoder + from ltx_core.pipeline.components.schedulers import LTX2Scheduler + + +def _to_torch_device(device: Device) -> torch.device: + """Convert device specification to torch.device.""" + return torch.device(device) if isinstance(device, str) else device + + +# ============================================================================= +# Individual Component Loaders +# ============================================================================= + + +def load_transformer( + checkpoint_path: str | Path, + device: Device = "cpu", + dtype: torch.dtype = torch.bfloat16, +) -> "LTXModel": + """Load the LTX transformer model. + + Args: + checkpoint_path: Path to the safetensors checkpoint file + device: Device to load model on + dtype: Data type for model weights + + Returns: + Loaded LTXModel transformer + """ + from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder + from ltx_core.model.transformer.model_configurator import ( + LTXV_MODEL_COMFY_RENAMING_MAP, + LTXModelConfigurator, + ) + + return SingleGPUModelBuilder( + model_path=str(checkpoint_path), + model_class_configurator=LTXModelConfigurator, + model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP, + ).build(device=_to_torch_device(device), dtype=dtype) + + +def load_video_vae_encoder( + checkpoint_path: str | Path, + device: Device = "cpu", + dtype: torch.dtype = torch.bfloat16, +) -> "VideoVAEEncoder": + """Load the video VAE encoder (for preprocessing). 
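+
+    Example (sketch; the checkpoint path is a placeholder):
+
+        encoder = load_video_vae_encoder("/path/to/checkpoint.safetensors", device="cuda")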
+ + Args: + checkpoint_path: Path to the safetensors checkpoint file + device: Device to load model on + dtype: Data type for model weights + + Returns: + Loaded VideoVAEEncoder + """ + from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder + from ltx_core.model.video_vae.model_configurator import VAE_ENCODER_COMFY_KEYS_FILTER + from ltx_core.model.video_vae.model_configurator import ( + VAEEncoderConfigurator as VideoVAEEncoderConfigurator, + ) + + return SingleGPUModelBuilder( + model_path=str(checkpoint_path), + model_class_configurator=VideoVAEEncoderConfigurator, + model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER, + ).build(device=_to_torch_device(device), dtype=dtype) + + +def load_video_vae_decoder( + checkpoint_path: str | Path, + device: Device = "cpu", + dtype: torch.dtype = torch.bfloat16, +) -> "VideoVAEDecoder": + """Load the video VAE decoder (for inference/validation). + + Args: + checkpoint_path: Path to the safetensors checkpoint file + device: Device to load model on + dtype: Data type for model weights + + Returns: + Loaded VideoVAEDecoder + """ + from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder + from ltx_core.model.video_vae.model_configurator import VAE_DECODER_COMFY_KEYS_FILTER + from ltx_core.model.video_vae.model_configurator import ( + VAEDecoderConfigurator as VideoVAEDecoderConfigurator, + ) + + return SingleGPUModelBuilder( + model_path=str(checkpoint_path), + model_class_configurator=VideoVAEDecoderConfigurator, + model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER, + ).build(device=_to_torch_device(device), dtype=dtype) + + +def load_audio_vae_encoder( + checkpoint_path: str | Path, + device: Device = "cpu", + dtype: torch.dtype = torch.bfloat16, +) -> "AudioVAEEncoder": + """Load the audio VAE encoder (for preprocessing). + + Args: + checkpoint_path: Path to the safetensors checkpoint file + device: Device to load model on + dtype: Data type for model weights (default bfloat16, but float32 recommended for quality) + + Returns: + Loaded AudioVAEEncoder + """ + from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder + from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER + from ltx_core.model.audio_vae.model_configurator import ( + VAEEncoderConfigurator as AudioVAEEncoderConfigurator, + ) + + return SingleGPUModelBuilder( + model_path=str(checkpoint_path), + model_class_configurator=AudioVAEEncoderConfigurator, + model_sd_ops=AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER, + ).build(device=_to_torch_device(device), dtype=dtype) + + +def load_audio_vae_decoder( + checkpoint_path: str | Path, + device: Device = "cpu", + dtype: torch.dtype = torch.bfloat16, +) -> "AudioVAEDecoder": + """Load the audio VAE decoder. 
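+
+    Typically used together with the vocoder when turning decoded audio latents
+    into a waveform. Sketch (the checkpoint path is a placeholder):
+
+        audio_decoder = load_audio_vae_decoder("/path/to/checkpoint.safetensors", device="cuda")
+        vocoder = load_vocoder("/path/to/checkpoint.safetensors", device="cuda")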
+ + Args: + checkpoint_path: Path to the safetensors checkpoint file + device: Device to load model on + dtype: Data type for model weights + + Returns: + Loaded AudioVAEDecoder + """ + from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder + from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER + from ltx_core.model.audio_vae.model_configurator import ( + VAEDecoderConfigurator as AudioVAEDecoderConfigurator, + ) + + return SingleGPUModelBuilder( + model_path=str(checkpoint_path), + model_class_configurator=AudioVAEDecoderConfigurator, + model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, + ).build(device=_to_torch_device(device), dtype=dtype) + + +def load_vocoder( + checkpoint_path: str | Path, + device: Device = "cpu", + dtype: torch.dtype = torch.bfloat16, +) -> "Vocoder": + """Load the vocoder (for audio waveform generation). + + Args: + checkpoint_path: Path to the safetensors checkpoint file + device: Device to load model on + dtype: Data type for model weights + + Returns: + Loaded Vocoder + """ + from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder + from ltx_core.model.audio_vae.model_configurator import VOCODER_COMFY_KEYS_FILTER, VocoderConfigurator + + return SingleGPUModelBuilder( + model_path=str(checkpoint_path), + model_class_configurator=VocoderConfigurator, + model_sd_ops=VOCODER_COMFY_KEYS_FILTER, + ).build(device=_to_torch_device(device), dtype=dtype) + + +def load_text_encoder( + checkpoint_path: str | Path, + gemma_model_path: str | Path, + device: Device = "cpu", + dtype: torch.dtype = torch.bfloat16, +) -> "AVGemmaTextEncoderModel": + """Load the Gemma text encoder. + + Args: + checkpoint_path: Path to the LTX-2 safetensors checkpoint file + gemma_model_path: Path to Gemma model directory + device: Device to load model on + dtype: Data type for model weights + + Returns: + Loaded AVGemmaTextEncoderModel + """ + from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder + from ltx_core.model.clip.gemma.encoders.av_encoder import ( + AV_GEMMA_TEXT_ENCODER_KEY_OPS, + AVGemmaTextEncoderModelConfigurator, + ) + from ltx_core.model.clip.gemma.encoders.base_encoder import module_ops_from_gemma_root + + if not Path(gemma_model_path).is_dir(): + raise ValueError(f"Gemma model path is not a directory: {gemma_model_path}") + + torch_device = _to_torch_device(device) + text_encoder = SingleGPUModelBuilder( + model_path=str(checkpoint_path), + model_class_configurator=AVGemmaTextEncoderModelConfigurator, + model_sd_ops=AV_GEMMA_TEXT_ENCODER_KEY_OPS, + module_ops=module_ops_from_gemma_root(str(gemma_model_path)), + ).build(device=torch_device, dtype=dtype) + + return text_encoder + + +# ============================================================================= +# Combined Component Loader +# ============================================================================= + + +@dataclass +class LtxModelComponents: + """Container for all LTX-2 model components.""" + + transformer: "LTXModel" + video_vae_encoder: "VideoVAEEncoder | None" = None + video_vae_decoder: "VideoVAEDecoder | None" = None + audio_vae_decoder: "AudioVAEDecoder | None" = None + vocoder: "Vocoder | None" = None + text_encoder: "AVGemmaTextEncoderModel | None" = None + scheduler: "LTX2Scheduler | None" = None + + +def load_model( + checkpoint_path: str | Path, + text_encoder_path: str | Path | None = None, + device: Device = "cpu", + dtype: torch.dtype = torch.bfloat16, + with_video_vae_encoder: bool = False, + 
with_video_vae_decoder: bool = True, + with_audio_vae_decoder: bool = True, + with_vocoder: bool = True, + with_text_encoder: bool = True, +) -> LtxModelComponents: + """ + Load LTX-2 model components from a safetensors checkpoint. + + This is a convenience function that loads multiple components at once. + For loading individual components, use the dedicated functions: + - load_transformer() + - load_video_vae_encoder() + - load_video_vae_decoder() + - load_audio_vae_decoder() + - load_vocoder() + - load_text_encoder() + + Args: + checkpoint_path: Path to the safetensors checkpoint file + text_encoder_path: Path to Gemma model directory (required if with_text_encoder=True) + device: Device to load models on ("cuda", "cpu", etc.) + dtype: Data type for model weights + with_video_vae_encoder: Whether to load the video VAE encoder (for preprocessing) + with_video_vae_decoder: Whether to load the video VAE decoder (for inference/validation) + with_audio_vae_decoder: Whether to load the audio VAE decoder + with_vocoder: Whether to load the vocoder + with_text_encoder: Whether to load the text encoder + + Returns: + LtxModelComponents containing all loaded model components + """ + from ltx_core.pipeline.components.schedulers import LTX2Scheduler + + checkpoint_path = Path(checkpoint_path) + + # Validate checkpoint exists + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + logger.info(f"Loading LTX-2 model from {checkpoint_path}") + + torch_device = _to_torch_device(device) + + # Load transformer + logger.debug("Loading transformer...") + transformer = load_transformer(checkpoint_path, torch_device, dtype) + + # Load video VAE encoder + video_vae_encoder = None + if with_video_vae_encoder: + logger.debug("Loading video VAE encoder...") + video_vae_encoder = load_video_vae_encoder(checkpoint_path, torch_device, dtype) + + # Load video VAE decoder + video_vae_decoder = None + if with_video_vae_decoder: + logger.debug("Loading video VAE decoder...") + video_vae_decoder = load_video_vae_decoder(checkpoint_path, torch_device, dtype) + + # Load audio VAE decoder + audio_vae_decoder = None + if with_audio_vae_decoder: + logger.debug("Loading audio VAE decoder...") + audio_vae_decoder = load_audio_vae_decoder(checkpoint_path, torch_device, dtype) + + # Load vocoder + vocoder = None + if with_vocoder: + logger.debug("Loading vocoder...") + vocoder = load_vocoder(checkpoint_path, torch_device, dtype) + + # Load text encoder + text_encoder = None + if with_text_encoder: + if text_encoder_path is None: + raise ValueError("text_encoder_path must be provided when with_text_encoder=True") + logger.debug("Loading Gemma text encoder...") + text_encoder = load_text_encoder(checkpoint_path, text_encoder_path, torch_device, dtype) + + # Create scheduler (stateless, no loading needed) + scheduler = LTX2Scheduler() + + return LtxModelComponents( + transformer=transformer, + video_vae_encoder=video_vae_encoder, + video_vae_decoder=video_vae_decoder, + audio_vae_decoder=audio_vae_decoder, + vocoder=vocoder, + text_encoder=text_encoder, + scheduler=scheduler, + ) diff --git a/packages/ltx-trainer/src/ltx_trainer/progress.py b/packages/ltx-trainer/src/ltx_trainer/progress.py new file mode 100644 index 0000000000000000000000000000000000000000..cbcbaf8681cfbb04b6687bcb120b1adc352039ab --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/progress.py @@ -0,0 +1,249 @@ +"""Progress tracking for LTX training. 
+ +This module provides a unified progress display for training and validation sampling, +encapsulating all Rich progress bar logic in one place. +""" + +from rich.progress import ( + BarColumn, + Progress, + TaskID, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) + + +class SamplingContext: + """Context for validation sampling progress tracking. + + Provides a unified progress display showing current video and denoising step. + Display format: "Sampling X/Y [████████████] step Z/W" + The progress bar shows the denoising progress for the current video. + """ + + def __init__(self, progress: Progress | None, task: TaskID | None, num_prompts: int, num_steps: int): + self._progress = progress + self._task = task + self._num_prompts = num_prompts + self._num_steps = num_steps + + def start_video(self, video_idx: int) -> None: + """Start tracking a new video (resets step progress).""" + if self._progress is None or self._task is None: + return + # Reset task for new video: completed=0, total=num_steps + self._progress.reset(self._task, total=self._num_steps) + self._progress.update( + self._task, + completed=0, + video=f"{video_idx + 1}/{self._num_prompts}", + info=f"step 0/{self._num_steps}", + ) + + def advance_step(self) -> None: + """Advance the denoising step by one.""" + if self._progress is None or self._task is None: + return + self._progress.advance(self._task) + completed = int(self._progress.tasks[self._task].completed) + self._progress.update(self._task, info=f"step {completed}/{self._num_steps}") + + def cleanup(self) -> None: + """Hide sampling task when done.""" + if self._progress is None or self._task is None: + return + self._progress.update(self._task, visible=False) + + +class StandaloneSamplingProgress: + """Standalone progress display for inference scripts. + + Unlike SamplingContext (which integrates with TrainingProgress), this class + manages its own Rich Progress instance for use in standalone inference scripts. + + Usage: + with StandaloneSamplingProgress(num_steps=30) as ctx: + for step in range(30): + # ... denoising step ... + ctx.advance_step() + """ + + def __init__(self, num_steps: int, description: str = "Generating"): + """Initialize standalone sampling progress. 
+ + Args: + num_steps: Total number of denoising steps + description: Description to show in progress bar + """ + self._num_steps = num_steps + self._description = description + self._progress: Progress | None = None + self._task: TaskID | None = None + + def __enter__(self) -> "StandaloneSamplingProgress": + """Start the progress display.""" + self._progress = Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn(bar_width=40, style="blue"), + TextColumn("{task.fields[info]}", style="cyan"), + TimeElapsedColumn(), + TextColumn("ETA:"), + TimeRemainingColumn(compact=True), + ) + self._progress.__enter__() + self._task = self._progress.add_task( + self._description, + total=self._num_steps, + info=f"step 0/{self._num_steps}", + ) + return self + + def __exit__(self, *args) -> None: + """Stop the progress display.""" + if self._progress is not None: + self._progress.__exit__(*args) + + def advance_step(self) -> None: + """Advance the denoising step by one.""" + if self._progress is None or self._task is None: + return + self._progress.advance(self._task) + completed = int(self._progress.tasks[self._task].completed) + self._progress.update(self._task, info=f"step {completed}/{self._num_steps}") + + +class TrainingProgress: + """Manages Rich progress display for training and validation. + + This class encapsulates all progress bar logic, providing a clean interface + for the trainer to update progress without dealing with Rich internals. + + Usage: + with TrainingProgress(enabled=True, total_steps=1000) as progress: + for step in range(1000): + # ... training step ... + progress.update_training(loss=0.1, lr=1e-4, step_time=0.5) + + if should_validate: + sampling_ctx = progress.start_sampling(num_prompts=3, num_steps=30) + sampler = ValidationSampler(..., sampling_context=sampling_ctx) + for prompt_idx, prompt in enumerate(prompts): + sampling_ctx.start_video(prompt_idx) + sampler.generate(...) + sampling_ctx.cleanup() + """ + + def __init__(self, enabled: bool, total_steps: int): + """Initialize progress tracking. + + Args: + enabled: Whether to display progress bars (False for non-main processes) + total_steps: Total number of training steps + """ + self._enabled = enabled + self._total_steps = total_steps + self._train_task: TaskID | None = None + + if not enabled: + self._progress = None + return + + # Single Progress instance with flexible columns + self._progress = Progress( + TextColumn("[progress.description]{task.description}"), + TextColumn("{task.fields[video]}", style="magenta"), + BarColumn(bar_width=40, style="blue"), + TextColumn("{task.fields[info]}", style="cyan"), + TimeElapsedColumn(), + TextColumn("ETA:"), + TimeRemainingColumn(compact=True), + ) + + def __enter__(self) -> "TrainingProgress": + """Enter the progress context, starting the live display.""" + if self._progress is not None: + self._progress.__enter__() + self._train_task = self._progress.add_task( + "Training", + total=self._total_steps, + video=f"0/{self._total_steps}", + info="Starting...", + ) + return self + + def __exit__(self, *args) -> None: + """Exit the progress context, stopping the live display.""" + if self._progress is not None: + self._progress.__exit__(*args) + + @property + def enabled(self) -> bool: + """Whether progress display is enabled.""" + return self._enabled + + def update_training( + self, + *, + loss: float, + lr: float, + step_time: float, + advance: bool = True, + ) -> None: + """Update the training progress display. 
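+
+        Example (sketch):
+
+            progress.update_training(loss=0.0421, lr=5e-4, step_time=0.83)
+            # info column renders as: "Loss: 0.0421 | LR: 5.00e-04 | 0.83s/step"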
+ + Args: + loss: Current training loss + lr: Current learning rate + step_time: Time taken for this step in seconds + advance: Whether to advance the progress by one step + """ + if self._progress is None or self._train_task is None: + return + + info = f"Loss: {loss:.4f} | LR: {lr:.2e} | {step_time:.2f}s/step" + self._progress.update( + self._train_task, + advance=1 if advance else 0, + info=info, + ) + # Update step count in video column + completed = int(self._progress.tasks[self._train_task].completed) + self._progress.update(self._train_task, video=f"{completed}/{self._total_steps}") + + def start_sampling(self, num_prompts: int, num_steps: int) -> SamplingContext: + """Start validation sampling progress tracking. + + Creates a task that shows current video and denoising step progress. + Format: "Sampling X/Y [████████████] step Z/W" + + Args: + num_prompts: Number of validation prompts to sample + num_steps: Number of denoising steps per sample + + Returns: + SamplingContext for tracking progress (no-op if progress is disabled) + """ + if self._progress is None: + # Return a no-op context when progress is disabled + return SamplingContext( + progress=None, + task=None, + num_prompts=num_prompts, + num_steps=num_steps, + ) + + task = self._progress.add_task( + "Sampling", + total=num_steps, + completed=0, + video=f"0/{num_prompts}", + info=f"step 0/{num_steps}", + ) + + return SamplingContext( + progress=self._progress, + task=task, + num_prompts=num_prompts, + num_steps=num_steps, + ) diff --git a/packages/ltx-trainer/src/ltx_trainer/quantization.py b/packages/ltx-trainer/src/ltx_trainer/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..4028174eaa3b3bc25c7ca84b7445595e7da0f4be --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/quantization.py @@ -0,0 +1,92 @@ +# Adapted from: https://github.com/bghira/SimpleTuner/blob/main/helpers/training/quantisation/__init__.py +from typing import Literal + +import torch +from optimum.quanto import qtype + +from ltx_trainer import logger + +QuantizationOptions = Literal[ + "no_change", + "int8-quanto", + "int4-quanto", + "int2-quanto", + "fp8-quanto", + "fp8uz-quanto", +] + + +def quantize_model( + model: torch.nn.Module, + precision: QuantizationOptions, + quantize_activations: bool = False, +) -> torch.nn.Module: + """ + Quantize a model using the specified precision settings. + + Args: + model: The model to quantize. + precision: The precision level to quantize to (e.g. "int8-quanto", "fp8-quanto"). + quantize_activations: Whether to quantize activations in addition to weights. + + Returns: + The quantized model, or the original model if no quantization is performed. 
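+
+    Example (sketch; assumes `optimum-quanto` is installed):
+
+        model = quantize_model(model, "int8-quanto")
+        # Weights are quantized and frozen in place; modules matching the
+        # exclude patterns (norms, projections, rope, ...) keep their dtype.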
+ """ + if precision is None or precision == "no_change": + return model + + from optimum.quanto import freeze, quantize # noqa: PLC0415 + + weight_quant = _quanto_type_map(precision) + extra_quanto_args = { + "exclude": [ + "proj_in", + "time_embed.*", + "caption_projection.*", + "rope", + "*norm*", + "proj_out", + ] + } + if quantize_activations: + logger.info("Freezing model weights and activations") + extra_quanto_args["activations"] = weight_quant + else: + logger.info("Freezing model weights only") + + quantize(model, weights=weight_quant, **extra_quanto_args) + freeze(model) + return model + + +def _quanto_type_map(precision: QuantizationOptions) -> torch.dtype | qtype | None: # noqa: PLR0911 + if precision == "no_change": + return None + + from optimum.quanto import ( # noqa: PLC0415 + qfloat8, + qfloat8_e4m3fnuz, + qint2, + qint4, + qint8, + ) + + if precision == "int2-quanto": + return qint2 + elif precision == "int4-quanto": + return qint4 + elif precision == "int8-quanto": + return qint8 + elif precision in ("fp8-quanto", "fp8uz-quanto"): + if torch.backends.mps.is_available(): + logger.warning( + "MPS doesn't support dtype float8. " + "you must select another precision level such as int2, int8, or int8.", + ) + return None + if precision == "fp8-quanto": + return qfloat8 + elif precision == "fp8uz-quanto": + return qfloat8_e4m3fnuz + + raise ValueError(f"Invalid quantisation level: {precision}") diff --git a/packages/ltx-trainer/src/ltx_trainer/timestep_samplers.py b/packages/ltx-trainer/src/ltx_trainer/timestep_samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..264e9e1770ea35abc470998f607265b75af654da --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/timestep_samplers.py @@ -0,0 +1,138 @@ +import torch + + +class TimestepSampler: + """Base class for timestep samplers. + + Timestep samplers are used to sample timesteps for diffusion models. + They should implement both sample() and sample_for() methods. + """ + + def sample(self, batch_size: int, seq_length: int | None = None, device: torch.device = None) -> torch.Tensor: + """Sample timesteps for a batch. + + Args: + batch_size: Number of timesteps to sample + seq_length: (optional) Length of the sequence being processed + device: Device to place the samples on + + Returns: + Tensor of shape (batch_size,) containing timesteps + """ + raise NotImplementedError + + def sample_for(self, batch: torch.Tensor) -> torch.Tensor: + """Sample timesteps for a specific batch tensor. + + Args: + batch: Input tensor of shape (batch_size, seq_length, ...) 
+ + Returns: + Tensor of shape (batch_size,) containing timesteps + """ + raise NotImplementedError + + +class UniformTimestepSampler(TimestepSampler): + """Samples timesteps uniformly between min_value and max_value (default 0 and 1).""" + + def __init__(self, min_value: float = 0.0, max_value: float = 1.0): + self.min_value = min_value + self.max_value = max_value + + def sample(self, batch_size: int, seq_length: int | None = None, device: torch.device = None) -> torch.Tensor: # noqa: ARG002 + return torch.rand(batch_size, device=device) * (self.max_value - self.min_value) + self.min_value + + def sample_for(self, batch: torch.Tensor) -> torch.Tensor: + if batch.ndim != 3: + raise ValueError(f"Batch should have 3 dimensions, got {batch.ndim}") + + return self.sample(batch.shape[0], device=batch.device) + + +class ShiftedLogitNormalTimestepSampler: + """ + Samples timesteps from a shifted logit-normal distribution, + where the shift is determined by the sequence length. + """ + + def __init__(self, std: float = 1.0): + self.std = std + + def sample(self, batch_size: int, seq_length: int, device: torch.device = None) -> torch.Tensor: + """Sample timesteps for a batch from a shifted logit-normal distribution. + + Args: + batch_size: Number of timesteps to sample + seq_length: Length of the sequence being processed, used to determine the shift + device: Device to place the samples on + + Returns: + Tensor of shape (batch_size,) containing timesteps sampled from a shifted + logit-normal distribution, where the shift is determined by seq_length + """ + shift = self._get_shift_for_sequence_length(seq_length) + normal_samples = torch.randn((batch_size,), device=device) * self.std + shift + timesteps = torch.sigmoid(normal_samples) + return timesteps + + def sample_for(self, batch: torch.Tensor) -> torch.Tensor: + """Sample timesteps for a specific batch tensor. + + Args: + batch: Input tensor of shape (batch_size, seq_length, ...) + + Returns: + Tensor of shape (batch_size,) containing timesteps sampled from a shifted + logit-normal distribution, where the shift is determined by the sequence length + of the input batch + + Raises: + ValueError: If the input batch does not have 3 dimensions + """ + if batch.ndim != 3: + raise ValueError(f"Batch should have 3 dimensions, got {batch.ndim}") + + batch_size, seq_length, _ = batch.shape + return self.sample(batch_size, seq_length, device=batch.device) + + @staticmethod + def _get_shift_for_sequence_length( + seq_length: int, + min_tokens: int = 1024, + max_tokens: int = 4096, + min_shift: float = 0.95, + max_shift: float = 2.05, + ) -> float: + # Calculate the shift value for a given sequence length using linear interpolation + # between min_shift and max_shift based on sequence length. 
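+ # Illustrative values with the defaults above: seq_length=1024 -> 0.95,
+ # seq_length=2048 -> ~1.32, seq_length=4096 -> 2.05; lengths outside
+ # [min_tokens, max_tokens] extrapolate along the same line.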
+ m = (max_shift - min_shift) / (max_tokens - min_tokens) # Calculate slope + b = min_shift - m * min_tokens # Calculate y-intercept + shift = m * seq_length + b # Apply linear equation y = mx + b + return shift + + +SAMPLERS = { + "uniform": UniformTimestepSampler, + "shifted_logit_normal": ShiftedLogitNormalTimestepSampler, +} + + +def example() -> None: + # noinspection PyUnresolvedReferences + import matplotlib.pyplot as plt # noqa: PLC0415 + + sampler = ShiftedLogitNormalTimestepSampler() + for seq_length in [1024, 2048, 4096, 8192]: + samples = sampler.sample(batch_size=1_000_000, seq_length=seq_length) + + # plot the histogram of the samples + plt.hist(samples.numpy(), bins=100, density=True) + plt.title(f"Timestep Samples for Sequence Length {seq_length}") + plt.xlabel("Timestep") + plt.ylabel("Density") + plt.show() + + +if __name__ == "__main__": + example() diff --git a/packages/ltx-trainer/src/ltx_trainer/trainer.py b/packages/ltx-trainer/src/ltx_trainer/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..fc8fb709cb4daefc9fc52ed22bf4cef7127047eb --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/trainer.py @@ -0,0 +1,958 @@ +import os +import time +import warnings +from pathlib import Path +from typing import Callable + +import torch +import wandb +import yaml +from accelerate import Accelerator, DistributedType +from accelerate.utils import set_seed +from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict +from peft.tuners.tuners_utils import BaseTunerLayer +from peft.utils import ModulesToSaveWrapper +from pydantic import BaseModel +from safetensors.torch import load_file, save_file +from torch import Tensor +from torch.optim import AdamW +from torch.optim.lr_scheduler import ( + CosineAnnealingLR, + CosineAnnealingWarmRestarts, + LinearLR, + LRScheduler, + PolynomialLR, + StepLR, +) +from torch.utils.data import DataLoader +from torchvision.transforms import functional as F # noqa: N812 + +from ltx_trainer import logger +from ltx_trainer.config import LtxTrainerConfig +from ltx_trainer.config_display import print_config +from ltx_trainer.datasets import PrecomputedDataset +from ltx_trainer.hf_hub_utils import push_to_hub +from ltx_trainer.model_loader import load_model as load_ltx_model +from ltx_trainer.model_loader import load_text_encoder +from ltx_trainer.progress import TrainingProgress +from ltx_trainer.quantization import quantize_model +from ltx_trainer.timestep_samplers import SAMPLERS +from ltx_trainer.training_strategies import get_training_strategy +from ltx_trainer.utils import get_gpu_memory_gb, open_image_as_srgb +from ltx_trainer.validation_sampler import CachedPromptEmbeddings, GenerationConfig, ValidationSampler +from ltx_trainer.video_utils import read_video, save_video + +# Disable irrelevant warnings from transformers +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +# Silence bitsandbytes warnings about casting +warnings.filterwarnings( + "ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization" +) + +# Disable progress bars if not main process +IS_MAIN_PROCESS = os.environ.get("LOCAL_RANK", "0") == "0" +if not IS_MAIN_PROCESS: + from transformers.utils.logging import disable_progress_bar + + disable_progress_bar() + +StepCallback = Callable[[int, int, list[Path]], None] # (step, total, list[sampled_video_path]) -> None + +MEMORY_CHECK_INTERVAL = 200 + + +class TrainingStats(BaseModel): + """Statistics collected during training""" + 
+ total_time_seconds: float + steps_per_second: float + samples_per_second: float + peak_gpu_memory_gb: float + global_batch_size: int + num_processes: int + + +class LtxvTrainer: + def __init__(self, trainer_config: LtxTrainerConfig) -> None: + self._config = trainer_config + if IS_MAIN_PROCESS: + print_config(trainer_config) + self._training_strategy = get_training_strategy(self._config.training_strategy) + self._cached_validation_embeddings = self._load_text_encoder_and_cache_embeddings() + self._load_models() + self._setup_accelerator() + self._collect_trainable_params() + self._load_checkpoint() + self._prepare_models_for_training() + self._dataset = None + self._global_step = -1 + self._checkpoint_paths = [] + self._init_wandb() + + def train( # noqa: PLR0912, PLR0915 + self, + disable_progress_bars: bool = False, + step_callback: StepCallback | None = None, + ) -> tuple[Path, TrainingStats]: + """ + Start the training process. + Returns: + Tuple of (saved_model_path, training_stats) + """ + device = self._accelerator.device + cfg = self._config + start_mem = get_gpu_memory_gb(device) + + train_start_time = time.time() + + # Use the same seed for all processes and ensure deterministic operations + set_seed(cfg.seed) + logger.debug(f"Process {self._accelerator.process_index} using seed: {cfg.seed}") + + self._init_optimizer() + self._init_dataloader() + data_iter = iter(self._dataloader) + self._init_timestep_sampler() + + # Synchronize all processes after initialization + self._accelerator.wait_for_everyone() + + Path(cfg.output_dir).mkdir(parents=True, exist_ok=True) + + # Save the training configuration as YAML + self._save_config() + + logger.info("🚀 Starting training...") + + # Create progress tracking (disabled for non-main processes or when explicitly disabled) + progress_enabled = IS_MAIN_PROCESS and not disable_progress_bars + progress = TrainingProgress( + enabled=progress_enabled, + total_steps=cfg.optimization.steps, + ) + + if IS_MAIN_PROCESS and disable_progress_bars: + logger.warning("Progress bars disabled. 
Intermediate status messages will be logged instead.") + + self._transformer.train() + self._global_step = 0 + + peak_mem_during_training = start_mem + + sampled_videos_paths = None + + with progress: + # Initial validation before training starts + if cfg.validation.interval and not cfg.validation.skip_initial_validation: + sampled_videos_paths = self._sample_videos(progress) + if IS_MAIN_PROCESS and sampled_videos_paths and self._config.wandb.log_validation_videos: + self._log_validation_videos(sampled_videos_paths, cfg.validation.prompts) + + self._accelerator.wait_for_everyone() + + for step in range(cfg.optimization.steps * cfg.optimization.gradient_accumulation_steps): + # Get next batch, reset the dataloader if needed + try: + batch = next(data_iter) + except StopIteration: + data_iter = iter(self._dataloader) + batch = next(data_iter) + + step_start_time = time.time() + with self._accelerator.accumulate(self._transformer): + is_optimization_step = (step + 1) % cfg.optimization.gradient_accumulation_steps == 0 + if is_optimization_step: + self._global_step += 1 + + loss = self._training_step(batch) + self._accelerator.backward(loss) + + if self._accelerator.sync_gradients and cfg.optimization.max_grad_norm > 0: + self._accelerator.clip_grad_norm_( + self._trainable_params, + cfg.optimization.max_grad_norm, + ) + + self._optimizer.step() + self._optimizer.zero_grad() + + if self._lr_scheduler is not None: + self._lr_scheduler.step() + + # Run validation if needed + if ( + cfg.validation.interval + and self._global_step > 0 + and self._global_step % cfg.validation.interval == 0 + and is_optimization_step + ): + if self._accelerator.distributed_type == DistributedType.FSDP: + # FSDP: All processes must participate in validation + sampled_videos_paths = self._sample_videos(progress) + if IS_MAIN_PROCESS and sampled_videos_paths and self._config.wandb.log_validation_videos: + self._log_validation_videos(sampled_videos_paths, cfg.validation.prompts) + # DDP: Only main process runs validation + elif IS_MAIN_PROCESS: + sampled_videos_paths = self._sample_videos(progress) + if sampled_videos_paths and self._config.wandb.log_validation_videos: + self._log_validation_videos(sampled_videos_paths, cfg.validation.prompts) + + # Save checkpoint if needed + if ( + cfg.checkpoints.interval + and self._global_step > 0 + and self._global_step % cfg.checkpoints.interval == 0 + and is_optimization_step + ): + self._save_checkpoint() + + self._accelerator.wait_for_everyone() + + # Call step callback if provided + if step_callback and is_optimization_step: + step_callback(self._global_step, cfg.optimization.steps, sampled_videos_paths) + + self._accelerator.wait_for_everyone() + + # Update progress and log metrics + current_lr = self._optimizer.param_groups[0]["lr"] + step_time = (time.time() - step_start_time) * cfg.optimization.gradient_accumulation_steps + + progress.update_training( + loss=loss.item(), + lr=current_lr, + step_time=step_time, + advance=is_optimization_step, + ) + + # Log metrics to W&B (only on main process and optimization steps) + if IS_MAIN_PROCESS and is_optimization_step: + self._log_metrics( + { + "train/loss": loss.item(), + "train/learning_rate": current_lr, + "train/step_time": step_time, + "train/global_step": self._global_step, + } + ) + + # Fallback logging when progress bars are disabled + if disable_progress_bars and IS_MAIN_PROCESS and self._global_step % 20 == 0: + elapsed = time.time() - train_start_time + progress_percentage = self._global_step / cfg.optimization.steps 
+ if progress_percentage > 0: + total_estimated = elapsed / progress_percentage + total_time = f"{total_estimated // 3600:.0f}h {(total_estimated % 3600) // 60:.0f}m" + else: + total_time = "calculating..." + logger.info( + f"Step {self._global_step}/{cfg.optimization.steps} - " + f"Loss: {loss.item():.4f}, LR: {current_lr:.2e}, " + f"Time/Step: {step_time:.2f}s, Total Time: {total_time}", + ) + + # Sample GPU memory periodically + if step % MEMORY_CHECK_INTERVAL == 0: + current_mem = get_gpu_memory_gb(device) + peak_mem_during_training = max(peak_mem_during_training, current_mem) + + # Collect final stats + train_end_time = time.time() + end_mem = get_gpu_memory_gb(device) + peak_mem = max(start_mem, end_mem, peak_mem_during_training) + + # Calculate steps/second over entire training + total_time_seconds = train_end_time - train_start_time + steps_per_second = cfg.optimization.steps / total_time_seconds + + samples_per_second = steps_per_second * self._accelerator.num_processes * cfg.optimization.batch_size + + stats = TrainingStats( + total_time_seconds=total_time_seconds, + steps_per_second=steps_per_second, + samples_per_second=samples_per_second, + peak_gpu_memory_gb=peak_mem, + num_processes=self._accelerator.num_processes, + global_batch_size=cfg.optimization.batch_size * self._accelerator.num_processes, + ) + + saved_path = self._save_checkpoint() + + if IS_MAIN_PROCESS: + # Log the training statistics + self._log_training_stats(stats) + + # Upload artifacts to hub if enabled + if cfg.hub.push_to_hub: + push_to_hub(saved_path, sampled_videos_paths, self._config) + + # Log final stats to W&B + if self._wandb_run is not None: + self._log_metrics( + { + "stats/total_time_minutes": stats.total_time_seconds / 60, + "stats/steps_per_second": stats.steps_per_second, + "stats/samples_per_second": stats.samples_per_second, + "stats/peak_gpu_memory_gb": stats.peak_gpu_memory_gb, + } + ) + self._wandb_run.finish() + + self._accelerator.wait_for_everyone() + self._accelerator.end_training() + + return saved_path, stats + + def _training_step(self, batch: dict[str, dict[str, Tensor]]) -> Tensor: + """Perform a single training step using the configured strategy.""" + # Apply embedding connectors to transform pre-computed text embeddings + conditions = batch["conditions"] + video_embeds, audio_embeds, attention_mask = self._text_encoder._run_connectors( + conditions["prompt_embeds"], conditions["prompt_attention_mask"] + ) + conditions["video_prompt_embeds"] = video_embeds + conditions["audio_prompt_embeds"] = audio_embeds + conditions["prompt_attention_mask"] = attention_mask + + # Use strategy to prepare training inputs (returns ModelInputs with Modality objects) + model_inputs = self._training_strategy.prepare_training_inputs(batch, self._timestep_sampler) + + # Run transformer forward pass with Modality-based interface + video_pred, audio_pred = self._transformer( + video=model_inputs.video, + audio=model_inputs.audio, + perturbations=None, + ) + + # Use strategy to compute loss + loss = self._training_strategy.compute_loss(video_pred, audio_pred, model_inputs) + + return loss + + def _load_text_encoder_and_cache_embeddings(self) -> list[CachedPromptEmbeddings] | None: + """Load text encoder, computes and returns validation embeddings.""" + + # This method: + # 1. Loads the text encoder on GPU + # 2. If validation prompts are configured, computes and caches their embeddings + # 3. 
Unloads the heavy Gemma model while keeping the lightweight embedding connectors + # The text encoder is kept (as self._text_encoder) but with model/tokenizer/feature_extractor + # set to None. Only the embedding connectors remain for use during training. + + # Load text encoder on GPU + logger.debug("Loading text encoder...") + if self._config.acceleration.load_text_encoder_in_8bit: + logger.warning( + "⚠️ load_text_encoder_in_8bit is set to True but 8-bit text encoder loading " + "is not currently implemented. The text encoder will be loaded in bfloat16 precision." + ) + + self._text_encoder = load_text_encoder( + checkpoint_path=self._config.model.model_path, + gemma_model_path=self._config.model.text_encoder_path, + device="cuda", + dtype=torch.bfloat16, + ) + + # Cache validation embeddings if prompts are configured + cached_embeddings = None + if self._config.validation.prompts: + logger.info(f"Pre-computing embeddings for {len(self._config.validation.prompts)} validation prompts...") + cached_embeddings = [] + with torch.inference_mode(): + for prompt in self._config.validation.prompts: + v_ctx_pos, a_ctx_pos, _ = self._text_encoder(prompt) + v_ctx_neg, a_ctx_neg, _ = self._text_encoder(self._config.validation.negative_prompt) + + cached_embeddings.append( + CachedPromptEmbeddings( + video_context_positive=v_ctx_pos.cpu(), + audio_context_positive=a_ctx_pos.cpu(), + video_context_negative=v_ctx_neg.cpu() if v_ctx_neg is not None else None, + audio_context_negative=a_ctx_neg.cpu() if a_ctx_neg is not None else None, + ) + ) + + # Unload heavy components to free VRAM, keeping only the embedding connectors + self._text_encoder.model = None + self._text_encoder.tokenizer = None + self._text_encoder.feature_extractor_linear = None + torch.cuda.empty_cache() + + logger.debug("Validation prompt embeddings cached. Gemma model unloaded") + return cached_embeddings + + def _load_models(self) -> None: + """Load the LTX-2 model components.""" + # Load audio components if: + # 1. Training strategy requires audio (training the audio branch), OR + # 2. Validation is configured to generate audio (even if not training audio) + load_audio = self._training_strategy.requires_audio or self._config.validation.generate_audio + + # Check if we need VAE encoder (for image or reference video conditioning) + need_vae_encoder = ( + self._config.validation.images is not None or self._config.validation.reference_videos is not None + ) + + # Load all model components (except text encoder - already handled) + components = load_ltx_model( + checkpoint_path=self._config.model.model_path, + device="cpu", + dtype=torch.bfloat16, + with_video_vae_encoder=need_vae_encoder, # Needed for image conditioning + with_video_vae_decoder=True, # Needed for validation sampling + with_audio_vae_decoder=load_audio, + with_vocoder=load_audio, + with_text_encoder=False, # Text encoder handled separately + ) + + # Extract components + self._transformer = components.transformer + self._vae_decoder = components.video_vae_decoder.to(dtype=torch.bfloat16) + self._vae_encoder = components.video_vae_encoder + if self._vae_encoder is not None: + self._vae_encoder = self._vae_encoder.to(dtype=torch.bfloat16) + self._scheduler = components.scheduler + self._audio_vae = components.audio_vae_decoder + self._vocoder = components.vocoder + # Note: self._text_encoder was set in _load_text_encoder_and_cache_embeddings + + # Determine initial dtype based on training mode. 
+ # Note: For FSDP + LoRA, we'll cast to FP32 later in _prepare_models_for_training() + # after the accelerator is set up, and we can detect FSDP. + transformer_dtype = torch.bfloat16 if self._config.model.training_mode == "lora" else torch.float32 + self._transformer = self._transformer.to(dtype=transformer_dtype) + + if self._config.acceleration.quantization is not None: + if self._config.model.training_mode == "full": + raise ValueError("Quantization is not supported in full training mode.") + + logger.warning(f"Quantizing model with precision: {self._config.acceleration.quantization}") + self._transformer = quantize_model( + self._transformer, + precision=self._config.acceleration.quantization, + ) + + # Freeze all models. We later unfreeze the transformer based on training mode. + # Note: embedding_connectors are already frozen (they come from the frozen text encoder) + self._vae_decoder.requires_grad_(False) + if self._vae_encoder is not None: + self._vae_encoder.requires_grad_(False) + self._transformer.requires_grad_(False) + if self._audio_vae is not None: + self._audio_vae.requires_grad_(False) + if self._vocoder is not None: + self._vocoder.requires_grad_(False) + + def _collect_trainable_params(self) -> None: + """Collect trainable parameters based on training mode.""" + if self._config.model.training_mode == "lora": + # For LoRA training, first set up LoRA layers + self._setup_lora() + elif self._config.model.training_mode == "full": + # For full training, unfreeze all transformer parameters + self._transformer.requires_grad_(True) + else: + raise ValueError(f"Unknown training mode: {self._config.model.training_mode}") + + self._trainable_params = [p for p in self._transformer.parameters() if p.requires_grad] + logger.debug(f"Trainable params count: {sum(p.numel() for p in self._trainable_params):,}") + + def _init_timestep_sampler(self) -> None: + """Initialize the timestep sampler based on the config.""" + sampler_cls = SAMPLERS[self._config.flow_matching.timestep_sampling_mode] + self._timestep_sampler = sampler_cls(**self._config.flow_matching.timestep_sampling_params) + + def _setup_lora(self) -> None: + """Configure LoRA adapters for the transformer. 
Only called in LoRA training mode.""" + logger.debug(f"Adding LoRA adapter with rank {self._config.lora.rank}") + lora_config = LoraConfig( + r=self._config.lora.rank, + lora_alpha=self._config.lora.alpha, + target_modules=self._config.lora.target_modules, + lora_dropout=self._config.lora.dropout, + init_lora_weights=True, + ) + # Wrap the transformer with PEFT to add LoRA layers + # noinspection PyTypeChecker + self._transformer = get_peft_model(self._transformer, lora_config) + + def _load_checkpoint(self) -> None: + """Load checkpoint if specified in config.""" + if not self._config.model.load_checkpoint: + return + + checkpoint_path = self._find_checkpoint(self._config.model.load_checkpoint) + if not checkpoint_path: + logger.warning(f"⚠️ Could not find checkpoint at {self._config.model.load_checkpoint}") + return + + logger.info(f"📥 Loading checkpoint from {checkpoint_path}") + + if self._config.model.training_mode == "full": + self._load_full_checkpoint(checkpoint_path) + else: # LoRA mode + self._load_lora_checkpoint(checkpoint_path) + + def _load_full_checkpoint(self, checkpoint_path: Path) -> None: + """Load full model checkpoint.""" + state_dict = load_file(checkpoint_path) + self._transformer.load_state_dict(state_dict, strict=True) + + logger.info("✅ Full model checkpoint loaded successfully") + + def _load_lora_checkpoint(self, checkpoint_path: Path) -> None: + """Load LoRA checkpoint with DDP/FSDP compatibility.""" + state_dict = load_file(checkpoint_path) + + # Adjust layer names to match internal format. + # (Weights are saved in ComfyUI-compatible format, with "diffusion_model." prefix) + state_dict = {k.replace("diffusion_model.", "", 1): v for k, v in state_dict.items()} + + # Load LoRA weights and verify all weights were loaded + base_model = self._transformer.get_base_model() + set_peft_model_state_dict(base_model, state_dict) + + logger.info("✅ LoRA checkpoint loaded successfully") + + def _prepare_models_for_training(self) -> None: + """Prepare models for training with Accelerate.""" + + # For FSDP + LoRA: Cast entire model to FP32. + # FSDP requires uniform dtype across all parameters in wrapped modules. + # In LoRA mode, PEFT creates LoRA params in FP32 while base model is BF16. + # We cast the base model to FP32 to match the LoRA params. 
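+ # Illustrative check (not executed here): in LoRA mode,
+ # {p.dtype for p in self._transformer.parameters()}
+ # would typically show {torch.bfloat16, torch.float32}, a mix that FSDP's
+ # flat parameter wrapping rejects; casting everything to float32 avoids it.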
+ if self._accelerator.distributed_type == DistributedType.FSDP and self._config.model.training_mode == "lora": + logger.debug("FSDP: casting transformer to FP32 for uniform dtype") + self._transformer = self._transformer.to(dtype=torch.float32) + + # Enable gradient checkpointing if requested + # For PeftModel, we need to access the underlying base model + transformer = ( + self._transformer.get_base_model() if hasattr(self._transformer, "get_base_model") else self._transformer + ) + + transformer.set_gradient_checkpointing(self._config.optimization.enable_gradient_checkpointing) + + # Keep frozen models on CPU for memory efficiency + self._vae_decoder = self._vae_decoder.to("cpu") + if self._vae_encoder is not None: + self._vae_encoder = self._vae_encoder.to("cpu") + + # Embedding connectors are already on GPU from _load_text_encoder_and_cache_embeddings + + # noinspection PyTypeChecker + self._transformer = self._accelerator.prepare(self._transformer) + + # Log GPU memory usage after model preparation + vram_usage_gb = torch.cuda.memory_allocated() / 1024**3 + logger.debug(f"GPU memory usage after models preparation: {vram_usage_gb:.2f} GB") + + @staticmethod + def _find_checkpoint(checkpoint_path: str | Path) -> Path | None: + """Find the checkpoint file to load, handling both file and directory paths.""" + checkpoint_path = Path(checkpoint_path) + + if checkpoint_path.is_file(): + if not checkpoint_path.suffix == ".safetensors": + raise ValueError(f"Checkpoint file must have a .safetensors extension: {checkpoint_path}") + return checkpoint_path + + if checkpoint_path.is_dir(): + # Look for checkpoint files in the directory + checkpoints = list(checkpoint_path.rglob("*step_*.safetensors")) + + if not checkpoints: + return None + + # Sort by step number and return the latest + def _get_step_num(p: Path) -> int: + try: + return int(p.stem.split("step_")[1]) + except (IndexError, ValueError): + return -1 + + latest = max(checkpoints, key=_get_step_num) + return latest + + else: + raise ValueError(f"Invalid checkpoint path: {checkpoint_path}. 
Must be a file or directory.") + + def _init_dataloader(self) -> None: + """Initialize the training data loader using the strategy's data sources.""" + if self._dataset is None: + # Get data sources from the training strategy + data_sources = self._training_strategy.get_data_sources() + + self._dataset = PrecomputedDataset(self._config.data.preprocessed_data_root, data_sources=data_sources) + logger.debug(f"Loaded dataset with {len(self._dataset):,} samples from sources: {list(data_sources)}") + + num_workers = self._config.data.num_dataloader_workers + dataloader = DataLoader( + self._dataset, + batch_size=self._config.optimization.batch_size, + shuffle=True, + drop_last=True, + num_workers=num_workers, + pin_memory=num_workers > 0, + persistent_workers=num_workers > 0, + ) + + self._dataloader = self._accelerator.prepare(dataloader) + + def _init_lora_weights(self) -> None: + """Initialize LoRA weights for the transformer.""" + logger.debug("Initializing LoRA weights...") + for _, module in self._transformer.named_modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.reset_lora_parameters(adapter_name="default", init_lora_weights=True) + + def _init_optimizer(self) -> None: + """Initialize the optimizer and learning rate scheduler.""" + opt_cfg = self._config.optimization + + lr = opt_cfg.learning_rate + if opt_cfg.optimizer_type == "adamw": + optimizer = AdamW(self._trainable_params, lr=lr) + elif opt_cfg.optimizer_type == "adamw8bit": + # noinspection PyUnresolvedReferences + from bitsandbytes.optim import AdamW8bit # noqa: PLC0415 + + optimizer = AdamW8bit(self._trainable_params, lr=lr) + else: + raise ValueError(f"Unknown optimizer type: {opt_cfg.optimizer_type}") + + # Add scheduler initialization + lr_scheduler = self._create_scheduler(optimizer) + + # noinspection PyTypeChecker + self._optimizer, self._lr_scheduler = self._accelerator.prepare(optimizer, lr_scheduler) + + def _create_scheduler(self, optimizer: torch.optim.Optimizer) -> LRScheduler | None: + """Create learning rate scheduler based on config.""" + scheduler_type = self._config.optimization.scheduler_type + steps = self._config.optimization.steps + params = self._config.optimization.scheduler_params or {} + + if scheduler_type is None: + return None + + if scheduler_type == "linear": + scheduler = LinearLR( + optimizer, + start_factor=params.pop("start_factor", 1.0), + end_factor=params.pop("end_factor", 0.1), + total_iters=steps, + **params, + ) + elif scheduler_type == "cosine": + scheduler = CosineAnnealingLR( + optimizer, + T_max=steps, + eta_min=params.pop("eta_min", 0), + **params, + ) + elif scheduler_type == "cosine_with_restarts": + scheduler = CosineAnnealingWarmRestarts( + optimizer, + T_0=params.pop("T_0", steps // 4), # First restart cycle length + T_mult=params.pop("T_mult", 1), # Multiplicative factor for cycle lengths + eta_min=params.pop("eta_min", 5e-5), + **params, + ) + elif scheduler_type == "polynomial": + scheduler = PolynomialLR( + optimizer, + total_iters=steps, + power=params.pop("power", 1.0), + **params, + ) + elif scheduler_type == "step": + scheduler = StepLR( + optimizer, + step_size=params.pop("step_size", steps // 2), + gamma=params.pop("gamma", 0.1), + **params, + ) + elif scheduler_type == "constant": + scheduler = None + else: + raise ValueError(f"Unknown scheduler type: {scheduler_type}") + + return scheduler + + def _setup_accelerator(self) -> None: + """Initialize the Accelerator with the appropriate settings.""" + + # All distributed setup 
+ (DDP/FSDP, number of processes, etc.) is controlled by + # the user's Accelerate configuration (accelerate config / accelerate launch). + self._accelerator = Accelerator( + mixed_precision=self._config.acceleration.mixed_precision_mode, + gradient_accumulation_steps=self._config.optimization.gradient_accumulation_steps, + ) + + if self._accelerator.num_processes > 1: + logger.info( + f"{self._accelerator.distributed_type.value} distributed training enabled " + f"with {self._accelerator.num_processes} processes" + ) + + local_batch = self._config.optimization.batch_size + global_batch = self._config.optimization.batch_size * self._accelerator.num_processes + logger.info(f"Local batch size: {local_batch}, global batch size: {global_batch}") + + # Log torch.compile status from Accelerate's dynamo plugin + is_compile_enabled = ( + hasattr(self._accelerator.state, "dynamo_plugin") and self._accelerator.state.dynamo_plugin.backend != "NO" + ) + if is_compile_enabled: + plugin = self._accelerator.state.dynamo_plugin + logger.info(f"🔥 torch.compile enabled via Accelerate: backend={plugin.backend}, mode={plugin.mode}") + + if self._accelerator.distributed_type == DistributedType.FSDP: + logger.warning( + "⚠️ FSDP + torch.compile is experimental and may hang on the first training iteration. " + "If this occurs, disable torch.compile by removing dynamo_config from your Accelerate config." + ) + + if self._accelerator.distributed_type == DistributedType.FSDP and self._config.acceleration.quantization: + logger.warning( + f"FSDP with quantization ({self._config.acceleration.quantization}) may have compatibility issues. " + "Monitor training stability and consider disabling quantization if issues arise." + ) + + # Note: Use @torch.no_grad() instead of @torch.inference_mode() to avoid FSDP inplace update errors after validation + @torch.no_grad() + def _sample_videos(self, progress: TrainingProgress) -> list[Path] | None: + """Run validation by generating videos from validation prompts.""" + use_images = self._config.validation.images is not None + use_reference_videos = self._config.validation.reference_videos is not None + generate_audio = self._config.validation.generate_audio + inference_steps = self._config.validation.inference_steps + + # Free up GPU memory before validation sampling. + # Zero gradients and empty the cache to reclaim memory. 
+ self._optimizer.zero_grad(set_to_none=True) + torch.cuda.empty_cache() + + # Start sampling progress tracking + sampling_ctx = progress.start_sampling( + num_prompts=len(self._config.validation.prompts), + num_steps=inference_steps, + ) + + # Create validation sampler with loaded models and progress tracking + sampler = ValidationSampler( + transformer=self._transformer, + vae_decoder=self._vae_decoder, + vae_encoder=self._vae_encoder, + text_encoder=None, + audio_decoder=self._audio_vae if generate_audio else None, + vocoder=self._vocoder if generate_audio else None, + sampling_context=sampling_ctx, + ) + + output_dir = Path(self._config.output_dir) / "samples" + output_dir.mkdir(exist_ok=True, parents=True) + + video_paths = [] + width, height, num_frames = self._config.validation.video_dims + + for prompt_idx, prompt in enumerate(self._config.validation.prompts): + # Update progress to show current video + sampling_ctx.start_video(prompt_idx) + + # Load conditioning image if provided + condition_image = None + if use_images: + image_path = self._config.validation.images[prompt_idx] + image = open_image_as_srgb(image_path) + # Convert PIL image to tensor [C, H, W] in [0, 1] + condition_image = F.to_tensor(image) + + # Load reference video if provided (for IC-LoRA) + reference_video = None + if use_reference_videos: + ref_video_path = self._config.validation.reference_videos[prompt_idx] + # read_video returns [F, C, H, W] in [0, 1] + reference_video, _ = read_video(ref_video_path, max_frames=num_frames) + + # Get cached embeddings for this prompt if available + cached_embeddings = ( + self._cached_validation_embeddings[prompt_idx] + if self._cached_validation_embeddings is not None + else None + ) + + # Create generation config + gen_config = GenerationConfig( + prompt=prompt, + negative_prompt=self._config.validation.negative_prompt, + height=height, + width=width, + num_frames=num_frames, + frame_rate=self._config.validation.frame_rate, + num_inference_steps=inference_steps, + guidance_scale=self._config.validation.guidance_scale, + seed=self._config.validation.seed, + condition_image=condition_image, + reference_video=reference_video, + generate_audio=generate_audio, + include_reference_in_output=self._config.validation.include_reference_in_output, + cached_embeddings=cached_embeddings, + stg_scale=self._config.validation.stg_scale, + stg_blocks=self._config.validation.stg_blocks, + stg_mode=self._config.validation.stg_mode, + ) + + # Generate sample + video, audio = sampler.generate( + config=gen_config, + device=self._accelerator.device, + ) + + # Save video + if IS_MAIN_PROCESS: + video_path = output_dir / f"step_{self._global_step:06d}_{prompt_idx + 1}.mp4" + save_video( + video_tensor=video, + output_path=video_path, + fps=self._config.validation.frame_rate, + audio=audio, + audio_sample_rate=self._vocoder.output_sample_rate if audio is not None else None, + ) + video_paths.append(video_path) + + # Clean up progress tasks + sampling_ctx.cleanup() + + # Clear GPU cache after validation + torch.cuda.empty_cache() + + rel_outputs_path = output_dir.relative_to(self._config.output_dir) + logger.info(f"🎥 Validation samples for step {self._global_step} saved in {rel_outputs_path}") + return video_paths + + @staticmethod + def _log_training_stats(stats: TrainingStats) -> None: + """Log training statistics.""" + stats_str = ( + "📊 Training Statistics:\n" + f" - Total time: {stats.total_time_seconds / 60:.1f} minutes\n" + f" - Training speed: {stats.steps_per_second:.2f} steps/second\n" + 
f" - Samples/second: {stats.samples_per_second:.2f}\n" + f" - Peak GPU memory: {stats.peak_gpu_memory_gb:.2f} GB" + ) + if stats.num_processes > 1: + stats_str += f"\n - Number of processes: {stats.num_processes}\n" + stats_str += f" - Global batch size: {stats.global_batch_size}" + logger.info(stats_str) + + def _save_checkpoint(self) -> Path | None: + """Save the model weights.""" + is_lora = self._config.model.training_mode == "lora" + is_fsdp = self._accelerator.distributed_type == DistributedType.FSDP + + # Prepare paths + save_dir = Path(self._config.output_dir) / "checkpoints" + prefix = "lora" if is_lora else "model" + filename = f"{prefix}_weights_step_{self._global_step:05d}.safetensors" + saved_weights_path = save_dir / filename + + # Get state dict (collective operation - all processes must participate) + self._accelerator.wait_for_everyone() + full_state_dict = self._accelerator.get_state_dict(self._transformer) + + if not IS_MAIN_PROCESS: + return None + + save_dir.mkdir(exist_ok=True, parents=True) + + # For LoRA: extract only adapter weights; for full: use as-is + if is_lora: + unwrapped = self._accelerator.unwrap_model(self._transformer, keep_torch_compile=False) + # For FSDP, pass full_state_dict since model params aren't directly accessible + state_dict = get_peft_model_state_dict(unwrapped, state_dict=full_state_dict if is_fsdp else None) + + # Remove "base_model.model." prefix added by PEFT + state_dict = {k.replace("base_model.model.", "", 1): v for k, v in state_dict.items()} + + # Convert to ComfyUI-compatible format (add "diffusion_model." prefix) + state_dict = {f"diffusion_model.{k}": v for k, v in state_dict.items()} + + # Save to disk + save_file(state_dict, saved_weights_path) + else: + # Save to disk + self._accelerator.save(full_state_dict, saved_weights_path) + + rel_path = saved_weights_path.relative_to(self._config.output_dir) + logger.info(f"💾 {prefix.capitalize()} weights for step {self._global_step} saved in {rel_path}") + + # Keep track of checkpoint paths, and cleanup old checkpoints if needed + self._checkpoint_paths.append(saved_weights_path) + self._cleanup_checkpoints() + return saved_weights_path + + def _cleanup_checkpoints(self) -> None: + """Clean up old checkpoints.""" + if 0 < self._config.checkpoints.keep_last_n < len(self._checkpoint_paths): + checkpoints_to_remove = self._checkpoint_paths[: -self._config.checkpoints.keep_last_n] + for old_checkpoint in checkpoints_to_remove: + if old_checkpoint.exists(): + old_checkpoint.unlink() + logger.info(f"Removed old checkpoints: {old_checkpoint}") + # Update the list to only contain kept checkpoints + self._checkpoint_paths = self._checkpoint_paths[-self._config.checkpoints.keep_last_n :] + + def _save_config(self) -> None: + """Save the training configuration as a YAML file in the output directory.""" + if not IS_MAIN_PROCESS: + return + + config_path = Path(self._config.output_dir) / "training_config.yaml" + with open(config_path, "w") as f: + yaml.dump(self._config.model_dump(), f, default_flow_style=False, indent=2) + + logger.info(f"💾 Training configuration saved to: {config_path.relative_to(self._config.output_dir)}") + + def _init_wandb(self) -> None: + """Initialize Weights & Biases run.""" + if not self._config.wandb.enabled or not IS_MAIN_PROCESS: + self._wandb_run = None + return + + wandb_config = self._config.wandb + run = wandb.init( + project=wandb_config.project, + entity=wandb_config.entity, + name=Path(self._config.output_dir).name, + tags=wandb_config.tags, + 
config=self._config.model_dump(), + ) + self._wandb_run = run + + def _log_metrics(self, metrics: dict[str, float]) -> None: + """Log metrics to Weights & Biases.""" + if self._wandb_run is not None: + self._wandb_run.log(metrics) + + def _log_validation_videos(self, video_paths: list[Path], prompts: list[str]) -> None: + """Log validation videos to Weights & Biases.""" + if not self._config.wandb.log_validation_videos or self._wandb_run is None: + return + + # Create lists of videos with their captions + validation_videos = [ + wandb.Video(str(video_path), caption=prompt, format="mp4") + for video_path, prompt in zip(video_paths, prompts, strict=False) + ] + + # Log all videos at once + self._wandb_run.log( + { + "validation_videos": validation_videos, + }, + step=self._global_step, + ) diff --git a/packages/ltx-trainer/src/ltx_trainer/training_strategies/__init__.py b/packages/ltx-trainer/src/ltx_trainer/training_strategies/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15157908305276d12217310a6b5004b18a222c51 --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/training_strategies/__init__.py @@ -0,0 +1,64 @@ +"""Training strategies for different conditioning modes. + +This package implements the Strategy Pattern to handle different training modes: +- Text-to-video training (standard generation, optionally with audio) +- Video-to-video training (IC-LoRA mode with reference videos) + +Each strategy encapsulates the specific logic for preparing model inputs and computing loss. +""" + +from ltx_trainer import logger +from ltx_trainer.training_strategies.base_strategy import ( + DEFAULT_FPS, + VIDEO_SCALE_FACTORS, + ModelInputs, + TrainingStrategy, + TrainingStrategyConfigBase, +) +from ltx_trainer.training_strategies.text_to_video import TextToVideoConfig, TextToVideoStrategy +from ltx_trainer.training_strategies.video_to_video import VideoToVideoConfig, VideoToVideoStrategy + +# Type alias for all strategy config types +TrainingStrategyConfig = TextToVideoConfig | VideoToVideoConfig + +__all__ = [ + "DEFAULT_FPS", + "VIDEO_SCALE_FACTORS", + "ModelInputs", + "TextToVideoConfig", + "TextToVideoStrategy", + "TrainingStrategy", + "TrainingStrategyConfig", + "TrainingStrategyConfigBase", + "VideoToVideoConfig", + "VideoToVideoStrategy", + "get_training_strategy", +] + + +def get_training_strategy(config: TrainingStrategyConfig) -> TrainingStrategy: + """Factory function to create the appropriate training strategy. + + The strategy is determined by the `name` field in the configuration. 
+ + Args: + config: Strategy-specific configuration with a `name` field + + Returns: + The appropriate training strategy instance + + Raises: + ValueError: If strategy name is not supported + """ + + match config: + case TextToVideoConfig(): + strategy = TextToVideoStrategy(config) + case VideoToVideoConfig(): + strategy = VideoToVideoStrategy(config) + case _: + raise ValueError(f"Unknown training strategy config type: {type(config).__name__}") + + audio_mode = "(audio enabled 🔈)" if getattr(config, "with_audio", False) else "(audio disabled 🔇)" + logger.debug(f"🎯 Using {strategy.__class__.__name__} training strategy {audio_mode}") + return strategy diff --git a/packages/ltx-trainer/src/ltx_trainer/training_strategies/base_strategy.py b/packages/ltx-trainer/src/ltx_trainer/training_strategies/base_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..f42c46f1f2a70d5c29a22b7727da4c2f42cdba29 --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/training_strategies/base_strategy.py @@ -0,0 +1,274 @@ +"""Base class for training strategies. + +This module defines the abstract base class that all training strategies must implement, +along with the base configuration class. +""" + +import random +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Literal + +import torch +from pydantic import BaseModel, ConfigDict, Field +from torch import Tensor + +from ltx_core.model.transformer.modality import Modality +from ltx_core.pipeline.components.patchifiers import ( + AudioLatentShape, + AudioPatchifier, + VideoLatentPatchifier, + VideoLatentShape, + get_pixel_coords, +) +from ltx_trainer.timestep_samplers import TimestepSampler + +# Default frames per second for video missing in the FPS metadata +DEFAULT_FPS = 24 + +# VAE scale factors for LTX-2 +VIDEO_SCALE_FACTORS = (8, 32, 32) # (temporal, height, width) + + +class TrainingStrategyConfigBase(BaseModel): + """Base configuration class for training strategies. + + All strategy-specific configuration classes should inherit from this. + """ + + model_config = ConfigDict(extra="forbid") + + name: Literal["text_to_video", "video_to_video"] = Field( + description="Unique name identifying the training strategy type" + ) + + +@dataclass +class ModelInputs: + """Container for model inputs using the Modality-based interface.""" + + video: Modality + audio: Modality | None + + # Training targets (for loss computation) + video_targets: Tensor + audio_targets: Tensor | None + + # Masks for loss computation + video_loss_mask: Tensor # Boolean mask: True = compute loss for this token + audio_loss_mask: Tensor | None + + # Metadata needed for loss computation in some strategies + ref_seq_len: int | None = None # For IC-LoRA: length of reference sequence + + +class TrainingStrategy(ABC): + """Abstract base class for training strategies. + + Each strategy encapsulates the logic for a specific training mode, + handling input preparation and loss computation. + """ + + def __init__(self, config: TrainingStrategyConfigBase): + """Initialize strategy with configuration. + + Args: + config: Strategy-specific configuration + """ + self.config = config + self._video_patchifier = VideoLatentPatchifier(patch_size=1) + self._audio_patchifier = AudioPatchifier(patch_size=1) + + @property + def requires_audio(self) -> bool: + """Whether this training strategy requires audio components. + + Override this property in subclasses that support audio training. 
+ The trainer uses this to determine whether to load audio VAE and vocoder. + + Returns: + True if audio components should be loaded, False otherwise. + """ + return False + + @abstractmethod + def get_data_sources(self) -> list[str] | dict[str, str]: + """Get the required data sources for this training strategy. + + Returns: + Either a list of data directory names (where output keys match directory names) + or a dictionary mapping data directory names to custom output keys for the dataset + """ + + @abstractmethod + def prepare_training_inputs( + self, + batch: dict[str, Any], + timestep_sampler: TimestepSampler, + ) -> ModelInputs: + """Prepare training inputs from a raw data batch. + + Args: + batch: Raw batch data from the dataset. Contains: + - "latents": Video latent data + - "conditions": Text embeddings with keys: + - "video_prompt_embeds": Already processed by embedding connectors + - "audio_prompt_embeds": Already processed by embedding connectors + - "prompt_attention_mask": Attention mask + - Additional keys depending on strategy (e.g., "ref_latents" for IC-LoRA) + timestep_sampler: Sampler for generating timesteps and noise + + Returns: + ModelInputs containing Modality objects and training targets + """ + + @abstractmethod + def compute_loss( + self, + video_pred: Tensor, + audio_pred: Tensor | None, + inputs: ModelInputs, + ) -> Tensor: + """Compute the training loss. + + Args: + video_pred: Video prediction from the transformer model + audio_pred: Audio prediction from the transformer model (None for video-only) + inputs: The prepared model inputs containing targets and masks + + Returns: + Scalar loss tensor + """ + + def _get_video_positions( + self, + num_frames: int, + height: int, + width: int, + batch_size: int, + fps: float, + device: torch.device, + dtype: torch.dtype, + ) -> Tensor: + """Generate video position embeddings using ltx_core's native implementation. + + Args: + num_frames: Number of latent frames + height: Latent height + width: Latent width + batch_size: Batch size + fps: Frames per second + device: Target device + dtype: Target dtype + + Returns: + Position tensor of shape [B, 3, seq_len, 2] + """ + latent_coords = self._video_patchifier.get_patch_grid_bounds( + output_shape=VideoLatentShape( + frames=num_frames, + height=height, + width=width, + batch=batch_size, + channels=128, # Video latent channels + ), + device=device, + ) + + # Convert latent coords to pixel coords with causal fix + pixel_coords = get_pixel_coords( + latent_coords=latent_coords, + scale_factors=VIDEO_SCALE_FACTORS, + causal_fix=True, + ).to(dtype) + + # Scale temporal dimension by 1/fps to get time in seconds + pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps + + return pixel_coords + + def _get_audio_positions( + self, + num_time_steps: int, + batch_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> Tensor: + """Generate audio position embeddings using ltx_core's native implementation. + + Args: + num_time_steps: Number of audio time steps (T, not T*mel_bins) + batch_size: Batch size + device: Target device + dtype: Target dtype + + Returns: + Position tensor of shape [B, 1, num_time_steps, 2] + + Note: + Audio latents should be in patchified format [B, T, C*F] = [B, T, 128] + where T is the number of time steps, C=8 channels, F=16 mel bins. + This matches the format produced by AudioPatchifier.patchify(). 
+ """ + mel_bins = 16 + + latent_coords = self._audio_patchifier.get_patch_grid_bounds( + output_shape=AudioLatentShape( + frames=num_time_steps, + mel_bins=mel_bins, + batch=batch_size, + channels=8, # Audio latent channels + ), + device=device, + ) + + return latent_coords.to(dtype) + + @staticmethod + def _create_per_token_timesteps(conditioning_mask: Tensor, sampled_sigma: Tensor) -> Tensor: + """Create per-token timesteps based on conditioning mask. + + Args: + conditioning_mask: Boolean mask of shape (batch_size, sequence_length), + where True = conditioning token (timestep=0), False = target token (use sigma) + sampled_sigma: Sampled sigma values of shape (batch_size,) or (batch_size, 1, 1) + + Returns: + Timesteps tensor of shape [batch_size, sequence_length] + """ + # Expand to match conditioning mask shape [B, seq_len] + expanded_sigma = sampled_sigma.view(-1, 1).expand_as(conditioning_mask) + + # Conditioning tokens get 0, target tokens get the sampled sigma + return torch.where(conditioning_mask, torch.zeros_like(expanded_sigma), expanded_sigma) + + @staticmethod + def _create_first_frame_conditioning_mask( + batch_size: int, + sequence_length: int, + height: int, + width: int, + device: torch.device, + first_frame_conditioning_p: float = 0.0, + ) -> Tensor: + """Create conditioning mask for first frame conditioning. + + Args: + batch_size: Batch size + sequence_length: Total sequence length + height: Latent height + width: Latent width + device: Target device + first_frame_conditioning_p: Probability of conditioning on the first frame + + Returns: + Boolean mask where True indicates first frame tokens (if conditioning is enabled) + """ + conditioning_mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool, device=device) + + if first_frame_conditioning_p > 0 and random.random() < first_frame_conditioning_p: + first_frame_end_idx = height * width + if first_frame_end_idx < sequence_length: + conditioning_mask[:, :first_frame_end_idx] = True + + return conditioning_mask diff --git a/packages/ltx-trainer/src/ltx_trainer/training_strategies/text_to_video.py b/packages/ltx-trainer/src/ltx_trainer/training_strategies/text_to_video.py new file mode 100644 index 0000000000000000000000000000000000000000..c9de82e09da2d92ead7cc50a699f1b13ab6f42df --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/training_strategies/text_to_video.py @@ -0,0 +1,294 @@ +"""Text-to-video training strategy. 
+ +This strategy implements standard text-to-video generation training where: +- Only target latents are used (no reference videos) +- Standard noise application and loss computation +- Supports first frame conditioning +- Optionally supports joint audio-video training +""" + +from typing import Any, Literal + +import torch +from pydantic import Field +from torch import Tensor + +from ltx_core.model.transformer.modality import Modality +from ltx_trainer import logger +from ltx_trainer.timestep_samplers import TimestepSampler +from ltx_trainer.training_strategies.base_strategy import ( + DEFAULT_FPS, + ModelInputs, + TrainingStrategy, + TrainingStrategyConfigBase, +) + + +class TextToVideoConfig(TrainingStrategyConfigBase): + """Configuration for text-to-video training strategy.""" + + name: Literal["text_to_video"] = "text_to_video" + + first_frame_conditioning_p: float = Field( + default=0.1, + description="Probability of conditioning on the first frame during training", + ge=0.0, + le=1.0, + ) + + with_audio: bool = Field( + default=False, + description="Whether to include audio in training (joint audio-video generation)", + ) + + audio_latents_dir: str = Field( + default="audio_latents", + description="Directory name for audio latents when with_audio is True", + ) + + +class TextToVideoStrategy(TrainingStrategy): + """Text-to-video training strategy. + + This strategy implements regular video generation training where: + - Only target latents are used (no reference videos) + - Standard noise application and loss computation + - Supports first frame conditioning + - Optionally supports joint audio-video training when with_audio=True + """ + + config: TextToVideoConfig + + def __init__(self, config: TextToVideoConfig): + """Initialize strategy with configuration. + + Args: + config: Text-to-video configuration + """ + super().__init__(config) + + @property + def requires_audio(self) -> bool: + """Whether this training strategy requires audio components.""" + return self.config.with_audio + + def get_data_sources(self) -> list[str] | dict[str, str]: + """ + Text-to-video training requires latents and text conditions. + When with_audio is True, also requires audio latents. + """ + sources = { + "latents": "latents", + "conditions": "conditions", + } + + if self.config.with_audio: + sources[self.config.audio_latents_dir] = "audio_latents" + + return sources + + def prepare_training_inputs( + self, + batch: dict[str, Any], + timestep_sampler: TimestepSampler, + ) -> ModelInputs: + """Prepare inputs for text-to-video training.""" + # Get pre-encoded latents - dataset provides uniform non-patchified format [B, C, F, H, W] + latents = batch["latents"] + video_latents = latents["latents"] + + # Get video dimensions (assume same for all batch elements) + num_frames = latents["num_frames"][0].item() + height = latents["height"][0].item() + width = latents["width"][0].item() + + # Patchify latents: [B, C, F, H, W] -> [B, seq_len, C] + video_latents = self._video_patchifier.patchify(video_latents) + + # Handle FPS with backward compatibility + fps = latents.get("fps", None) + if fps is not None and not torch.all(fps == fps[0]): + logger.warning( + f"Different FPS values found in the batch. 
Found: {fps.tolist()}, using the first one: {fps[0].item()}" + ) + fps = fps[0].item() if fps is not None else DEFAULT_FPS + + # Get text embeddings (already processed by embedding connectors in trainer) + conditions = batch["conditions"] + video_prompt_embeds = conditions["video_prompt_embeds"] + audio_prompt_embeds = conditions["audio_prompt_embeds"] + prompt_attention_mask = conditions["prompt_attention_mask"] + + batch_size = video_latents.shape[0] + video_seq_len = video_latents.shape[1] + device = video_latents.device + dtype = video_latents.dtype + + # Create conditioning mask (first frame conditioning) + video_conditioning_mask = self._create_first_frame_conditioning_mask( + batch_size=batch_size, + sequence_length=video_seq_len, + height=height, + width=width, + device=device, + first_frame_conditioning_p=self.config.first_frame_conditioning_p, + ) + + # Sample noise and sigmas + sigmas = timestep_sampler.sample_for(video_latents) + video_noise = torch.randn_like(video_latents) + + # Apply noise: noisy = (1 - sigma) * clean + sigma * noise + sigmas_expanded = sigmas.view(-1, 1, 1) + noisy_video = (1 - sigmas_expanded) * video_latents + sigmas_expanded * video_noise + + # For conditioning tokens, use clean latents + conditioning_mask_expanded = video_conditioning_mask.unsqueeze(-1) + noisy_video = torch.where(conditioning_mask_expanded, video_latents, noisy_video) + + # Compute video targets (velocity prediction) + video_targets = video_noise - video_latents + + # Create per-token timesteps + video_timesteps = self._create_per_token_timesteps(video_conditioning_mask, sigmas.squeeze()) + + # Generate video positions using ltx_core's native implementation + video_positions = self._get_video_positions( + num_frames=num_frames, + height=height, + width=width, + batch_size=batch_size, + fps=fps, + device=device, + dtype=dtype, + ) + + # Create video Modality + video_modality = Modality( + enabled=True, + latent=noisy_video, + timesteps=video_timesteps, + positions=video_positions, + context=video_prompt_embeds, + context_mask=prompt_attention_mask, + ) + + # Video loss mask: True for tokens we want to compute loss on (non-conditioning tokens) + video_loss_mask = ~video_conditioning_mask + + # Handle audio if enabled + audio_modality = None + audio_targets = None + audio_loss_mask = None + + if self.config.with_audio: + audio_modality, audio_targets, audio_loss_mask = self._prepare_audio_inputs( + batch=batch, + sigmas=sigmas, + audio_prompt_embeds=audio_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + batch_size=batch_size, + device=device, + dtype=dtype, + ) + + return ModelInputs( + video=video_modality, + audio=audio_modality, + video_targets=video_targets, + audio_targets=audio_targets, + video_loss_mask=video_loss_mask, + audio_loss_mask=audio_loss_mask, + ) + + def _prepare_audio_inputs( + self, + batch: dict[str, Any], + sigmas: Tensor, + audio_prompt_embeds: Tensor, + prompt_attention_mask: Tensor, + batch_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> tuple[Modality, Tensor, Tensor]: + """Prepare audio inputs for joint audio-video training. 
+ + Args: + batch: Raw batch data containing audio_latents + sigmas: Sampled sigma values (same as video) + audio_prompt_embeds: Audio context embeddings + prompt_attention_mask: Attention mask for context + batch_size: Batch size + device: Target device + dtype: Target dtype + + Returns: + Tuple of (audio_modality, audio_targets, audio_loss_mask) + """ + # Get audio latents - dataset provides uniform non-patchified format [B, C, T, F] + audio_data = batch["audio_latents"] + audio_latents = audio_data["latents"] + + # Patchify audio latents: [B, C, T, F] -> [B, T, C*F] + audio_latents = self._audio_patchifier.patchify(audio_latents) + + audio_seq_len = audio_latents.shape[1] + + # Sample audio noise + audio_noise = torch.randn_like(audio_latents) + + # Apply noise to audio (same sigma as video) + sigmas_expanded = sigmas.view(-1, 1, 1) + noisy_audio = (1 - sigmas_expanded) * audio_latents + sigmas_expanded * audio_noise + + # Compute audio targets + audio_targets = audio_noise - audio_latents + + # Audio timesteps: all tokens use the sampled sigma (no conditioning mask) + audio_timesteps = sigmas.view(-1, 1).expand(-1, audio_seq_len) + + # Generate audio positions + audio_positions = self._get_audio_positions( + num_time_steps=audio_seq_len, + batch_size=batch_size, + device=device, + dtype=dtype, + ) + + # Create audio Modality + audio_modality = Modality( + enabled=True, + latent=noisy_audio, + timesteps=audio_timesteps, + positions=audio_positions, + context=audio_prompt_embeds, + context_mask=prompt_attention_mask, + ) + + # Audio loss mask: all tokens contribute to loss (no conditioning) + audio_loss_mask = torch.ones(batch_size, audio_seq_len, dtype=torch.bool, device=device) + + return audio_modality, audio_targets, audio_loss_mask + + def compute_loss( + self, + video_pred: Tensor, + audio_pred: Tensor | None, + inputs: ModelInputs, + ) -> Tensor: + """Compute masked MSE loss for video and optionally audio.""" + # Video loss + video_loss = (video_pred - inputs.video_targets).pow(2) + video_loss_mask = inputs.video_loss_mask.unsqueeze(-1).float() + video_loss = video_loss.mul(video_loss_mask).div(video_loss_mask.mean()) + video_loss = video_loss.mean() + + # If no audio, return video loss only + if not self.config.with_audio or audio_pred is None or inputs.audio_targets is None: + return video_loss + + # Audio loss (no conditioning mask) + audio_loss = (audio_pred - inputs.audio_targets).pow(2).mean() + + # Combined loss + return video_loss + audio_loss diff --git a/packages/ltx-trainer/src/ltx_trainer/training_strategies/video_to_video.py b/packages/ltx-trainer/src/ltx_trainer/training_strategies/video_to_video.py new file mode 100644 index 0000000000000000000000000000000000000000..1167795c018a0bbdf8723c2aa39f90204ed6c5c6 --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/training_strategies/video_to_video.py @@ -0,0 +1,226 @@ +"""Video-to-video training strategy for IC-LoRA. 
+ +This strategy implements training with reference video conditioning where: +- Reference latents (clean) are concatenated with target latents (noised) +- Video coordinates handle both reference and target sequences +- Loss is computed only on the target portion +""" + +from typing import Any, Literal + +import torch +from pydantic import Field +from torch import Tensor + +from ltx_core.model.transformer.modality import Modality +from ltx_trainer import logger +from ltx_trainer.timestep_samplers import TimestepSampler +from ltx_trainer.training_strategies.base_strategy import ( + DEFAULT_FPS, + ModelInputs, + TrainingStrategy, + TrainingStrategyConfigBase, +) + + +class VideoToVideoConfig(TrainingStrategyConfigBase): + """Configuration for video-to-video (IC-LoRA) training strategy.""" + + name: Literal["video_to_video"] = "video_to_video" + + first_frame_conditioning_p: float = Field( + default=0.1, + description="Probability of conditioning on the first frame during training", + ge=0.0, + le=1.0, + ) + + reference_latents_dir: str = Field( + default="reference_latents", + description="Directory name for latents of reference videos", + ) + + +class VideoToVideoStrategy(TrainingStrategy): + """Video-to-video training strategy for IC-LoRA. + + This strategy implements training with reference video conditioning where: + - Reference latents (clean) are concatenated with target latents (noised) + - Video coordinates handle both reference and target sequences + - Loss is computed only on the target portion + """ + + config: VideoToVideoConfig + + def __init__(self, config: VideoToVideoConfig): + """Initialize strategy with configuration. + + Args: + config: Video-to-video configuration + """ + super().__init__(config) + + def get_data_sources(self) -> dict[str, str]: + """IC-LoRA training requires latents, conditions, and reference latents.""" + return { + "latents": "latents", + "conditions": "conditions", + self.config.reference_latents_dir: "ref_latents", + } + + def prepare_training_inputs( + self, + batch: dict[str, Any], + timestep_sampler: TimestepSampler, + ) -> ModelInputs: + """Prepare inputs for IC-LoRA training with reference videos.""" + # Get pre-encoded latents - dataset provides uniform non-patchified format [B, C, F, H, W] + latents = batch["latents"] + target_latents = latents["latents"] + ref_latents = batch["ref_latents"]["latents"] + + # Get dimensions + num_frames = latents["num_frames"][0].item() + height = latents["height"][0].item() + width = latents["width"][0].item() + + ref_latents_info = batch["ref_latents"] + ref_frames = ref_latents_info["num_frames"][0].item() + ref_height = ref_latents_info["height"][0].item() + ref_width = ref_latents_info["width"][0].item() + + # Patchify latents: [B, C, F, H, W] -> [B, seq_len, C] + target_latents = self._video_patchifier.patchify(target_latents) + ref_latents = self._video_patchifier.patchify(ref_latents) + + # Handle FPS + fps = latents.get("fps", None) + if fps is not None and not torch.all(fps == fps[0]): + logger.warning( + f"Different FPS values found in the batch. 
Found: {fps.tolist()}, using the first one: {fps[0].item()}" + ) + fps = fps[0].item() if fps is not None else DEFAULT_FPS + + # Get text embeddings (already processed by embedding connectors in trainer) + # Video-to-video uses only video embeddings + conditions = batch["conditions"] + prompt_embeds = conditions["video_prompt_embeds"] + prompt_attention_mask = conditions["prompt_attention_mask"] + + batch_size = target_latents.shape[0] + ref_seq_len = ref_latents.shape[1] + target_seq_len = target_latents.shape[1] + device = target_latents.device + dtype = target_latents.dtype + + # Create conditioning mask + # Reference tokens are always conditioning (timestep=0) + ref_conditioning_mask = torch.ones(batch_size, ref_seq_len, dtype=torch.bool, device=device) + + # Target tokens: check for first frame conditioning + target_conditioning_mask = self._create_first_frame_conditioning_mask( + batch_size=batch_size, + sequence_length=target_seq_len, + height=height, + width=width, + device=device, + first_frame_conditioning_p=self.config.first_frame_conditioning_p, + ) + + # Combined conditioning mask + conditioning_mask = torch.cat([ref_conditioning_mask, target_conditioning_mask], dim=1) + + # Sample noise and sigmas for target + sigmas = timestep_sampler.sample_for(target_latents) + noise = torch.randn_like(target_latents) + sigmas_expanded = sigmas.view(-1, 1, 1) + + # Apply noise to target + noisy_target = (1 - sigmas_expanded) * target_latents + sigmas_expanded * noise + + # For first frame conditioning in target, use clean latents + target_conditioning_mask_expanded = target_conditioning_mask.unsqueeze(-1) + noisy_target = torch.where(target_conditioning_mask_expanded, target_latents, noisy_target) + + # Targets for loss computation + targets = noise - target_latents + + # Concatenate reference (clean) and target (noisy) + combined_latents = torch.cat([ref_latents, noisy_target], dim=1) + + # Create per-token timesteps + timesteps = self._create_per_token_timesteps(conditioning_mask, sigmas.squeeze()) + + # Generate positions for reference and target separately, then concatenate + ref_positions = self._get_video_positions( + num_frames=ref_frames, + height=ref_height, + width=ref_width, + batch_size=batch_size, + fps=fps, + device=device, + dtype=dtype, + ) + + target_positions = self._get_video_positions( + num_frames=num_frames, + height=height, + width=width, + batch_size=batch_size, + fps=fps, + device=device, + dtype=dtype, + ) + + # Concatenate positions along sequence dimension + positions = torch.cat([ref_positions, target_positions], dim=2) + + # Create video Modality + video_modality = Modality( + enabled=True, + latent=combined_latents, + timesteps=timesteps, + positions=positions, + context=prompt_embeds, + context_mask=prompt_attention_mask, + ) + + # Loss mask: only compute loss on non-conditioning target tokens + # Reference tokens: all False (no loss) + # Target tokens: True where not conditioning + ref_loss_mask = torch.zeros(batch_size, ref_seq_len, dtype=torch.bool, device=device) + target_loss_mask = ~target_conditioning_mask + video_loss_mask = torch.cat([ref_loss_mask, target_loss_mask], dim=1) + + return ModelInputs( + video=video_modality, + audio=None, + video_targets=targets, + audio_targets=None, + video_loss_mask=video_loss_mask, + audio_loss_mask=None, + ref_seq_len=ref_seq_len, + ) + + def compute_loss( + self, + video_pred: Tensor, + _audio_pred: Tensor | None, + inputs: ModelInputs, + ) -> Tensor: + """Compute masked loss only on target portion.""" + # 
Extract target portion of prediction + ref_seq_len = inputs.ref_seq_len + target_pred = video_pred[:, ref_seq_len:, :] + + # Get target portion of loss mask + target_loss_mask = inputs.video_loss_mask[:, ref_seq_len:] + + # Compute loss + loss = (target_pred - inputs.video_targets).pow(2) + + # Apply loss mask + loss_mask = target_loss_mask.unsqueeze(-1).float() + loss = loss.mul(loss_mask).div(loss_mask.mean()) + + return loss.mean() diff --git a/packages/ltx-trainer/src/ltx_trainer/utils.py b/packages/ltx-trainer/src/ltx_trainer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c834330cea9a385987e924b36c89a377edc19994 --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/utils.py @@ -0,0 +1,69 @@ +import io +import subprocess +from pathlib import Path + +import torch +from PIL import ExifTags, Image, ImageCms, ImageOps +from PIL.Image import Image as PilImage + +from ltx_trainer import logger + + +def get_gpu_memory_gb(device: torch.device) -> float: + """Get current GPU memory usage in GB using nvidia-smi.""" + try: + device_id = device.index if device.index is not None else 0 + result = subprocess.check_output( + [ + "nvidia-smi", + "--query-gpu=memory.used", + "--format=csv,nounits,noheader", + "-i", + str(device_id), + ], + encoding="utf-8", + ) + return float(result.strip()) / 1024 # Convert MB to GB + except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e: + logger.error(f"Failed to get GPU memory from nvidia-smi: {e}") + # Fallback to torch + return torch.cuda.memory_allocated(device) / 1024**3 + + +def open_image_as_srgb(image_path: str | Path | io.BytesIO) -> PilImage: + """ + Opens an image file, applies rotation (if it's set in metadata) and converts it + to the sRGB color space, respecting the original image's color space. + """ + exif_colorspace_srgb = 1 + + with Image.open(image_path) as img_raw: + img = ImageOps.exif_transpose(img_raw) + + input_icc_profile = img.info.get("icc_profile") + + # Try to convert to sRGB if the image has ICC profile metadata + srgb_profile = ImageCms.createProfile(colorSpace="sRGB") + if input_icc_profile is not None: + input_profile = ImageCms.ImageCmsProfile(io.BytesIO(input_icc_profile)) + srgb_img = ImageCms.profileToProfile(img, input_profile, srgb_profile, outputMode="RGB") + else: + # Fall back to checking EXIF + exif_data = img.getexif() + if exif_data is not None: + # Assume sRGB if no ICC profile and EXIF has no ColorSpace tag + color_space_value = exif_data.get(ExifTags.Base.ColorSpace.value) + if color_space_value is not None and color_space_value != exif_colorspace_srgb: + raise ValueError( + "Image has colorspace tag in EXIF but it isn't set to sRGB," + " conversion is not supported." + f" EXIF ColorSpace tag value is {color_space_value}", + ) + + srgb_img = img.convert("RGB") + + # Set the sRGB profile in metadata, since the image is now assumed to be in sRGB. + srgb_profile_data = ImageCms.ImageCmsProfile(srgb_profile).tobytes() + srgb_img.info["icc_profile"] = srgb_profile_data + + return srgb_img diff --git a/packages/ltx-trainer/src/ltx_trainer/validation_sampler.py b/packages/ltx-trainer/src/ltx_trainer/validation_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..7278ce6503460fb04861d8c02a6890799a1c64c3 --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/validation_sampler.py @@ -0,0 +1,843 @@ +"""Validation sampling for LTX-2 training using ltx-core components. 
+ +This module provides a simplified validation pipeline for generating samples during training, +using the new ltx-core components (VideoLatentTools, AudioLatentTools, LatentState, etc.). +""" + +from dataclasses import dataclass, replace +from typing import TYPE_CHECKING, Literal + +import torch +from einops import rearrange +from torch import Tensor + +from ltx_core.guidance.perturbations import ( + BatchedPerturbationConfig, + Perturbation, + PerturbationConfig, + PerturbationType, +) +from ltx_core.model.transformer.modality import Modality +from ltx_core.model.transformer.model import X0Model +from ltx_core.pipeline.components.diffusion_steps import EulerDiffusionStep +from ltx_core.pipeline.components.guiders import CFGGuider, STGGuider +from ltx_core.pipeline.components.noisers import GaussianNoiser +from ltx_core.pipeline.components.patchifiers import ( + AudioLatentShape, + AudioPatchifier, + VideoLatentPatchifier, + VideoLatentShape, + get_pixel_coords, +) +from ltx_core.pipeline.components.protocols import VideoPixelShape +from ltx_core.pipeline.components.schedulers import LTX2Scheduler +from ltx_core.pipeline.conditioning.tools import AudioLatentTools, LatentState, VideoLatentTools +from ltx_core.tiling import SpatialTilingConfig, TemporalTilingConfig, TilingConfig +from ltx_trainer.progress import SamplingContext + +if TYPE_CHECKING: + from ltx_core.model.audio_vae.audio_vae import Decoder as AudioDecoder + from ltx_core.model.audio_vae.vocoder import Vocoder + from ltx_core.model.clip.gemma.encoders.av_encoder import AVGemmaTextEncoderModel + from ltx_core.model.transformer.model import LTXModel + from ltx_core.model.video_vae.video_vae import Decoder as VideoDecoder + from ltx_core.model.video_vae.video_vae import Encoder as VideoEncoder + +# Video VAE scale factors (temporal, height, width) +VIDEO_SCALE_FACTORS = (8, 32, 32) + + +@dataclass +class CachedPromptEmbeddings: + """Pre-computed text embeddings for a validation prompt. + + These embeddings are computed once at training start and reused for all validation runs, + avoiding the need to load the full Gemma text encoder during validation. + """ + + video_context_positive: Tensor # [1, seq_len, hidden_dim] + audio_context_positive: Tensor # [1, seq_len, hidden_dim] + video_context_negative: Tensor | None = None + audio_context_negative: Tensor | None = None + + +@dataclass +class TiledDecodingConfig: + """Configuration for tiled video decoding to reduce VRAM usage. + + Tiled decoding splits the latent tensor into overlapping tiles, decodes each + tile individually, and blends them together. This significantly reduces peak + VRAM usage at the cost of slightly slower decoding. + + Defaults match the recommended values from ltx-core tests. 
+ """ + + enabled: bool = True # Whether to use tiled decoding (enabled by default) + tile_size_pixels: int = 192 # Spatial tile size in pixels (must be ≥64 and divisible by 32) + tile_overlap_pixels: int = 64 # Spatial tile overlap in pixels (must be divisible by 32) + tile_size_frames: int = 48 # Temporal tile size in frames (must be ≥16 and divisible by 8) + tile_overlap_frames: int = 24 # Temporal tile overlap in frames (must be divisible by 8) + + +@dataclass +class GenerationConfig: + """Configuration for video/audio generation.""" + + prompt: str # Text prompt for generation + negative_prompt: str = "" # Negative prompt to avoid unwanted artifacts + height: int = 544 # Output video height in pixels + width: int = 960 # Output video width in pixels + num_frames: int = 97 # Number of frames to generate + frame_rate: float = 25.0 # Frame rate for temporal position scaling + num_inference_steps: int = 30 # Number of denoising steps + guidance_scale: float = 3.0 # CFG guidance scale + seed: int = 42 # Random seed for reproducibility + condition_image: Tensor | None = None # Optional first frame image for image-to-video + reference_video: Tensor | None = None # For IC-LoRA: [F, C, H, W] in [0, 1] + generate_audio: bool = True # Whether to generate audio alongside video + include_reference_in_output: bool = False # For IC-LoRA: concatenate original reference with generated output + cached_embeddings: CachedPromptEmbeddings | None = None # Pre-computed text embeddings (avoids loading Gemma) + stg_scale: float = 0.0 # STG strength (0.0 = disabled, recommended: 1.0) + stg_blocks: list[int] | None = None # Transformer blocks to perturb (None = all, recommended: [29]) + stg_mode: Literal["stg_av", "stg_v"] = "stg_av" # STG mode: "stg_av" (audio+video) or "stg_v" (video only) + # Tiled decoding config: None = use defaults (enabled), False = disable, or TiledDecodingConfig for custom settings + tiled_decoding: TiledDecodingConfig | Literal[False] | None = None + + def __post_init__(self) -> None: + """Apply default tiled decoding config if not provided.""" + if self.tiled_decoding is None: + # Use default config with tiling enabled + object.__setattr__(self, "tiled_decoding", TiledDecodingConfig()) + elif self.tiled_decoding is False: + # Explicitly disabled - use config with enabled=False + object.__setattr__(self, "tiled_decoding", TiledDecodingConfig(enabled=False)) + + +class ValidationSampler: + """Generates validation samples during training using ltx-core components. + + This class provides a simplified interface for generating video (and optionally audio) + samples during training validation. It supports: + - Text-to-video generation + - Image-to-video generation (first frame conditioning) + - Video-to-video generation (IC-LoRA reference video conditioning) + - Optional audio generation + + The implementation follows the patterns from ltx_pipelines.single_stage. + + Text embeddings can be provided either via: + - A full text_encoder (encodes prompts on-the-fly) + - Pre-computed cached_embeddings (avoids loading Gemma during validation) + """ + + def __init__( + self, + transformer: "LTXModel", + vae_decoder: "VideoDecoder", + vae_encoder: "VideoEncoder | None", + text_encoder: "AVGemmaTextEncoderModel | None" = None, + audio_decoder: "AudioDecoder | None" = None, + vocoder: "Vocoder | None" = None, + sampling_context: SamplingContext | None = None, + ): + """Initialize the validation sampler. 
+ + Args: + transformer: LTX-2 transformer model + vae_decoder: Video VAE decoder + vae_encoder: Video VAE encoder (for image/video conditioning), can be None if not needed + text_encoder: Gemma text encoder with embeddings connector (optional if cached_embeddings in config) + audio_decoder: Optional audio VAE decoder (for audio generation) + vocoder: Optional vocoder (for audio generation) + sampling_context: Optional SamplingContext for progress display during denoising + """ + self._transformer = transformer + self._vae_decoder = vae_decoder + self._vae_encoder = vae_encoder + self._text_encoder = text_encoder + self._audio_decoder = audio_decoder + self._vocoder = vocoder + self._sampling_context = sampling_context + + # Patchifiers + self._video_patchifier = VideoLatentPatchifier(patch_size=1) + self._audio_patchifier = AudioPatchifier(patch_size=1) + + # Note: Use @torch.no_grad() instead of @torch.inference_mode() to avoid FSDP inplace update errors after validation + @torch.no_grad() + def generate( + self, + config: GenerationConfig, + device: torch.device | str = "cuda", + ) -> tuple[Tensor, Tensor | None]: + """Generate a video (and optionally audio) sample. + + Args: + config: Generation configuration + device: Device to run generation on + + Returns: + Tuple of: + - video: Video tensor [C, F, H, W] in [0, 1] (float32) + - audio: Audio waveform tensor [C, samples] or None + """ + device = torch.device(device) if isinstance(device, str) else device + self._validate_config(config) + + # Route to appropriate generation method + if config.reference_video is not None: + return self._generate_with_reference(config, device) + return self._generate_standard(config, device) + + def _generate_standard(self, config: GenerationConfig, device: torch.device) -> tuple[Tensor, Tensor | None]: + """Standard generation (text-to-video or image-to-video).""" + # Get prompt embeddings (from cache or encode on-the-fly) + v_ctx_pos, a_ctx_pos, v_ctx_neg, a_ctx_neg = self._get_prompt_embeddings(config, device) + + # Setup generator + generator = torch.Generator(device=device).manual_seed(config.seed) + + # Create latent tools + video_tools = self._create_video_latent_tools(config) + audio_tools = self._create_audio_latent_tools(config) if config.generate_audio else None + + # Create initial states + video_clean_state = video_tools.create_initial_state(device=device, dtype=torch.bfloat16) + audio_clean_state = ( + audio_tools.create_initial_state(device=device, dtype=torch.bfloat16) if audio_tools else None + ) + + # Apply image conditioning if provided + if config.condition_image is not None: + video_clean_state = self._apply_image_conditioning( + video_clean_state, config.condition_image, config, device + ) + + # Add noise + noiser = GaussianNoiser(generator=generator) + video_state = noiser(latent_state=video_clean_state, noise_scale=1.0) + audio_state = noiser(latent_state=audio_clean_state, noise_scale=1.0) if audio_clean_state else None + + # Run denoising loop + video_state, audio_state = self._run_denoising( + config=config, + video_state=video_state, + audio_state=audio_state, + video_clean_state=video_clean_state, + audio_clean_state=audio_clean_state, + v_ctx_pos=v_ctx_pos, + a_ctx_pos=a_ctx_pos, + v_ctx_neg=v_ctx_neg, + a_ctx_neg=a_ctx_neg, + device=device, + ) + + # Decode outputs + video_state = video_tools.clear_conditioning(video_state) + video_state = video_tools.unpatchify(video_state) + video_output = self._decode_video(video_state, device, config.tiled_decoding) + + audio_output = 
None + if audio_state is not None and audio_tools is not None: + audio_state = audio_tools.clear_conditioning(audio_state) + audio_state = audio_tools.unpatchify(audio_state) + audio_output = self._decode_audio(audio_state, device) + + return video_output, audio_output + + def _generate_with_reference(self, config: GenerationConfig, device: torch.device) -> tuple[Tensor, Tensor | None]: + """Generate with reference video conditioning (IC-LoRA style). + + For IC-LoRA: + - Reference video latents are concatenated with target latents + - Reference latents have timestep=0 (clean, not denoised) + - Target latents are denoised normally + - If condition_image is also provided, the first frame of the target is conditioned + - If include_reference_in_output is True, the preprocessed reference video + is concatenated side-by-side with the generated video + """ + # Get prompt embeddings (from cache or encode on-the-fly) + v_ctx_pos, a_ctx_pos, v_ctx_neg, a_ctx_neg = self._get_prompt_embeddings(config, device) + + # Setup generator + generator = torch.Generator(device=device).manual_seed(config.seed) + + # Preprocess and encode reference video + ref_video_preprocessed = self._preprocess_reference_video(config) + ref_latent, ref_positions = self._encode_video(ref_video_preprocessed, config.frame_rate, device) + ref_seq_len = ref_latent.shape[1] + + # Create target video state + video_tools = self._create_video_latent_tools(config) + target_clean_state = video_tools.create_initial_state(device=device, dtype=torch.bfloat16) + + # Apply first-frame image conditioning to target if provided + if config.condition_image is not None: + target_clean_state = self._apply_image_conditioning( + target_clean_state, config.condition_image, config, device + ) + + # Create combined state (reference + target) + # denoise_mask shape is [B, seq_len, 1] after patchification + ref_denoise_mask = torch.zeros(1, ref_seq_len, 1, device=device, dtype=torch.float32) + combined_clean_state = LatentState( + latent=torch.cat([ref_latent, target_clean_state.latent], dim=1), + denoise_mask=torch.cat([ref_denoise_mask, target_clean_state.denoise_mask], dim=1), + positions=torch.cat([ref_positions, target_clean_state.positions], dim=2), + clean_latent=torch.cat([ref_latent, target_clean_state.clean_latent], dim=1), + ) + + # Add noise (only to the target portion via denoise_mask) + noiser = GaussianNoiser(generator=generator) + combined_state = noiser(latent_state=combined_clean_state, noise_scale=1.0) + + # Create audio state if needed + audio_tools = self._create_audio_latent_tools(config) if config.generate_audio else None + audio_clean_state = ( + audio_tools.create_initial_state(device=device, dtype=torch.bfloat16) if audio_tools else None + ) + audio_state = noiser(latent_state=audio_clean_state, noise_scale=1.0) if audio_clean_state else None + + # Run denoising loop + combined_state, audio_state = self._run_denoising( + config=config, + video_state=combined_state, + audio_state=audio_state, + video_clean_state=combined_clean_state, + audio_clean_state=audio_clean_state, + v_ctx_pos=v_ctx_pos, + a_ctx_pos=a_ctx_pos, + v_ctx_neg=v_ctx_neg, + a_ctx_neg=a_ctx_neg, + device=device, + ) + + # Extract target portion and decode + target_latent = combined_state.latent[:, ref_seq_len:] + video_output = self._decode_video_latent(target_latent, config, device) + + # Optionally concatenate original reference video side-by-side + if config.include_reference_in_output: + # Use preprocessed reference (already resized/cropped, in pixel space) + 
# Convert from [B, C, F, H, W] to [C, F, H, W] + ref_video_pixels = ref_video_preprocessed[0].cpu() + # Normalize from [-1, 1] to [0, 1] + ref_video_pixels = ((ref_video_pixels + 1.0) / 2.0).clamp(0.0, 1.0) + video_output = self._concatenate_videos_side_by_side(ref_video_pixels, video_output) + + # Decode audio + audio_output = None + if audio_state is not None and audio_tools is not None: + audio_state = audio_tools.clear_conditioning(audio_state) + audio_state = audio_tools.unpatchify(audio_state) + audio_output = self._decode_audio(audio_state, device) + + return video_output, audio_output + + def _create_video_latent_tools(self, config: GenerationConfig) -> VideoLatentTools: + """Create video latent tools for the given configuration.""" + pixel_shape = VideoPixelShape( + batch=1, + frames=config.num_frames, + height=config.height, + width=config.width, + fps=config.frame_rate, + ) + return VideoLatentTools( + patchifier=self._video_patchifier, + target_shape=VideoLatentShape.from_pixel_shape(shape=pixel_shape), + fps=config.frame_rate, + scale_factors=VIDEO_SCALE_FACTORS, + causal_fix=True, + ) + + def _create_audio_latent_tools(self, config: GenerationConfig) -> AudioLatentTools: + """Create audio latent tools for the given configuration.""" + return AudioLatentTools( + patchifier=self._audio_patchifier, + target_shape=AudioLatentShape.from_duration(batch=1, duration=config.num_frames / config.frame_rate), + ) + + def _apply_image_conditioning( + self, video_state: LatentState, image: Tensor, config: GenerationConfig, device: torch.device + ) -> LatentState: + """Apply first-frame image conditioning to the video state.""" + # Encode the image + encoded_image = self._encode_conditioning_image(image, config.height, config.width, device) + + # Patchify the encoded image (single frame) + patchified_image = self._video_patchifier.patchify(encoded_image) # [1, 1, C] -> [1, num_patches, C] + num_image_tokens = patchified_image.shape[1] + + # Update the first frame tokens in the latent + new_latent = video_state.latent.clone() + new_latent[:, :num_image_tokens] = patchified_image.to(new_latent.dtype) + + # Update clean_latent as well (conditioning image is clean) + new_clean_latent = video_state.clean_latent.clone() + new_clean_latent[:, :num_image_tokens] = patchified_image.to(new_clean_latent.dtype) + + # Set denoise_mask to 0 for conditioned tokens (don't denoise them) + new_denoise_mask = video_state.denoise_mask.clone() + new_denoise_mask[:, :num_image_tokens] = 0.0 + + return LatentState( + latent=new_latent, + denoise_mask=new_denoise_mask, + positions=video_state.positions, + clean_latent=new_clean_latent, + ) + + @staticmethod + def _preprocess_reference_video(config: GenerationConfig) -> Tensor: + """Preprocess reference video: resize, crop, and convert to model input format. 
+ + Args: + config: Generation configuration with reference_video + + Returns: + Preprocessed video tensor [B, C, F, H, W] in [-1, 1] range + """ + ref_video = config.reference_video # [F, C, H, W] in [0, 1] + target_height, target_width = config.height, config.width + current_height, current_width = ref_video.shape[2:] + + # Resize maintaining aspect ratio and center crop if needed + if current_height != target_height or current_width != target_width: + aspect_ratio = current_width / current_height + target_aspect_ratio = target_width / target_height + + if aspect_ratio > target_aspect_ratio: + resize_height, resize_width = target_height, int(target_height * aspect_ratio) + else: + resize_height, resize_width = int(target_width / aspect_ratio), target_width + + ref_video = torch.nn.functional.interpolate( + ref_video, size=(resize_height, resize_width), mode="bilinear", align_corners=False + ) + + # Center crop + h_start = (resize_height - target_height) // 2 + w_start = (resize_width - target_width) // 2 + ref_video = ref_video[:, :, h_start : h_start + target_height, w_start : w_start + target_width] + + # Convert to [B, C, F, H, W] and trim to valid frame count (k*8 + 1) + ref_video = rearrange(ref_video, "f c h w -> 1 c f h w") + valid_frames = (ref_video.shape[2] - 1) // 8 * 8 + 1 + ref_video = ref_video[:, :, :valid_frames] + + # Convert to [-1, 1] range + return ref_video * 2.0 - 1.0 + + def _encode_video(self, video: Tensor, fps: float, device: torch.device) -> tuple[Tensor, Tensor]: + """Encode video to patchified latents and compute positions. + + Args: + video: Video tensor [B, C, F, H, W] in [-1, 1] range + fps: Frame rate for temporal position scaling + device: Device to run encoding on + + Returns: + Tuple of (patchified_latents, positions) + """ + video = video.to(device=device, dtype=torch.float32) + + # Encode with VAE + self._vae_encoder.to(device) + with torch.autocast(device_type=str(device).split(":")[0], dtype=torch.bfloat16): + latents = self._vae_encoder(video) + self._vae_encoder.to("cpu") + + latents = latents.to(torch.bfloat16) + patchified = self._video_patchifier.patchify(latents) + + # Compute positions + latent_shape = VideoLatentShape( + batch=1, + channels=latents.shape[1], + frames=latents.shape[2], + height=latents.shape[3], + width=latents.shape[4], + ) + latent_coords = self._video_patchifier.get_patch_grid_bounds(output_shape=latent_shape, device=device) + positions = get_pixel_coords(latent_coords, scale_factors=VIDEO_SCALE_FACTORS, causal_fix=True) + positions = positions.to(torch.bfloat16) + positions[:, 0, ...] = positions[:, 0, ...] 
/ fps + + return patchified, positions + + def _run_denoising( + self, + config: GenerationConfig, + video_state: LatentState, + audio_state: LatentState | None, + video_clean_state: LatentState, + audio_clean_state: LatentState | None, + v_ctx_pos: Tensor, + a_ctx_pos: Tensor, + v_ctx_neg: Tensor | None, + a_ctx_neg: Tensor | None, + device: torch.device, + ) -> tuple[LatentState, LatentState | None]: + """Run the denoising loop using X0 prediction with CFG and optional STG.""" + scheduler = LTX2Scheduler() + sigmas = scheduler.execute(steps=config.num_inference_steps).to(device).float() + stepper = EulerDiffusionStep() + cfg_guider = CFGGuider(config.guidance_scale) + stg_guider = STGGuider(config.stg_scale) + + # Build STG perturbation config if STG is enabled + stg_perturbation_config = self._build_stg_perturbation_config(config) if stg_guider.enabled() else None + + # Create initial modalities (will be updated each step via replace()) + video = Modality( + enabled=True, + latent=video_state.latent, + timesteps=video_state.denoise_mask, + positions=video_state.positions, + context=v_ctx_pos, + context_mask=None, + ) + + # Audio modality is None when not generating audio + audio: Modality | None = None + if audio_state is not None: + audio = Modality( + enabled=True, + latent=audio_state.latent, + timesteps=audio_state.denoise_mask, + positions=audio_state.positions, + context=a_ctx_pos, + context_mask=None, + ) + + # Wrap transformer with X0Model to convert velocity predictions to denoised outputs + self._transformer.to(device) + x0_model = X0Model(self._transformer) + + with torch.autocast(device_type=str(device).split(":")[0], dtype=torch.bfloat16): + for step_idx, sigma in enumerate(sigmas[:-1]): + # Update modalities with current state and timesteps + video = replace( + video, + latent=video_state.latent, + timesteps=sigma * video_state.denoise_mask, + positions=video_state.positions, + ) + + if audio is not None and audio_state is not None: + audio = replace( + audio, + latent=audio_state.latent, + timesteps=sigma * audio_state.denoise_mask, + positions=audio_state.positions, + ) + + # Run model (positive pass) - X0Model returns denoised outputs + pos_video, pos_audio = x0_model(video=video, audio=audio, perturbations=None) + denoised_video, denoised_audio = pos_video, pos_audio + + # Apply CFG if guidance_scale != 1.0 + if cfg_guider.enabled() and v_ctx_neg is not None: + video_neg = replace(video, context=v_ctx_neg) + audio_neg = replace(audio, context=a_ctx_neg) if audio is not None else None + neg_video, neg_audio = x0_model(video=video_neg, audio=audio_neg, perturbations=None) + + denoised_video = denoised_video + cfg_guider.delta(pos_video, neg_video) + if audio is not None and denoised_audio is not None: + denoised_audio = denoised_audio + cfg_guider.delta(pos_audio, neg_audio) + + # Apply STG if stg_scale != 0.0 + if stg_guider.enabled() and stg_perturbation_config is not None: + perturbed_video, perturbed_audio = x0_model( + video=video, audio=audio, perturbations=stg_perturbation_config + ) + denoised_video = denoised_video + stg_guider.delta(pos_video, perturbed_video) + if audio is not None and denoised_audio is not None and perturbed_audio is not None: + denoised_audio = denoised_audio + stg_guider.delta(pos_audio, perturbed_audio) + + # Apply conditioning mask (keep conditioned tokens clean) + denoised_video = denoised_video * video_state.denoise_mask + video_clean_state.latent.float() * ( + 1 - video_state.denoise_mask + ) + if audio is not None and audio_state is 
not None and audio_clean_state is not None: + denoised_audio = denoised_audio * audio_state.denoise_mask + audio_clean_state.latent.float() * ( + 1 - audio_state.denoise_mask + ) + + # Euler step + video_state = replace( + video_state, + latent=stepper.step( + sample=video.latent, denoised_sample=denoised_video, sigmas=sigmas, step_index=step_idx + ), + ) + if audio is not None and audio_state is not None: + audio_state = replace( + audio_state, + latent=stepper.step( + sample=audio.latent, denoised_sample=denoised_audio, sigmas=sigmas, step_index=step_idx + ), + ) + + # Update progress + if self._sampling_context is not None: + self._sampling_context.advance_step() + + return video_state, audio_state + + @staticmethod + def _build_stg_perturbation_config(config: GenerationConfig) -> BatchedPerturbationConfig: + """Build the perturbation config for STG based on the stg_mode.""" + # Always skip video self-attention for STG + perturbations: list[Perturbation] = [ + Perturbation(type=PerturbationType.SKIP_VIDEO_SELF_ATTN, blocks=config.stg_blocks) + ] + + # Optionally also skip audio self-attention (stg_av mode) + if config.stg_mode == "stg_av": + perturbations.append(Perturbation(type=PerturbationType.SKIP_AUDIO_SELF_ATTN, blocks=config.stg_blocks)) + + perturbation_config = PerturbationConfig(perturbations=perturbations) + # Batch size is 1 for validation + return BatchedPerturbationConfig(perturbations=[perturbation_config]) + + def _decode_video_latent(self, latent: Tensor, config: GenerationConfig, device: torch.device) -> Tensor: + """Decode patchified video latent to pixel space.""" + # Unpatchify + latent_frames = config.num_frames // VIDEO_SCALE_FACTORS[0] + 1 + latent_height = config.height // VIDEO_SCALE_FACTORS[1] + latent_width = config.width // VIDEO_SCALE_FACTORS[2] + + unpatchified = self._video_patchifier.unpatchify( + latent, + output_shape=VideoLatentShape( + height=latent_height, + width=latent_width, + frames=latent_frames, + batch=1, + channels=128, + ), + ) + + # Decode - ensure bfloat16 to match decoder weights + self._vae_decoder.to(device) + unpatchified = unpatchified.to(dtype=torch.bfloat16) + tiled_config = config.tiled_decoding + + if tiled_config is not None and tiled_config.enabled: + # Use tiled decoding for reduced VRAM + tiling_config = TilingConfig( + spatial_config=SpatialTilingConfig( + tile_size_in_pixels=tiled_config.tile_size_pixels, + tile_overlap_in_pixels=tiled_config.tile_overlap_pixels, + ), + temporal_config=TemporalTilingConfig( + tile_size_in_frames=tiled_config.tile_size_frames, + tile_overlap_in_frames=tiled_config.tile_overlap_frames, + ), + ) + chunks = [] + for video_chunk, _ in self._vae_decoder.tiled_decode( + unpatchified, + tiling_config=tiling_config, + ): + chunks.append(video_chunk) + decoded_video = torch.cat(chunks, dim=2) + else: + # Standard full decoding + decoded_video = self._vae_decoder(unpatchified) + + decoded_video = ((decoded_video + 1.0) / 2.0).clamp(0.0, 1.0) + self._vae_decoder.to("cpu") + + return decoded_video[0].float().cpu() + + def _validate_config(self, config: GenerationConfig) -> None: + """Validate generation configuration.""" + if config.height % 32 != 0 or config.width % 32 != 0: + raise ValueError(f"height and width must be divisible by 32, got {config.height}x{config.width}") + if config.num_frames % 8 != 1: + raise ValueError(f"num_frames must satisfy num_frames % 8 == 1, got {config.num_frames}") + if config.generate_audio and (self._audio_decoder is None or self._vocoder is None): + raise 
ValueError("Audio generation requires audio_decoder and vocoder") + if config.condition_image is not None and self._vae_encoder is None: + raise ValueError("Image conditioning requires vae_encoder") + if config.reference_video is not None and self._vae_encoder is None: + raise ValueError("Reference video conditioning requires vae_encoder") + + # Validate prompt embedding source + if config.cached_embeddings is None and self._text_encoder is None: + raise ValueError("Either text_encoder or config.cached_embeddings must be provided") + + def _get_prompt_embeddings( + self, config: GenerationConfig, device: torch.device + ) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]: + """Get prompt embeddings from config cache or encode on-the-fly.""" + if config.cached_embeddings is not None: + # Use pre-computed embeddings from config + cached = config.cached_embeddings + v_ctx_pos = cached.video_context_positive.to(device) + a_ctx_pos = cached.audio_context_positive.to(device) + v_ctx_neg = cached.video_context_negative.to(device) if cached.video_context_negative is not None else None + a_ctx_neg = cached.audio_context_negative.to(device) if cached.audio_context_negative is not None else None + return v_ctx_pos, a_ctx_pos, v_ctx_neg, a_ctx_neg + + # Fall back to encoding on-the-fly + return self._encode_prompts(config, device) + + def _encode_prompts( + self, config: GenerationConfig, device: torch.device + ) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]: + """Encode positive and negative prompts using the text encoder.""" + self._text_encoder.to(device) + v_ctx_pos, a_ctx_pos, _ = self._text_encoder(config.prompt) + v_ctx_neg, a_ctx_neg = None, None + if config.guidance_scale != 1.0: + v_ctx_neg, a_ctx_neg, _ = self._text_encoder(config.negative_prompt) + + # Move the base Gemma model to CPU but keep embeddings connectors on GPU + # as this module is also used during training + self._text_encoder.model.to("cpu") + self._text_encoder.feature_extractor_linear.to("cpu") + + return v_ctx_pos, a_ctx_pos, v_ctx_neg, a_ctx_neg + + def _decode_video( + self, video_state: LatentState, device: torch.device, tiled_config: TiledDecodingConfig | None = None + ) -> Tensor: + """Decode video latents to pixel space. 
+ + Args: + video_state: Video latent state to decode + device: Device to run decoding on + tiled_config: Optional tiled decoding configuration for reduced VRAM usage + + Returns: + Decoded video tensor [C, F, H, W] in [0, 1] range + """ + self._vae_decoder.to(device) + # Ensure latent is bfloat16 to match decoder weights + latent = video_state.latent.to(dtype=torch.bfloat16) + + if tiled_config is not None and tiled_config.enabled: + # Use tiled decoding for reduced VRAM + tiling_config = TilingConfig( + spatial_config=SpatialTilingConfig( + tile_size_in_pixels=tiled_config.tile_size_pixels, + tile_overlap_in_pixels=tiled_config.tile_overlap_pixels, + ), + temporal_config=TemporalTilingConfig( + tile_size_in_frames=tiled_config.tile_size_frames, + tile_overlap_in_frames=tiled_config.tile_overlap_frames, + ), + ) + chunks = [] + for video_chunk, _ in self._vae_decoder.tiled_decode( + latent, + tiling_config=tiling_config, + ): + chunks.append(video_chunk) + decoded_video = torch.cat(chunks, dim=2) + else: + # Standard full decoding + decoded_video = self._vae_decoder(latent) + + decoded_video = ((decoded_video + 1.0) / 2.0).clamp(0.0, 1.0) + self._vae_decoder.to("cpu") + return decoded_video[0].float().cpu() + + def _decode_audio(self, audio_state: LatentState, device: torch.device) -> Tensor: + """Decode audio latents to waveform.""" + self._audio_decoder.to(device) + # Ensure latent is bfloat16 to match decoder weights + latent = audio_state.latent.to(dtype=torch.bfloat16) + decoded_audio = self._audio_decoder(latent) + self._audio_decoder.to("cpu") + + self._vocoder.to(device) + audio_waveform = self._vocoder(decoded_audio) + self._vocoder.to("cpu") + + return audio_waveform.squeeze(0).float().cpu() + + @staticmethod + def _concatenate_videos_side_by_side(left_video: Tensor, right_video: Tensor) -> Tensor: + """Concatenate two videos side-by-side (horizontally). + + If the videos have different frame counts, the shorter one is padded with + its last frame repeated. + + Args: + left_video: Left video tensor [C, F1, H, W] in [0, 1] + right_video: Right video tensor [C, F2, H, W] in [0, 1] + + Returns: + Concatenated video tensor [C, max(F1,F2), H, W*2] in [0, 1] + """ + left_frames = left_video.shape[1] + right_frames = right_video.shape[1] + + # Pad shorter video by repeating last frame + if left_frames < right_frames: + padding = left_video[:, -1:, :, :].expand(-1, right_frames - left_frames, -1, -1) + left_video = torch.cat([left_video, padding], dim=1) + elif right_frames < left_frames: + padding = right_video[:, -1:, :, :].expand(-1, left_frames - right_frames, -1, -1) + right_video = torch.cat([right_video, padding], dim=1) + + # Concatenate along width dimension + return torch.cat([left_video, right_video], dim=3) + + def _encode_conditioning_image( + self, + image: Tensor, + target_height: int, + target_width: int, + device: torch.device, + ) -> Tensor: + """Encode a conditioning image to latent space. + + The image is resized to cover the target dimensions while preserving aspect ratio, + then center-cropped to exactly match the target size. 
+ """ + # image is [C, H, W] in [0, 1] # noqa: ERA001 + current_height, current_width = image.shape[1:] + + # Resize maintaining aspect ratio (cover target, then center crop) + if current_height != target_height or current_width != target_width: + aspect_ratio = current_width / current_height + target_aspect_ratio = target_width / target_height + + if aspect_ratio > target_aspect_ratio: + # Image is wider than target - resize to match height, crop width + resize_height = target_height + resize_width = int(target_height * aspect_ratio) + else: + # Image is taller than target - resize to match width, crop height + resize_height = int(target_width / aspect_ratio) + resize_width = target_width + + image = rearrange(image, "c h w -> 1 c h w") + image = torch.nn.functional.interpolate( + image, size=(resize_height, resize_width), mode="bilinear", align_corners=False + ) + + # Center crop to target dimensions + h_start = (resize_height - target_height) // 2 + w_start = (resize_width - target_width) // 2 + image = image[:, :, h_start : h_start + target_height, w_start : w_start + target_width] + else: + image = rearrange(image, "c h w -> 1 c h w") + + # Add frame dimension and convert to [-1, 1] + image = rearrange(image, "b c h w -> b c 1 h w") + image = (image * 2.0 - 1.0).to(device=device, dtype=torch.float32) + + # Encode + self._vae_encoder.to(device) + with torch.autocast(device_type=str(device).split(":")[0], dtype=torch.bfloat16): + encoded = self._vae_encoder(image) + self._vae_encoder.to("cpu") + + return encoded diff --git a/packages/ltx-trainer/src/ltx_trainer/video_utils.py b/packages/ltx-trainer/src/ltx_trainer/video_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cbfc481cdabb1de1780b189b14007a4cf9f76f4e --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/video_utils.py @@ -0,0 +1,165 @@ +"""Video I/O utilities using PyAV. + +This module provides functions for reading and writing video files using PyAV, +with optional audio support. +""" + +from fractions import Fraction +from pathlib import Path + +import av +import numpy as np +import torch +from torch import Tensor + + +def get_video_frame_count(video_path: str | Path) -> int: + """Get the number of frames in a video file. + + Args: + video_path: Path to the video file + + Returns: + Number of frames in the video + """ + with av.open(str(video_path)) as container: + video_stream = container.streams.video[0] + frame_count = video_stream.frames + if frame_count == 0: + # Fallback: count frames by decoding + frame_count = sum(1 for _ in container.decode(video=0)) + return frame_count + + +def read_video(video_path: str | Path, max_frames: int | None = None) -> tuple[Tensor, float]: + """Load frames from a video file using PyAV. + + Args: + video_path: Path to the video file + max_frames: Maximum number of frames to read. If None, reads all frames. + + Returns: + Video tensor with shape [F, C, H, W] in range [0, 1] and frames per second (fps). 
+ """ + with av.open(str(video_path)) as container: + video_stream = container.streams.video[0] + fps = float(video_stream.average_rate or video_stream.base_rate or 24) + + frames = [] + for frame in container.decode(video=0): + if max_frames is not None and len(frames) >= max_frames: + break + frames.append(frame.to_ndarray(format="rgb24")) + + frames_np = np.stack(frames, axis=0) # [F, H, W, C] + video = torch.from_numpy(frames_np).float().div(255.0) # [F, H, W, C] in [0, 1] + return video.permute(0, 3, 1, 2), fps # [F, C, H, W] + + +def save_video( + video_tensor: torch.Tensor, + output_path: Path | str, + fps: float = 24.0, + audio: torch.Tensor | None = None, + audio_sample_rate: int | None = None, +) -> None: + """Save a video tensor to a file using PyAV, optionally with audio. + + Args: + video_tensor: Video tensor of shape [C, F, H, W] or [F, C, H, W] in range [0, 1] or [0, 255] + output_path: Path to save the video + fps: Frames per second for the output video + audio: Optional audio tensor of shape [C, samples] or [samples, C] in range [-1, 1] + audio_sample_rate: Sample rate for the audio (required if audio is provided) + """ + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Normalize to [F, H, W, C] uint8 numpy array + video_np = _prepare_video_array(video_tensor) + _, height, width, _ = video_np.shape + + with av.open(str(output_path), mode="w") as container: + # Setup video stream + video_stream = container.add_stream("libx264", rate=int(fps)) + video_stream.width = width + video_stream.height = height + video_stream.pix_fmt = "yuv420p" + video_stream.options = {"crf": "18"} + + # Setup audio stream if needed + if audio is not None: + if audio_sample_rate is None: + raise ValueError("audio_sample_rate must be provided when audio is given") + audio_stream = container.add_stream("aac", rate=audio_sample_rate) + audio_stream.layout = "stereo" + audio_stream.time_base = Fraction(1, audio_sample_rate) + + # Write video frames + for frame_array in video_np: + frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") + for packet in video_stream.encode(frame): + container.mux(packet) + for packet in video_stream.encode(): + container.mux(packet) + + # Write audio if provided + if audio is not None: + _write_audio(container, audio_stream, audio, audio_sample_rate) + + +def _prepare_video_array(video_tensor: torch.Tensor) -> np.ndarray: + """Convert video tensor to [F, H, W, C] uint8 numpy array.""" + # Handle [C, F, H, W] vs [F, C, H, W] format + if video_tensor.shape[0] == 3 and video_tensor.shape[1] > 3: + video_tensor = video_tensor.permute(1, 0, 2, 3) # [C, F, H, W] -> [F, C, H, W] + + # Normalize to [0, 255] uint8 + if video_tensor.max() <= 1.0: + video_tensor = video_tensor * 255 + + # [F, C, H, W] -> [F, H, W, C] + return video_tensor.permute(0, 2, 3, 1).to(torch.uint8).cpu().numpy() + + +def _write_audio( + container: av.container.Container, + audio_stream: av.audio.AudioStream, + audio: torch.Tensor, + sample_rate: int, +) -> None: + """Write audio tensor to container as stereo AAC.""" + audio = audio.cpu().float() + + # Normalize to [samples, 2] stereo format + if audio.ndim == 1: + audio = audio.unsqueeze(1).repeat(1, 2) # Mono -> stereo + elif audio.shape[0] == 2 and audio.shape[1] != 2: + audio = audio.T # [2, samples] -> [samples, 2] + if audio.shape[1] == 1: + audio = audio.repeat(1, 2) # Mono -> stereo + + # Convert to int16 interleaved: [samples, 2] -> [1, samples*2] + audio_int16 = (audio.clamp(-1, 1) * 
32767).to(torch.int16) + audio_interleaved = audio_int16.contiguous().view(1, -1).numpy() + + # Create audio frame + frame = av.AudioFrame.from_ndarray(audio_interleaved, format="s16", layout="stereo") + frame.sample_rate = sample_rate + + # Resample to encoder format and write + resampler = av.audio.resampler.AudioResampler( + format=audio_stream.codec_context.format, + layout=audio_stream.codec_context.layout, + rate=sample_rate, + ) + + pts = 0 + for resampled_frame in resampler.resample(frame): + resampled_frame.pts = pts + pts += resampled_frame.samples + for packet in audio_stream.encode(resampled_frame): + container.mux(packet) + + for packet in audio_stream.encode(): + container.mux(packet) diff --git a/packages/ltx-trainer/templates/model_card.md b/packages/ltx-trainer/templates/model_card.md new file mode 100644 index 0000000000000000000000000000000000000000..d8fb0d6d1d0a961412aab716f6babfa70b60c447 --- /dev/null +++ b/packages/ltx-trainer/templates/model_card.md @@ -0,0 +1,56 @@ +--- +tags: +- ltx-2 +- ltx-video +- text-to-video +- audio-video +pinned: true +language: +- en +license: other +pipeline_tag: text-to-video +library_name: diffusers +--- + +# {model_name} + +This is a fine-tuned version of [`{base_model}`]({base_model_link}) trained on custom data. + +## Model Details + +- **Base Model:** [`{base_model}`]({base_model_link}) +- **Training Type:** {training_type} +- **Training Steps:** {training_steps} +- **Learning Rate:** {learning_rate} +- **Batch Size:** {batch_size} + +## Sample Outputs + +| | | | | +|:---:|:---:|:---:|:---:| +{sample_grid} + +## Usage + +This model is designed to be used with the LTX-2 (Lightricks Audio-Video) pipeline. + +### 🔌 Using Trained LoRAs in ComfyUI +In order to use the trained LoRA in ComfyUI, follow these steps: +1. Copy your trained LoRA checkpoint (`.safetensors` file) to the `models/loras` folder in your ComfyUI installation. +2. In your ComfyUI workflow: + - Add the "Load LoRA" node to choose your LoRA file + - Connect it to the "Load Checkpoint" node to apply the LoRA to the base model + +You can find reference Text-to-Video (T2V) and Image-to-Video (I2V) workflows in the official [LTX-2 repository](https://github.com/Lightricks/LTX-Video). + +### Example Prompts + +{validation_prompts} + + +This model inherits the license of the base model ([`{base_model}`]({base_model_link})). + +## Acknowledgments + +- Base model by [Lightricks](https://huggingface.co/Lightricks) +- Trainer: [LTX-2](https://github.com/Lightricks/LTX-Video) diff --git a/packages/ltx-trainer/tests/__init__.py b/packages/ltx-trainer/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/packages/ltx-trainer/tests/test_configs.py b/packages/ltx-trainer/tests/test_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..e9281eecfb254cac5bfffae6a9360d2a14bb2235 --- /dev/null +++ b/packages/ltx-trainer/tests/test_configs.py @@ -0,0 +1,84 @@ +"""Test that all configuration files are valid and can be loaded. + +This test automatically discovers all YAML files in the configs/ directory +and validates that they can be deserialized into LtxTrainerConfig objects. +""" + +from pathlib import Path + +import pytest +import yaml + +from ltx_trainer.config import LtxTrainerConfig + + +def get_config_files() -> list[Path]: + """Discover all YAML config files in the configs directory. 
+ + Returns: + List of paths to YAML config files (excluding accelerate configs) + """ + configs_dir = Path(__file__).parent.parent / "configs" + + # Find all .yaml and .yml files, excluding accelerate subfolder + config_files = [] + for pattern in ["*.yaml", "*.yml"]: + config_files.extend(configs_dir.glob(pattern)) + + return sorted(config_files) + + +@pytest.mark.parametrize("config_file", get_config_files(), ids=lambda p: p.name) +def test_config_file_valid(config_file: Path, tmp_path: Path) -> None: + """Test that a config file can be loaded and validated. + + This test parses the YAML and validates it against LtxTrainerConfig schema. + Pydantic handles all validation - if deserialization succeeds, the config is valid. + + Note: We create actual dummy files since Pydantic validators check path existence. + """ + # Load YAML + with open(config_file) as f: + config_dict = yaml.safe_load(f) + + # Create dummy files that validators will check for existence + dummy_model = tmp_path / "dummy_model.safetensors" + dummy_model.touch() + + dummy_gemma = tmp_path / "dummy_gemma" + dummy_gemma.mkdir() + + dummy_video = tmp_path / "dummy_video.mp4" + dummy_video.touch() + + dummy_image = tmp_path / "dummy_image.png" + dummy_image.touch() + + # Replace file paths with dummy paths that actually exist + if "model" in config_dict: + if "model_path" in config_dict["model"]: + config_dict["model"]["model_path"] = str(dummy_model) + if "text_encoder_path" in config_dict["model"]: + config_dict["model"]["text_encoder_path"] = str(dummy_gemma) + + if "data" in config_dict and "preprocessed_data_root" in config_dict["data"]: + config_dict["data"]["preprocessed_data_root"] = str(tmp_path) + + if "validation" in config_dict: + # Replace validation paths with dummy paths + if "images" in config_dict["validation"] and config_dict["validation"]["images"]: + # Provide dummy image paths (one per prompt) + num_prompts = len(config_dict["validation"].get("prompts", [])) + config_dict["validation"]["images"] = [str(dummy_image)] * num_prompts + + if "reference_videos" in config_dict["validation"] and config_dict["validation"]["reference_videos"]: + # Provide dummy video paths (one per prompt) + num_prompts = len(config_dict["validation"].get("prompts", [])) + config_dict["validation"]["reference_videos"] = [str(dummy_video)] * num_prompts + + # Validate config - Pydantic does all the work! + # If this doesn't raise ValidationError, the config is valid + config = LtxTrainerConfig.model_validate(config_dict) + + # Basic sanity check that we got a valid config object + assert config is not None