|
|
|
|
|
|
|
|
""" |
|
|
Compute latent representations for video generation training. |
|
|
|
|
|
This module provides functionality for processing video and image files, including: |
|
|
- Loading videos/images from various file formats (CSV, JSON, JSONL) |
|
|
- Resizing, cropping, and transforming media |
|
|
- MediaDataset for video/image preprocessing workflows
|
|
- Encoding video (and optionally audio) into VAE latent representations
|
|
|
|
|
Can be used as a standalone script: |
|
|
    python scripts/process_videos.py dataset.csv --resolution-buckets 768x768x25 \\


        --output-dir /path/to/output --model-path /path/to/ltx2.safetensors
|
|
""" |
|
|
|
|
|
import json |
|
|
import math |
|
|
from pathlib import Path |
|
|
from typing import Any |
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import torch |
|
|
import torchaudio |
|
|
import typer |
|
|
from pillow_heif import register_heif_opener |
|
|
from rich.console import Console |
|
|
from rich.progress import ( |
|
|
BarColumn, |
|
|
MofNCompleteColumn, |
|
|
Progress, |
|
|
SpinnerColumn, |
|
|
TaskProgressColumn, |
|
|
TextColumn, |
|
|
TimeElapsedColumn, |
|
|
TimeRemainingColumn, |
|
|
) |
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
from torchvision import transforms |
|
|
from torchvision.transforms import InterpolationMode |
|
|
from torchvision.transforms.functional import crop, resize, to_tensor |
|
|
from transformers.utils.logging import disable_progress_bar |
|
|
|
|
|
from ltx_core.model.audio_vae.ops import AudioProcessor |
|
|
from ltx_trainer import logger |
|
|
from ltx_trainer.model_loader import load_audio_vae_encoder, load_video_vae_encoder |
|
|
from ltx_trainer.utils import open_image_as_srgb |
|
|
from ltx_trainer.video_utils import get_video_frame_count, read_video |
|
|
|
|
|
disable_progress_bar() |
|
|
|
|
|
|
|
|
register_heif_opener() |
|
|
|
|
|
|
|
|
VAE_SPATIAL_FACTOR = 32 |
|
|
VAE_TEMPORAL_FACTOR = 8 |
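# Constraints enforced by parse_resolution_buckets: bucket widths/heights must be
# multiples of VAE_SPATIAL_FACTOR, and frame counts must be of the form
# VAE_TEMPORAL_FACTOR * k + 1.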
|
|
|
|
|
|
|
|
AUDIO_LATENT_CHANNELS = 8 |
|
|
AUDIO_FREQUENCY_BINS = 16 |
|
|
|
|
|
app = typer.Typer( |
|
|
pretty_exceptions_enable=False, |
|
|
no_args_is_help=True, |
|
|
help="Process videos/images and save latent representations for video generation training.", |
|
|
) |
|
|
|
|
|
|
|
|
class MediaDataset(Dataset): |
|
|
""" |
|
|
Dataset for processing video and image files. |
|
|
|
|
|
This dataset is designed for media preprocessing workflows where you need to: |
|
|
- Load and preprocess videos/images |
|
|
- Apply resizing and cropping transformations |
|
|
- Handle different resolution buckets |
|
|
- Filter out invalid media files |
|
|
- Optionally extract audio from video files |
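
    Example (illustrative; the metadata file name, column name, and bucket values
    are placeholders):

        dataset = MediaDataset(
            dataset_file="dataset.csv",
            main_media_column="media_path",
            video_column="media_path",
            resolution_buckets=[(25, 768, 768)],
        )
        sample = dataset[0]  # {"video": Tensor[C, F, H, W], "video_metadata": {...}, ...}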
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dataset_file: str | Path, |
|
|
main_media_column: str, |
|
|
video_column: str, |
|
|
resolution_buckets: list[tuple[int, int, int]], |
|
|
reshape_mode: str = "center", |
|
|
with_audio: bool = False, |
|
|
) -> None: |
|
|
""" |
|
|
Initialize the media dataset. |
|
|
|
|
|
Args: |
|
|
            dataset_file: Path to CSV/JSON/JSONL metadata file
            main_media_column: Column name for the main media paths (used to derive
                output latent file names; may be the same as video_column)
            video_column: Column name for video paths in the metadata file
|
|
resolution_buckets: List of (frames, height, width) tuples |
|
|
reshape_mode: How to crop videos ("center", "random") |
|
|
with_audio: Whether to extract audio from video files |
|
|
""" |
|
|
super().__init__() |
|
|
|
|
|
self.dataset_file = Path(dataset_file) |
|
|
self.main_media_column = main_media_column |
|
|
self.resolution_buckets = resolution_buckets |
|
|
self.reshape_mode = reshape_mode |
|
|
self.with_audio = with_audio |
|
|
|
|
|
|
|
|
self.main_media_paths = self._load_video_paths(main_media_column) |
|
|
|
|
|
|
|
|
self.video_paths = self._load_video_paths(video_column) |
|
|
|
|
|
|
|
|
self._filter_valid_videos() |
|
|
|
|
|
self.max_target_frames = max(self.resolution_buckets, key=lambda x: x[0])[0] |
|
|
|
|
|
|
|
|
self.transforms = transforms.Compose( |
|
|
[ |
|
|
transforms.Lambda(lambda x: x.clamp_(0, 1)), |
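                # Normalize rescales pixel values from [0, 1] to [-1, 1].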
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), |
|
|
] |
|
|
) |
|
|
|
|
|
def __len__(self) -> int: |
|
|
return len(self.video_paths) |
|
|
|
|
|
def __getitem__(self, index: int) -> dict[str, Any]: |
|
|
"""Get a single video/image with metadata, and optionally audio.""" |
|
|
if isinstance(index, list): |
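            # A list of indices (e.g. from a custom batch sampler) is passed through unchanged.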
|
|
|
|
|
return index |
|
|
|
|
|
video_path: Path = self.video_paths[index] |
|
|
|
|
|
|
|
|
data_root = self.dataset_file.parent |
|
|
relative_path = str(video_path.relative_to(data_root)) |
|
|
media_relative_path = str(self.main_media_paths[index].relative_to(data_root)) |
|
|
|
|
|
if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]: |
|
|
media_tensor = self._preprocess_image(video_path) |
|
|
fps = 1.0 |
|
|
audio_data = None |
|
|
else: |
|
|
media_tensor, fps = self._preprocess_video(video_path) |
|
|
|
|
|
|
|
|
if self.with_audio: |
|
|
|
|
|
|
|
|
|
|
|
target_duration = media_tensor.shape[1] / fps |
|
|
audio_data = self._extract_audio(video_path, target_duration) |
|
|
else: |
|
|
audio_data = None |
|
|
|
|
|
|
|
|
_, num_frames, height, width = media_tensor.shape |
|
|
|
|
|
result = { |
|
|
"video": media_tensor, |
|
|
"relative_path": relative_path, |
|
|
"main_media_relative_path": media_relative_path, |
|
|
"video_metadata": { |
|
|
"num_frames": num_frames, |
|
|
"height": height, |
|
|
"width": width, |
|
|
"fps": fps, |
|
|
}, |
|
|
} |
|
|
|
|
|
|
|
|
if audio_data is not None: |
|
|
result["audio"] = audio_data |
|
|
|
|
|
return result |
|
|
|
|
|
@staticmethod |
|
|
def _extract_audio(video_path: Path, target_duration: float) -> dict[str, torch.Tensor | int] | None: |
|
|
"""Extract audio track from a video file, trimmed to match video duration.""" |
|
|
try: |
|
|
|
|
|
|
|
|
waveform, sample_rate = torchaudio.load(str(video_path)) |
|
|
|
|
|
|
|
|
target_samples = int(target_duration * sample_rate) |
|
|
current_samples = waveform.shape[-1] |
|
|
|
|
|
if current_samples > target_samples: |
|
|
|
|
|
waveform = waveform[..., :target_samples] |
|
|
elif current_samples < target_samples: |
|
|
|
|
|
padding = target_samples - current_samples |
|
|
waveform = torch.nn.functional.pad(waveform, (0, padding)) |
|
|
logger.warning(f"Padded audio to {target_duration:.2f} seconds for {video_path}") |
|
|
|
|
|
return {"waveform": waveform, "sample_rate": sample_rate} |
|
|
|
|
|
except Exception as e: |
|
|
logger.debug(f"Could not extract audio from {video_path}: {e}") |
|
|
return None |
|
|
|
|
|
def _load_video_paths(self, column: str) -> list[Path]: |
|
|
"""Load video paths from the specified data source.""" |
|
|
if self.dataset_file.suffix == ".csv": |
|
|
return self._load_video_paths_from_csv(column) |
|
|
elif self.dataset_file.suffix == ".json": |
|
|
return self._load_video_paths_from_json(column) |
|
|
elif self.dataset_file.suffix == ".jsonl": |
|
|
return self._load_video_paths_from_jsonl(column) |
|
|
else: |
|
|
raise ValueError("Expected `dataset_file` to be a path to a CSV, JSON, or JSONL file.") |
|
|
|
|
|
def _load_video_paths_from_csv(self, column: str) -> list[Path]: |
|
|
"""Load video paths from a CSV file.""" |
|
|
df = pd.read_csv(self.dataset_file) |
|
|
if column not in df.columns: |
|
|
raise ValueError(f"Column '{column}' not found in CSV file") |
|
|
|
|
|
data_root = self.dataset_file.parent |
|
|
video_paths = [data_root / Path(line.strip()) for line in df[column].tolist()] |
|
|
|
|
|
|
|
|
invalid_paths = [path for path in video_paths if not path.is_file()] |
|
|
if invalid_paths: |
|
|
raise ValueError(f"Found {len(invalid_paths)} invalid video paths. First few: {invalid_paths[:5]}") |
|
|
|
|
|
return video_paths |
|
|
|
|
|
def _load_video_paths_from_json(self, column: str) -> list[Path]: |
|
|
"""Load video paths from a JSON file.""" |
|
|
with open(self.dataset_file, "r", encoding="utf-8") as file: |
|
|
data = json.load(file) |
|
|
|
|
|
if not isinstance(data, list): |
|
|
raise ValueError("JSON file must contain a list of objects") |
|
|
|
|
|
data_root = self.dataset_file.parent |
|
|
video_paths = [] |
|
|
for entry in data: |
|
|
if column not in entry: |
|
|
raise ValueError(f"Key '{column}' not found in JSON entry") |
|
|
video_paths.append(data_root / Path(entry[column].strip())) |
|
|
|
|
|
|
|
|
invalid_paths = [path for path in video_paths if not path.is_file()] |
|
|
if invalid_paths: |
|
|
raise ValueError(f"Found {len(invalid_paths)} invalid video paths. First few: {invalid_paths[:5]}") |
|
|
|
|
|
return video_paths |
|
|
|
|
|
def _load_video_paths_from_jsonl(self, column: str) -> list[Path]: |
|
|
"""Load video paths from a JSONL file.""" |
|
|
data_root = self.dataset_file.parent |
|
|
video_paths = [] |
|
|
with open(self.dataset_file, "r", encoding="utf-8") as file: |
|
|
for line in file: |
|
|
entry = json.loads(line) |
|
|
if column not in entry: |
|
|
raise ValueError(f"Key '{column}' not found in JSONL entry") |
|
|
video_paths.append(data_root / Path(entry[column].strip())) |
|
|
|
|
|
|
|
|
invalid_paths = [path for path in video_paths if not path.is_file()] |
|
|
if invalid_paths: |
|
|
raise ValueError(f"Found {len(invalid_paths)} invalid video paths. First few: {invalid_paths[:5]}") |
|
|
|
|
|
return video_paths |
|
|
|
|
|
def _filter_valid_videos(self) -> None: |
|
|
"""Filter out videos with insufficient frames.""" |
|
|
original_length = len(self.video_paths) |
|
|
valid_video_paths = [] |
|
|
valid_main_media_paths = [] |
|
|
min_frames_required = min(self.resolution_buckets, key=lambda x: x[0])[0] |
|
|
|
|
|
for i, video_path in enumerate(self.video_paths): |
|
|
if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]: |
|
|
valid_video_paths.append(video_path) |
|
|
valid_main_media_paths.append(self.main_media_paths[i]) |
|
|
continue |
|
|
|
|
|
try: |
|
|
frame_count = get_video_frame_count(video_path) |
|
|
|
|
|
if frame_count >= min_frames_required: |
|
|
valid_video_paths.append(video_path) |
|
|
valid_main_media_paths.append(self.main_media_paths[i]) |
|
|
else: |
|
|
logger.warning( |
|
|
f"Skipping video at {video_path} - has {frame_count} frames, " |
|
|
f"which is less than the minimum required frames ({min_frames_required})" |
|
|
) |
|
|
except Exception as e: |
|
|
logger.warning(f"Failed to read video at {video_path}: {e!s}") |
|
|
|
|
|
|
|
|
self.video_paths = valid_video_paths |
|
|
self.main_media_paths = valid_main_media_paths |
|
|
|
|
|
if len(self.video_paths) < original_length: |
|
|
logger.warning( |
|
|
f"Filtered out {original_length - len(self.video_paths)} videos with insufficient frames. " |
|
|
f"Proceeding with {len(self.video_paths)} valid videos." |
|
|
) |
|
|
|
|
|
def _preprocess_image(self, path: Path) -> torch.Tensor: |
|
|
"""Preprocess a single image by resizing and applying transforms.""" |
|
|
image = open_image_as_srgb(path) |
|
|
image = to_tensor(image) |
|
|
image = image.unsqueeze(0) |
|
|
|
|
|
|
|
|
nearest_bucket = self._get_resolution_bucket_for_item(image) |
|
|
_, target_height, target_width = nearest_bucket |
|
|
image_resized = self._resize_and_crop(image, target_height, target_width) |
|
|
|
|
|
|
|
|
|
|
|
image = self.transforms(image_resized) |
|
|
|
|
|
|
|
|
image = image.unsqueeze(1) |
|
|
return image |
|
|
|
|
|
def _preprocess_video(self, path: Path) -> tuple[torch.Tensor, float]: |
|
|
"""Preprocess a video by loading, resizing, and applying transforms. |
|
|
|
|
|
Returns: |
|
|
Tuple of (video tensor in [C, F, H, W] format, fps) |
|
|
""" |
|
|
|
|
|
video, fps = read_video(path, max_frames=self.max_target_frames) |
|
|
|
|
|
nearest_bucket = self._get_resolution_bucket_for_item(video) |
|
|
target_num_frames, target_height, target_width = nearest_bucket |
|
|
frames_resized = self._resize_and_crop(video, target_height, target_width) |
|
|
|
|
|
|
|
|
frames_resized = frames_resized[:target_num_frames] |
|
|
|
|
|
|
|
|
video = torch.stack([self.transforms(frame) for frame in frames_resized], dim=0) |
|
|
|
|
|
|
|
|
|
|
|
video = video.permute(1, 0, 2, 3).contiguous() |
|
|
|
|
|
return video, fps |
|
|
|
|
|
def _get_resolution_bucket_for_item(self, media_tensor: torch.Tensor) -> tuple[int, int, int]: |
|
|
"""Get the nearest resolution bucket for the given media tensor.""" |
|
|
num_frames, _, height, width = media_tensor.shape |
|
|
|
|
|
def distance(bucket: tuple[int, int, int]) -> tuple: |
|
|
bucket_num_frames, bucket_height, bucket_width = bucket |
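            # Buckets are ranked by (1) aspect-ratio distance in log space, then
            # (2) more frames preferred, then (3) larger spatial area preferred;
            # tuple comparison applies these criteria in order.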
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return ( |
|
|
abs(math.log(width / height) - math.log(bucket_width / bucket_height)), |
|
|
-bucket_num_frames, |
|
|
-(bucket_height * bucket_width), |
|
|
) |
|
|
|
|
|
|
|
|
relevant_buckets = [b for b in self.resolution_buckets if b[0] <= num_frames] |
|
|
if not relevant_buckets: |
|
|
raise ValueError(f"No resolution buckets have <= {num_frames} frames. Available: {self.resolution_buckets}") |
|
|
|
|
|
|
|
|
nearest_bucket = min(relevant_buckets, key=distance) |
|
|
|
|
|
return nearest_bucket |
|
|
|
|
|
def _resize_and_crop(self, media_tensor: torch.Tensor, target_height: int, target_width: int) -> torch.Tensor: |
|
|
"""Resize and crop tensor to target size.""" |
|
|
|
|
|
current_height, current_width = media_tensor.shape[2], media_tensor.shape[3] |
|
|
|
|
|
|
|
|
current_aspect = current_width / current_height |
|
|
target_aspect = target_width / target_height |
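
        # Resize so the target area is fully covered while preserving aspect ratio,
        # then crop the overflow (center or random, per `reshape_mode`).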
|
|
|
|
|
|
|
|
if current_aspect > target_aspect: |
|
|
|
|
|
new_width = int(current_width * target_height / current_height) |
|
|
media_tensor = resize( |
|
|
media_tensor, |
|
|
size=[target_height, new_width], |
|
|
interpolation=InterpolationMode.BICUBIC, |
|
|
) |
|
|
else: |
|
|
|
|
|
new_height = int(current_height * target_width / current_width) |
|
|
media_tensor = resize( |
|
|
media_tensor, |
|
|
size=[new_height, target_width], |
|
|
interpolation=InterpolationMode.BICUBIC, |
|
|
) |
|
|
|
|
|
|
|
|
current_height, current_width = media_tensor.shape[2], media_tensor.shape[3] |
|
|
media_tensor = media_tensor.squeeze(0) |
|
|
|
|
|
|
|
|
delta_h = current_height - target_height |
|
|
delta_w = current_width - target_width |
|
|
|
|
|
|
|
|
if self.reshape_mode == "random": |
|
|
|
|
|
top = np.random.randint(0, delta_h + 1) |
|
|
left = np.random.randint(0, delta_w + 1) |
|
|
elif self.reshape_mode == "center": |
|
|
|
|
|
top, left = delta_h // 2, delta_w // 2 |
|
|
else: |
|
|
raise ValueError(f"Unsupported reshape mode: {self.reshape_mode}") |
|
|
|
|
|
|
|
|
media_tensor = crop(media_tensor, top=top, left=left, height=target_height, width=target_width) |
|
|
return media_tensor |
|
|
|
|
|
|
|
|
def compute_latents( |
|
|
dataset_file: str | Path, |
|
|
video_column: str, |
|
|
resolution_buckets: list[tuple[int, int, int]], |
|
|
output_dir: str, |
|
|
model_path: str, |
|
|
main_media_column: str | None = None, |
|
|
reshape_mode: str = "center", |
|
|
batch_size: int = 1, |
|
|
device: str = "cuda", |
|
|
vae_tiling: bool = False, |
|
|
with_audio: bool = False, |
|
|
audio_output_dir: str | None = None, |
|
|
) -> None: |
|
|
""" |
|
|
Process videos and save latent representations. |
|
|
|
|
|
Args: |
|
|
dataset_file: Path to metadata file (CSV/JSON/JSONL) containing video paths |
|
|
video_column: Column name for video paths in the metadata file |
|
|
resolution_buckets: List of (frames, height, width) tuples |
|
|
output_dir: Directory to save video latents |
|
|
model_path: Path to LTX-2 checkpoint (.safetensors) |
|
|
        main_media_column: Column name for main media paths (if different from video_column)


        reshape_mode: How to crop videos ("center", "random")
|
|
batch_size: Batch size for processing |
|
|
device: Device to use for computation |
|
|
vae_tiling: Whether to enable VAE tiling |
|
|
with_audio: Whether to extract and encode audio from videos |
|
|
audio_output_dir: Directory to save audio latents (required if with_audio=True) |
|
|
""" |
|
|
|
|
|
if with_audio and audio_output_dir is None: |
|
|
raise ValueError("audio_output_dir must be provided when with_audio=True") |
|
|
|
|
|
console = Console() |
|
|
torch_device = torch.device(device) |
|
|
|
|
|
|
|
|
dataset = MediaDataset( |
|
|
dataset_file=dataset_file, |
|
|
main_media_column=main_media_column or video_column, |
|
|
video_column=video_column, |
|
|
resolution_buckets=resolution_buckets, |
|
|
reshape_mode=reshape_mode, |
|
|
with_audio=with_audio, |
|
|
) |
|
|
logger.info(f"Loaded {len(dataset)} valid media files") |
|
|
|
|
|
output_path = Path(output_dir) |
|
|
output_path.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
audio_output_path = None |
|
|
if with_audio: |
|
|
audio_output_path = Path(audio_output_dir) |
|
|
audio_output_path.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
with console.status(f"[bold]Loading video VAE encoder from [cyan]{model_path}[/]...", spinner="dots"): |
|
|
vae = load_video_vae_encoder(model_path, device=torch_device, dtype=torch.bfloat16) |
|
|
|
|
|
if vae_tiling: |
|
|
vae.enable_tiling() |
|
|
|
|
|
|
|
|
audio_vae_encoder = None |
|
|
audio_processor = None |
|
|
if with_audio: |
|
|
with console.status(f"[bold]Loading audio VAE encoder from [cyan]{model_path}[/]...", spinner="dots"): |
|
|
audio_vae_encoder = load_audio_vae_encoder( |
|
|
checkpoint_path=model_path, |
|
|
device=torch_device, |
|
|
dtype=torch.float32, |
|
|
) |
|
|
|
|
|
audio_processor = AudioProcessor( |
|
|
sample_rate=audio_vae_encoder.sample_rate, |
|
|
mel_bins=audio_vae_encoder.mel_bins, |
|
|
mel_hop_length=audio_vae_encoder.mel_hop_length, |
|
|
n_fft=audio_vae_encoder.n_fft, |
|
|
).to(torch_device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if with_audio and batch_size > 1: |
|
|
logger.warning("Audio processing requires batch_size=1. Overriding batch_size to 1.") |
|
|
batch_size = 1 |
|
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4) |
|
|
|
|
|
|
|
|
audio_success_count = 0 |
|
|
audio_skip_count = 0 |
|
|
|
|
|
|
|
|
with Progress( |
|
|
SpinnerColumn(), |
|
|
TextColumn("[progress.description]{task.description}"), |
|
|
BarColumn(), |
|
|
TaskProgressColumn(), |
|
|
MofNCompleteColumn(), |
|
|
TimeElapsedColumn(), |
|
|
TimeRemainingColumn(), |
|
|
console=console, |
|
|
) as progress: |
|
|
task = progress.add_task("Processing videos", total=len(dataloader)) |
|
|
|
|
|
for batch in dataloader: |
|
|
|
|
|
video = batch["video"] |
|
|
|
|
|
|
|
|
with torch.inference_mode(): |
|
|
video_latent_data = encode_video(vae=vae, video=video) |
|
|
|
|
|
|
|
|
for i in range(len(batch["relative_path"])): |
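                # Each latent file mirrors the relative path of its source media, with a .pt suffix.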
|
|
output_rel_path = Path(batch["main_media_relative_path"][i]).with_suffix(".pt") |
|
|
output_file = output_path / output_rel_path |
|
|
|
|
|
|
|
|
output_file.parent.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
latent_data = { |
|
|
"latents": video_latent_data["latents"][i].cpu().contiguous(), |
|
|
"num_frames": video_latent_data["num_frames"], |
|
|
"height": video_latent_data["height"], |
|
|
"width": video_latent_data["width"], |
|
|
"fps": batch["video_metadata"]["fps"][i].item(), |
|
|
} |
|
|
|
|
|
torch.save(latent_data, output_file) |
|
|
|
|
|
|
|
|
if with_audio: |
|
|
audio_batch = batch.get("audio") |
|
|
if audio_batch is not None: |
|
|
|
|
|
|
|
|
audio_data = { |
|
|
"waveform": audio_batch["waveform"][i], |
|
|
"sample_rate": audio_batch["sample_rate"][i].item(), |
|
|
} |
|
|
|
|
|
|
|
|
with torch.inference_mode(): |
|
|
audio_latents = encode_audio(audio_vae_encoder, audio_processor, audio_data) |
|
|
|
|
|
|
|
|
audio_output_file = audio_output_path / output_rel_path |
|
|
audio_output_file.parent.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
audio_save_data = { |
|
|
"latents": audio_latents["latents"].cpu().contiguous(), |
|
|
"num_time_steps": audio_latents["num_time_steps"], |
|
|
"frequency_bins": audio_latents["frequency_bins"], |
|
|
"duration": audio_latents["duration"], |
|
|
} |
|
|
|
|
|
torch.save(audio_save_data, audio_output_file) |
|
|
audio_success_count += 1 |
|
|
else: |
|
|
|
|
|
audio_skip_count += 1 |
|
|
|
|
|
progress.advance(task) |
|
|
|
|
|
|
|
|
logger.info(f"Processed {len(dataset)} videos. Latents saved to {output_path}") |
|
|
if with_audio: |
|
|
logger.info( |
|
|
f"Audio processing: {audio_success_count} videos with audio, " |
|
|
f"{audio_skip_count} videos without audio (skipped)" |
|
|
) |
|
|
|
|
|
|
|
|
def encode_video( |
|
|
vae: torch.nn.Module, |
|
|
video: torch.Tensor, |
|
|
dtype: torch.dtype | None = None, |
|
|
) -> dict[str, torch.Tensor | int]: |
|
|
"""Encode video into non-patchified latent representation. |
|
|
|
|
|
Args: |
|
|
vae: Video VAE encoder model |
|
|
video: Input tensor of shape [B, C, F, H, W] (batch, channels, frames, height, width) |
|
|
This is the format expected by the VAE encoder. |
|
|
dtype: Target dtype for output latents |
|
|
|
|
|
Returns: |
|
|
Dict containing non-patchified latents and shape information: |
|
|
{ |
|
|
"latents": Tensor[B, C, F', H', W'], # Non-patchified format with batch dim |
|
|
"num_frames": int, # Latent frame count |
|
|
"height": int, # Latent height |
|
|
"width": int, # Latent width |
|
|
} |
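
    Example (a sketch, assuming the checkpoint uses the module-level compression
    factors VAE_SPATIAL_FACTOR=32 and VAE_TEMPORAL_FACTOR=8):

        out = encode_video(vae, video=torch.rand(1, 3, 25, 768, 768))
        # Expected roughly: num_frames = (25 - 1) / 8 + 1 = 4, height = width = 768 / 32 = 24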
|
|
""" |
|
|
device = next(vae.parameters()).device |
|
|
vae_dtype = next(vae.parameters()).dtype |
|
|
|
|
|
|
|
|
if video.ndim == 4: |
|
|
video = video.unsqueeze(0) |
|
|
|
|
|
video = video.to(device=device, dtype=vae_dtype) |
|
|
|
|
|
|
|
|
latents = vae(video) |
|
|
|
|
|
if dtype is not None: |
|
|
latents = latents.to(dtype=dtype) |
|
|
|
|
|
_, _, num_frames, height, width = latents.shape |
|
|
|
|
|
return { |
|
|
"latents": latents, |
|
|
"num_frames": num_frames, |
|
|
"height": height, |
|
|
"width": width, |
|
|
} |
|
|
|
|
|
|
|
|
def encode_audio( |
|
|
audio_vae_encoder: torch.nn.Module, |
|
|
audio_processor: torch.nn.Module, |
|
|
audio_data: dict[str, torch.Tensor | int], |
|
|
) -> dict[str, torch.Tensor | int | float]: |
|
|
"""Encode audio waveform into latent representation. |
|
|
|
|
|
Args: |
|
|
audio_vae_encoder: Audio VAE encoder model from ltx-core |
|
|
audio_processor: AudioProcessor for waveform-to-spectrogram conversion |
|
|
audio_data: Dict with {"waveform": Tensor[channels, samples], "sample_rate": int} |
|
|
|
|
|
Returns: |
|
|
Dict containing audio latents and shape information: |
|
|
{ |
|
|
"latents": Tensor[C, T, F], # Non-patchified format |
|
|
"num_time_steps": int, |
|
|
"frequency_bins": int, |
|
|
"duration": float, |
|
|
} |
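
    Example (illustrative; the waveform is random noise and latent sizes depend on
    the encoder configuration):

        audio = {"waveform": torch.randn(2, 48000), "sample_rate": 48000}
        out = encode_audio(audio_vae_encoder, audio_processor, audio)
        # out["duration"] == 1.0 for this one-second input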
|
|
""" |
|
|
device = next(audio_vae_encoder.parameters()).device |
|
|
dtype = next(audio_vae_encoder.parameters()).dtype |
|
|
|
|
|
waveform = audio_data["waveform"].to(device=device, dtype=dtype) |
|
|
sample_rate = audio_data["sample_rate"] |
|
|
|
|
|
|
|
|
if waveform.dim() == 2: |
|
|
waveform = waveform.unsqueeze(0) |
|
|
|
|
|
|
|
|
duration = waveform.shape[-1] / sample_rate |
|
|
|
|
|
|
|
|
mel_spectrogram = audio_processor.waveform_to_mel(waveform, waveform_sample_rate=sample_rate) |
|
|
mel_spectrogram = mel_spectrogram.to(dtype=dtype) |
|
|
|
|
|
|
|
|
latents = audio_vae_encoder(mel_spectrogram) |
|
|
|
|
|
|
|
|
_, _channels, time_steps, freq_bins = latents.shape |
|
|
|
|
|
return { |
|
|
"latents": latents.squeeze(0), |
|
|
"num_time_steps": time_steps, |
|
|
"frequency_bins": freq_bins, |
|
|
"duration": duration, |
|
|
} |
|
|
|
|
|
|
|
|
def parse_resolution_buckets(resolution_buckets_str: str) -> list[tuple[int, int, int]]: |
|
|
"""Parse resolution buckets from string format to list of tuples (frames, height, width)""" |
|
|
resolution_buckets = [] |
|
|
for bucket_str in resolution_buckets_str.split(";"): |
|
|
w, h, f = map(int, bucket_str.split("x")) |
|
|
|
|
|
if w % VAE_SPATIAL_FACTOR != 0 or h % VAE_SPATIAL_FACTOR != 0: |
|
|
raise typer.BadParameter( |
|
|
f"Width and height must be multiples of {VAE_SPATIAL_FACTOR}, got {w}x{h}", |
|
|
param_hint="resolution-buckets", |
|
|
) |
|
|
|
|
|
if f % VAE_TEMPORAL_FACTOR != 1: |
|
|
raise typer.BadParameter( |
|
|
f"Number of frames must be a multiple of {VAE_TEMPORAL_FACTOR} plus 1, got {f}", |
|
|
param_hint="resolution-buckets", |
|
|
) |
|
|
|
|
|
resolution_buckets.append((f, h, w)) |
|
|
return resolution_buckets |
|
|
|
|
|
|
|
|
@app.command() |
|
|
def main( |
|
|
dataset_file: str = typer.Argument( |
|
|
..., |
|
|
help="Path to metadata file (CSV/JSON/JSONL) containing video paths", |
|
|
), |
|
|
resolution_buckets: str = typer.Option( |
|
|
..., |
|
|
help='Resolution buckets in format "WxHxF;WxHxF;..." (e.g. "768x768x25;512x512x49")', |
|
|
), |
|
|
output_dir: str = typer.Option( |
|
|
..., |
|
|
help="Output directory to save video latents", |
|
|
), |
|
|
model_path: str = typer.Option( |
|
|
..., |
|
|
help="Path to LTX-2 checkpoint (.safetensors file)", |
|
|
), |
|
|
video_column: str = typer.Option( |
|
|
default="media_path", |
|
|
help="Column name in the dataset JSON/JSONL/CSV file containing video paths", |
|
|
), |
|
|
batch_size: int = typer.Option( |
|
|
default=1, |
|
|
help="Batch size for processing", |
|
|
), |
|
|
device: str = typer.Option( |
|
|
default="cuda", |
|
|
help="Device to use for computation", |
|
|
), |
|
|
vae_tiling: bool = typer.Option( |
|
|
default=False, |
|
|
help="Enable VAE tiling for larger video resolutions", |
|
|
), |
|
|
reshape_mode: str = typer.Option( |
|
|
default="center", |
|
|
help="How to crop videos: 'center' or 'random'", |
|
|
), |
|
|
with_audio: bool = typer.Option( |
|
|
default=False, |
|
|
help="Extract and encode audio from video files", |
|
|
), |
|
|
audio_output_dir: str | None = typer.Option( |
|
|
default=None, |
|
|
help="Output directory for audio latents (required if --with-audio is set)", |
|
|
), |
|
|
) -> None: |
|
|
"""Process videos/images and save latent representations for video generation training. |
|
|
|
|
|
This script processes videos and images from metadata files and saves latent representations |
|
|
that can be used for training video generation models. The output latents will maintain |
|
|
the same folder structure and naming as the corresponding media files. |
|
|
|
|
|
Examples: |
|
|
# Process videos from a CSV file |
|
|
python scripts/process_videos.py dataset.csv --resolution-buckets 768x768x25 \\ |
|
|
--output-dir ./latents --model-path /path/to/ltx2.safetensors |
|
|
|
|
|
# Process videos from a JSON file with custom video column |
|
|
python scripts/process_videos.py dataset.json --resolution-buckets 768x768x25 \\ |
|
|
--output-dir ./latents --model-path /path/to/ltx2.safetensors --video-column "video_path" |
|
|
|
|
|
# Enable VAE tiling to save GPU VRAM |
|
|
python scripts/process_videos.py dataset.csv --resolution-buckets 1024x1024x25 \\ |
|
|
--output-dir ./latents --model-path /path/to/ltx2.safetensors --vae-tiling |
|
|
|
|
|
# Process videos with audio |
|
|
python scripts/process_videos.py dataset.csv --resolution-buckets 768x768x25 \\ |
|
|
--output-dir ./latents --model-path /path/to/ltx2.safetensors \\ |
|
|
--with-audio --audio-output-dir ./audio_latents |
|
|
""" |
|
|
|
|
|
|
|
|
if not Path(dataset_file).is_file(): |
|
|
raise typer.BadParameter(f"Dataset file not found: {dataset_file}") |
|
|
|
|
|
|
|
|
if with_audio and audio_output_dir is None: |
|
|
raise typer.BadParameter("--audio-output-dir is required when --with-audio is set") |
|
|
|
|
|
|
|
|
parsed_resolution_buckets = parse_resolution_buckets(resolution_buckets) |
|
|
|
|
|
if len(parsed_resolution_buckets) > 1: |
|
|
logger.warning( |
|
|
"Using multiple resolution buckets. " |
|
|
"When training with multiple resolution buckets, you must use a batch size of 1." |
|
|
) |
|
|
|
|
|
|
|
|
compute_latents( |
|
|
dataset_file=dataset_file, |
|
|
video_column=video_column, |
|
|
resolution_buckets=parsed_resolution_buckets, |
|
|
output_dir=output_dir, |
|
|
model_path=model_path, |
|
|
reshape_mode=reshape_mode, |
|
|
batch_size=batch_size, |
|
|
device=device, |
|
|
vae_tiling=vae_tiling, |
|
|
with_audio=with_audio, |
|
|
audio_output_dir=audio_output_dir, |
|
|
) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
app() |
|
|
|