""" |
|
|
Preprocess a video dataset by computing video clips latents and text captions embeddings. |
|
|
|
|
|
This script provides a command-line interface for preprocessing video datasets by computing |
|
|
latent representations of video clips and text embeddings of their captions. The preprocessed |
|
|
data can be used to accelerate training of video generation models and to save GPU memory. |
|
|
|
|
|
Basic usage: |
|
|
python scripts/process_dataset.py /path/to/dataset.json --resolution-buckets 768x768x49 \ |
|
|
--model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma |
|
|
|
|
|
The dataset must be a CSV, JSON, or JSONL file with columns for captions and video paths. |
|
|
""" |
|
|
|
|
|
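# Example dataset file (JSONL), assuming the default column names "caption" and
# "media_path" (see the CLI options below):
#   {"caption": "A dog runs along the beach at sunset.", "media_path": "videos/dog.mp4"}
#   {"caption": "Timelapse of clouds over a mountain range.", "media_path": "videos/clouds.mp4"}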
from pathlib import Path

import typer
from decode_latents import LatentsDecoder
from process_captions import compute_captions_embeddings
from process_videos import compute_latents, parse_resolution_buckets
from rich.console import Console

from ltx_trainer import logger

console = Console()
app = typer.Typer(
    pretty_exceptions_enable=False,
    no_args_is_help=True,
    help="Preprocess a video dataset by computing video clip latents and text caption embeddings. "
    "The dataset must be a CSV, JSON, or JSONL file with columns for captions and video paths.",
)

def preprocess_dataset(
    dataset_file: str,
    caption_column: str,
    video_column: str,
    resolution_buckets: list[tuple[int, int, int]],
    batch_size: int,
    output_dir: str | None,
    lora_trigger: str | None,
    vae_tiling: bool,
    decode: bool,
    model_path: str,
    text_encoder_path: str,
    device: str,
    remove_llm_prefixes: bool = False,
    reference_column: str | None = None,
    with_audio: bool = False,
) -> None:
    """Run the preprocessing pipeline with the given arguments."""
    _validate_dataset_file(dataset_file)

    output_base = Path(output_dir) if output_dir else Path(dataset_file).parent / ".precomputed"
    conditions_dir = output_base / "conditions"
    latents_dir = output_base / "latents"
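    # Layout written under output_base by the steps below:
    #   conditions/         - text caption embeddings
    #   latents/            - video clip latents
    #   audio_latents/      - audio latents (only with --with-audio)
    #   reference_latents/  - reference video latents (only with --reference-column)
    #   decoded_*/          - decoded media for verification (only with --decode)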
    if lora_trigger:
        logger.info(f'LoRA trigger word "{lora_trigger}" will be prepended to all captions')

    # Step 1: compute and cache text embeddings for all captions.
    compute_captions_embeddings(
        dataset_file=dataset_file,
        output_dir=str(conditions_dir),
        model_path=model_path,
        text_encoder_path=text_encoder_path,
        caption_column=caption_column,
        media_column=video_column,
        lora_trigger=lora_trigger,
        remove_llm_prefixes=remove_llm_prefixes,
        batch_size=batch_size,
        device=device,
    )
    # Step 2: compute video latents (and, optionally, audio latents).
    audio_latents_dir = None
    if with_audio:
        logger.info("Audio preprocessing enabled - will extract and encode audio from videos")
        audio_latents_dir = output_base / "audio_latents"

    compute_latents(
        dataset_file=dataset_file,
        video_column=video_column,
        resolution_buckets=resolution_buckets,
        output_dir=str(latents_dir),
        model_path=model_path,
        batch_size=batch_size,
        device=device,
        vae_tiling=vae_tiling,
        with_audio=with_audio,
        audio_output_dir=str(audio_latents_dir) if audio_latents_dir else None,
    )
    # Step 3 (optional): compute latents for reference videos (IC-LoRA training).
    if reference_column:
        logger.info("Processing reference videos for IC-LoRA training...")
        reference_latents_dir = output_base / "reference_latents"

        compute_latents(
            dataset_file=dataset_file,
            main_media_column=video_column,
            video_column=reference_column,
            resolution_buckets=resolution_buckets,
            output_dir=str(reference_latents_dir),
            model_path=model_path,
            batch_size=batch_size,
            device=device,
            vae_tiling=vae_tiling,
        )
    # Step 4 (optional): decode the latents back to media to verify the encodings.
    if decode:
        logger.info("Decoding latents for verification...")

        decoder = LatentsDecoder(
            model_path=model_path,
            device=device,
            vae_tiling=vae_tiling,
            with_audio=with_audio,
        )
        decoder.decode(latents_dir, output_base / "decoded_videos")

        if reference_column:
            reference_latents_dir = output_base / "reference_latents"
            if reference_latents_dir.exists():
                logger.info("Decoding reference videos...")
                decoder.decode(reference_latents_dir, output_base / "decoded_reference_videos")

        if with_audio and audio_latents_dir and audio_latents_dir.exists():
            logger.info("Decoding audio latents...")
            decoder.decode_audio(audio_latents_dir, output_base / "decoded_audio")
logger.info(f"Dataset preprocessing complete! Results saved to {output_base}") |
|
|
if reference_column: |
|
|
logger.info("Reference videos processed and saved to reference_latents/ directory for IC-LoRA training") |
|
|
if with_audio: |
|
|
logger.info("Audio latents saved to audio_latents/ directory for audio-video training") |
|
|
|
|
|
|
|
|
def _validate_dataset_file(dataset_path: str) -> None:
    """Validate that the dataset file exists and has the correct format."""
    dataset_file = Path(dataset_path)

    if not dataset_file.exists():
        raise FileNotFoundError(f"Dataset file does not exist: {dataset_file}")

    if not dataset_file.is_file():
        raise ValueError(f"Dataset path must be a file, not a directory: {dataset_file}")

    if dataset_file.suffix.lower() not in [".csv", ".json", ".jsonl"]:
        raise ValueError(f"Dataset file must be CSV, JSON, or JSONL format: {dataset_file}")

@app.command()
def main(
    dataset_path: str = typer.Argument(
        ...,
        help="Path to metadata file (CSV/JSON/JSONL) containing captions and video paths",
    ),
    resolution_buckets: str = typer.Option(
        ...,
        help='Resolution buckets in format "WxHxF;WxHxF;..." (e.g. "768x768x25;512x512x49")',
    ),
    model_path: str = typer.Option(
        ...,
        help="Path to LTX-2 checkpoint (.safetensors file)",
    ),
    text_encoder_path: str = typer.Option(
        ...,
        help="Path to Gemma text encoder directory",
    ),
    caption_column: str = typer.Option(
        default="caption",
        help="Column name containing captions in the dataset JSON/JSONL/CSV file",
    ),
    video_column: str = typer.Option(
        default="media_path",
        help="Column name containing video paths in the dataset JSON/JSONL/CSV file",
    ),
    batch_size: int = typer.Option(
        default=1,
        help="Batch size for preprocessing",
    ),
    device: str = typer.Option(
        default="cuda",
        help="Device to use for computation",
    ),
    vae_tiling: bool = typer.Option(
        default=False,
        help="Enable VAE tiling for larger video resolutions",
    ),
    output_dir: str | None = typer.Option(
        default=None,
        help="Output directory (defaults to .precomputed in dataset directory)",
    ),
    lora_trigger: str | None = typer.Option(
        default=None,
        help="Optional trigger word to prepend to each caption (activates the LoRA during inference)",
    ),
    decode: bool = typer.Option(
        default=False,
        help="Decode and save latents after encoding (videos and audio) for verification",
    ),
    remove_llm_prefixes: bool = typer.Option(
        default=False,
        help="Remove LLM prefixes from captions",
    ),
    reference_column: str | None = typer.Option(
        default=None,
        help="Column name containing reference video paths (for video-to-video training)",
    ),
    with_audio: bool = typer.Option(
        default=False,
        help="Extract and encode audio from video files",
    ),
) -> None:
"""Preprocess a video dataset by computing and saving latents and text embeddings. |
|
|
|
|
|
The dataset must be a CSV, JSON, or JSONL file with columns for captions and video paths. |
|
|
This script is designed for LTX-2 models which use the Gemma text encoder. |
|
|
|
|
|
Examples: |
|
|
# Process a dataset with LTX-2 model |
|
|
python scripts/process_dataset.py dataset.json --resolution-buckets 768x768x25 \\ |
|
|
--model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma |
|
|
|
|
|
# Process dataset with custom column names |
|
|
python scripts/process_dataset.py dataset.json --resolution-buckets 768x768x25 \\ |
|
|
--model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\ |
|
|
--caption-column "text" --video-column "video_path" |
|
|
|
|
|
# Process dataset with reference videos for IC-LoRA training |
|
|
python scripts/process_dataset.py dataset.json --resolution-buckets 768x768x25 \\ |
|
|
--model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\ |
|
|
--reference-column "reference_path" |
|
|
|
|
|
# Process dataset with audio for audio-video training |
|
|
python scripts/process_dataset.py dataset.json --resolution-buckets 768x512x97 \\ |
|
|
--model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\ |
|
|
--with-audio |
|
|
""" |
|
|
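    # parse_resolution_buckets turns the "WxHxF;WxHxF;..." string into int tuples,
    # e.g. "768x768x25;512x512x49" -> [(768, 768, 25), (512, 512, 49)]
    # (tuple order assumed to match the WxHxF format).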
    parsed_resolution_buckets = parse_resolution_buckets(resolution_buckets)

    if len(parsed_resolution_buckets) > 1:
        logger.warning(
            "Using multiple resolution buckets. "
            "When training with multiple resolution buckets, you must use a batch size of 1."
        )
    preprocess_dataset(
        dataset_file=dataset_path,
        caption_column=caption_column,
        video_column=video_column,
        resolution_buckets=parsed_resolution_buckets,
        batch_size=batch_size,
        output_dir=output_dir,
        lora_trigger=lora_trigger,
        vae_tiling=vae_tiling,
        decode=decode,
        model_path=model_path,
        text_encoder_path=text_encoder_path,
        device=device,
        remove_llm_prefixes=remove_llm_prefixes,
        reference_column=reference_column,
        with_audio=with_audio,
    )

if __name__ == "__main__":
    app()