# ruff: noqa: PLC0415
"""
Model loader for LTX-2 trainer using the new ltx-core package.
This module provides a unified interface for loading LTX-2 model components
for training, using SingleGPUModelBuilder from ltx-core.
Example usage:
# Load individual components
vae_encoder = load_video_vae_encoder("/path/to/checkpoint.safetensors", device="cuda")
vae_decoder = load_video_vae_decoder("/path/to/checkpoint.safetensors", device="cuda")
text_encoder = load_text_encoder("/path/to/checkpoint.safetensors", "/path/to/gemma", device="cuda")
# Load all components at once
components = load_model("/path/to/checkpoint.safetensors", text_encoder_path="/path/to/gemma")
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
import torch
from ltx_trainer import logger
# Type alias for device specification
Device = str | torch.device
# Type checking imports (not loaded at runtime)
if TYPE_CHECKING:
from ltx_core.model.audio_vae.audio_vae import Decoder as AudioVAEDecoder
from ltx_core.model.audio_vae.audio_vae import Encoder as AudioVAEEncoder
from ltx_core.model.audio_vae.vocoder import Vocoder
from ltx_core.model.clip.gemma.encoders.av_encoder import AVGemmaTextEncoderModel
from ltx_core.model.transformer.model import LTXModel
from ltx_core.model.video_vae.video_vae import Decoder as VideoVAEDecoder
from ltx_core.model.video_vae.video_vae import Encoder as VideoVAEEncoder
from ltx_core.pipeline.components.schedulers import LTX2Scheduler
def _to_torch_device(device: Device) -> torch.device:
"""Convert device specification to torch.device."""
return torch.device(device) if isinstance(device, str) else device
# =============================================================================
# Individual Component Loaders
# =============================================================================
def load_transformer(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "LTXModel":
    """Build the LTX transformer model from a safetensors checkpoint.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device the weights are placed on.
        dtype: Data type used for the model weights.

    Returns:
        The constructed LTXModel transformer.
    """
    # Imported lazily so that importing this module stays cheap.
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.transformer.model_configurator import (
        LTXV_MODEL_COMFY_RENAMING_MAP,
        LTXModelConfigurator,
    )

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=LTXModelConfigurator,
        model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_video_vae_encoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "VideoVAEEncoder":
    """Build the video VAE encoder used during preprocessing.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device the weights are placed on.
        dtype: Data type used for the model weights.

    Returns:
        The constructed VideoVAEEncoder.
    """
    # Imported lazily so that importing this module stays cheap.
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.video_vae.model_configurator import VAE_ENCODER_COMFY_KEYS_FILTER
    from ltx_core.model.video_vae.model_configurator import (
        VAEEncoderConfigurator as VideoVAEEncoderConfigurator,
    )

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=VideoVAEEncoderConfigurator,
        model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_video_vae_decoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "VideoVAEDecoder":
    """Build the video VAE decoder used for inference/validation.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device the weights are placed on.
        dtype: Data type used for the model weights.

    Returns:
        The constructed VideoVAEDecoder.
    """
    # Imported lazily so that importing this module stays cheap.
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.video_vae.model_configurator import VAE_DECODER_COMFY_KEYS_FILTER
    from ltx_core.model.video_vae.model_configurator import (
        VAEDecoderConfigurator as VideoVAEDecoderConfigurator,
    )

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=VideoVAEDecoderConfigurator,
        model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_audio_vae_encoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "AudioVAEEncoder":
    """Build the audio VAE encoder used during preprocessing.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device the weights are placed on.
        dtype: Data type for the weights (default bfloat16, but float32
            recommended for quality).

    Returns:
        The constructed AudioVAEEncoder.
    """
    # Imported lazily so that importing this module stays cheap.
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER
    from ltx_core.model.audio_vae.model_configurator import (
        VAEEncoderConfigurator as AudioVAEEncoderConfigurator,
    )

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=AudioVAEEncoderConfigurator,
        model_sd_ops=AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_audio_vae_decoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "AudioVAEDecoder":
    """Build the audio VAE decoder.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device the weights are placed on.
        dtype: Data type used for the model weights.

    Returns:
        The constructed AudioVAEDecoder.
    """
    # Imported lazily so that importing this module stays cheap.
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER
    from ltx_core.model.audio_vae.model_configurator import (
        VAEDecoderConfigurator as AudioVAEDecoderConfigurator,
    )

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=AudioVAEDecoderConfigurator,
        model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_vocoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "Vocoder":
    """Build the vocoder used for audio waveform generation.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device the weights are placed on.
        dtype: Data type used for the model weights.

    Returns:
        The constructed Vocoder.
    """
    # Imported lazily so that importing this module stays cheap.
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.audio_vae.model_configurator import VOCODER_COMFY_KEYS_FILTER, VocoderConfigurator

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=VocoderConfigurator,
        model_sd_ops=VOCODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_text_encoder(
    checkpoint_path: str | Path,
    gemma_model_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "AVGemmaTextEncoderModel":
    """Build the Gemma text encoder.

    Args:
        checkpoint_path: Path to the LTX-2 safetensors checkpoint file.
        gemma_model_path: Path to the Gemma model directory.
        device: Device the weights are placed on.
        dtype: Data type used for the model weights.

    Returns:
        The constructed AVGemmaTextEncoderModel.

    Raises:
        ValueError: If ``gemma_model_path`` is not a directory.
    """
    # Imported lazily so that importing this module stays cheap.
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.clip.gemma.encoders.av_encoder import (
        AV_GEMMA_TEXT_ENCODER_KEY_OPS,
        AVGemmaTextEncoderModelConfigurator,
    )
    from ltx_core.model.clip.gemma.encoders.base_encoder import module_ops_from_gemma_root

    gemma_root = Path(gemma_model_path)
    if not gemma_root.is_dir():
        raise ValueError(f"Gemma model path is not a directory: {gemma_model_path}")

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=AVGemmaTextEncoderModelConfigurator,
        model_sd_ops=AV_GEMMA_TEXT_ENCODER_KEY_OPS,
        module_ops=module_ops_from_gemma_root(str(gemma_model_path)),
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
# =============================================================================
# Combined Component Loader
# =============================================================================
@dataclass
class LtxModelComponents:
    """Bundle of LTX-2 model components produced by ``load_model``.

    Only the transformer is mandatory; every other component is optional
    and remains ``None`` when it was not requested at load time.
    """

    # Core denoising transformer (always loaded).
    transformer: "LTXModel"
    # Optional components, populated according to the load_model() flags.
    video_vae_encoder: "VideoVAEEncoder | None" = None
    video_vae_decoder: "VideoVAEDecoder | None" = None
    audio_vae_decoder: "AudioVAEDecoder | None" = None
    vocoder: "Vocoder | None" = None
    text_encoder: "AVGemmaTextEncoderModel | None" = None
    scheduler: "LTX2Scheduler | None" = None
def load_model(
    checkpoint_path: str | Path,
    text_encoder_path: str | Path | None = None,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
    with_video_vae_encoder: bool = False,
    with_video_vae_decoder: bool = True,
    with_audio_vae_decoder: bool = True,
    with_vocoder: bool = True,
    with_text_encoder: bool = True,
) -> LtxModelComponents:
    """
    Load LTX-2 model components from a safetensors checkpoint.
    This is a convenience function that loads multiple components at once.
    For loading individual components, use the dedicated functions:
    - load_transformer()
    - load_video_vae_encoder()
    - load_video_vae_decoder()
    - load_audio_vae_decoder()
    - load_vocoder()
    - load_text_encoder()
    Args:
        checkpoint_path: Path to the safetensors checkpoint file
        text_encoder_path: Path to Gemma model directory (required if with_text_encoder=True)
        device: Device to load models on ("cuda", "cpu", etc.)
        dtype: Data type for model weights
        with_video_vae_encoder: Whether to load the video VAE encoder (for preprocessing)
        with_video_vae_decoder: Whether to load the video VAE decoder (for inference/validation)
        with_audio_vae_decoder: Whether to load the audio VAE decoder
        with_vocoder: Whether to load the vocoder
        with_text_encoder: Whether to load the text encoder
    Returns:
        LtxModelComponents containing all loaded model components
    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        ValueError: If with_text_encoder=True and text_encoder_path is None.
    """
    checkpoint_path = Path(checkpoint_path)
    # Validate all inputs up front so we fail fast, instead of raising only
    # after several large components have already been loaded (previously the
    # text_encoder_path check ran after the transformer, VAEs and vocoder).
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    if with_text_encoder and text_encoder_path is None:
        raise ValueError("text_encoder_path must be provided when with_text_encoder=True")
    logger.info(f"Loading LTX-2 model from {checkpoint_path}")
    torch_device = _to_torch_device(device)
    # Load transformer (always required)
    logger.debug("Loading transformer...")
    transformer = load_transformer(checkpoint_path, torch_device, dtype)
    # Optional components stay None unless explicitly requested.
    # Load video VAE encoder
    video_vae_encoder = None
    if with_video_vae_encoder:
        logger.debug("Loading video VAE encoder...")
        video_vae_encoder = load_video_vae_encoder(checkpoint_path, torch_device, dtype)
    # Load video VAE decoder
    video_vae_decoder = None
    if with_video_vae_decoder:
        logger.debug("Loading video VAE decoder...")
        video_vae_decoder = load_video_vae_decoder(checkpoint_path, torch_device, dtype)
    # Load audio VAE decoder
    audio_vae_decoder = None
    if with_audio_vae_decoder:
        logger.debug("Loading audio VAE decoder...")
        audio_vae_decoder = load_audio_vae_decoder(checkpoint_path, torch_device, dtype)
    # Load vocoder
    vocoder = None
    if with_vocoder:
        logger.debug("Loading vocoder...")
        vocoder = load_vocoder(checkpoint_path, torch_device, dtype)
    # Load text encoder (path was validated above)
    text_encoder = None
    if with_text_encoder:
        logger.debug("Loading Gemma text encoder...")
        text_encoder = load_text_encoder(checkpoint_path, text_encoder_path, torch_device, dtype)
    # Imported at point of use so input-validation errors above do not
    # depend on the ltx_core import chain.
    from ltx_core.pipeline.components.schedulers import LTX2Scheduler

    # Create scheduler (stateless, no loading needed)
    scheduler = LTX2Scheduler()
    return LtxModelComponents(
        transformer=transformer,
        video_vae_encoder=video_vae_encoder,
        video_vae_decoder=video_vae_decoder,
        audio_vae_decoder=audio_vae_decoder,
        vocoder=vocoder,
        text_encoder=text_encoder,
        scheduler=scheduler,
    )