# ruff: noqa: PLC0415 """ Model loader for LTX-2 trainer using the new ltx-core package. This module provides a unified interface for loading LTX-2 model components for training, using SingleGPUModelBuilder from ltx-core. Example usage: # Load individual components vae_encoder = load_video_vae_encoder("/path/to/checkpoint.safetensors", device="cuda") vae_decoder = load_video_vae_decoder("/path/to/checkpoint.safetensors", device="cuda") text_encoder = load_text_encoder("/path/to/checkpoint.safetensors", "/path/to/gemma", device="cuda") # Load all components at once components = load_model("/path/to/checkpoint.safetensors", text_encoder_path="/path/to/gemma") """ from __future__ import annotations from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING import torch from ltx_trainer import logger # Type alias for device specification Device = str | torch.device # Type checking imports (not loaded at runtime) if TYPE_CHECKING: from ltx_core.model.audio_vae.audio_vae import Decoder as AudioVAEDecoder from ltx_core.model.audio_vae.audio_vae import Encoder as AudioVAEEncoder from ltx_core.model.audio_vae.vocoder import Vocoder from ltx_core.model.clip.gemma.encoders.av_encoder import AVGemmaTextEncoderModel from ltx_core.model.transformer.model import LTXModel from ltx_core.model.video_vae.video_vae import Decoder as VideoVAEDecoder from ltx_core.model.video_vae.video_vae import Encoder as VideoVAEEncoder from ltx_core.pipeline.components.schedulers import LTX2Scheduler def _to_torch_device(device: Device) -> torch.device: """Convert device specification to torch.device.""" return torch.device(device) if isinstance(device, str) else device # ============================================================================= # Individual Component Loaders # ============================================================================= def load_transformer( checkpoint_path: str | Path, device: Device = "cpu", dtype: torch.dtype = torch.bfloat16, ) -> "LTXModel": """Load the LTX transformer model. Args: checkpoint_path: Path to the safetensors checkpoint file device: Device to load model on dtype: Data type for model weights Returns: Loaded LTXModel transformer """ from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder from ltx_core.model.transformer.model_configurator import ( LTXV_MODEL_COMFY_RENAMING_MAP, LTXModelConfigurator, ) return SingleGPUModelBuilder( model_path=str(checkpoint_path), model_class_configurator=LTXModelConfigurator, model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP, ).build(device=_to_torch_device(device), dtype=dtype) def load_video_vae_encoder( checkpoint_path: str | Path, device: Device = "cpu", dtype: torch.dtype = torch.bfloat16, ) -> "VideoVAEEncoder": """Load the video VAE encoder (for preprocessing). Args: checkpoint_path: Path to the safetensors checkpoint file device: Device to load model on dtype: Data type for model weights Returns: Loaded VideoVAEEncoder """ from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder from ltx_core.model.video_vae.model_configurator import VAE_ENCODER_COMFY_KEYS_FILTER from ltx_core.model.video_vae.model_configurator import ( VAEEncoderConfigurator as VideoVAEEncoderConfigurator, ) return SingleGPUModelBuilder( model_path=str(checkpoint_path), model_class_configurator=VideoVAEEncoderConfigurator, model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER, ).build(device=_to_torch_device(device), dtype=dtype) def load_video_vae_decoder( checkpoint_path: str | Path, device: Device = "cpu", dtype: torch.dtype = torch.bfloat16, ) -> "VideoVAEDecoder": """Load the video VAE decoder (for inference/validation). Args: checkpoint_path: Path to the safetensors checkpoint file device: Device to load model on dtype: Data type for model weights Returns: Loaded VideoVAEDecoder """ from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder from ltx_core.model.video_vae.model_configurator import VAE_DECODER_COMFY_KEYS_FILTER from ltx_core.model.video_vae.model_configurator import ( VAEDecoderConfigurator as VideoVAEDecoderConfigurator, ) return SingleGPUModelBuilder( model_path=str(checkpoint_path), model_class_configurator=VideoVAEDecoderConfigurator, model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER, ).build(device=_to_torch_device(device), dtype=dtype) def load_audio_vae_encoder( checkpoint_path: str | Path, device: Device = "cpu", dtype: torch.dtype = torch.bfloat16, ) -> "AudioVAEEncoder": """Load the audio VAE encoder (for preprocessing). Args: checkpoint_path: Path to the safetensors checkpoint file device: Device to load model on dtype: Data type for model weights (default bfloat16, but float32 recommended for quality) Returns: Loaded AudioVAEEncoder """ from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER from ltx_core.model.audio_vae.model_configurator import ( VAEEncoderConfigurator as AudioVAEEncoderConfigurator, ) return SingleGPUModelBuilder( model_path=str(checkpoint_path), model_class_configurator=AudioVAEEncoderConfigurator, model_sd_ops=AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER, ).build(device=_to_torch_device(device), dtype=dtype) def load_audio_vae_decoder( checkpoint_path: str | Path, device: Device = "cpu", dtype: torch.dtype = torch.bfloat16, ) -> "AudioVAEDecoder": """Load the audio VAE decoder. Args: checkpoint_path: Path to the safetensors checkpoint file device: Device to load model on dtype: Data type for model weights Returns: Loaded AudioVAEDecoder """ from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER from ltx_core.model.audio_vae.model_configurator import ( VAEDecoderConfigurator as AudioVAEDecoderConfigurator, ) return SingleGPUModelBuilder( model_path=str(checkpoint_path), model_class_configurator=AudioVAEDecoderConfigurator, model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, ).build(device=_to_torch_device(device), dtype=dtype) def load_vocoder( checkpoint_path: str | Path, device: Device = "cpu", dtype: torch.dtype = torch.bfloat16, ) -> "Vocoder": """Load the vocoder (for audio waveform generation). Args: checkpoint_path: Path to the safetensors checkpoint file device: Device to load model on dtype: Data type for model weights Returns: Loaded Vocoder """ from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder from ltx_core.model.audio_vae.model_configurator import VOCODER_COMFY_KEYS_FILTER, VocoderConfigurator return SingleGPUModelBuilder( model_path=str(checkpoint_path), model_class_configurator=VocoderConfigurator, model_sd_ops=VOCODER_COMFY_KEYS_FILTER, ).build(device=_to_torch_device(device), dtype=dtype) def load_text_encoder( checkpoint_path: str | Path, gemma_model_path: str | Path, device: Device = "cpu", dtype: torch.dtype = torch.bfloat16, ) -> "AVGemmaTextEncoderModel": """Load the Gemma text encoder. Args: checkpoint_path: Path to the LTX-2 safetensors checkpoint file gemma_model_path: Path to Gemma model directory device: Device to load model on dtype: Data type for model weights Returns: Loaded AVGemmaTextEncoderModel """ from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder from ltx_core.model.clip.gemma.encoders.av_encoder import ( AV_GEMMA_TEXT_ENCODER_KEY_OPS, AVGemmaTextEncoderModelConfigurator, ) from ltx_core.model.clip.gemma.encoders.base_encoder import module_ops_from_gemma_root if not Path(gemma_model_path).is_dir(): raise ValueError(f"Gemma model path is not a directory: {gemma_model_path}") torch_device = _to_torch_device(device) text_encoder = SingleGPUModelBuilder( model_path=str(checkpoint_path), model_class_configurator=AVGemmaTextEncoderModelConfigurator, model_sd_ops=AV_GEMMA_TEXT_ENCODER_KEY_OPS, module_ops=module_ops_from_gemma_root(str(gemma_model_path)), ).build(device=torch_device, dtype=dtype) return text_encoder # ============================================================================= # Combined Component Loader # ============================================================================= @dataclass class LtxModelComponents: """Container for all LTX-2 model components.""" transformer: "LTXModel" video_vae_encoder: "VideoVAEEncoder | None" = None video_vae_decoder: "VideoVAEDecoder | None" = None audio_vae_decoder: "AudioVAEDecoder | None" = None vocoder: "Vocoder | None" = None text_encoder: "AVGemmaTextEncoderModel | None" = None scheduler: "LTX2Scheduler | None" = None def load_model( checkpoint_path: str | Path, text_encoder_path: str | Path | None = None, device: Device = "cpu", dtype: torch.dtype = torch.bfloat16, with_video_vae_encoder: bool = False, with_video_vae_decoder: bool = True, with_audio_vae_decoder: bool = True, with_vocoder: bool = True, with_text_encoder: bool = True, ) -> LtxModelComponents: """ Load LTX-2 model components from a safetensors checkpoint. This is a convenience function that loads multiple components at once. For loading individual components, use the dedicated functions: - load_transformer() - load_video_vae_encoder() - load_video_vae_decoder() - load_audio_vae_decoder() - load_vocoder() - load_text_encoder() Args: checkpoint_path: Path to the safetensors checkpoint file text_encoder_path: Path to Gemma model directory (required if with_text_encoder=True) device: Device to load models on ("cuda", "cpu", etc.) dtype: Data type for model weights with_video_vae_encoder: Whether to load the video VAE encoder (for preprocessing) with_video_vae_decoder: Whether to load the video VAE decoder (for inference/validation) with_audio_vae_decoder: Whether to load the audio VAE decoder with_vocoder: Whether to load the vocoder with_text_encoder: Whether to load the text encoder Returns: LtxModelComponents containing all loaded model components """ from ltx_core.pipeline.components.schedulers import LTX2Scheduler checkpoint_path = Path(checkpoint_path) # Validate checkpoint exists if not checkpoint_path.exists(): raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") logger.info(f"Loading LTX-2 model from {checkpoint_path}") torch_device = _to_torch_device(device) # Load transformer logger.debug("Loading transformer...") transformer = load_transformer(checkpoint_path, torch_device, dtype) # Load video VAE encoder video_vae_encoder = None if with_video_vae_encoder: logger.debug("Loading video VAE encoder...") video_vae_encoder = load_video_vae_encoder(checkpoint_path, torch_device, dtype) # Load video VAE decoder video_vae_decoder = None if with_video_vae_decoder: logger.debug("Loading video VAE decoder...") video_vae_decoder = load_video_vae_decoder(checkpoint_path, torch_device, dtype) # Load audio VAE decoder audio_vae_decoder = None if with_audio_vae_decoder: logger.debug("Loading audio VAE decoder...") audio_vae_decoder = load_audio_vae_decoder(checkpoint_path, torch_device, dtype) # Load vocoder vocoder = None if with_vocoder: logger.debug("Loading vocoder...") vocoder = load_vocoder(checkpoint_path, torch_device, dtype) # Load text encoder text_encoder = None if with_text_encoder: if text_encoder_path is None: raise ValueError("text_encoder_path must be provided when with_text_encoder=True") logger.debug("Loading Gemma text encoder...") text_encoder = load_text_encoder(checkpoint_path, text_encoder_path, torch_device, dtype) # Create scheduler (stateless, no loading needed) scheduler = LTX2Scheduler() return LtxModelComponents( transformer=transformer, video_vae_encoder=video_vae_encoder, video_vae_decoder=video_vae_decoder, audio_vae_decoder=audio_vae_decoder, vocoder=vocoder, text_encoder=text_encoder, scheduler=scheduler, )