| | |
| |
|
| | """ |
| | Compute latent representations for video generation training. |
| | |
| | This module provides functionality for processing video and image files, including: |
| | - Loading videos/images from various file formats (CSV, JSON, JSONL) |
| | - Resizing, cropping, and transforming media |
| | - MediaDataset for video-only preprocessing workflows |
| | - BucketSampler for grouping videos by resolution |
| | |
| | Can be used as a standalone script: |
| | python scripts/process_videos.py dataset.csv --resolution-buckets 768x768x25 \ |
| | --output-dir /path/to/output --model-path /path/to/ltx2.safetensors |
| | """ |
| |
|
| | import json |
| | import math |
| | from pathlib import Path |
| | from typing import Any |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | import torch |
| | import torchaudio |
| | import typer |
| | from pillow_heif import register_heif_opener |
| | from rich.console import Console |
| | from rich.progress import ( |
| | BarColumn, |
| | MofNCompleteColumn, |
| | Progress, |
| | SpinnerColumn, |
| | TaskProgressColumn, |
| | TextColumn, |
| | TimeElapsedColumn, |
| | TimeRemainingColumn, |
| | ) |
| | from torch.utils.data import DataLoader, Dataset |
| | from torchvision import transforms |
| | from torchvision.transforms import InterpolationMode |
| | from torchvision.transforms.functional import crop, resize, to_tensor |
| | from transformers.utils.logging import disable_progress_bar |
| |
|
| | from ltx_core.model.audio_vae.ops import AudioProcessor |
| | from ltx_trainer import logger |
| | from ltx_trainer.model_loader import load_audio_vae_encoder, load_video_vae_encoder |
| | from ltx_trainer.utils import open_image_as_srgb |
| | from ltx_trainer.video_utils import get_video_frame_count, read_video |
| |
|
# Turn off the progress bars of the `transformers` library.
disable_progress_bar()

# Register the `pillow_heif` opener so Pillow can open HEIF/HEIC images.
register_heif_opener()

# Constraints enforced on resolution buckets by `parse_resolution_buckets`:
# width and height must be multiples of VAE_SPATIAL_FACTOR, and the frame
# count must satisfy `frames % VAE_TEMPORAL_FACTOR == 1`.
VAE_SPATIAL_FACTOR = 32
VAE_TEMPORAL_FACTOR = 8

# NOTE(review): these two constants are not referenced anywhere else in this
# file; presumably they describe the audio latent layout — confirm before use.
AUDIO_LATENT_CHANNELS = 8
AUDIO_FREQUENCY_BINS = 16

# Typer application object; `main` below is registered on it as the command.
app = typer.Typer(
    pretty_exceptions_enable=False,
    no_args_is_help=True,
    help="Process videos/images and save latent representations for video generation training.",
)
| |
|
| |
|
| | class MediaDataset(Dataset): |
| | """ |
| | Dataset for processing video and image files. |
| | |
| | This dataset is designed for media preprocessing workflows where you need to: |
| | - Load and preprocess videos/images |
| | - Apply resizing and cropping transformations |
| | - Handle different resolution buckets |
| | - Filter out invalid media files |
| | - Optionally extract audio from video files |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | dataset_file: str | Path, |
| | main_media_column: str, |
| | video_column: str, |
| | resolution_buckets: list[tuple[int, int, int]], |
| | reshape_mode: str = "center", |
| | with_audio: bool = False, |
| | ) -> None: |
| | """ |
| | Initialize the media dataset. |
| | |
| | Args: |
| | dataset_file: Path to CSV/JSON/JSONL metadata file |
| | video_column: Column name for video paths in the metadata file |
| | resolution_buckets: List of (frames, height, width) tuples |
| | reshape_mode: How to crop videos ("center", "random") |
| | with_audio: Whether to extract audio from video files |
| | """ |
| | super().__init__() |
| |
|
| | self.dataset_file = Path(dataset_file) |
| | self.main_media_column = main_media_column |
| | self.resolution_buckets = resolution_buckets |
| | self.reshape_mode = reshape_mode |
| | self.with_audio = with_audio |
| |
|
| | |
| | self.main_media_paths = self._load_video_paths(main_media_column) |
| |
|
| | |
| | self.video_paths = self._load_video_paths(video_column) |
| |
|
| | |
| | self._filter_valid_videos() |
| |
|
| | self.max_target_frames = max(self.resolution_buckets, key=lambda x: x[0])[0] |
| |
|
| | |
| | self.transforms = transforms.Compose( |
| | [ |
| | transforms.Lambda(lambda x: x.clamp_(0, 1)), |
| | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), |
| | ] |
| | ) |
| |
|
| | def __len__(self) -> int: |
| | return len(self.video_paths) |
| |
|
| | def __getitem__(self, index: int) -> dict[str, Any]: |
| | """Get a single video/image with metadata, and optionally audio.""" |
| | if isinstance(index, list): |
| | |
| | return index |
| |
|
| | video_path: Path = self.video_paths[index] |
| |
|
| | |
| | data_root = self.dataset_file.parent |
| | relative_path = str(video_path.relative_to(data_root)) |
| | media_relative_path = str(self.main_media_paths[index].relative_to(data_root)) |
| |
|
| | if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]: |
| | media_tensor = self._preprocess_image(video_path) |
| | fps = 1.0 |
| | audio_data = None |
| | else: |
| | media_tensor, fps = self._preprocess_video(video_path) |
| |
|
| | |
| | if self.with_audio: |
| | |
| | |
| | |
| | target_duration = media_tensor.shape[1] / fps |
| | audio_data = self._extract_audio(video_path, target_duration) |
| | else: |
| | audio_data = None |
| |
|
| | |
| | _, num_frames, height, width = media_tensor.shape |
| |
|
| | result = { |
| | "video": media_tensor, |
| | "relative_path": relative_path, |
| | "main_media_relative_path": media_relative_path, |
| | "video_metadata": { |
| | "num_frames": num_frames, |
| | "height": height, |
| | "width": width, |
| | "fps": fps, |
| | }, |
| | } |
| |
|
| | |
| | if audio_data is not None: |
| | result["audio"] = audio_data |
| |
|
| | return result |
| |
|
| | @staticmethod |
| | def _extract_audio(video_path: Path, target_duration: float) -> dict[str, torch.Tensor | int] | None: |
| | """Extract audio track from a video file, trimmed to match video duration.""" |
| | try: |
| | |
| | |
| | waveform, sample_rate = torchaudio.load(str(video_path)) |
| |
|
| | |
| | target_samples = int(target_duration * sample_rate) |
| | current_samples = waveform.shape[-1] |
| |
|
| | if current_samples > target_samples: |
| | |
| | waveform = waveform[..., :target_samples] |
| | elif current_samples < target_samples: |
| | |
| | padding = target_samples - current_samples |
| | waveform = torch.nn.functional.pad(waveform, (0, padding)) |
| | logger.warning(f"Padded audio to {target_duration:.2f} seconds for {video_path}") |
| |
|
| | return {"waveform": waveform, "sample_rate": sample_rate} |
| |
|
| | except Exception as e: |
| | logger.debug(f"Could not extract audio from {video_path}: {e}") |
| | return None |
| |
|
| | def _load_video_paths(self, column: str) -> list[Path]: |
| | """Load video paths from the specified data source.""" |
| | if self.dataset_file.suffix == ".csv": |
| | return self._load_video_paths_from_csv(column) |
| | elif self.dataset_file.suffix == ".json": |
| | return self._load_video_paths_from_json(column) |
| | elif self.dataset_file.suffix == ".jsonl": |
| | return self._load_video_paths_from_jsonl(column) |
| | else: |
| | raise ValueError("Expected `dataset_file` to be a path to a CSV, JSON, or JSONL file.") |
| |
|
| | def _load_video_paths_from_csv(self, column: str) -> list[Path]: |
| | """Load video paths from a CSV file.""" |
| | df = pd.read_csv(self.dataset_file) |
| | if column not in df.columns: |
| | raise ValueError(f"Column '{column}' not found in CSV file") |
| |
|
| | data_root = self.dataset_file.parent |
| | video_paths = [data_root / Path(line.strip()) for line in df[column].tolist()] |
| |
|
| | |
| | invalid_paths = [path for path in video_paths if not path.is_file()] |
| | if invalid_paths: |
| | raise ValueError(f"Found {len(invalid_paths)} invalid video paths. First few: {invalid_paths[:5]}") |
| |
|
| | return video_paths |
| |
|
| | def _load_video_paths_from_json(self, column: str) -> list[Path]: |
| | """Load video paths from a JSON file.""" |
| | with open(self.dataset_file, "r", encoding="utf-8") as file: |
| | data = json.load(file) |
| |
|
| | if not isinstance(data, list): |
| | raise ValueError("JSON file must contain a list of objects") |
| |
|
| | data_root = self.dataset_file.parent |
| | video_paths = [] |
| | for entry in data: |
| | if column not in entry: |
| | raise ValueError(f"Key '{column}' not found in JSON entry") |
| | video_paths.append(data_root / Path(entry[column].strip())) |
| |
|
| | |
| | invalid_paths = [path for path in video_paths if not path.is_file()] |
| | if invalid_paths: |
| | raise ValueError(f"Found {len(invalid_paths)} invalid video paths. First few: {invalid_paths[:5]}") |
| |
|
| | return video_paths |
| |
|
| | def _load_video_paths_from_jsonl(self, column: str) -> list[Path]: |
| | """Load video paths from a JSONL file.""" |
| | data_root = self.dataset_file.parent |
| | video_paths = [] |
| | with open(self.dataset_file, "r", encoding="utf-8") as file: |
| | for line in file: |
| | entry = json.loads(line) |
| | if column not in entry: |
| | raise ValueError(f"Key '{column}' not found in JSONL entry") |
| | video_paths.append(data_root / Path(entry[column].strip())) |
| |
|
| | |
| | invalid_paths = [path for path in video_paths if not path.is_file()] |
| | if invalid_paths: |
| | raise ValueError(f"Found {len(invalid_paths)} invalid video paths. First few: {invalid_paths[:5]}") |
| |
|
| | return video_paths |
| |
|
| | def _filter_valid_videos(self) -> None: |
| | """Filter out videos with insufficient frames.""" |
| | original_length = len(self.video_paths) |
| | valid_video_paths = [] |
| | valid_main_media_paths = [] |
| | min_frames_required = min(self.resolution_buckets, key=lambda x: x[0])[0] |
| |
|
| | for i, video_path in enumerate(self.video_paths): |
| | if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]: |
| | valid_video_paths.append(video_path) |
| | valid_main_media_paths.append(self.main_media_paths[i]) |
| | continue |
| |
|
| | try: |
| | frame_count = get_video_frame_count(video_path) |
| |
|
| | if frame_count >= min_frames_required: |
| | valid_video_paths.append(video_path) |
| | valid_main_media_paths.append(self.main_media_paths[i]) |
| | else: |
| | logger.warning( |
| | f"Skipping video at {video_path} - has {frame_count} frames, " |
| | f"which is less than the minimum required frames ({min_frames_required})" |
| | ) |
| | except Exception as e: |
| | logger.warning(f"Failed to read video at {video_path}: {e!s}") |
| |
|
| | |
| | self.video_paths = valid_video_paths |
| | self.main_media_paths = valid_main_media_paths |
| |
|
| | if len(self.video_paths) < original_length: |
| | logger.warning( |
| | f"Filtered out {original_length - len(self.video_paths)} videos with insufficient frames. " |
| | f"Proceeding with {len(self.video_paths)} valid videos." |
| | ) |
| |
|
| | def _preprocess_image(self, path: Path) -> torch.Tensor: |
| | """Preprocess a single image by resizing and applying transforms.""" |
| | image = open_image_as_srgb(path) |
| | image = to_tensor(image) |
| | image = image.unsqueeze(0) |
| |
|
| | |
| | nearest_bucket = self._get_resolution_bucket_for_item(image) |
| | _, target_height, target_width = nearest_bucket |
| | image_resized = self._resize_and_crop(image, target_height, target_width) |
| | |
| |
|
| | |
| | image = self.transforms(image_resized) |
| |
|
| | |
| | image = image.unsqueeze(1) |
| | return image |
| |
|
| | def _preprocess_video(self, path: Path) -> tuple[torch.Tensor, float]: |
| | """Preprocess a video by loading, resizing, and applying transforms. |
| | |
| | Returns: |
| | Tuple of (video tensor in [C, F, H, W] format, fps) |
| | """ |
| | |
| | video, fps = read_video(path, max_frames=self.max_target_frames) |
| |
|
| | nearest_bucket = self._get_resolution_bucket_for_item(video) |
| | target_num_frames, target_height, target_width = nearest_bucket |
| | frames_resized = self._resize_and_crop(video, target_height, target_width) |
| |
|
| | |
| | frames_resized = frames_resized[:target_num_frames] |
| |
|
| | |
| | video = torch.stack([self.transforms(frame) for frame in frames_resized], dim=0) |
| |
|
| | |
| | |
| | video = video.permute(1, 0, 2, 3).contiguous() |
| |
|
| | return video, fps |
| |
|
| | def _get_resolution_bucket_for_item(self, media_tensor: torch.Tensor) -> tuple[int, int, int]: |
| | """Get the nearest resolution bucket for the given media tensor.""" |
| | num_frames, _, height, width = media_tensor.shape |
| |
|
| | def distance(bucket: tuple[int, int, int]) -> tuple: |
| | bucket_num_frames, bucket_height, bucket_width = bucket |
| | |
| | |
| | |
| | |
| | return ( |
| | abs(math.log(width / height) - math.log(bucket_width / bucket_height)), |
| | -bucket_num_frames, |
| | -(bucket_height * bucket_width), |
| | ) |
| |
|
| | |
| | relevant_buckets = [b for b in self.resolution_buckets if b[0] <= num_frames] |
| | if not relevant_buckets: |
| | raise ValueError(f"No resolution buckets have <= {num_frames} frames. Available: {self.resolution_buckets}") |
| |
|
| | |
| | nearest_bucket = min(relevant_buckets, key=distance) |
| |
|
| | return nearest_bucket |
| |
|
| | def _resize_and_crop(self, media_tensor: torch.Tensor, target_height: int, target_width: int) -> torch.Tensor: |
| | """Resize and crop tensor to target size.""" |
| | |
| | current_height, current_width = media_tensor.shape[2], media_tensor.shape[3] |
| |
|
| | |
| | current_aspect = current_width / current_height |
| | target_aspect = target_width / target_height |
| |
|
| | |
| | if current_aspect > target_aspect: |
| | |
| | new_width = int(current_width * target_height / current_height) |
| | media_tensor = resize( |
| | media_tensor, |
| | size=[target_height, new_width], |
| | interpolation=InterpolationMode.BICUBIC, |
| | ) |
| | else: |
| | |
| | new_height = int(current_height * target_width / current_width) |
| | media_tensor = resize( |
| | media_tensor, |
| | size=[new_height, target_width], |
| | interpolation=InterpolationMode.BICUBIC, |
| | ) |
| |
|
| | |
| | current_height, current_width = media_tensor.shape[2], media_tensor.shape[3] |
| | media_tensor = media_tensor.squeeze(0) |
| |
|
| | |
| | delta_h = current_height - target_height |
| | delta_w = current_width - target_width |
| |
|
| | |
| | if self.reshape_mode == "random": |
| | |
| | top = np.random.randint(0, delta_h + 1) |
| | left = np.random.randint(0, delta_w + 1) |
| | elif self.reshape_mode == "center": |
| | |
| | top, left = delta_h // 2, delta_w // 2 |
| | else: |
| | raise ValueError(f"Unsupported reshape mode: {self.reshape_mode}") |
| |
|
| | |
| | media_tensor = crop(media_tensor, top=top, left=left, height=target_height, width=target_width) |
| | return media_tensor |
| |
|
| |
|
| | def compute_latents( |
| | dataset_file: str | Path, |
| | video_column: str, |
| | resolution_buckets: list[tuple[int, int, int]], |
| | output_dir: str, |
| | model_path: str, |
| | main_media_column: str | None = None, |
| | reshape_mode: str = "center", |
| | batch_size: int = 1, |
| | device: str = "cuda", |
| | vae_tiling: bool = False, |
| | with_audio: bool = False, |
| | audio_output_dir: str | None = None, |
| | ) -> None: |
| | """ |
| | Process videos and save latent representations. |
| | |
| | Args: |
| | dataset_file: Path to metadata file (CSV/JSON/JSONL) containing video paths |
| | video_column: Column name for video paths in the metadata file |
| | resolution_buckets: List of (frames, height, width) tuples |
| | output_dir: Directory to save video latents |
| | model_path: Path to LTX-2 checkpoint (.safetensors) |
| | reshape_mode: How to crop videos ("center", "random") |
| | main_media_column: Column name for main media paths (if different from video_column) |
| | batch_size: Batch size for processing |
| | device: Device to use for computation |
| | vae_tiling: Whether to enable VAE tiling |
| | with_audio: Whether to extract and encode audio from videos |
| | audio_output_dir: Directory to save audio latents (required if with_audio=True) |
| | """ |
| | |
| | if with_audio and audio_output_dir is None: |
| | raise ValueError("audio_output_dir must be provided when with_audio=True") |
| |
|
| | console = Console() |
| | torch_device = torch.device(device) |
| |
|
| | |
| | dataset = MediaDataset( |
| | dataset_file=dataset_file, |
| | main_media_column=main_media_column or video_column, |
| | video_column=video_column, |
| | resolution_buckets=resolution_buckets, |
| | reshape_mode=reshape_mode, |
| | with_audio=with_audio, |
| | ) |
| | logger.info(f"Loaded {len(dataset)} valid media files") |
| |
|
| | output_path = Path(output_dir) |
| | output_path.mkdir(parents=True, exist_ok=True) |
| |
|
| | |
| | audio_output_path = None |
| | if with_audio: |
| | audio_output_path = Path(audio_output_dir) |
| | audio_output_path.mkdir(parents=True, exist_ok=True) |
| |
|
| | |
| | with console.status(f"[bold]Loading video VAE encoder from [cyan]{model_path}[/]...", spinner="dots"): |
| | vae = load_video_vae_encoder(model_path, device=torch_device, dtype=torch.bfloat16) |
| |
|
| | if vae_tiling: |
| | vae.enable_tiling() |
| |
|
| | |
| | audio_vae_encoder = None |
| | audio_processor = None |
| | if with_audio: |
| | with console.status(f"[bold]Loading audio VAE encoder from [cyan]{model_path}[/]...", spinner="dots"): |
| | audio_vae_encoder = load_audio_vae_encoder( |
| | checkpoint_path=model_path, |
| | device=torch_device, |
| | dtype=torch.float32, |
| | ) |
| | |
| | audio_processor = AudioProcessor( |
| | sample_rate=audio_vae_encoder.sample_rate, |
| | mel_bins=audio_vae_encoder.mel_bins, |
| | mel_hop_length=audio_vae_encoder.mel_hop_length, |
| | n_fft=audio_vae_encoder.n_fft, |
| | ).to(torch_device) |
| |
|
| | |
| | |
| | |
| | if with_audio and batch_size > 1: |
| | logger.warning("Audio processing requires batch_size=1. Overriding batch_size to 1.") |
| | batch_size = 1 |
| | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4) |
| |
|
| | |
| | audio_success_count = 0 |
| | audio_skip_count = 0 |
| |
|
| | |
| | with Progress( |
| | SpinnerColumn(), |
| | TextColumn("[progress.description]{task.description}"), |
| | BarColumn(), |
| | TaskProgressColumn(), |
| | MofNCompleteColumn(), |
| | TimeElapsedColumn(), |
| | TimeRemainingColumn(), |
| | console=console, |
| | ) as progress: |
| | task = progress.add_task("Processing videos", total=len(dataloader)) |
| |
|
| | for batch in dataloader: |
| | |
| | video = batch["video"] |
| |
|
| | |
| | with torch.inference_mode(): |
| | video_latent_data = encode_video(vae=vae, video=video) |
| |
|
| | |
| | for i in range(len(batch["relative_path"])): |
| | output_rel_path = Path(batch["main_media_relative_path"][i]).with_suffix(".pt") |
| | output_file = output_path / output_rel_path |
| |
|
| | |
| | output_file.parent.mkdir(parents=True, exist_ok=True) |
| |
|
| | |
| | latent_data = { |
| | "latents": video_latent_data["latents"][i].cpu().contiguous(), |
| | "num_frames": video_latent_data["num_frames"], |
| | "height": video_latent_data["height"], |
| | "width": video_latent_data["width"], |
| | "fps": batch["video_metadata"]["fps"][i].item(), |
| | } |
| |
|
| | torch.save(latent_data, output_file) |
| |
|
| | |
| | if with_audio: |
| | audio_batch = batch.get("audio") |
| | if audio_batch is not None: |
| | |
| | |
| | audio_data = { |
| | "waveform": audio_batch["waveform"][i], |
| | "sample_rate": audio_batch["sample_rate"][i].item(), |
| | } |
| |
|
| | |
| | with torch.inference_mode(): |
| | audio_latents = encode_audio(audio_vae_encoder, audio_processor, audio_data) |
| |
|
| | |
| | audio_output_file = audio_output_path / output_rel_path |
| | audio_output_file.parent.mkdir(parents=True, exist_ok=True) |
| |
|
| | audio_save_data = { |
| | "latents": audio_latents["latents"].cpu().contiguous(), |
| | "num_time_steps": audio_latents["num_time_steps"], |
| | "frequency_bins": audio_latents["frequency_bins"], |
| | "duration": audio_latents["duration"], |
| | } |
| |
|
| | torch.save(audio_save_data, audio_output_file) |
| | audio_success_count += 1 |
| | else: |
| | |
| | audio_skip_count += 1 |
| |
|
| | progress.advance(task) |
| |
|
| | |
| | logger.info(f"Processed {len(dataset)} videos. Latents saved to {output_path}") |
| | if with_audio: |
| | logger.info( |
| | f"Audio processing: {audio_success_count} videos with audio, " |
| | f"{audio_skip_count} videos without audio (skipped)" |
| | ) |
| |
|
| |
|
| | def encode_video( |
| | vae: torch.nn.Module, |
| | video: torch.Tensor, |
| | dtype: torch.dtype | None = None, |
| | ) -> dict[str, torch.Tensor | int]: |
| | """Encode video into non-patchified latent representation. |
| | |
| | Args: |
| | vae: Video VAE encoder model |
| | video: Input tensor of shape [B, C, F, H, W] (batch, channels, frames, height, width) |
| | This is the format expected by the VAE encoder. |
| | dtype: Target dtype for output latents |
| | |
| | Returns: |
| | Dict containing non-patchified latents and shape information: |
| | { |
| | "latents": Tensor[B, C, F', H', W'], # Non-patchified format with batch dim |
| | "num_frames": int, # Latent frame count |
| | "height": int, # Latent height |
| | "width": int, # Latent width |
| | } |
| | """ |
| | device = next(vae.parameters()).device |
| | vae_dtype = next(vae.parameters()).dtype |
| |
|
| | |
| | if video.ndim == 4: |
| | video = video.unsqueeze(0) |
| |
|
| | video = video.to(device=device, dtype=vae_dtype) |
| |
|
| | |
| | latents = vae(video) |
| |
|
| | if dtype is not None: |
| | latents = latents.to(dtype=dtype) |
| |
|
| | _, _, num_frames, height, width = latents.shape |
| |
|
| | return { |
| | "latents": latents, |
| | "num_frames": num_frames, |
| | "height": height, |
| | "width": width, |
| | } |
| |
|
| |
|
| | def encode_audio( |
| | audio_vae_encoder: torch.nn.Module, |
| | audio_processor: torch.nn.Module, |
| | audio_data: dict[str, torch.Tensor | int], |
| | ) -> dict[str, torch.Tensor | int | float]: |
| | """Encode audio waveform into latent representation. |
| | |
| | Args: |
| | audio_vae_encoder: Audio VAE encoder model from ltx-core |
| | audio_processor: AudioProcessor for waveform-to-spectrogram conversion |
| | audio_data: Dict with {"waveform": Tensor[channels, samples], "sample_rate": int} |
| | |
| | Returns: |
| | Dict containing audio latents and shape information: |
| | { |
| | "latents": Tensor[C, T, F], # Non-patchified format |
| | "num_time_steps": int, |
| | "frequency_bins": int, |
| | "duration": float, |
| | } |
| | """ |
| | device = next(audio_vae_encoder.parameters()).device |
| | dtype = next(audio_vae_encoder.parameters()).dtype |
| |
|
| | waveform = audio_data["waveform"].to(device=device, dtype=dtype) |
| | sample_rate = audio_data["sample_rate"] |
| |
|
| | |
| | if waveform.dim() == 2: |
| | waveform = waveform.unsqueeze(0) |
| |
|
| | |
| | duration = waveform.shape[-1] / sample_rate |
| |
|
| | |
| | mel_spectrogram = audio_processor.waveform_to_mel(waveform, waveform_sample_rate=sample_rate) |
| | mel_spectrogram = mel_spectrogram.to(dtype=dtype) |
| |
|
| | |
| | latents = audio_vae_encoder(mel_spectrogram) |
| |
|
| | |
| | _, _channels, time_steps, freq_bins = latents.shape |
| |
|
| | return { |
| | "latents": latents.squeeze(0), |
| | "num_time_steps": time_steps, |
| | "frequency_bins": freq_bins, |
| | "duration": duration, |
| | } |
| |
|
| |
|
def parse_resolution_buckets(resolution_buckets_str: str) -> list[tuple[int, int, int]]:
    """Parse resolution buckets from string format to list of tuples (frames, height, width).

    The input is a ``;``-separated list of ``WxHxF`` segments, e.g.
    ``"768x768x25;512x512x49"``. Note the order flips: segments are written
    width-first, tuples are returned frames-first. Empty segments (such as the
    one produced by a trailing ``;``) are ignored and the ``x`` separator is
    case-insensitive.

    Args:
        resolution_buckets_str: Bucket specification string.

    Returns:
        One ``(frames, height, width)`` tuple per segment, in input order.

    Raises:
        typer.BadParameter: If a segment is not three integers, a value is not
            positive, width/height is not a multiple of ``VAE_SPATIAL_FACTOR``,
            the frame count is not ``k * VAE_TEMPORAL_FACTOR + 1``, or no bucket
            was given at all.
    """
    resolution_buckets = []
    for raw_segment in resolution_buckets_str.split(";"):
        bucket_str = raw_segment.strip()
        if not bucket_str:
            continue

        try:
            w, h, f = (int(part) for part in bucket_str.lower().split("x"))
        except ValueError as err:
            # Covers both non-integer parts and a wrong number of parts.
            raise typer.BadParameter(
                f"Expected a bucket of the form WxHxF with integer values, got '{bucket_str}'",
                param_hint="resolution-buckets",
            ) from err

        # Python's modulo is non-negative, so e.g. -7 % 8 == 1 would slip through
        # the frame check below; reject non-positive values explicitly.
        if w <= 0 or h <= 0 or f <= 0:
            raise typer.BadParameter(
                f"Width, height and number of frames must be positive, got {w}x{h}x{f}",
                param_hint="resolution-buckets",
            )

        if w % VAE_SPATIAL_FACTOR != 0 or h % VAE_SPATIAL_FACTOR != 0:
            raise typer.BadParameter(
                f"Width and height must be multiples of {VAE_SPATIAL_FACTOR}, got {w}x{h}",
                param_hint="resolution-buckets",
            )

        if f % VAE_TEMPORAL_FACTOR != 1:
            raise typer.BadParameter(
                f"Number of frames must be a multiple of {VAE_TEMPORAL_FACTOR} plus 1, got {f}",
                param_hint="resolution-buckets",
            )

        resolution_buckets.append((f, h, w))

    if not resolution_buckets:
        raise typer.BadParameter(
            "At least one resolution bucket is required",
            param_hint="resolution-buckets",
        )

    return resolution_buckets
| |
|
| |
|
@app.command()
def main(
    dataset_file: str = typer.Argument(
        ...,
        help="Path to metadata file (CSV/JSON/JSONL) containing video paths",
    ),
    resolution_buckets: str = typer.Option(
        ...,
        help='Resolution buckets in format "WxHxF;WxHxF;..." (e.g. "768x768x25;512x512x49")',
    ),
    output_dir: str = typer.Option(
        ...,
        help="Output directory to save video latents",
    ),
    model_path: str = typer.Option(
        ...,
        help="Path to LTX-2 checkpoint (.safetensors file)",
    ),
    video_column: str = typer.Option(
        default="media_path",
        help="Column name in the dataset JSON/JSONL/CSV file containing video paths",
    ),
    batch_size: int = typer.Option(
        default=1,
        help="Batch size for processing",
    ),
    device: str = typer.Option(
        default="cuda",
        help="Device to use for computation",
    ),
    vae_tiling: bool = typer.Option(
        default=False,
        help="Enable VAE tiling for larger video resolutions",
    ),
    reshape_mode: str = typer.Option(
        default="center",
        help="How to crop videos: 'center' or 'random'",
    ),
    with_audio: bool = typer.Option(
        default=False,
        help="Extract and encode audio from video files",
    ),
    audio_output_dir: str | None = typer.Option(
        default=None,
        help="Output directory for audio latents (required if --with-audio is set)",
    ),
) -> None:
    """Process videos/images and save latent representations for video generation training.

    This script processes videos and images from metadata files and saves latent representations
    that can be used for training video generation models. The output latents will maintain
    the same folder structure and naming as the corresponding media files.

    Examples:
        # Process videos from a CSV file
        python scripts/process_videos.py dataset.csv --resolution-buckets 768x768x25 \\
            --output-dir ./latents --model-path /path/to/ltx2.safetensors

        # Process videos from a JSON file with custom video column
        python scripts/process_videos.py dataset.json --resolution-buckets 768x768x25 \\
            --output-dir ./latents --model-path /path/to/ltx2.safetensors --video-column "video_path"

        # Enable VAE tiling to save GPU VRAM
        python scripts/process_videos.py dataset.csv --resolution-buckets 1024x1024x25 \\
            --output-dir ./latents --model-path /path/to/ltx2.safetensors --vae-tiling

        # Process videos with audio
        python scripts/process_videos.py dataset.csv --resolution-buckets 768x768x25 \\
            --output-dir ./latents --model-path /path/to/ltx2.safetensors \\
            --with-audio --audio-output-dir ./audio_latents
    """
    # NOTE: the docstring above is rendered by Typer as the command's help text.

    if not Path(dataset_file).is_file():
        raise typer.BadParameter(f"Dataset file not found: {dataset_file}")

    if with_audio and audio_output_dir is None:
        raise typer.BadParameter("--audio-output-dir is required when --with-audio is set")

    # Reject an unsupported crop mode here; otherwise it only surfaces as a
    # ValueError from the dataset once the first item is being cropped, after
    # the model has already been loaded.
    if reshape_mode not in ("center", "random"):
        raise typer.BadParameter(
            f"Unsupported reshape mode: '{reshape_mode}' (expected 'center' or 'random')",
            param_hint="reshape-mode",
        )

    parsed_resolution_buckets = parse_resolution_buckets(resolution_buckets)

    if len(parsed_resolution_buckets) > 1:
        logger.warning(
            "Using multiple resolution buckets. "
            "When training with multiple resolution buckets, you must use a batch size of 1."
        )

    compute_latents(
        dataset_file=dataset_file,
        video_column=video_column,
        resolution_buckets=parsed_resolution_buckets,
        output_dir=output_dir,
        model_path=model_path,
        reshape_mode=reshape_mode,
        batch_size=batch_size,
        device=device,
        vae_tiling=vae_tiling,
        with_audio=with_audio,
        audio_output_dir=audio_output_dir,
    )
| |
|
| |
|
# Invoke the Typer application only when this file is executed as a script.
if __name__ == "__main__":
    app()
| |
|