File size: 13,023 Bytes
ebfc6b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
#!/usr/bin/env python3

"""
Decode precomputed video latents back into videos using the VAE.

This script loads latent files saved during preprocessing and decodes them
back into video clips using the same VAE model.

Basic usage:
    python scripts/decode_latents.py /path/to/latents/dir /path/to/output \
        --model-path /path/to/ltx2.safetensors
"""

from collections.abc import Callable
from pathlib import Path

import torch
import torchaudio
import torchvision.utils
import typer
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from transformers.utils.logging import disable_progress_bar

from ltx_trainer import logger
from ltx_trainer.model_loader import load_audio_vae_decoder, load_video_vae_decoder, load_vocoder
from ltx_trainer.video_utils import save_video

# Silence the progress bars that transformers prints on its own.
disable_progress_bar()

# Shared rich console used for status spinners and progress bars.
console = Console()

# Typer CLI app: invoking it with no arguments prints the help text, and
# pretty exception rendering is turned off so errors show plain tracebacks.
app = typer.Typer(
    help="Decode precomputed video latents back into videos using the VAE.",
    no_args_is_help=True,
    pretty_exceptions_enable=False,
)


class LatentsDecoder:
    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        vae_tiling: bool = False,
        with_audio: bool = False,
    ):
        """Initialize the decoder with model configuration.

        Args:
            model_path: Path to LTX-2 checkpoint (.safetensors)
            device: Device to use for computation
            vae_tiling: Whether to enable VAE tiling for larger video resolutions
            with_audio: Whether to load audio VAE for audio decoding
        """
        self.device = torch.device(device)
        self.model_path = model_path
        self.vae = None
        self.audio_vae = None
        self.vocoder = None
        self._load_model(model_path, vae_tiling, with_audio)

    def _load_model(self, model_path: str, vae_tiling: bool, with_audio: bool = False) -> None:
        """Initialize and load the VAE model(s)."""
        with console.status(f"[bold]Loading video VAE decoder from {model_path}...", spinner="dots"):
            self.vae = load_video_vae_decoder(model_path, device=self.device, dtype=torch.bfloat16)

            if vae_tiling:
                self.vae.enable_tiling()

        if with_audio:
            with console.status(f"[bold]Loading audio VAE decoder from {model_path}...", spinner="dots"):
                self.audio_vae = load_audio_vae_decoder(model_path, device=self.device, dtype=torch.bfloat16)

            with console.status(f"[bold]Loading vocoder from {model_path}...", spinner="dots"):
                self.vocoder = load_vocoder(model_path, device=self.device)

    @torch.inference_mode()
    def decode(self, latents_dir: Path, output_dir: Path, seed: int | None = None) -> None:
        """Decode all latent files in the directory recursively.

        Args:
            latents_dir: Directory containing latent files (.pt)
            output_dir: Directory to save decoded videos
            seed: Optional random seed for noise generation
        """
        # Find all .pt files recursively
        latent_files = list(latents_dir.rglob("*.pt"))

        if not latent_files:
            logger.warning(f"No .pt files found in {latents_dir}")
            return

        logger.info(f"Found {len(latent_files):,} latent files to decode")

        # Process files with progress bar
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Decoding latents", total=len(latent_files))

            for latent_file in latent_files:
                # Calculate relative path to maintain directory structure
                rel_path = latent_file.relative_to(latents_dir)
                output_subdir = output_dir / rel_path.parent
                output_subdir.mkdir(parents=True, exist_ok=True)

                try:
                    self._process_file(latent_file, output_subdir, seed)
                except Exception as e:
                    logger.error(f"Error processing {latent_file}: {e}")
                    continue

                progress.advance(task)

        logger.info(f"Decoding complete! Videos saved to {output_dir}")

    def _process_file(self, latent_file: Path, output_dir: Path, seed: int | None) -> None:
        """Process a single latent file."""
        # Load the latent data
        data = torch.load(latent_file, map_location=self.device, weights_only=False)

        # Get latents - handle both old patchified [seq_len, C] and new [C, F, H, W] formats
        latents = data["latents"]
        num_frames = data["num_frames"]
        height = data["height"]
        width = data["width"]

        # Check if latents need reshaping (old patchified format)
        if latents.dim() == 2:
            # Old format: [seq_len, C] -> reshape to [C, F, H, W]
            _seq_len, channels = latents.shape
            latents = latents.reshape(num_frames, height, width, channels)
            latents = latents.permute(3, 0, 1, 2)  # [F, H, W, C] -> [C, F, H, W]

        # Add batch dimension: [C, F, H, W] -> [1, C, F, H, W]
        latents = latents.unsqueeze(0).to(device=self.device, dtype=torch.bfloat16)

        # Create generator only if seed is provided
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device)
            generator.manual_seed(seed)

        # Decode the video (VAE decoder uses forward/call, not decode method)
        video = self.vae(latents)  # [B, C, F, H, W]

        # Convert to [F, C, H, W] format and normalize to [0, 1]
        video = video[0]  # Remove batch dimension -> [C, F, H, W]
        video = video.permute(1, 0, 2, 3)  # [C, F, H, W] -> [F, C, H, W]
        video = (video + 1) / 2  # Denormalize from [-1, 1] to [0, 1]
        video = video.clamp(0, 1)

        # Determine output format and save
        is_image = video.shape[0] == 1
        if is_image:
            # Save as PNG for single frame
            output_path = output_dir / f"{latent_file.stem}.png"
            torchvision.utils.save_image(
                video[0],  # [C, H, W] in [0, 1]
                str(output_path),
            )
        else:
            # Save as MP4 for video using PyAV-based save_video
            output_path = output_dir / f"{latent_file.stem}.mp4"
            fps = data.get("fps", 24)  # Use stored FPS or default to 24
            save_video(
                video_tensor=video,  # [F, C, H, W] in [0, 1]
                output_path=output_path,
                fps=fps,
            )

    @torch.inference_mode()
    def decode_audio(self, latents_dir: Path, output_dir: Path) -> None:
        """Decode all audio latent files in the directory recursively.

        Args:
            latents_dir: Directory containing audio latent files (.pt)
            output_dir: Directory to save decoded audio files
        """
        # Check if audio VAE is loaded
        if self.audio_vae is None or self.vocoder is None:
            logger.warning("Audio VAE or vocoder not loaded. Skipping audio decoding.")
            return

        # Find all .pt files recursively
        latent_files = list(latents_dir.rglob("*.pt"))

        if not latent_files:
            logger.warning(f"No .pt files found in {latents_dir}")
            return

        logger.info(f"Found {len(latent_files):,} audio latent files to decode")

        # Process files with progress bar
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Decoding audio latents", total=len(latent_files))

            for latent_file in latent_files:
                # Calculate relative path to maintain directory structure
                rel_path = latent_file.relative_to(latents_dir)
                output_subdir = output_dir / rel_path.parent
                output_subdir.mkdir(parents=True, exist_ok=True)

                try:
                    self._process_audio_file(latent_file, output_subdir)
                except Exception as e:
                    logger.error(f"Error processing audio {latent_file}: {e}")
                    continue

                progress.advance(task)

        logger.info(f"Audio decoding complete! Audio files saved to {output_dir}")

    def _process_audio_file(self, latent_file: Path, output_dir: Path) -> None:
        """Process a single audio latent file."""
        # Load the latent data
        data = torch.load(latent_file, map_location=self.device, weights_only=False)

        latents = data["latents"].to(device=self.device, dtype=torch.float32)
        num_time_steps = data["num_time_steps"]
        freq_bins = data["frequency_bins"]

        # Handle both old patchified [seq_len, C] and new [C, T, F] formats
        if latents.dim() == 2:
            # Old format: [seq_len, channels] where seq_len = time * freq
            # Reshape to [C, T, F]
            latents = latents.reshape(num_time_steps, freq_bins, -1)  # [T, F, C]
            latents = latents.permute(2, 0, 1)  # [T, F, C] -> [C, T, F]

        # Add batch dimension: [C, T, F] -> [1, C, T, F]
        latents = latents.unsqueeze(0)

        # Set correct dtype for audio VAE
        latents = latents.to(dtype=torch.bfloat16)

        # Decode audio using audio VAE decoder (produces mel spectrogram)
        mel_spectrogram = self.audio_vae(latents)

        # Convert mel spectrogram to waveform using vocoder
        waveform = self.vocoder(mel_spectrogram)

        # Save as WAV
        output_path = output_dir / f"{latent_file.stem}.wav"
        sample_rate = self.vocoder.output_sample_rate
        torchaudio.save(str(output_path), waveform[0].cpu(), sample_rate)


@app.command()
def main(
    latents_dir: str = typer.Argument(
        ...,
        help="Directory containing the precomputed latent files (searched recursively)",
    ),
    output_dir: str = typer.Argument(
        ...,
        help="Directory to save the decoded videos (maintains same folder hierarchy as input)",
    ),
    model_path: str = typer.Option(
        ...,
        help="Path to LTX-2 checkpoint (.safetensors file)",
    ),
    device: str = typer.Option(
        default="cuda",
        help="Device to use for computation",
    ),
    vae_tiling: bool = typer.Option(
        default=False,
        help="Enable VAE tiling for larger video resolutions",
    ),
    seed: int | None = typer.Option(
        default=None,
        help="Random seed for noise generation during decoding",
    ),
    with_audio: bool = typer.Option(
        default=False,
        help="Also decode audio latents (requires audio_latents directory)",
    ),
    audio_latents_dir: str | None = typer.Option(
        default=None,
        help="Directory containing audio latent files (defaults to 'audio_latents' sibling of latents_dir)",
    ),
) -> None:
    """Decode precomputed video latents back into videos using the VAE.

    This script recursively searches for .pt latent files in the input directory
    and decodes them to videos, maintaining the same folder hierarchy in the output.

    Examples:
        # Basic usage
        python scripts/decode_latents.py /path/to/latents /path/to/videos \\
            --model-path /path/to/ltx2.safetensors

        # With VAE tiling for large videos
        python scripts/decode_latents.py /path/to/latents /path/to/videos \\
            --model-path /path/to/ltx2.safetensors --vae-tiling

        # With audio decoding
        python scripts/decode_latents.py /path/to/latents /path/to/videos \\
            --model-path /path/to/ltx2.safetensors --with-audio
    """
    latents_path = Path(latents_dir)
    output_path = Path(output_dir)

    # Validate before loading any models so a bad path fails fast with an accurate message.
    if not latents_path.exists():
        raise typer.BadParameter(f"Latents directory does not exist: {latents_path}")
    if not latents_path.is_dir():
        raise typer.BadParameter(f"Latents path is not a directory: {latents_path}")

    decoder = LatentsDecoder(
        model_path=model_path,
        device=device,
        vae_tiling=vae_tiling,
        with_audio=with_audio,
    )
    decoder.decode(latents_path, output_path, seed=seed)

    # Decode audio if requested
    if with_audio:
        # Default: an "audio_latents" directory next to latents_dir.
        audio_path = Path(audio_latents_dir) if audio_latents_dir else latents_path.parent / "audio_latents"

        if audio_path.is_dir():
            # Decoded audio goes to a "decoded_audio" directory next to (not inside) output_dir.
            audio_output_path = output_path.parent / "decoded_audio"
            decoder.decode_audio(audio_path, audio_output_path)
        else:
            logger.warning(f"Audio latents directory not found: {audio_path}")


if __name__ == "__main__":
    app()