Spaces:
Running
on
Zero
Running
on
Zero
| #!/usr/bin/env python3 | |
| """ | |
| Decode precomputed video latents back into videos using the VAE. | |
| This script loads latent files saved during preprocessing and decodes them | |
| back into video clips using the same VAE model. | |
| Basic usage: | |
| python scripts/decode_latents.py /path/to/latents/dir /path/to/output \ | |
| --model-path /path/to/ltx2.safetensors | |
| """ | |
| from pathlib import Path | |
| import torch | |
| import torchaudio | |
| import torchvision.utils | |
| import typer | |
| from rich.console import Console | |
| from rich.progress import ( | |
| BarColumn, | |
| MofNCompleteColumn, | |
| Progress, | |
| SpinnerColumn, | |
| TextColumn, | |
| TimeElapsedColumn, | |
| TimeRemainingColumn, | |
| ) | |
| from transformers.utils.logging import disable_progress_bar | |
| from ltx_trainer import logger | |
| from ltx_trainer.model_loader import load_audio_vae_decoder, load_video_vae_decoder, load_vocoder | |
| from ltx_trainer.video_utils import save_video | |
# Turn off progress bars through the transformers logging utilities.
# This call runs once, at import time.
disable_progress_bar()

# Typer application object; it is invoked by the ``__main__`` guard at the
# bottom of the file.
app = typer.Typer(
    help="Decode precomputed video latents back into videos using the VAE.",
    no_args_is_help=True,
    pretty_exceptions_enable=False,
)

# Shared rich console used for the status spinners and progress bars below.
console = Console()
class LatentsDecoder:
    """Decode precomputed latent files (``.pt``) back into media files.

    Video latents are written as ``.mp4`` clips, or as ``.png`` images when the
    decoded result has a single frame. Audio latents are written as ``.wav``
    files. The folder hierarchy below the input directory is mirrored in the
    output directory.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        vae_tiling: bool = False,
        with_audio: bool = False,
    ):
        """Initialize the decoder with model configuration.

        Args:
            model_path: Path to LTX-2 checkpoint (.safetensors)
            device: Device to use for computation
            vae_tiling: Whether to enable VAE tiling for larger video resolutions
            with_audio: Whether to load audio VAE for audio decoding
        """
        self.device = torch.device(device)
        self.model_path = model_path
        self.vae = None
        self.audio_vae = None
        self.vocoder = None
        self._load_model(model_path, vae_tiling, with_audio)

    def _load_model(self, model_path: str, vae_tiling: bool, with_audio: bool = False) -> None:
        """Load the video VAE decoder and, if requested, the audio VAE decoder and vocoder."""
        with console.status(f"[bold]Loading video VAE decoder from {model_path}...", spinner="dots"):
            self.vae = load_video_vae_decoder(model_path, device=self.device, dtype=torch.bfloat16)
            if vae_tiling:
                self.vae.enable_tiling()

        if with_audio:
            with console.status(f"[bold]Loading audio VAE decoder from {model_path}...", spinner="dots"):
                self.audio_vae = load_audio_vae_decoder(model_path, device=self.device, dtype=torch.bfloat16)
            with console.status(f"[bold]Loading vocoder from {model_path}...", spinner="dots"):
                self.vocoder = load_vocoder(model_path, device=self.device)

    def _decode_files(
        self,
        latent_files: list[Path],
        latents_dir: Path,
        output_dir: Path,
        process_one,
        description: str,
        error_prefix: str,
    ) -> int:
        """Run ``process_one`` on every latent file while showing a progress bar.

        Args:
            latent_files: Latent files to process (all located under ``latents_dir``)
            latents_dir: Root input directory, used to compute relative paths
            output_dir: Root output directory; input sub-folders are recreated below it
            process_one: Callable taking ``(latent_file, output_subdir)``
            description: Label shown next to the progress bar
            error_prefix: Leading text of the error message logged for a failed file

        Returns:
            Number of files for which ``process_one`` raised an exception.
        """
        failed = 0
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            console=console,
        ) as progress:
            task = progress.add_task(description, total=len(latent_files))
            for latent_file in latent_files:
                # Calculate relative path to maintain directory structure
                rel_path = latent_file.relative_to(latents_dir)
                output_subdir = output_dir / rel_path.parent
                output_subdir.mkdir(parents=True, exist_ok=True)
                try:
                    process_one(latent_file, output_subdir)
                except Exception as e:
                    # Best effort: one bad file must not abort the whole run.
                    logger.error(f"{error_prefix} {latent_file}: {e}")
                    failed += 1
                finally:
                    # Advance for failures too, so the bar always reaches its total.
                    progress.advance(task)
        return failed

    def decode(self, latents_dir: Path, output_dir: Path, seed: int | None = None) -> None:
        """Decode all latent files in the directory recursively.

        Args:
            latents_dir: Directory containing latent files (.pt)
            output_dir: Directory to save decoded videos
            seed: Accepted for interface compatibility; the decode call in this
                class does not consume any random state, so it has no effect
        """
        # Find all .pt files recursively
        latent_files = list(latents_dir.rglob("*.pt"))
        if not latent_files:
            logger.warning(f"No .pt files found in {latents_dir}")
            return

        logger.info(f"Found {len(latent_files):,} latent files to decode")
        failed = self._decode_files(
            latent_files,
            latents_dir,
            output_dir,
            process_one=lambda latent_file, output_subdir: self._process_file(latent_file, output_subdir, seed),
            description="Decoding latents",
            error_prefix="Error processing",
        )
        if failed:
            logger.warning(f"Failed to decode {failed:,} of {len(latent_files):,} latent files")
        logger.info(f"Decoding complete! Videos saved to {output_dir}")

    @torch.no_grad()  # decoding only: do not record autograd graphs for the VAE forward pass
    def _process_file(self, latent_file: Path, output_dir: Path, seed: int | None) -> None:
        """Decode a single video latent file and write it as PNG (one frame) or MP4.

        Args:
            latent_file: Path to the ``.pt`` file holding a dict with a ``"latents"`` entry
            output_dir: Directory in which the decoded file is written
            seed: Unused; kept so the signature stays stable for callers
        """
        # NOTE(security): weights_only=False unpickles arbitrary Python objects.
        # Only decode latent files that come from a trusted source.
        data = torch.load(latent_file, map_location=self.device, weights_only=False)

        # Get latents - handle both old patchified [seq_len, C] and new [C, F, H, W] formats
        latents = data["latents"]
        if latents.dim() == 2:
            # Old format: [seq_len, C] -> reshape to [C, F, H, W].
            # The shape metadata is only needed (and only read) for this format.
            num_frames = data["num_frames"]
            height = data["height"]
            width = data["width"]
            channels = latents.shape[1]
            latents = latents.reshape(num_frames, height, width, channels)
            latents = latents.permute(3, 0, 1, 2)  # [F, H, W, C] -> [C, F, H, W]

        # Add batch dimension: [C, F, H, W] -> [1, C, F, H, W]
        latents = latents.unsqueeze(0).to(device=self.device, dtype=torch.bfloat16)

        # Decode the video (VAE decoder uses forward/call, not decode method)
        video = self.vae(latents)  # [B, C, F, H, W]

        # Convert to [F, C, H, W] format and normalize to [0, 1]
        video = video[0]  # Remove batch dimension -> [C, F, H, W]
        video = video.permute(1, 0, 2, 3)  # [C, F, H, W] -> [F, C, H, W]
        video = (video + 1) / 2  # Denormalize from [-1, 1] to [0, 1]
        video = video.clamp(0, 1)

        if video.shape[0] == 1:
            # Save as PNG for single frame
            output_path = output_dir / f"{latent_file.stem}.png"
            torchvision.utils.save_image(
                video[0],  # [C, H, W] in [0, 1]
                str(output_path),
            )
        else:
            # Save as MP4 for video using PyAV-based save_video
            output_path = output_dir / f"{latent_file.stem}.mp4"
            fps = data.get("fps", 24)  # Use stored FPS or default to 24
            save_video(
                video_tensor=video,  # [F, C, H, W] in [0, 1]
                output_path=output_path,
                fps=fps,
            )

    def decode_audio(self, latents_dir: Path, output_dir: Path) -> None:
        """Decode all audio latent files in the directory recursively.

        Does nothing (apart from logging a warning) when the audio VAE or the
        vocoder has not been loaded.

        Args:
            latents_dir: Directory containing audio latent files (.pt)
            output_dir: Directory to save decoded audio files
        """
        # Check if audio VAE is loaded
        if self.audio_vae is None or self.vocoder is None:
            logger.warning("Audio VAE or vocoder not loaded. Skipping audio decoding.")
            return

        # Find all .pt files recursively
        latent_files = list(latents_dir.rglob("*.pt"))
        if not latent_files:
            logger.warning(f"No .pt files found in {latents_dir}")
            return

        logger.info(f"Found {len(latent_files):,} audio latent files to decode")
        failed = self._decode_files(
            latent_files,
            latents_dir,
            output_dir,
            process_one=self._process_audio_file,
            description="Decoding audio latents",
            error_prefix="Error processing audio",
        )
        if failed:
            logger.warning(f"Failed to decode {failed:,} of {len(latent_files):,} audio latent files")
        logger.info(f"Audio decoding complete! Audio files saved to {output_dir}")

    @torch.no_grad()  # decoding only: do not record autograd graphs for the model forward passes
    def _process_audio_file(self, latent_file: Path, output_dir: Path) -> None:
        """Decode a single audio latent file and write it as a WAV file.

        Args:
            latent_file: Path to the ``.pt`` file holding a dict with a ``"latents"`` entry
            output_dir: Directory in which the ``.wav`` file is written
        """
        # NOTE(security): weights_only=False unpickles arbitrary Python objects.
        # Only decode latent files that come from a trusted source.
        data = torch.load(latent_file, map_location=self.device, weights_only=False)
        latents = data["latents"].to(device=self.device, dtype=torch.float32)

        # Handle both old patchified [seq_len, C] and new [C, T, F] formats
        if latents.dim() == 2:
            # Old format: [seq_len, channels] where seq_len = time * freq.
            # The shape metadata is only needed (and only read) for this format.
            num_time_steps = data["num_time_steps"]
            freq_bins = data["frequency_bins"]
            latents = latents.reshape(num_time_steps, freq_bins, -1)  # [T, F, C]
            latents = latents.permute(2, 0, 1)  # [T, F, C] -> [C, T, F]

        # Add batch dimension: [C, T, F] -> [1, C, T, F]
        latents = latents.unsqueeze(0)
        # Set correct dtype for audio VAE
        latents = latents.to(dtype=torch.bfloat16)

        # Decode audio using audio VAE decoder (produces mel spectrogram)
        mel_spectrogram = self.audio_vae(latents)
        # Convert mel spectrogram to waveform using vocoder
        waveform = self.vocoder(mel_spectrogram)

        # Save as WAV
        output_path = output_dir / f"{latent_file.stem}.wav"
        sample_rate = self.vocoder.output_sample_rate
        torchaudio.save(str(output_path), waveform[0].cpu(), sample_rate)
# Register ``main`` on the Typer app; without a registered command ``app()``
# has nothing to run.
@app.command()
def main(
    latents_dir: str = typer.Argument(
        ...,
        help="Directory containing the precomputed latent files (searched recursively)",
    ),
    output_dir: str = typer.Argument(
        ...,
        help="Directory to save the decoded videos (maintains same folder hierarchy as input)",
    ),
    model_path: str = typer.Option(
        ...,
        help="Path to LTX-2 checkpoint (.safetensors file)",
    ),
    device: str = typer.Option(
        default="cuda",
        help="Device to use for computation",
    ),
    vae_tiling: bool = typer.Option(
        default=False,
        help="Enable VAE tiling for larger video resolutions",
    ),
    seed: int | None = typer.Option(
        default=None,
        help="Random seed for noise generation during decoding",
    ),
    with_audio: bool = typer.Option(
        default=False,
        help="Also decode audio latents (requires audio_latents directory)",
    ),
    audio_latents_dir: str | None = typer.Option(
        default=None,
        help="Directory containing audio latent files (defaults to 'audio_latents' sibling of latents_dir)",
    ),
) -> None:
    """Decode precomputed video latents back into videos using the VAE.

    This script recursively searches for .pt latent files in the input directory
    and decodes them to videos, maintaining the same folder hierarchy in the output.

    Examples:
        # Basic usage
        python scripts/decode_latents.py /path/to/latents /path/to/videos \\
            --model-path /path/to/ltx2.safetensors

        # With VAE tiling for large videos
        python scripts/decode_latents.py /path/to/latents /path/to/videos \\
            --model-path /path/to/ltx2.safetensors --vae-tiling

        # With audio decoding
        python scripts/decode_latents.py /path/to/latents /path/to/videos \\
            --model-path /path/to/ltx2.safetensors --with-audio
    """
    latents_path = Path(latents_dir)
    output_path = Path(output_dir)

    # is_dir() is already False for a missing path, so one check covers both cases.
    if not latents_path.is_dir():
        raise typer.BadParameter(f"Latents directory does not exist: {latents_path}")

    decoder = LatentsDecoder(
        model_path=model_path,
        device=device,
        vae_tiling=vae_tiling,
        with_audio=with_audio,
    )
    decoder.decode(latents_path, output_path, seed=seed)

    # Decode audio only if requested
    if not with_audio:
        return

    # Default location: an "audio_latents" directory next to the video latents.
    audio_path = Path(audio_latents_dir) if audio_latents_dir else latents_path.parent / "audio_latents"
    if not audio_path.exists():
        logger.warning(f"Audio latents directory not found: {audio_path}")
        return

    # Audio is written next to the video output directory, in "decoded_audio".
    audio_output_path = output_path.parent / "decoded_audio"
    decoder.decode_audio(audio_path, audio_output_path)


if __name__ == "__main__":
    app()