"""
Compute text embeddings for video generation training.

This module provides functionality for processing text captions, including:
- Loading captions from various file formats (CSV, JSON, JSONL)
- Cleaning and preprocessing text (removing LLM prefixes, prepending LoRA trigger words)
- CaptionsDataset for caption-only preprocessing workflows

Can be used as a standalone script:
    python scripts/process_captions.py dataset.json --output-dir /path/to/output \
        --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma
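
Output embeddings mirror the media paths from the metadata file, with the media
suffix replaced by .pt (e.g. videos/clip_001.mp4 -> videos/clip_001.pt).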
"""

import json
import os
from pathlib import Path
from typing import Any

import pandas as pd
import torch
import typer
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from torch.utils.data import DataLoader, Dataset
from transformers.utils.logging import disable_progress_bar

from ltx_trainer import logger
from ltx_trainer.model_loader import load_text_encoder


os.environ["TOKENIZERS_PARALLELISM"] = "false"

disable_progress_bar()


COMMON_BEGINNING_PHRASES: tuple[str, ...] = (
    "This video",
    "The video",
    "This clip",
    "The clip",
    "The animation",
    "This image",
    "The image",
    "This picture",
    "The picture",
)

COMMON_CONTINUATION_WORDS: tuple[str, ...] = (
    "shows",
    "depicts",
    "features",
    "captures",
    "highlights",
    "introduces",
    "presents",
)
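
# Besides the literal phrases below, every "<beginning> <continuation>" pair from
# the tuples above (e.g. "This video shows", "The clip depicts") is also treated
# as a prefix to remove.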
COMMON_LLM_START_PHRASES: tuple[str, ...] = (
    "In the video,",
    "In this video,",
    "In this video clip,",
    "In the clip,",
    "Caption:",
    *(
        f"{beginning} {continuation}"
        for beginning in COMMON_BEGINNING_PHRASES
        for continuation in COMMON_CONTINUATION_WORDS
    ),
)

app = typer.Typer(
    pretty_exceptions_enable=False,
    no_args_is_help=True,
    help="Process text captions and save embeddings for video generation training.",
)


class CaptionsDataset(Dataset):
    """
    Dataset for processing text captions only.

    This dataset is designed for caption preprocessing workflows where you only need
    to process text without loading videos. Useful for:
    - Precomputing text embeddings
    - Caption cleaning and preprocessing
    - Text-only preprocessing pipelines
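
    Example (illustrative file name and column):
        dataset = CaptionsDataset("data/metadata.json", caption_column="caption")
        sample = dataset[0]  # {"prompt": ..., "output_path": ..., "index": 0}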
    """

    def __init__(
        self,
        dataset_file: str | Path,
        caption_column: str,
        media_column: str = "media_path",
        lora_trigger: str | None = None,
        remove_llm_prefixes: bool = False,
    ) -> None:
        """
        Initialize the captions dataset.

        Args:
            dataset_file: Path to CSV/JSON/JSONL metadata file
            caption_column: Column name for captions in the metadata file
            media_column: Column name for media paths (used for output naming)
            lora_trigger: Optional trigger word to prepend to each caption
            remove_llm_prefixes: Whether to remove common LLM-generated prefixes
        """
        super().__init__()

        self.dataset_file = Path(dataset_file)
        self.caption_column = caption_column
        self.media_column = media_column
        self.lora_trigger = f"{lora_trigger.strip()} " if lora_trigger else ""

        self.caption_data = self._load_caption_data()

        self.output_paths = list(self.caption_data.keys())
        self.prompts = list(self.caption_data.values())

        if remove_llm_prefixes:
            self._clean_llm_prefixes()

    def __len__(self) -> int:
        return len(self.prompts)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Get a single caption with optional trigger word prepended and output path."""
        prompt = self.lora_trigger + self.prompts[index]
        return {
            "prompt": prompt,
            "output_path": self.output_paths[index],
            "index": index,
        }

    def _load_caption_data(self) -> dict[str, str]:
        """Load captions and compute their output embedding paths."""
        if self.dataset_file.suffix == ".csv":
            return self._load_caption_data_from_csv()
        elif self.dataset_file.suffix == ".json":
            return self._load_caption_data_from_json()
        elif self.dataset_file.suffix == ".jsonl":
            return self._load_caption_data_from_jsonl()
        else:
            raise ValueError("Expected `dataset_file` to be a path to a CSV, JSON, or JSONL file.")

    def _load_caption_data_from_csv(self) -> dict[str, str]:
        """Load captions from a CSV file and compute output embedding paths."""
        df = pd.read_csv(self.dataset_file)

        if self.caption_column not in df.columns:
            raise ValueError(f"Column '{self.caption_column}' not found in CSV file")
        if self.media_column not in df.columns:
            raise ValueError(f"Column '{self.media_column}' not found in CSV file")

        caption_data = {}
        for _, row in df.iterrows():
            media_path = Path(row[self.media_column].strip())

            output_path = str(media_path.with_suffix(".pt"))
            caption_data[output_path] = row[self.caption_column]

        return caption_data

    def _load_caption_data_from_json(self) -> dict[str, str]:
        """Load captions from a JSON file and compute output embedding paths."""
        with open(self.dataset_file, "r", encoding="utf-8") as file:
            data = json.load(file)

        if not isinstance(data, list):
            raise ValueError("JSON file must contain a list of objects")

        caption_data = {}
        for entry in data:
            if self.caption_column not in entry:
                raise ValueError(f"Key '{self.caption_column}' not found in JSON entry: {entry}")
            if self.media_column not in entry:
                raise ValueError(f"Key '{self.media_column}' not found in JSON entry: {entry}")

            media_path = Path(entry[self.media_column].strip())

            output_path = str(media_path.with_suffix(".pt"))
            caption_data[output_path] = entry[self.caption_column]

        return caption_data

    def _load_caption_data_from_jsonl(self) -> dict[str, str]:
        """Load captions from a JSONL file and compute output embedding paths."""
        caption_data = {}
        with open(self.dataset_file, "r", encoding="utf-8") as file:
            for line in file:
                entry = json.loads(line)
                if self.caption_column not in entry:
                    raise ValueError(f"Key '{self.caption_column}' not found in JSONL entry: {entry}")
                if self.media_column not in entry:
                    raise ValueError(f"Key '{self.media_column}' not found in JSONL entry: {entry}")

                media_path = Path(entry[self.media_column].strip())

                output_path = str(media_path.with_suffix(".pt"))
                caption_data[output_path] = entry[self.caption_column]

        return caption_data

    def _clean_llm_prefixes(self) -> None:
        """Remove common LLM-generated prefixes from captions."""
        for i in range(len(self.prompts)):
            self.prompts[i] = self.prompts[i].strip()
            for phrase in COMMON_LLM_START_PHRASES:
                if self.prompts[i].startswith(phrase):
                    self.prompts[i] = self.prompts[i].removeprefix(phrase).strip()
                    break


def compute_captions_embeddings(
    dataset_file: str | Path,
    output_dir: str,
    model_path: str,
    text_encoder_path: str,
    caption_column: str = "caption",
    media_column: str = "media_path",
    lora_trigger: str | None = None,
    remove_llm_prefixes: bool = False,
    batch_size: int = 8,
    device: str = "cuda",
) -> None:
    """
    Process captions and save text embeddings.

    Args:
        dataset_file: Path to metadata file (CSV/JSON/JSONL) containing captions and media paths
        output_dir: Directory to save embeddings
        model_path: Path to LTX-2 checkpoint (.safetensors)
        text_encoder_path: Path to Gemma text encoder directory
        caption_column: Column name containing captions in the metadata file
        media_column: Column name containing media paths (used for output naming)
        lora_trigger: Optional trigger word to prepend to each caption
        remove_llm_prefixes: Whether to remove common LLM-generated prefixes
        batch_size: Batch size for processing
        device: Device to use for computation
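
    Each output file is a dict saved with torch.save(), holding "prompt_embeds"
    and "prompt_attention_mask" tensors, named after its media path with a .pt suffix.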
    """
    console = Console()

    dataset = CaptionsDataset(
        dataset_file=dataset_file,
        caption_column=caption_column,
        media_column=media_column,
        lora_trigger=lora_trigger,
        remove_llm_prefixes=remove_llm_prefixes,
    )
    logger.info(f"Loaded {len(dataset):,} captions")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    with console.status("[bold]Loading Gemma text encoder...", spinner="dots"):
        text_encoder = load_text_encoder(model_path, text_encoder_path, device=device, dtype=torch.bfloat16)

    logger.info("Text encoder loaded successfully")

    if batch_size > 1:
        logger.warning(
            "Batch size greater than 1 is not currently supported with the Gemma tokenizer. "
            "Overriding batch_size to 1. This will be fixed in a future update."
        )
        batch_size = 1

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    total_batches = len(dataloader)
    logger.info(f"Processing captions in {total_batches:,} batches...")

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TimeRemainingColumn(),
        console=console,
    ) as progress:
        task = progress.add_task("Processing captions", total=len(dataloader))
        for batch in dataloader:
            with torch.inference_mode():
                for i in range(len(batch["prompt"])):
                    prompt_embeds, prompt_attention_mask = text_encoder._preprocess_text(
                        batch["prompt"][i], padding_side="left"
                    )

                    output_rel_path = Path(batch["output_path"][i])
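
                    # Media paths are assumed to be relative; an absolute path would
                    # make pathlib discard output_path in the joins below.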
                    output_dir_path = output_path / output_rel_path.parent
                    output_dir_path.mkdir(parents=True, exist_ok=True)

                    embedding_data = {
                        "prompt_embeds": prompt_embeds[0].cpu().contiguous(),
                        "prompt_attention_mask": prompt_attention_mask[0].cpu().contiguous(),
                    }

                    output_file = output_path / output_rel_path
                    torch.save(embedding_data, output_file)
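                    # The file round-trips with torch.load(output_file), returning
                    # the dict of CPU tensors above.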

            progress.advance(task)

    logger.info(f"Processed {len(dataset):,} captions. Embeddings saved to {output_path}")


@app.command()
def main(
    dataset_file: str = typer.Argument(
        ...,
        help="Path to metadata file (CSV/JSON/JSONL) containing captions and media paths",
    ),
    output_dir: str = typer.Option(
        ...,
        help="Output directory to save text embeddings",
    ),
    model_path: str = typer.Option(
        ...,
        help="Path to LTX-2 checkpoint (.safetensors file)",
    ),
    text_encoder_path: str = typer.Option(
        ...,
        help="Path to Gemma text encoder directory",
    ),
    caption_column: str = typer.Option(
        default="caption",
        help="Column name containing captions in the dataset JSON/JSONL/CSV file",
    ),
    media_column: str = typer.Option(
        default="media_path",
        help="Column name in the dataset JSON/JSONL/CSV file containing media paths "
        "(used for output file naming and folder structure)",
    ),
    batch_size: int = typer.Option(
        default=8,
        help="Batch size for processing",
    ),
    device: str = typer.Option(
        default="cuda",
        help="Device to use for computation",
    ),
    lora_trigger: str | None = typer.Option(
        default=None,
        help="Optional trigger word to prepend to each caption (activates the LoRA during inference)",
    ),
    remove_llm_prefixes: bool = typer.Option(
        default=False,
        help="Remove common LLM-generated prefixes from captions",
    ),
) -> None:
    """Process text captions and save embeddings for video generation training.

    This script processes captions from metadata files and saves text embeddings
    that can be used for training video generation models. The output embeddings
    maintain the same folder structure and naming as the corresponding media files.

    Note: This script is designed for LTX-2 models, which use the Gemma text encoder.

    Examples:
        # Process captions with LTX-2 model
        python scripts/process_captions.py dataset.json --output-dir ./embeddings \\
            --model-path /path/to/ltx2_checkpoint.safetensors \\
            --text-encoder-path /path/to/gemma

        # Add a trigger word for LoRA training
        python scripts/process_captions.py dataset.json --output-dir ./embeddings \\
            --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\
            --lora-trigger "mytoken"

        # Remove LLM-generated prefixes from captions
        python scripts/process_captions.py dataset.json --output-dir ./embeddings \\
            --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\
            --remove-llm-prefixes
    """

    if not Path(dataset_file).is_file():
        raise typer.BadParameter(f"Dataset file not found: {dataset_file}")

    if lora_trigger:
        logger.info(f'LoRA trigger word "{lora_trigger}" will be prepended to all captions')

    compute_captions_embeddings(
        dataset_file=dataset_file,
        output_dir=output_dir,
        model_path=model_path,
        text_encoder_path=text_encoder_path,
        caption_column=caption_column,
        media_column=media_column,
        lora_trigger=lora_trigger,
        remove_llm_prefixes=remove_llm_prefixes,
        batch_size=batch_size,
        device=device,
    )


if __name__ == "__main__":
    app()