Spaces:
Running
on
Zero
Running
on
Zero
File size: 13,023 Bytes
ebfc6b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 |
#!/usr/bin/env python3
"""
Decode precomputed video latents back into videos using the VAE.
This script loads latent files saved during preprocessing and decodes them
back into video clips using the same VAE model.
Basic usage:
python scripts/decode_latents.py /path/to/latents/dir /path/to/output \
--model-path /path/to/ltx2.safetensors
"""
from collections.abc import Callable
from pathlib import Path

import torch
import torchaudio
import torchvision.utils
import typer
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from transformers.utils.logging import disable_progress_bar

from ltx_trainer import logger
from ltx_trainer.model_loader import load_audio_vae_decoder, load_video_vae_decoder, load_vocoder
from ltx_trainer.video_utils import save_video
# Turn off the transformers library's own progress bars; this script renders
# its progress through rich instead.
disable_progress_bar()

# Shared rich console used for status spinners and the progress display.
console = Console()

# Typer application object; `main` below is registered on it as the command.
app = typer.Typer(
    pretty_exceptions_enable=False,
    no_args_is_help=True,
    help="Decode precomputed video latents back into videos using the VAE.",
)
class LatentsDecoder:
def __init__(
self,
model_path: str,
device: str = "cuda",
vae_tiling: bool = False,
with_audio: bool = False,
):
"""Initialize the decoder with model configuration.
Args:
model_path: Path to LTX-2 checkpoint (.safetensors)
device: Device to use for computation
vae_tiling: Whether to enable VAE tiling for larger video resolutions
with_audio: Whether to load audio VAE for audio decoding
"""
self.device = torch.device(device)
self.model_path = model_path
self.vae = None
self.audio_vae = None
self.vocoder = None
self._load_model(model_path, vae_tiling, with_audio)
def _load_model(self, model_path: str, vae_tiling: bool, with_audio: bool = False) -> None:
"""Initialize and load the VAE model(s)."""
with console.status(f"[bold]Loading video VAE decoder from {model_path}...", spinner="dots"):
self.vae = load_video_vae_decoder(model_path, device=self.device, dtype=torch.bfloat16)
if vae_tiling:
self.vae.enable_tiling()
if with_audio:
with console.status(f"[bold]Loading audio VAE decoder from {model_path}...", spinner="dots"):
self.audio_vae = load_audio_vae_decoder(model_path, device=self.device, dtype=torch.bfloat16)
with console.status(f"[bold]Loading vocoder from {model_path}...", spinner="dots"):
self.vocoder = load_vocoder(model_path, device=self.device)
@torch.inference_mode()
def decode(self, latents_dir: Path, output_dir: Path, seed: int | None = None) -> None:
"""Decode all latent files in the directory recursively.
Args:
latents_dir: Directory containing latent files (.pt)
output_dir: Directory to save decoded videos
seed: Optional random seed for noise generation
"""
# Find all .pt files recursively
latent_files = list(latents_dir.rglob("*.pt"))
if not latent_files:
logger.warning(f"No .pt files found in {latents_dir}")
return
logger.info(f"Found {len(latent_files):,} latent files to decode")
# Process files with progress bar
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
MofNCompleteColumn(),
TimeElapsedColumn(),
TimeRemainingColumn(),
console=console,
) as progress:
task = progress.add_task("Decoding latents", total=len(latent_files))
for latent_file in latent_files:
# Calculate relative path to maintain directory structure
rel_path = latent_file.relative_to(latents_dir)
output_subdir = output_dir / rel_path.parent
output_subdir.mkdir(parents=True, exist_ok=True)
try:
self._process_file(latent_file, output_subdir, seed)
except Exception as e:
logger.error(f"Error processing {latent_file}: {e}")
continue
progress.advance(task)
logger.info(f"Decoding complete! Videos saved to {output_dir}")
def _process_file(self, latent_file: Path, output_dir: Path, seed: int | None) -> None:
"""Process a single latent file."""
# Load the latent data
data = torch.load(latent_file, map_location=self.device, weights_only=False)
# Get latents - handle both old patchified [seq_len, C] and new [C, F, H, W] formats
latents = data["latents"]
num_frames = data["num_frames"]
height = data["height"]
width = data["width"]
# Check if latents need reshaping (old patchified format)
if latents.dim() == 2:
# Old format: [seq_len, C] -> reshape to [C, F, H, W]
_seq_len, channels = latents.shape
latents = latents.reshape(num_frames, height, width, channels)
latents = latents.permute(3, 0, 1, 2) # [F, H, W, C] -> [C, F, H, W]
# Add batch dimension: [C, F, H, W] -> [1, C, F, H, W]
latents = latents.unsqueeze(0).to(device=self.device, dtype=torch.bfloat16)
# Create generator only if seed is provided
generator = None
if seed is not None:
generator = torch.Generator(device=self.device)
generator.manual_seed(seed)
# Decode the video (VAE decoder uses forward/call, not decode method)
video = self.vae(latents) # [B, C, F, H, W]
# Convert to [F, C, H, W] format and normalize to [0, 1]
video = video[0] # Remove batch dimension -> [C, F, H, W]
video = video.permute(1, 0, 2, 3) # [C, F, H, W] -> [F, C, H, W]
video = (video + 1) / 2 # Denormalize from [-1, 1] to [0, 1]
video = video.clamp(0, 1)
# Determine output format and save
is_image = video.shape[0] == 1
if is_image:
# Save as PNG for single frame
output_path = output_dir / f"{latent_file.stem}.png"
torchvision.utils.save_image(
video[0], # [C, H, W] in [0, 1]
str(output_path),
)
else:
# Save as MP4 for video using PyAV-based save_video
output_path = output_dir / f"{latent_file.stem}.mp4"
fps = data.get("fps", 24) # Use stored FPS or default to 24
save_video(
video_tensor=video, # [F, C, H, W] in [0, 1]
output_path=output_path,
fps=fps,
)
@torch.inference_mode()
def decode_audio(self, latents_dir: Path, output_dir: Path) -> None:
"""Decode all audio latent files in the directory recursively.
Args:
latents_dir: Directory containing audio latent files (.pt)
output_dir: Directory to save decoded audio files
"""
# Check if audio VAE is loaded
if self.audio_vae is None or self.vocoder is None:
logger.warning("Audio VAE or vocoder not loaded. Skipping audio decoding.")
return
# Find all .pt files recursively
latent_files = list(latents_dir.rglob("*.pt"))
if not latent_files:
logger.warning(f"No .pt files found in {latents_dir}")
return
logger.info(f"Found {len(latent_files):,} audio latent files to decode")
# Process files with progress bar
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
MofNCompleteColumn(),
TimeElapsedColumn(),
TimeRemainingColumn(),
console=console,
) as progress:
task = progress.add_task("Decoding audio latents", total=len(latent_files))
for latent_file in latent_files:
# Calculate relative path to maintain directory structure
rel_path = latent_file.relative_to(latents_dir)
output_subdir = output_dir / rel_path.parent
output_subdir.mkdir(parents=True, exist_ok=True)
try:
self._process_audio_file(latent_file, output_subdir)
except Exception as e:
logger.error(f"Error processing audio {latent_file}: {e}")
continue
progress.advance(task)
logger.info(f"Audio decoding complete! Audio files saved to {output_dir}")
def _process_audio_file(self, latent_file: Path, output_dir: Path) -> None:
"""Process a single audio latent file."""
# Load the latent data
data = torch.load(latent_file, map_location=self.device, weights_only=False)
latents = data["latents"].to(device=self.device, dtype=torch.float32)
num_time_steps = data["num_time_steps"]
freq_bins = data["frequency_bins"]
# Handle both old patchified [seq_len, C] and new [C, T, F] formats
if latents.dim() == 2:
# Old format: [seq_len, channels] where seq_len = time * freq
# Reshape to [C, T, F]
latents = latents.reshape(num_time_steps, freq_bins, -1) # [T, F, C]
latents = latents.permute(2, 0, 1) # [T, F, C] -> [C, T, F]
# Add batch dimension: [C, T, F] -> [1, C, T, F]
latents = latents.unsqueeze(0)
# Set correct dtype for audio VAE
latents = latents.to(dtype=torch.bfloat16)
# Decode audio using audio VAE decoder (produces mel spectrogram)
mel_spectrogram = self.audio_vae(latents)
# Convert mel spectrogram to waveform using vocoder
waveform = self.vocoder(mel_spectrogram)
# Save as WAV
output_path = output_dir / f"{latent_file.stem}.wav"
sample_rate = self.vocoder.output_sample_rate
torchaudio.save(str(output_path), waveform[0].cpu(), sample_rate)
@app.command()
def main(
    latents_dir: str = typer.Argument(
        ...,
        help="Directory containing the precomputed latent files (searched recursively)",
    ),
    output_dir: str = typer.Argument(
        ...,
        help="Directory to save the decoded videos (maintains same folder hierarchy as input)",
    ),
    model_path: str = typer.Option(
        ...,
        help="Path to LTX-2 checkpoint (.safetensors file)",
    ),
    device: str = typer.Option(
        default="cuda",
        help="Device to use for computation",
    ),
    vae_tiling: bool = typer.Option(
        default=False,
        help="Enable VAE tiling for larger video resolutions",
    ),
    seed: int | None = typer.Option(
        default=None,
        help="Random seed for noise generation during decoding",
    ),
    with_audio: bool = typer.Option(
        default=False,
        help="Also decode audio latents (requires audio_latents directory)",
    ),
    audio_latents_dir: str | None = typer.Option(
        default=None,
        help="Directory containing audio latent files (defaults to 'audio_latents' sibling of latents_dir)",
    ),
) -> None:
    """Decode precomputed video latents back into videos using the VAE.
    This script recursively searches for .pt latent files in the input directory
    and decodes them to videos, maintaining the same folder hierarchy in the output.
    Examples:
    # Basic usage
    python scripts/decode_latents.py /path/to/latents /path/to/videos \\
    --model-path /path/to/ltx2.safetensors
    # With VAE tiling for large videos
    python scripts/decode_latents.py /path/to/latents /path/to/videos \\
    --model-path /path/to/ltx2.safetensors --vae-tiling
    # With audio decoding
    python scripts/decode_latents.py /path/to/latents /path/to/videos \\
    --model-path /path/to/ltx2.safetensors --with-audio
    """
    src_root = Path(latents_dir)
    dst_root = Path(output_dir)

    # Reject a missing path as well as a path that is not a directory.
    if not (src_root.exists() and src_root.is_dir()):
        raise typer.BadParameter(f"Latents directory does not exist: {src_root}")

    decoder = LatentsDecoder(
        model_path=model_path,
        device=device,
        vae_tiling=vae_tiling,
        with_audio=with_audio,
    )
    decoder.decode(src_root, dst_root, seed=seed)

    # Video-only run: nothing left to do.
    if not with_audio:
        return

    # Explicit audio directory wins; otherwise fall back to the
    # "audio_latents" sibling of the video latents directory.
    if audio_latents_dir:
        audio_root = Path(audio_latents_dir)
    else:
        audio_root = src_root.parent / "audio_latents"

    if not audio_root.exists():
        logger.warning(f"Audio latents directory not found: {audio_root}")
        return

    # Decoded audio goes to a "decoded_audio" sibling of the video output directory.
    decoder.decode_audio(audio_root, dst_root.parent / "decoded_audio")
# Run the Typer application only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    app()
|