| | |
| |
|
| | """ |
| | Decode precomputed video latents back into videos using the VAE. |
| | |
| | This script loads latent files saved during preprocessing and decodes them |
| | back into video clips using the same VAE model. |
| | |
| | Basic usage: |
| | python scripts/decode_latents.py /path/to/latents/dir /path/to/output \\ |
| | --model-path /path/to/ltx2.safetensors |
| | """ |
| |
|
| | from pathlib import Path |
| |
|
| | import torch |
| | import torchaudio |
| | import torchvision.utils |
| | import typer |
| | from rich.console import Console |
| | from rich.progress import ( |
| | BarColumn, |
| | MofNCompleteColumn, |
| | Progress, |
| | SpinnerColumn, |
| | TextColumn, |
| | TimeElapsedColumn, |
| | TimeRemainingColumn, |
| | ) |
| | from transformers.utils.logging import disable_progress_bar |
| |
|
| | from ltx_trainer import logger |
| | from ltx_trainer.model_loader import load_audio_vae_decoder, load_video_vae_decoder, load_vocoder |
| | from ltx_trainer.video_utils import save_video |
| |
|
# Turn off the progress bars of the `transformers` library.
disable_progress_bar()

# Module-wide rich console, used for status spinners and the decoding progress bars.
console = Console()

# Typer application object; the CLI command is registered on it further down.
app = typer.Typer(
    help="Decode precomputed video latents back into videos using the VAE.",
    no_args_is_help=True,
    pretty_exceptions_enable=False,
)
| |
|
| |
|
class LatentsDecoder:
    """Decode precomputed latent files (.pt) back into images, videos and audio.

    The video VAE decoder is always loaded. The audio VAE decoder and the vocoder
    are loaded only when ``with_audio`` is set; otherwise ``decode_audio`` logs a
    warning and returns without doing anything.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        vae_tiling: bool = False,
        with_audio: bool = False,
    ):
        """Initialize the decoder with model configuration.

        Args:
            model_path: Path to LTX-2 checkpoint (.safetensors)
            device: Device to use for computation
            vae_tiling: Whether to enable VAE tiling for larger video resolutions
            with_audio: Whether to load audio VAE for audio decoding
        """
        self.device = torch.device(device)
        self.model_path = model_path
        self.vae = None
        self.audio_vae = None
        self.vocoder = None
        self._load_model(model_path, vae_tiling, with_audio)

    def _load_model(self, model_path: str, vae_tiling: bool, with_audio: bool = False) -> None:
        """Initialize and load the VAE model(s).

        Args:
            model_path: Checkpoint path handed to every loader.
            vae_tiling: When true, ``enable_tiling()`` is called on the video VAE.
            with_audio: When true, the audio VAE decoder and the vocoder are loaded too.
        """
        with console.status(f"[bold]Loading video VAE decoder from {model_path}...", spinner="dots"):
            self.vae = load_video_vae_decoder(model_path, device=self.device, dtype=torch.bfloat16)

        if vae_tiling:
            self.vae.enable_tiling()

        if with_audio:
            with console.status(f"[bold]Loading audio VAE decoder from {model_path}...", spinner="dots"):
                self.audio_vae = load_audio_vae_decoder(model_path, device=self.device, dtype=torch.bfloat16)

            with console.status(f"[bold]Loading vocoder from {model_path}...", spinner="dots"):
                self.vocoder = load_vocoder(model_path, device=self.device)

    @staticmethod
    def _new_progress() -> Progress:
        """Build the rich progress bar shared by the video and audio decoding loops."""
        return Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            console=console,
        )

    @staticmethod
    def _mirror_output_dir(latent_file: Path, latents_dir: Path, output_dir: Path) -> Path:
        """Create and return the output folder mirroring ``latent_file``'s location under ``latents_dir``."""
        output_subdir = output_dir / latent_file.relative_to(latents_dir).parent
        output_subdir.mkdir(parents=True, exist_ok=True)
        return output_subdir

    @torch.inference_mode()
    def decode(self, latents_dir: Path, output_dir: Path, seed: int | None = None) -> None:
        """Decode all latent files in the directory recursively.

        A file that fails to decode is logged and skipped; the remaining files are
        still processed and the number of failures is reported at the end.

        Args:
            latents_dir: Directory containing latent files (.pt)
            output_dir: Directory to save decoded videos
            seed: Optional random seed, forwarded to ``_process_file`` (which does
                not currently consume it)
        """
        # Sorted so the processing order does not depend on filesystem enumeration order.
        latent_files = sorted(latents_dir.rglob("*.pt"))

        if not latent_files:
            logger.warning(f"No .pt files found in {latents_dir}")
            return

        logger.info(f"Found {len(latent_files):,} latent files to decode")

        failed = 0
        with self._new_progress() as progress:
            task = progress.add_task("Decoding latents", total=len(latent_files))

            for latent_file in latent_files:
                output_subdir = self._mirror_output_dir(latent_file, latents_dir, output_dir)

                try:
                    self._process_file(latent_file, output_subdir, seed)
                except Exception as e:
                    logger.error(f"Error processing {latent_file}: {e}")
                    failed += 1

                # Advance for failed files as well, so the bar always reaches its total.
                progress.advance(task)

        if failed:
            logger.warning(f"{failed:,} of {len(latent_files):,} latent files failed to decode")

        logger.info(f"Decoding complete! Videos saved to {output_dir}")

    def _process_file(self, latent_file: Path, output_dir: Path, seed: int | None) -> None:
        """Process a single latent file.

        Loads the saved dict, decodes ``data["latents"]`` with the video VAE and
        writes ``<stem>.png`` when the result has a single frame, otherwise
        ``<stem>.mp4``.

        Args:
            latent_file: Path of the .pt file to decode.
            output_dir: Existing directory that receives the decoded file.
            seed: Kept for interface compatibility. It is currently unused because
                the decoder is invoked as ``self.vae(latents)`` with no generator.
        """
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
        # objects, so only decode latent files that come from a trusted source.
        data = torch.load(latent_file, map_location=self.device, weights_only=False)

        latents = data["latents"]
        num_frames = data["num_frames"]
        height = data["height"]
        width = data["width"]

        if latents.dim() == 2:
            # Flattened [seq_len, channels] layout: restore the grid, then move channels first
            # -> [channels, num_frames, height, width].
            channels = latents.shape[-1]
            latents = latents.reshape(num_frames, height, width, channels)
            latents = latents.permute(3, 0, 1, 2)

        # Add the batch dimension and match the dtype the VAE was loaded with.
        latents = latents.unsqueeze(0).to(device=self.device, dtype=torch.bfloat16)

        video = self.vae(latents)

        # Drop the batch dimension and swap the first two axes
        # (presumably [C, F, H, W] -> [F, C, H, W]; confirm against the VAE decoder).
        video = video[0].permute(1, 0, 2, 3)
        # Rescale to [0, 1]; assumes the decoder output is in [-1, 1] (TODO confirm).
        video = ((video + 1) / 2).clamp(0, 1)

        if video.shape[0] == 1:
            # A single frame is saved as a still image.
            output_path = output_dir / f"{latent_file.stem}.png"
            torchvision.utils.save_image(
                video[0],
                str(output_path),
            )
        else:
            output_path = output_dir / f"{latent_file.stem}.mp4"
            fps = data.get("fps", 24)
            save_video(
                video_tensor=video,
                output_path=output_path,
                fps=fps,
            )

    @torch.inference_mode()
    def decode_audio(self, latents_dir: Path, output_dir: Path) -> None:
        """Decode all audio latent files in the directory recursively.

        Does nothing (apart from a warning) when the audio VAE or vocoder was not
        loaded. A file that fails to decode is logged and skipped; the number of
        failures is reported at the end.

        Args:
            latents_dir: Directory containing audio latent files (.pt)
            output_dir: Directory to save decoded audio files
        """
        if self.audio_vae is None or self.vocoder is None:
            logger.warning("Audio VAE or vocoder not loaded. Skipping audio decoding.")
            return

        # Sorted so the processing order does not depend on filesystem enumeration order.
        latent_files = sorted(latents_dir.rglob("*.pt"))

        if not latent_files:
            logger.warning(f"No .pt files found in {latents_dir}")
            return

        logger.info(f"Found {len(latent_files):,} audio latent files to decode")

        failed = 0
        with self._new_progress() as progress:
            task = progress.add_task("Decoding audio latents", total=len(latent_files))

            for latent_file in latent_files:
                output_subdir = self._mirror_output_dir(latent_file, latents_dir, output_dir)

                try:
                    self._process_audio_file(latent_file, output_subdir)
                except Exception as e:
                    logger.error(f"Error processing audio {latent_file}: {e}")
                    failed += 1

                # Advance for failed files as well, so the bar always reaches its total.
                progress.advance(task)

        if failed:
            logger.warning(f"{failed:,} of {len(latent_files):,} audio latent files failed to decode")

        logger.info(f"Audio decoding complete! Audio files saved to {output_dir}")

    def _process_audio_file(self, latent_file: Path, output_dir: Path) -> None:
        """Process a single audio latent file.

        Decodes ``data["latents"]`` with the audio VAE, feeds the result to the
        vocoder and writes ``<stem>.wav`` at the vocoder's output sample rate.

        Args:
            latent_file: Path of the .pt file to decode.
            output_dir: Existing directory that receives the .wav file.
        """
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
        # objects, so only decode latent files that come from a trusted source.
        data = torch.load(latent_file, map_location=self.device, weights_only=False)

        latents = data["latents"].to(device=self.device, dtype=torch.float32)
        num_time_steps = data["num_time_steps"]
        freq_bins = data["frequency_bins"]

        if latents.dim() == 2:
            # Flattened layout: restore [time, freq, channels], then move channels first
            # -> [channels, time, freq].
            latents = latents.reshape(num_time_steps, freq_bins, -1)
            latents = latents.permute(2, 0, 1)

        # Add the batch dimension and match the dtype the audio VAE was loaded with.
        latents = latents.unsqueeze(0).to(dtype=torch.bfloat16)

        mel_spectrogram = self.audio_vae(latents)
        waveform = self.vocoder(mel_spectrogram)

        output_path = output_dir / f"{latent_file.stem}.wav"
        sample_rate = self.vocoder.output_sample_rate
        torchaudio.save(str(output_path), waveform[0].cpu(), sample_rate)
| |
|
| |
|
@app.command()
def main(
    latents_dir: str = typer.Argument(
        ...,
        help="Directory containing the precomputed latent files (searched recursively)",
    ),
    output_dir: str = typer.Argument(
        ...,
        help="Directory to save the decoded videos (maintains same folder hierarchy as input)",
    ),
    model_path: str = typer.Option(
        ...,
        help="Path to LTX-2 checkpoint (.safetensors file)",
    ),
    device: str = typer.Option(
        default="cuda",
        help="Device to use for computation",
    ),
    vae_tiling: bool = typer.Option(
        default=False,
        help="Enable VAE tiling for larger video resolutions",
    ),
    seed: int | None = typer.Option(
        default=None,
        help="Random seed for noise generation during decoding",
    ),
    with_audio: bool = typer.Option(
        default=False,
        help="Also decode audio latents (requires audio_latents directory)",
    ),
    audio_latents_dir: str | None = typer.Option(
        default=None,
        help="Directory containing audio latent files (defaults to 'audio_latents' sibling of latents_dir)",
    ),
) -> None:
    """Decode precomputed video latents back into videos using the VAE.

    This script recursively searches for .pt latent files in the input directory
    and decodes them to videos, maintaining the same folder hierarchy in the output.

    Examples:
        # Basic usage
        python scripts/decode_latents.py /path/to/latents /path/to/videos \\
            --model-path /path/to/ltx2.safetensors

        # With VAE tiling for large videos
        python scripts/decode_latents.py /path/to/latents /path/to/videos \\
            --model-path /path/to/ltx2.safetensors --vae-tiling

        # With audio decoding
        python scripts/decode_latents.py /path/to/latents /path/to/videos \\
            --model-path /path/to/ltx2.safetensors --with-audio
    """
    source_root = Path(latents_dir)
    destination_root = Path(output_dir)

    # Reject the input before any model is loaded.
    if not (source_root.exists() and source_root.is_dir()):
        raise typer.BadParameter(f"Latents directory does not exist: {source_root}")

    decoder = LatentsDecoder(
        model_path=model_path,
        device=device,
        vae_tiling=vae_tiling,
        with_audio=with_audio,
    )
    decoder.decode(source_root, destination_root, seed=seed)

    # Audio decoding is opt-in; without the flag there is nothing left to do.
    if not with_audio:
        return

    # An explicit directory wins; otherwise look for "audio_latents" next to the latents directory.
    if audio_latents_dir:
        audio_root = Path(audio_latents_dir)
    else:
        audio_root = source_root.parent / "audio_latents"

    if not audio_root.exists():
        logger.warning(f"Audio latents directory not found: {audio_root}")
        return

    # Decoded audio goes to a "decoded_audio" folder next to the video output directory.
    decoder.decode_audio(audio_root, destination_root.parent / "decoded_audio")
| |
|
| |
|
# Run the Typer application only when this file is executed as a script.
if __name__ == "__main__":
    app()
| |
|