Spaces:
Running
on
Zero
Running
on
Zero
| # ruff: noqa: PLC0415 | |
| """ | |
| Model loader for LTX-2 trainer using the new ltx-core package. | |
| This module provides a unified interface for loading LTX-2 model components | |
| for training, using SingleGPUModelBuilder from ltx-core. | |
| Example usage: | |
| # Load individual components | |
| vae_encoder = load_video_vae_encoder("/path/to/checkpoint.safetensors", device="cuda") | |
| vae_decoder = load_video_vae_decoder("/path/to/checkpoint.safetensors", device="cuda") | |
| text_encoder = load_text_encoder("/path/to/checkpoint.safetensors", "/path/to/gemma", device="cuda") | |
| # Load all components at once | |
| components = load_model("/path/to/checkpoint.safetensors", text_encoder_path="/path/to/gemma") | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING | |
| import torch | |
| from ltx_trainer import logger | |
| # Type alias for device specification | |
| Device = str | torch.device | |
| # Type checking imports (not loaded at runtime) | |
| if TYPE_CHECKING: | |
| from ltx_core.model.audio_vae.audio_vae import Decoder as AudioVAEDecoder | |
| from ltx_core.model.audio_vae.audio_vae import Encoder as AudioVAEEncoder | |
| from ltx_core.model.audio_vae.vocoder import Vocoder | |
| from ltx_core.model.clip.gemma.encoders.av_encoder import AVGemmaTextEncoderModel | |
| from ltx_core.model.transformer.model import LTXModel | |
| from ltx_core.model.video_vae.video_vae import Decoder as VideoVAEDecoder | |
| from ltx_core.model.video_vae.video_vae import Encoder as VideoVAEEncoder | |
| from ltx_core.pipeline.components.schedulers import LTX2Scheduler | |
| def _to_torch_device(device: Device) -> torch.device: | |
| """Convert device specification to torch.device.""" | |
| return torch.device(device) if isinstance(device, str) else device | |
| # ============================================================================= | |
| # Individual Component Loaders | |
| # ============================================================================= | |
def load_transformer(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "LTXModel":
    """Build the LTX denoising transformer from a safetensors checkpoint.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device to load the model on.
        dtype: Data type for model weights.

    Returns:
        Loaded LTXModel transformer.
    """
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.transformer.model_configurator import (
        LTXV_MODEL_COMFY_RENAMING_MAP,
        LTXModelConfigurator,
    )

    # ComfyUI-style keys in the checkpoint are renamed to the model's names.
    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=LTXModelConfigurator,
        model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_video_vae_encoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "VideoVAEEncoder":
    """Build the video VAE encoder, used for preprocessing frames to latents.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device to load the model on.
        dtype: Data type for model weights.

    Returns:
        Loaded VideoVAEEncoder.
    """
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.video_vae.model_configurator import VAE_ENCODER_COMFY_KEYS_FILTER
    from ltx_core.model.video_vae.model_configurator import (
        VAEEncoderConfigurator as VideoVAEEncoderConfigurator,
    )

    # The filter selects only the encoder weights out of the combined checkpoint.
    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=VideoVAEEncoderConfigurator,
        model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_video_vae_decoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "VideoVAEDecoder":
    """Build the video VAE decoder, used for inference/validation decoding.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device to load the model on.
        dtype: Data type for model weights.

    Returns:
        Loaded VideoVAEDecoder.
    """
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.video_vae.model_configurator import VAE_DECODER_COMFY_KEYS_FILTER
    from ltx_core.model.video_vae.model_configurator import (
        VAEDecoderConfigurator as VideoVAEDecoderConfigurator,
    )

    # The filter selects only the decoder weights out of the combined checkpoint.
    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=VideoVAEDecoderConfigurator,
        model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_audio_vae_encoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "AudioVAEEncoder":
    """Build the audio VAE encoder, used for preprocessing audio to latents.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device to load the model on.
        dtype: Data type for model weights (default bfloat16, but float32
            recommended for quality).

    Returns:
        Loaded AudioVAEEncoder.
    """
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER
    from ltx_core.model.audio_vae.model_configurator import (
        VAEEncoderConfigurator as AudioVAEEncoderConfigurator,
    )

    # The filter selects only the audio-encoder weights from the combined checkpoint.
    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=AudioVAEEncoderConfigurator,
        model_sd_ops=AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_audio_vae_decoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "AudioVAEDecoder":
    """Build the audio VAE decoder.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device to load the model on.
        dtype: Data type for model weights.

    Returns:
        Loaded AudioVAEDecoder.
    """
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER
    from ltx_core.model.audio_vae.model_configurator import (
        VAEDecoderConfigurator as AudioVAEDecoderConfigurator,
    )

    # The filter selects only the audio-decoder weights from the combined checkpoint.
    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=AudioVAEDecoderConfigurator,
        model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_vocoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "Vocoder":
    """Build the vocoder, which turns audio latents into waveforms.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device to load the model on.
        dtype: Data type for model weights.

    Returns:
        Loaded Vocoder.
    """
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.audio_vae.model_configurator import VOCODER_COMFY_KEYS_FILTER, VocoderConfigurator

    # The filter selects only the vocoder weights from the combined checkpoint.
    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=VocoderConfigurator,
        model_sd_ops=VOCODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
def load_text_encoder(
    checkpoint_path: str | Path,
    gemma_model_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "AVGemmaTextEncoderModel":
    """Build the Gemma-based text encoder.

    Args:
        checkpoint_path: Path to the LTX-2 safetensors checkpoint file.
        gemma_model_path: Path to the Gemma model directory.
        device: Device to load the model on.
        dtype: Data type for model weights.

    Returns:
        Loaded AVGemmaTextEncoderModel.

    Raises:
        ValueError: If ``gemma_model_path`` is not an existing directory.
    """
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.clip.gemma.encoders.av_encoder import (
        AV_GEMMA_TEXT_ENCODER_KEY_OPS,
        AVGemmaTextEncoderModelConfigurator,
    )
    from ltx_core.model.clip.gemma.encoders.base_encoder import module_ops_from_gemma_root

    # Fail fast on a bad Gemma path before any heavy checkpoint work starts.
    if not Path(gemma_model_path).is_dir():
        raise ValueError(f"Gemma model path is not a directory: {gemma_model_path}")

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=AVGemmaTextEncoderModelConfigurator,
        model_sd_ops=AV_GEMMA_TEXT_ENCODER_KEY_OPS,
        # Module-level ops are derived from the Gemma root directory.
        module_ops=module_ops_from_gemma_root(str(gemma_model_path)),
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
| # ============================================================================= | |
| # Combined Component Loader | |
| # ============================================================================= | |
@dataclass
class LtxModelComponents:
    """Container for all LTX-2 model components.

    Only ``transformer`` is required; every other component defaults to
    ``None`` and is populated by ``load_model`` according to its flags.
    """

    # Core denoising transformer (always loaded).
    transformer: "LTXModel"
    # Optional components; None when the corresponding load flag was False.
    video_vae_encoder: "VideoVAEEncoder | None" = None
    video_vae_decoder: "VideoVAEDecoder | None" = None
    audio_vae_decoder: "AudioVAEDecoder | None" = None
    vocoder: "Vocoder | None" = None
    text_encoder: "AVGemmaTextEncoderModel | None" = None
    scheduler: "LTX2Scheduler | None" = None
def load_model(
    checkpoint_path: str | Path,
    text_encoder_path: str | Path | None = None,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
    with_video_vae_encoder: bool = False,
    with_video_vae_decoder: bool = True,
    with_audio_vae_decoder: bool = True,
    with_vocoder: bool = True,
    with_text_encoder: bool = True,
) -> LtxModelComponents:
    """Load LTX-2 model components from a safetensors checkpoint.

    Convenience wrapper that loads several components in one call. To load
    a single component, use the dedicated functions instead:
        - load_transformer()
        - load_video_vae_encoder()
        - load_video_vae_decoder()
        - load_audio_vae_decoder()
        - load_vocoder()
        - load_text_encoder()

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        text_encoder_path: Path to the Gemma model directory (required if
            with_text_encoder=True).
        device: Device to load models on ("cuda", "cpu", etc.).
        dtype: Data type for model weights.
        with_video_vae_encoder: Load the video VAE encoder (for preprocessing).
        with_video_vae_decoder: Load the video VAE decoder (for inference/validation).
        with_audio_vae_decoder: Load the audio VAE decoder.
        with_vocoder: Load the vocoder.
        with_text_encoder: Load the text encoder.

    Returns:
        LtxModelComponents containing all loaded model components.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        ValueError: If with_text_encoder=True but text_encoder_path is None.
    """
    from ltx_core.pipeline.components.schedulers import LTX2Scheduler

    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    logger.info(f"Loading LTX-2 model from {checkpoint_path}")
    torch_device = _to_torch_device(device)

    def _maybe(enabled: bool, message: str, loader):
        # Load an optional component, logging first; None when disabled.
        if not enabled:
            return None
        logger.debug(message)
        return loader()

    # The transformer is always loaded.
    logger.debug("Loading transformer...")
    transformer = load_transformer(checkpoint_path, torch_device, dtype)

    video_vae_encoder = _maybe(
        with_video_vae_encoder,
        "Loading video VAE encoder...",
        lambda: load_video_vae_encoder(checkpoint_path, torch_device, dtype),
    )
    video_vae_decoder = _maybe(
        with_video_vae_decoder,
        "Loading video VAE decoder...",
        lambda: load_video_vae_decoder(checkpoint_path, torch_device, dtype),
    )
    audio_vae_decoder = _maybe(
        with_audio_vae_decoder,
        "Loading audio VAE decoder...",
        lambda: load_audio_vae_decoder(checkpoint_path, torch_device, dtype),
    )
    vocoder = _maybe(
        with_vocoder,
        "Loading vocoder...",
        lambda: load_vocoder(checkpoint_path, torch_device, dtype),
    )

    # The text encoder additionally needs the Gemma directory.
    text_encoder = None
    if with_text_encoder:
        if text_encoder_path is None:
            raise ValueError("text_encoder_path must be provided when with_text_encoder=True")
        logger.debug("Loading Gemma text encoder...")
        text_encoder = load_text_encoder(checkpoint_path, text_encoder_path, torch_device, dtype)

    # The scheduler is stateless; nothing to load from the checkpoint.
    scheduler = LTX2Scheduler()

    return LtxModelComponents(
        transformer=transformer,
        video_vae_encoder=video_vae_encoder,
        video_vae_decoder=video_vae_decoder,
        audio_vae_decoder=audio_vae_decoder,
        vocoder=vocoder,
        text_encoder=text_encoder,
        scheduler=scheduler,
    )