| |
|
|
| """ |
| Decode precomputed video latents back into videos using the VAE. |
| This script loads latent files saved during preprocessing and decodes them |
| back into video clips using the same VAE model. |
| Basic usage: |
| python scripts/decode_latents.py /path/to/latents/dir /path/to/output \ |
| --model-source /path/to/ltx2.safetensors |
| """ |
|
|
| from pathlib import Path |
|
|
| import torch |
| import torchaudio |
| import torchvision.utils |
| import typer |
| from einops import rearrange |
| from rich.console import Console |
| from rich.progress import ( |
| BarColumn, |
| MofNCompleteColumn, |
| Progress, |
| SpinnerColumn, |
| TextColumn, |
| TimeElapsedColumn, |
| TimeRemainingColumn, |
| ) |
| from transformers.utils.logging import disable_progress_bar |
|
|
| from ltx_core.model.video_vae import SpatialTilingConfig, TemporalTilingConfig, TilingConfig |
| from ltx_trainer import logger |
| from ltx_trainer.model_loader import load_audio_vae_decoder, load_video_vae_decoder, load_vocoder |
| from ltx_trainer.video_utils import save_video |
|
|
| DEFAULT_TILE_SIZE_PIXELS = 512 |
| DEFAULT_TILE_OVERLAP_PIXELS = 128 |
| DEFAULT_TILE_SIZE_FRAMES = 128 |
| DEFAULT_TILE_OVERLAP_FRAMES = 24 |
|
|
| disable_progress_bar() |
| console = Console() |
| app = typer.Typer( |
| pretty_exceptions_enable=False, |
| no_args_is_help=True, |
| help="Decode precomputed video latents back into videos using the VAE.", |
| ) |
|
|
|
|
| class LatentsDecoder: |
| def __init__( |
| self, |
| model_path: str, |
| device: str = "cuda", |
| vae_tiling: bool = False, |
| with_audio: bool = False, |
| ): |
| """Initialize the decoder with model configuration. |
| Args: |
| model_path: Path to LTX-2 checkpoint (.safetensors) |
| device: Device to use for computation |
| vae_tiling: Whether to enable VAE tiling for larger video resolutions |
| with_audio: Whether to load audio VAE for audio decoding |
| """ |
| self.device = torch.device(device) |
| self.model_path = model_path |
| self.vae = None |
| self.audio_vae = None |
| self.vocoder = None |
| self.vae_tiling = vae_tiling |
|
|
| self._load_model(model_path, with_audio) |
|
|
| def _load_model(self, model_path: str, with_audio: bool = False) -> None: |
| """Initialize and load the VAE model(s).""" |
| with console.status(f"[bold]Loading video VAE decoder from {model_path}...", spinner="dots"): |
| self.vae = load_video_vae_decoder(model_path, device=self.device, dtype=torch.bfloat16) |
|
|
| if with_audio: |
| with console.status(f"[bold]Loading audio VAE decoder from {model_path}...", spinner="dots"): |
| self.audio_vae = load_audio_vae_decoder(model_path, device=self.device, dtype=torch.bfloat16) |
|
|
| with console.status(f"[bold]Loading vocoder from {model_path}...", spinner="dots"): |
| self.vocoder = load_vocoder(model_path, device=self.device) |
|
|
| @torch.inference_mode() |
| def decode(self, latents_dir: Path, output_dir: Path, seed: int | None = None) -> None: |
| """Decode all latent files in the directory recursively. |
| Args: |
| latents_dir: Directory containing latent files (.pt) |
| output_dir: Directory to save decoded videos |
| seed: Optional random seed for noise generation |
| """ |
| |
| latent_files = list(latents_dir.rglob("*.pt")) |
|
|
| if not latent_files: |
| logger.warning(f"No .pt files found in {latents_dir}") |
| return |
|
|
| logger.info(f"Found {len(latent_files):,} latent files to decode") |
|
|
| |
| with Progress( |
| SpinnerColumn(), |
| TextColumn("[progress.description]{task.description}"), |
| BarColumn(), |
| MofNCompleteColumn(), |
| TimeElapsedColumn(), |
| TimeRemainingColumn(), |
| console=console, |
| ) as progress: |
| task = progress.add_task("Decoding latents", total=len(latent_files)) |
|
|
| for latent_file in latent_files: |
| |
| rel_path = latent_file.relative_to(latents_dir) |
| output_subdir = output_dir / rel_path.parent |
| output_subdir.mkdir(parents=True, exist_ok=True) |
|
|
| try: |
| self._process_file(latent_file, output_subdir, seed) |
| except Exception as e: |
| logger.error(f"Error processing {latent_file}: {e}") |
| continue |
|
|
| progress.advance(task) |
|
|
| logger.info(f"Decoding complete! Videos saved to {output_dir}") |
|
|
| @torch.inference_mode() |
| def decode_audio(self, latents_dir: Path, output_dir: Path) -> None: |
| """Decode all audio latent files in the directory recursively. |
| Args: |
| latents_dir: Directory containing audio latent files (.pt) |
| output_dir: Directory to save decoded audio files |
| """ |
| |
| if self.audio_vae is None or self.vocoder is None: |
| logger.warning("Audio VAE or vocoder not loaded. Skipping audio decoding.") |
| return |
|
|
| |
| latent_files = list(latents_dir.rglob("*.pt")) |
|
|
| if not latent_files: |
| logger.warning(f"No .pt files found in {latents_dir}") |
| return |
|
|
| logger.info(f"Found {len(latent_files):,} audio latent files to decode") |
|
|
| |
| with Progress( |
| SpinnerColumn(), |
| TextColumn("[progress.description]{task.description}"), |
| BarColumn(), |
| MofNCompleteColumn(), |
| TimeElapsedColumn(), |
| TimeRemainingColumn(), |
| console=console, |
| ) as progress: |
| task = progress.add_task("Decoding audio latents", total=len(latent_files)) |
|
|
| for latent_file in latent_files: |
| |
| rel_path = latent_file.relative_to(latents_dir) |
| output_subdir = output_dir / rel_path.parent |
| output_subdir.mkdir(parents=True, exist_ok=True) |
|
|
| try: |
| self._process_audio_file(latent_file, output_subdir) |
| except Exception as e: |
| logger.error(f"Error processing audio {latent_file}: {e}") |
| continue |
|
|
| progress.advance(task) |
|
|
| logger.info(f"Audio decoding complete! Audio files saved to {output_dir}") |
|
|
| def _process_file(self, latent_file: Path, output_dir: Path, seed: int | None) -> None: |
| """Process a single latent file.""" |
| |
| data = torch.load(latent_file, map_location=self.device, weights_only=False) |
|
|
| |
| latents = data["latents"] |
| num_frames = data["num_frames"] |
| height = data["height"] |
| width = data["width"] |
|
|
| |
| if latents.dim() == 2: |
| |
| latents = rearrange(latents, "(f h w) c -> c f h w", f=num_frames, h=height, w=width) |
|
|
| |
| latents = latents.unsqueeze(0).to(device=self.device, dtype=torch.bfloat16) |
|
|
| |
| generator = None |
| if seed is not None: |
| generator = torch.Generator(device=self.device) |
| generator.manual_seed(seed) |
|
|
| |
| video = self._decode_video(latents, generator) |
|
|
| |
| is_image = video.shape[0] == 1 |
| if is_image: |
| |
| output_path = output_dir / f"{latent_file.stem}.png" |
| torchvision.utils.save_image( |
| video[0], |
| str(output_path), |
| ) |
| else: |
| |
| output_path = output_dir / f"{latent_file.stem}.mp4" |
| fps = data.get("fps", 24) |
| save_video( |
| video_tensor=video, |
| output_path=output_path, |
| fps=fps, |
| ) |
|
|
| def _decode_video(self, latents: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor: |
| """Decode latents to video frames.""" |
| if self.vae_tiling: |
| |
| tiling_config = TilingConfig( |
| spatial_config=SpatialTilingConfig( |
| tile_size_in_pixels=DEFAULT_TILE_SIZE_PIXELS, |
| tile_overlap_in_pixels=DEFAULT_TILE_OVERLAP_PIXELS, |
| ), |
| temporal_config=TemporalTilingConfig( |
| tile_size_in_frames=DEFAULT_TILE_SIZE_FRAMES, |
| tile_overlap_in_frames=DEFAULT_TILE_OVERLAP_FRAMES, |
| ), |
| ) |
| chunks = list( |
| self.vae.tiled_decode( |
| latents, |
| tiling_config=tiling_config, |
| generator=generator, |
| ) |
| ) |
| |
| video = torch.cat(chunks, dim=2) |
| else: |
| |
| video = self.vae(latents, generator=generator) |
|
|
| |
| video = rearrange(video, "1 c f h w -> f c h w") |
| video = (video + 1) / 2 |
| video = video.clamp(0, 1) |
|
|
| return video |
|
|
| def _process_audio_file(self, latent_file: Path, output_dir: Path) -> None: |
| """Process a single audio latent file.""" |
| |
| data = torch.load(latent_file, map_location=self.device, weights_only=False) |
|
|
| latents = data["latents"].to(device=self.device, dtype=torch.float32) |
| num_time_steps = data["num_time_steps"] |
| freq_bins = data["frequency_bins"] |
|
|
| |
| if latents.dim() == 2: |
| |
| |
| latents = rearrange(latents, "(t f) c -> c t f", t=num_time_steps, f=freq_bins) |
|
|
| |
| latents = latents.unsqueeze(0) |
|
|
| |
| latents = latents.to(dtype=torch.bfloat16) |
|
|
| |
| mel_spectrogram = self.audio_vae(latents) |
|
|
| |
| waveform = self.vocoder(mel_spectrogram) |
|
|
| |
| output_path = output_dir / f"{latent_file.stem}.wav" |
| sample_rate = self.vocoder.output_sampling_rate |
| torchaudio.save(str(output_path), waveform[0].cpu(), sample_rate) |
|
|
|
|
| @app.command() |
| def main( |
| latents_dir: str = typer.Argument( |
| ..., |
| help="Directory containing the precomputed latent files (searched recursively)", |
| ), |
| output_dir: str = typer.Argument( |
| ..., |
| help="Directory to save the decoded videos (maintains same folder hierarchy as input)", |
| ), |
| model_path: str = typer.Option( |
| ..., |
| help="Path to LTX-2 checkpoint (.safetensors file)", |
| ), |
| device: str = typer.Option( |
| default="cuda", |
| help="Device to use for computation", |
| ), |
| vae_tiling: bool = typer.Option( |
| default=False, |
| help="Enable VAE tiling for larger video resolutions", |
| ), |
| seed: int | None = typer.Option( |
| default=None, |
| help="Random seed for noise generation during decoding", |
| ), |
| with_audio: bool = typer.Option( |
| default=False, |
| help="Also decode audio latents (requires audio_latents directory)", |
| ), |
| audio_latents_dir: str | None = typer.Option( |
| default=None, |
| help="Directory containing audio latent files (defaults to 'audio_latents' sibling of latents_dir)", |
| ), |
| ) -> None: |
| """Decode precomputed video latents back into videos using the VAE. |
| This script recursively searches for .pt latent files in the input directory |
| and decodes them to videos, maintaining the same folder hierarchy in the output. |
| Examples: |
| # Basic usage |
| python scripts/decode_latents.py /path/to/latents /path/to/videos \\ |
| --model-path /path/to/ltx2.safetensors |
| # With VAE tiling for large videos |
| python scripts/decode_latents.py /path/to/latents /path/to/videos \\ |
| --model-path /path/to/ltx2.safetensors --vae-tiling |
| # With audio decoding |
| python scripts/decode_latents.py /path/to/latents /path/to/videos \\ |
| --model-path /path/to/ltx2.safetensors --with-audio |
| """ |
| latents_path = Path(latents_dir) |
| output_path = Path(output_dir) |
|
|
| if not latents_path.exists() or not latents_path.is_dir(): |
| raise typer.BadParameter(f"Latents directory does not exist: {latents_path}") |
|
|
| decoder = LatentsDecoder( |
| model_path=model_path, |
| device=device, |
| vae_tiling=vae_tiling, |
| with_audio=with_audio, |
| ) |
| decoder.decode(latents_path, output_path, seed=seed) |
|
|
| |
| if with_audio: |
| audio_path = Path(audio_latents_dir) if audio_latents_dir else latents_path.parent / "audio_latents" |
|
|
| if audio_path.exists(): |
| audio_output_path = output_path.parent / "decoded_audio" |
| decoder.decode_audio(audio_path, audio_output_path) |
| else: |
| logger.warning(f"Audio latents directory not found: {audio_path}") |
|
|
|
|
| if __name__ == "__main__": |
| app() |
|
|