|
|
|
|
|
|
|
|
""" |
|
|
Model loader for LTX-2 trainer using the new ltx-core package. |
|
|
|
|
|
This module provides a unified interface for loading LTX-2 model components |
|
|
for training, using SingleGPUModelBuilder from ltx-core. |
|
|
|
|
|
Example usage: |
|
|
# Load individual components |
|
|
vae_encoder = load_video_vae_encoder("/path/to/checkpoint.safetensors", device="cuda") |
|
|
vae_decoder = load_video_vae_decoder("/path/to/checkpoint.safetensors", device="cuda") |
|
|
text_encoder = load_text_encoder("/path/to/checkpoint.safetensors", "/path/to/gemma", device="cuda") |
|
|
|
|
|
# Load all components at once |
|
|
components = load_model("/path/to/checkpoint.safetensors", text_encoder_path="/path/to/gemma") |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import TYPE_CHECKING |
|
|
|
|
|
import torch |
|
|
|
|
|
from ltx_trainer import logger |
|
|
|
|
|
|
|
|
# Accepted device specification: a string such as "cuda" / "cuda:0" / "cpu",
# or an already-constructed torch.device.
Device = str | torch.device
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from ltx_core.model.audio_vae.audio_vae import Decoder as AudioVAEDecoder |
|
|
from ltx_core.model.audio_vae.audio_vae import Encoder as AudioVAEEncoder |
|
|
from ltx_core.model.audio_vae.vocoder import Vocoder |
|
|
from ltx_core.model.clip.gemma.encoders.av_encoder import AVGemmaTextEncoderModel |
|
|
from ltx_core.model.transformer.model import LTXModel |
|
|
from ltx_core.model.video_vae.video_vae import Decoder as VideoVAEDecoder |
|
|
from ltx_core.model.video_vae.video_vae import Encoder as VideoVAEEncoder |
|
|
from ltx_core.pipeline.components.schedulers import LTX2Scheduler |
|
|
|
|
|
|
|
|
def _to_torch_device(device: Device) -> torch.device: |
|
|
"""Convert device specification to torch.device.""" |
|
|
return torch.device(device) if isinstance(device, str) else device |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_transformer(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "LTXModel":
    """Load the LTX transformer model from a safetensors checkpoint.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device to load the model on.
        dtype: Data type for the model weights.

    Returns:
        The loaded LTXModel transformer.
    """
    # Imported lazily so that importing this module stays cheap.
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.transformer.model_configurator import (
        LTXV_MODEL_COMFY_RENAMING_MAP,
        LTXModelConfigurator,
    )

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=LTXModelConfigurator,
        model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
|
|
|
|
|
|
|
|
def load_video_vae_encoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "VideoVAEEncoder":
    """Load the video VAE encoder (used for preprocessing).

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device to load the model on.
        dtype: Data type for the model weights.

    Returns:
        The loaded VideoVAEEncoder.
    """
    # Imported lazily so that importing this module stays cheap.
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.video_vae.model_configurator import VAE_ENCODER_COMFY_KEYS_FILTER
    from ltx_core.model.video_vae.model_configurator import (
        VAEEncoderConfigurator as VideoVAEEncoderConfigurator,
    )

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=VideoVAEEncoderConfigurator,
        model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
|
|
|
|
|
|
|
|
def load_video_vae_decoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "VideoVAEDecoder":
    """Load the video VAE decoder (used for inference/validation).

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device to load the model on.
        dtype: Data type for the model weights.

    Returns:
        The loaded VideoVAEDecoder.
    """
    # Imported lazily so that importing this module stays cheap.
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.video_vae.model_configurator import VAE_DECODER_COMFY_KEYS_FILTER
    from ltx_core.model.video_vae.model_configurator import (
        VAEDecoderConfigurator as VideoVAEDecoderConfigurator,
    )

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=VideoVAEDecoderConfigurator,
        model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
|
|
|
|
|
|
|
|
def load_audio_vae_encoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "AudioVAEEncoder":
    """Load the audio VAE encoder (used for preprocessing).

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device to load the model on.
        dtype: Data type for the model weights (defaults to bfloat16;
            float32 is recommended for quality).

    Returns:
        The loaded AudioVAEEncoder.
    """
    # Imported lazily so that importing this module stays cheap.
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER
    from ltx_core.model.audio_vae.model_configurator import (
        VAEEncoderConfigurator as AudioVAEEncoderConfigurator,
    )

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=AudioVAEEncoderConfigurator,
        model_sd_ops=AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
|
|
|
|
|
|
|
|
def load_audio_vae_decoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "AudioVAEDecoder":
    """Load the audio VAE decoder.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device to load the model on.
        dtype: Data type for the model weights.

    Returns:
        The loaded AudioVAEDecoder.
    """
    # Imported lazily so that importing this module stays cheap.
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER
    from ltx_core.model.audio_vae.model_configurator import (
        VAEDecoderConfigurator as AudioVAEDecoderConfigurator,
    )

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=AudioVAEDecoderConfigurator,
        model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
|
|
|
|
|
|
|
|
def load_vocoder(
    checkpoint_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "Vocoder":
    """Load the vocoder (used for audio waveform generation).

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        device: Device to load the model on.
        dtype: Data type for the model weights.

    Returns:
        The loaded Vocoder.
    """
    # Imported lazily so that importing this module stays cheap.
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.audio_vae.model_configurator import (
        VOCODER_COMFY_KEYS_FILTER,
        VocoderConfigurator,
    )

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=VocoderConfigurator,
        model_sd_ops=VOCODER_COMFY_KEYS_FILTER,
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
|
|
|
|
|
|
|
|
def load_text_encoder(
    checkpoint_path: str | Path,
    gemma_model_path: str | Path,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> "AVGemmaTextEncoderModel":
    """Load the Gemma text encoder.

    Args:
        checkpoint_path: Path to the LTX-2 safetensors checkpoint file.
        gemma_model_path: Path to the Gemma model directory.
        device: Device to load the model on.
        dtype: Data type for the model weights.

    Returns:
        The loaded AVGemmaTextEncoderModel.

    Raises:
        ValueError: If ``gemma_model_path`` is not an existing directory.
    """
    # Imported lazily so that importing this module stays cheap.
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
    from ltx_core.model.clip.gemma.encoders.av_encoder import (
        AV_GEMMA_TEXT_ENCODER_KEY_OPS,
        AVGemmaTextEncoderModelConfigurator,
    )
    from ltx_core.model.clip.gemma.encoders.base_encoder import module_ops_from_gemma_root

    # Fail fast on a bad Gemma path before touching the checkpoint.
    if not Path(gemma_model_path).is_dir():
        raise ValueError(f"Gemma model path is not a directory: {gemma_model_path}")

    builder = SingleGPUModelBuilder(
        model_path=str(checkpoint_path),
        model_class_configurator=AVGemmaTextEncoderModelConfigurator,
        model_sd_ops=AV_GEMMA_TEXT_ENCODER_KEY_OPS,
        module_ops=module_ops_from_gemma_root(str(gemma_model_path)),
    )
    return builder.build(device=_to_torch_device(device), dtype=dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class LtxModelComponents:
    """Container for all LTX-2 model components.

    Only ``transformer`` is always present; every other component is
    optional and stays ``None`` unless it was requested at load time
    (see ``load_model``'s ``with_*`` flags).
    """

    # Denoising transformer (always loaded).
    transformer: "LTXModel"
    # Video VAE encoder, used for preprocessing.
    video_vae_encoder: "VideoVAEEncoder | None" = None
    # Video VAE decoder, used for inference/validation.
    video_vae_decoder: "VideoVAEDecoder | None" = None
    # Audio VAE decoder.
    audio_vae_decoder: "AudioVAEDecoder | None" = None
    # Vocoder, used for audio waveform generation.
    vocoder: "Vocoder | None" = None
    # Gemma-based text encoder.
    text_encoder: "AVGemmaTextEncoderModel | None" = None
    # Noise scheduler; created unconditionally by load_model.
    scheduler: "LTX2Scheduler | None" = None
|
|
|
|
|
|
|
|
def load_model(
    checkpoint_path: str | Path,
    text_encoder_path: str | Path | None = None,
    device: Device = "cpu",
    dtype: torch.dtype = torch.bfloat16,
    with_video_vae_encoder: bool = False,
    with_video_vae_decoder: bool = True,
    with_audio_vae_decoder: bool = True,
    with_vocoder: bool = True,
    with_text_encoder: bool = True,
) -> LtxModelComponents:
    """Load LTX-2 model components from a safetensors checkpoint.

    Convenience wrapper that loads several components in one call. To load
    a single component, use the dedicated loaders instead:
    - load_transformer()
    - load_video_vae_encoder()
    - load_video_vae_decoder()
    - load_audio_vae_decoder()
    - load_vocoder()
    - load_text_encoder()

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.
        text_encoder_path: Path to the Gemma model directory (required when
            with_text_encoder=True).
        device: Device to load models on ("cuda", "cpu", etc.).
        dtype: Data type for model weights.
        with_video_vae_encoder: Load the video VAE encoder (for preprocessing).
        with_video_vae_decoder: Load the video VAE decoder (for inference/validation).
        with_audio_vae_decoder: Load the audio VAE decoder.
        with_vocoder: Load the vocoder.
        with_text_encoder: Load the text encoder.

    Returns:
        LtxModelComponents holding every loaded component.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        ValueError: If the text encoder is requested without a path.
    """
    from ltx_core.pipeline.components.schedulers import LTX2Scheduler

    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    logger.info(f"Loading LTX-2 model from {checkpoint_path}")
    torch_device = _to_torch_device(device)

    def _load_if(enabled: bool, message: str, loader):
        """Log *message* and run *loader* when *enabled*; otherwise return None."""
        if not enabled:
            return None
        logger.debug(message)
        return loader()

    # The transformer is mandatory; everything else is flag-gated.
    logger.debug("Loading transformer...")
    transformer = load_transformer(checkpoint_path, torch_device, dtype)

    video_vae_encoder = _load_if(
        with_video_vae_encoder,
        "Loading video VAE encoder...",
        lambda: load_video_vae_encoder(checkpoint_path, torch_device, dtype),
    )
    video_vae_decoder = _load_if(
        with_video_vae_decoder,
        "Loading video VAE decoder...",
        lambda: load_video_vae_decoder(checkpoint_path, torch_device, dtype),
    )
    audio_vae_decoder = _load_if(
        with_audio_vae_decoder,
        "Loading audio VAE decoder...",
        lambda: load_audio_vae_decoder(checkpoint_path, torch_device, dtype),
    )
    vocoder = _load_if(
        with_vocoder,
        "Loading vocoder...",
        lambda: load_vocoder(checkpoint_path, torch_device, dtype),
    )

    # The text encoder additionally needs the Gemma directory, so it gets
    # explicit validation instead of the generic helper path.
    text_encoder = None
    if with_text_encoder:
        if text_encoder_path is None:
            raise ValueError("text_encoder_path must be provided when with_text_encoder=True")
        logger.debug("Loading Gemma text encoder...")
        text_encoder = load_text_encoder(checkpoint_path, text_encoder_path, torch_device, dtype)

    return LtxModelComponents(
        transformer=transformer,
        video_vae_encoder=video_vae_encoder,
        video_vae_decoder=video_vae_decoder,
        audio_vae_decoder=audio_vae_decoder,
        vocoder=vocoder,
        text_encoder=text_encoder,
        scheduler=LTX2Scheduler(),
    )
|
|
|