#!/usr/bin/env python3
"""
Auto-caption videos with audio using multimodal models.
This script provides a command-line interface for generating captions for videos
(including audio) using multimodal models. It supports:
- Qwen2.5-Omni: Local model for audio-visual captioning (default)
- Gemini Flash: Cloud-based API for audio-visual captioning
The paths to videos in the generated dataset/captions file will be RELATIVE to the
directory where the output file is stored. This makes the dataset more portable and
easier to use in different environments.
Basic usage:
# Caption a single video (includes audio by default)
caption_videos.py video.mp4 --output captions.json
# Caption all videos in a directory
caption_videos.py videos_dir/ --output captions.csv
# Caption with custom instruction
caption_videos.py video.mp4 --instruction "Describe what happens in this video in detail."
Advanced usage:
# Use Gemini Flash API (requires GEMINI_API_KEY or GOOGLE_API_KEY env var)
caption_videos.py videos_dir/ --captioner-type gemini_flash
# Disable audio processing (video-only captions)
caption_videos.py videos_dir/ --no-audio
# Process videos with specific extensions and save as JSON
caption_videos.py videos_dir/ --extensions mp4,mov,avi --output captions.json
"""
import csv
import io
import json
import os
from enum import Enum
from pathlib import Path

import torch
import typer
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from transformers.utils.logging import disable_progress_bar

from ltx_trainer.captioning import (
    CaptionerType,
    MediaCaptioningModel,
    create_captioner,
)
# Media types the captioner knows how to open.
VIDEO_EXTENSIONS = ["mp4", "avi", "mov", "mkv", "webm"]
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png"]
MEDIA_EXTENSIONS = VIDEO_EXTENSIONS + IMAGE_EXTENSIONS

# Checkpoint the captions file to disk every N processed media files.
SAVE_INTERVAL = 5

console = Console()
app = typer.Typer(
    pretty_exceptions_enable=False,
    no_args_is_help=True,
    help="Auto-caption videos with audio using multimodal models.",
)

# Silence the transformers download progress bars so they don't interleave
# with our own rich progress display.
disable_progress_bar()
class OutputFormat(str, Enum):
    """Available output formats for captions."""

    TXT = "txt"  # Two sibling files: one caption per line, one media path per line
    CSV = "csv"  # CSV file with "caption" and "media_path" columns (caption first)
    JSON = "json"  # JSON list of {"caption": ..., "media_path": ...} objects
    JSONL = "jsonl"  # JSON Lines file, one {"caption", "media_path"} object per line
def caption_media(
    input_path: Path,
    output_path: Path,
    captioner: MediaCaptioningModel,
    extensions: list[str],
    recursive: bool,
    fps: int,
    include_audio: bool,
    clean_caption: bool,
    output_format: OutputFormat,
    override: bool,
) -> None:
    """Caption videos and images using the provided captioning model.

    Captions are stored keyed by path RELATIVE to the output file's directory,
    and the file is checkpointed every ``SAVE_INTERVAL`` items so a crash does
    not lose all progress.

    Args:
        input_path: Path to input video file or directory
        output_path: Path to output caption file
        captioner: Media captioning model
        extensions: List of media file extensions to include
        recursive: Whether to search subdirectories recursively
        fps: Frames per second to sample from videos (ignored for images)
        include_audio: Whether to include audio in captioning
        clean_caption: Whether to clean up captions
        output_format: Format to save the captions in
        override: Whether to override existing captions
    """
    # Get list of media files to process
    media_files = _get_media_files(input_path, extensions, recursive)
    if not media_files:
        console.print("[bold yellow]No media files found to process.[/]")
        return
    console.print(f"Found [bold]{len(media_files)}[/] media files to process.")

    # Load existing captions and determine which files need processing.
    # Existing keys are relative to the output file's directory, so resolve
    # them back to absolute paths before comparing.
    base_dir = output_path.parent.resolve()
    existing_captions = _load_existing_captions(output_path, output_format)
    existing_abs_paths = {str((base_dir / p).resolve()) for p in existing_captions}

    if override:
        media_to_process = media_files
    else:
        media_to_process = [f for f in media_files if str(f.resolve()) not in existing_abs_paths]
        if skipped := len(media_files) - len(media_to_process):
            console.print(f"[bold yellow]Skipping {skipped} media that already have captions.[/]")

    if not media_to_process:
        console.print("[bold yellow]All media already have captions. Use --override to recaption.[/]")
        return

    # Process media files
    captions = existing_captions.copy()
    successfully_captioned = 0
    progress = Progress(
        SpinnerColumn(),
        TextColumn("{task.description}"),
        BarColumn(bar_width=40),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn("•"),
        TimeRemainingColumn(),
        console=console,
    )
    with progress:
        task = progress.add_task("Captioning", total=len(media_to_process))
        for i, media_file in enumerate(media_to_process):
            progress.update(task, description=f"Captioning [bold blue]{media_file.name}[/]")
            try:
                # Generate caption for the media
                caption = captioner.caption(
                    path=media_file,
                    fps=fps,
                    include_audio=include_audio,
                    clean_caption=clean_caption,
                )
                # Key by the path relative to the output file's directory.
                # os.path.relpath (unlike Path.relative_to) also handles media
                # that is NOT located under that directory, emitting "../"
                # components instead of raising ValueError.
                rel_path = os.path.relpath(media_file.resolve(), base_dir)
                captions[rel_path] = caption
                successfully_captioned += 1
            except Exception as e:  # noqa: BLE001 - keep going on per-file failures
                console.print(f"[bold red]Error captioning {media_file}: {e}[/]")
            # Periodically checkpoint so a crash doesn't lose all progress.
            if i % SAVE_INTERVAL == 0:
                _save_captions(captions, output_path, output_format)
            # Advance progress bar
            progress.advance(task)

    # Final save of all captions
    _save_captions(captions, output_path, output_format)

    # Print summary
    console.print(
        f"[bold green]✓[/] Captioned [bold]{successfully_captioned}/{len(media_to_process)}[/] media successfully.",
    )
def _get_media_files(
    input_path: Path,
    extensions: list[str] = MEDIA_EXTENSIONS,
    recursive: bool = False,
) -> list[Path]:
    """Get all media files from the input path.

    Args:
        input_path: A media file, or a directory to scan for media files.
        extensions: Media file extensions to accept (matched case-insensitively).
        recursive: Whether to scan subdirectories when *input_path* is a directory.

    Returns:
        Sorted list of matching media file paths (a single-element list when
        *input_path* is itself a matching file, empty list otherwise).

    Raises:
        typer.Exit: If *input_path* does not exist.
    """
    input_path = Path(input_path)
    # Normalize extensions to a lowercase, dot-free SET: also dedupes, so a
    # repeated extension in the list cannot produce duplicate results.
    ext_set = {ext.lower().lstrip(".") for ext in extensions}

    if input_path.is_file():
        # If input is a file, check if it has a valid extension
        if input_path.suffix.lstrip(".").lower() in ext_set:
            return [input_path]
        typer.echo(f"Warning: {input_path} is not a recognized media file. Skipping.")
        return []

    if input_path.is_dir():
        # Define the glob pattern based on whether we're searching recursively
        glob_pattern = "**/*" if recursive else "*"
        # Glob once and compare suffixes case-insensitively ourselves: glob is
        # case-sensitive on most filesystems, so per-extension globbing would
        # silently miss files like "clip.MP4". Also require is_file() so a
        # directory whose name merely looks like a media file is skipped.
        media_files = [
            path
            for path in input_path.glob(glob_pattern)
            if path.is_file() and path.suffix.lstrip(".").lower() in ext_set
        ]
        return sorted(media_files)

    typer.echo(f"Error: {input_path} does not exist.")
    raise typer.Exit(code=1)
def _save_captions(
captions: dict[str, str],
output_path: Path,
format_type: OutputFormat,
) -> None:
"""Save captions to a file in the specified format.
Args:
captions: Dictionary mapping media paths to captions
output_path: Path to save the output file
format_type: Format to save the captions in
"""
# Create parent directories if they don't exist
output_path.parent.mkdir(parents=True, exist_ok=True)
console.print("[bold blue]Saving captions...[/]")
match format_type:
case OutputFormat.TXT:
# Create two separate files for captions and media paths
captions_file = output_path.with_stem(f"{output_path.stem}_captions")
paths_file = output_path.with_stem(f"{output_path.stem}_paths")
with captions_file.open("w", encoding="utf-8") as f:
for caption in captions.values():
f.write(f"{caption}\n")
with paths_file.open("w", encoding="utf-8") as f:
for media_path in captions:
f.write(f"{media_path}\n")
console.print(f"[bold green]✓[/] Captions saved to [cyan]{captions_file}[/]")
console.print(f"[bold green]✓[/] Media paths saved to [cyan]{paths_file}[/]")
case OutputFormat.CSV:
with output_path.open("w", encoding="utf-8", newline="") as f:
writer = csv.writer(f)
writer.writerow(["caption", "media_path"])
for media_path, caption in captions.items():
writer.writerow([caption, media_path])
console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]")
case OutputFormat.JSON:
# Format as list of dictionaries with caption and media_path keys
json_data = [{"caption": caption, "media_path": media_path} for media_path, caption in captions.items()]
with output_path.open("w", encoding="utf-8") as f:
json.dump(json_data, f, indent=2, ensure_ascii=False)
console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]")
case OutputFormat.JSONL:
with output_path.open("w", encoding="utf-8") as f:
for media_path, caption in captions.items():
f.write(json.dumps({"caption": caption, "media_path": media_path}, ensure_ascii=False) + "\n")
console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]")
case _:
raise ValueError(f"Unsupported output format: {format_type}")
def _load_existing_captions( # noqa: PLR0912
output_path: Path,
format_type: OutputFormat,
) -> dict[str, str]:
"""Load existing captions from a file.
Args:
output_path: Path to the captions file
format_type: Format of the captions file
Returns:
Dictionary mapping media paths to captions, or empty dict if file doesn't exist
"""
if not output_path.exists():
return {}
console.print(f"[bold blue]Loading existing captions from [cyan]{output_path}[/]...[/]")
existing_captions = {}
try:
match format_type:
case OutputFormat.TXT:
# For TXT format, we have two separate files
captions_file = output_path.with_stem(f"{output_path.stem}_captions")
paths_file = output_path.with_stem(f"{output_path.stem}_paths")
if captions_file.exists() and paths_file.exists():
captions = captions_file.read_text(encoding="utf-8").splitlines()
paths = paths_file.read_text(encoding="utf-8").splitlines()
if len(captions) == len(paths):
existing_captions = dict(zip(paths, captions, strict=False))
case OutputFormat.CSV:
with output_path.open("r", encoding="utf-8", newline="") as f:
reader = csv.reader(f)
# Skip header
next(reader, None)
for row in reader:
if len(row) >= 2:
caption, media_path = row[0], row[1]
existing_captions[media_path] = caption
case OutputFormat.JSON:
with output_path.open("r", encoding="utf-8") as f:
json_data = json.load(f)
for item in json_data:
if "caption" in item and "media_path" in item:
existing_captions[item["media_path"]] = item["caption"]
case OutputFormat.JSONL:
with output_path.open("r", encoding="utf-8") as f:
for line in f:
item = json.loads(line)
if "caption" in item and "media_path" in item:
existing_captions[item["media_path"]] = item["caption"]
case _:
raise ValueError(f"Unsupported output format: {format_type}")
console.print(f"[bold green]✓[/] Loaded [bold]{len(existing_captions)}[/] existing captions")
return existing_captions
except Exception as e:
console.print(f"[bold yellow]Warning: Could not load existing captions: {e}[/]")
return {}
@app.command()
def main(  # noqa: PLR0913
    input_path: Path = typer.Argument(  # noqa: B008
        ...,
        help="Path to input video/image file or directory containing media files",
        exists=True,
    ),
    output: Path | None = typer.Option(  # noqa: B008
        None,
        "--output",
        "-o",
        help="Path to output file for captions. Format determined by file extension.",
    ),
    captioner_type: CaptionerType = typer.Option(  # noqa: B008
        CaptionerType.QWEN_OMNI,
        "--captioner-type",
        "-c",
        help="Type of captioner to use. Valid values: 'qwen_omni' (local), 'gemini_flash' (API)",
        case_sensitive=False,
    ),
    device: str | None = typer.Option(
        None,
        "--device",
        "-d",
        help="Device to use for inference (e.g., 'cuda', 'cuda:0', 'cpu'). Only for local models.",
    ),
    use_8bit: bool = typer.Option(
        False,
        "--use-8bit",
        help="Whether to use 8-bit precision for the captioning model (reduces memory usage)",
    ),
    instruction: str | None = typer.Option(
        None,
        "--instruction",
        "-i",
        help="Custom instruction for the captioning model. If not provided, uses an appropriate default.",
    ),
    extensions: str = typer.Option(
        ",".join(MEDIA_EXTENSIONS),
        "--extensions",
        "-e",
        help="Comma-separated list of media file extensions to process",
    ),
    recursive: bool = typer.Option(
        False,
        "--recursive",
        "-r",
        help="Search for media files in subdirectories recursively",
    ),
    fps: int = typer.Option(
        3,
        "--fps",
        "-f",
        help="Frames per second to sample from videos (ignored for images)",
    ),
    include_audio: bool = typer.Option(
        True,
        "--audio/--no-audio",
        help="Whether to include audio in captioning (for videos with audio tracks)",
    ),
    clean_caption: bool = typer.Option(
        True,
        "--clean-caption/--raw-caption",
        help="Whether to clean up captions by removing common VLM patterns",
    ),
    override: bool = typer.Option(
        False,
        "--override",
        help="Whether to override existing captions for media",
    ),
    api_key: str | None = typer.Option(
        None,
        "--api-key",
        envvar=["GOOGLE_API_KEY", "GEMINI_API_KEY"],
        help="API key for Gemini Flash (can also use GOOGLE_API_KEY or GEMINI_API_KEY env var)",
    ),
) -> None:
    """Auto-caption videos with audio using multimodal models.

    This script supports audio-visual captioning using:

    - Qwen2.5-Omni: Local model (default) - processes both video and audio
    - Gemini Flash: Cloud API - requires GOOGLE_API_KEY environment variable

    The paths in the output file will be relative to the output file's directory.

    Examples:
        # Caption videos with audio using Qwen2.5-Omni (default)
        caption_videos.py videos_dir/ -o captions.json

        # Caption using Gemini Flash API
        caption_videos.py videos_dir/ -o captions.json -c gemini_flash

        # Caption without audio (video-only)
        caption_videos.py videos_dir/ -o captions.json --no-audio

        # Caption with custom instruction
        caption_videos.py video.mp4 -o captions.json -i "Describe this video in detail"
    """
    # Determine device for local models
    device_str = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # Parse extensions
    ext_list = [ext.strip() for ext in extensions.split(",")]

    # Determine output path and format
    if output is None:
        output_format = OutputFormat.JSON
        if input_path.is_file():  # noqa: SIM108
            # Default to a JSON file with the same name as the input media
            output = input_path.with_suffix(".dataset.json")
        else:
            # Default to a JSON file in the input directory
            output = input_path / "dataset.json"
    else:
        # Determine format from the file extension; fail with a readable CLI
        # error instead of a raw ValueError traceback (pretty exceptions are
        # disabled on the app).
        suffix = Path(output).suffix.lstrip(".").lower()
        try:
            output_format = OutputFormat(suffix)
        except ValueError as e:
            valid_formats = ", ".join(fmt.value for fmt in OutputFormat)
            console.print(
                f"[bold red]Unsupported output format '{suffix}'. Valid extensions: {valid_formats}[/]",
            )
            raise typer.Exit(code=1) from e

    # Ensure output path is absolute
    output = Path(output).resolve()
    console.print(f"Output will be saved to [bold blue]{output}[/]")

    # Initialize captioning model
    with console.status("Loading captioning model...", spinner="dots"):
        if captioner_type == CaptionerType.QWEN_OMNI:
            captioner = create_captioner(
                captioner_type=captioner_type,
                device=device_str,
                use_8bit=use_8bit,
                instruction=instruction,
            )
        elif captioner_type == CaptionerType.GEMINI_FLASH:
            captioner = create_captioner(
                captioner_type=captioner_type,
                api_key=api_key,
                instruction=instruction,
            )
        else:
            raise ValueError(f"Unsupported captioner type: {captioner_type}")
    console.print(f"[bold green]✓[/] {captioner_type.value} captioning model loaded successfully")

    # Caption media files
    caption_media(
        input_path=input_path,
        output_path=output,
        captioner=captioner,
        extensions=ext_list,
        recursive=recursive,
        fps=fps,
        include_audio=include_audio,
        clean_caption=clean_caption,
        output_format=output_format,
        override=override,
    )
# Script entry point: dispatch to the Typer CLI app.
if __name__ == "__main__":
    app()