|
|
|
|
|
|
|
|
""" |
|
|
Auto-caption videos with audio using multimodal models. |
|
|
|
|
|
This script provides a command-line interface for generating captions for videos |
|
|
(including audio) using multimodal models. It supports: |
|
|
|
|
|
- Qwen2.5-Omni: Local model for audio-visual captioning (default) |
|
|
- Gemini Flash: Cloud-based API for audio-visual captioning |
|
|
|
|
|
The paths to videos in the generated dataset/captions file will be RELATIVE to the |
|
|
directory where the output file is stored. This makes the dataset more portable and |
|
|
easier to use in different environments. |
|
|
|
|
|
Basic usage: |
|
|
# Caption a single video (includes audio by default) |
|
|
caption_videos.py video.mp4 --output captions.json |
|
|
|
|
|
# Caption all videos in a directory |
|
|
caption_videos.py videos_dir/ --output captions.csv |
|
|
|
|
|
# Caption with custom instruction |
|
|
caption_videos.py video.mp4 --instruction "Describe what happens in this video in detail." |
|
|
|
|
|
Advanced usage: |
|
|
# Use Gemini Flash API (requires GEMINI_API_KEY or GOOGLE_API_KEY env var) |
|
|
caption_videos.py videos_dir/ --captioner-type gemini_flash |
|
|
|
|
|
# Disable audio processing (video-only captions) |
|
|
caption_videos.py videos_dir/ --no-audio |
|
|
|
|
|
# Process videos with specific extensions and save as JSON |
|
|
caption_videos.py videos_dir/ --extensions mp4,mov,avi --output captions.json |
|
|
""" |
|
|
|
|
|
import csv |
|
|
import json |
|
|
from enum import Enum |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
import typer |
|
|
from rich.console import Console |
|
|
from rich.progress import ( |
|
|
BarColumn, |
|
|
MofNCompleteColumn, |
|
|
Progress, |
|
|
SpinnerColumn, |
|
|
TextColumn, |
|
|
TimeElapsedColumn, |
|
|
TimeRemainingColumn, |
|
|
) |
|
|
from transformers.utils.logging import disable_progress_bar |
|
|
|
|
|
from ltx_trainer.captioning import ( |
|
|
CaptionerType, |
|
|
MediaCaptioningModel, |
|
|
create_captioner, |
|
|
) |
|
|
|
|
|
# Extensions (lower-case, no leading dot) treated as video input.
VIDEO_EXTENSIONS = ["mp4", "avi", "mov", "mkv", "webm"]

# Extensions (lower-case, no leading dot) treated as still-image input.
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png"]

# Everything the script scans for by default: videos first, then images.
MEDIA_EXTENSIONS = [*VIDEO_EXTENSIONS, *IMAGE_EXTENSIONS]

# Modulus used by the captioning loop to decide when to flush captions to disk.
SAVE_INTERVAL = 5
|
|
|
|
|
# Single Rich console shared by every function in this script for user-facing output.
console = Console()

# Typer application object; the CLI command below registers itself on it.
app = typer.Typer(
    help="Auto-caption videos with audio using multimodal models.",
    no_args_is_help=True,
    pretty_exceptions_enable=False,
)

# Silence the transformers library's own progress bars
# (presumably so they do not interleave with the Rich progress display used here).
disable_progress_bar()
|
|
|
|
|
|
|
|
class OutputFormat(str, Enum):
    """Serialization formats supported for the captions file.

    Each member's value is the file extension (without the dot) that selects
    the format. Because the enum also derives from ``str``, members compare
    equal to their plain string values.
    """

    # Two plain-text files: "<stem>_captions" and "<stem>_paths", one entry per line.
    TXT = "txt"
    # CSV with a header row and the columns "caption", "media_path".
    CSV = "csv"
    # A JSON array of {"caption": ..., "media_path": ...} objects.
    JSON = "json"
    # One {"caption": ..., "media_path": ...} JSON object per line.
    JSONL = "jsonl"
|
|
|
|
|
|
|
|
def caption_media(
    input_path: Path,
    output_path: Path,
    captioner: MediaCaptioningModel,
    extensions: list[str],
    recursive: bool,
    fps: int,
    include_audio: bool,
    clean_caption: bool,
    output_format: OutputFormat,
    override: bool,
) -> None:
    """Caption videos and images using the provided captioning model.

    Media files are discovered under ``input_path``, captions already present
    in ``output_path`` are loaded, and (unless ``override`` is set) media that
    already have a caption are skipped. Results are written periodically while
    the loop runs and once more at the end. A failure on one file is reported
    and does not stop the run.

    Paths stored in the output are relative to the directory containing
    ``output_path``. Media located outside that directory are stored with
    ``..`` components instead of being rejected.

    Args:
        input_path: Path to input video file or directory
        output_path: Path to output caption file
        captioner: Media captioning model
        extensions: List of media file extensions to include
        recursive: Whether to search subdirectories recursively
        fps: Frames per second to sample from videos (ignored for images)
        include_audio: Whether to include audio in captioning
        clean_caption: Whether to clean up captions
        output_format: Format to save the captions in
        override: Whether to override existing captions
    """
    # Function-scope import: `os` is only needed here, for os.path.relpath.
    import os

    # Collect the candidate media files.
    media_files = _get_media_files(input_path, extensions, recursive)
    if not media_files:
        console.print("[bold yellow]No media files found to process.[/]")
        return

    console.print(f"Found [bold]{len(media_files)}[/] media files to process.")

    # Stored paths are relative to the output file's directory, so resolve them
    # against that directory to compare with the absolute paths of the inputs.
    base_dir = output_path.parent.resolve()
    existing_captions = _load_existing_captions(output_path, output_format)
    existing_abs_paths = {str((base_dir / p).resolve()) for p in existing_captions}

    if override:
        media_to_process = media_files
    else:
        media_to_process = [f for f in media_files if str(f.resolve()) not in existing_abs_paths]
        if skipped := len(media_files) - len(media_to_process):
            console.print(f"[bold yellow]Skipping {skipped} media that already have captions.[/]")

    if not media_to_process:
        console.print("[bold yellow]All media already have captions. Use --override to recaption.[/]")
        return

    # Start from the existing captions so that they are preserved in the output.
    captions = existing_captions.copy()
    successfully_captioned = 0
    progress = Progress(
        SpinnerColumn(),
        TextColumn("{task.description}"),
        BarColumn(bar_width=40),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn("•"),
        TimeRemainingColumn(),
        console=console,
    )

    with progress:
        task = progress.add_task("Captioning", total=len(media_to_process))

        for i, media_file in enumerate(media_to_process):
            progress.update(task, description=f"Captioning [bold blue]{media_file.name}[/]")

            try:
                # Compute the output key *before* running the model so that a path
                # problem cannot throw away a finished caption. os.path.relpath
                # (unlike Path.relative_to) also handles media that live outside
                # base_dir by emitting ".." components.
                rel_path = os.path.relpath(media_file.resolve(), base_dir)

                caption = captioner.caption(
                    path=media_file,
                    fps=fps,
                    include_audio=include_audio,
                    clean_caption=clean_caption,
                )

                captions[rel_path] = caption
                successfully_captioned += 1
            except Exception as e:
                # Best effort: report the failure and continue with the next file.
                console.print(f"[bold red]Error captioning {media_file}: {e}[/]")

            # Periodic checkpoint. i == 0 also matches, so the first result is
            # written right away and then again every SAVE_INTERVAL items.
            if i % SAVE_INTERVAL == 0:
                _save_captions(captions, output_path, output_format)

            progress.advance(task)

    # Final save so that the results of the last partial interval are persisted.
    _save_captions(captions, output_path, output_format)

    console.print(
        f"[bold green]✓[/] Captioned [bold]{successfully_captioned}/{len(media_to_process)}[/] media successfully.",
    )
|
|
|
|
|
|
|
|
def _get_media_files( |
|
|
input_path: Path, |
|
|
extensions: list[str] = MEDIA_EXTENSIONS, |
|
|
recursive: bool = False, |
|
|
) -> list[Path]: |
|
|
"""Get all media files from the input path.""" |
|
|
input_path = Path(input_path) |
|
|
|
|
|
extensions = [ext.lower().lstrip(".") for ext in extensions] |
|
|
|
|
|
if input_path.is_file(): |
|
|
|
|
|
if input_path.suffix.lstrip(".").lower() in extensions: |
|
|
return [input_path] |
|
|
else: |
|
|
typer.echo(f"Warning: {input_path} is not a recognized media file. Skipping.") |
|
|
return [] |
|
|
elif input_path.is_dir(): |
|
|
|
|
|
media_files = [] |
|
|
|
|
|
|
|
|
glob_pattern = "**/*" if recursive else "*" |
|
|
|
|
|
|
|
|
for ext in extensions: |
|
|
media_files.extend(input_path.glob(f"{glob_pattern}.{ext}")) |
|
|
|
|
|
return sorted(media_files) |
|
|
else: |
|
|
typer.echo(f"Error: {input_path} does not exist.") |
|
|
raise typer.Exit(code=1) |
|
|
|
|
|
|
|
|
def _save_captions( |
|
|
captions: dict[str, str], |
|
|
output_path: Path, |
|
|
format_type: OutputFormat, |
|
|
) -> None: |
|
|
"""Save captions to a file in the specified format. |
|
|
|
|
|
Args: |
|
|
captions: Dictionary mapping media paths to captions |
|
|
output_path: Path to save the output file |
|
|
format_type: Format to save the captions in |
|
|
""" |
|
|
|
|
|
output_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
console.print("[bold blue]Saving captions...[/]") |
|
|
|
|
|
match format_type: |
|
|
case OutputFormat.TXT: |
|
|
|
|
|
captions_file = output_path.with_stem(f"{output_path.stem}_captions") |
|
|
paths_file = output_path.with_stem(f"{output_path.stem}_paths") |
|
|
|
|
|
with captions_file.open("w", encoding="utf-8") as f: |
|
|
for caption in captions.values(): |
|
|
f.write(f"{caption}\n") |
|
|
|
|
|
with paths_file.open("w", encoding="utf-8") as f: |
|
|
for media_path in captions: |
|
|
f.write(f"{media_path}\n") |
|
|
|
|
|
console.print(f"[bold green]✓[/] Captions saved to [cyan]{captions_file}[/]") |
|
|
console.print(f"[bold green]✓[/] Media paths saved to [cyan]{paths_file}[/]") |
|
|
|
|
|
case OutputFormat.CSV: |
|
|
with output_path.open("w", encoding="utf-8", newline="") as f: |
|
|
writer = csv.writer(f) |
|
|
writer.writerow(["caption", "media_path"]) |
|
|
for media_path, caption in captions.items(): |
|
|
writer.writerow([caption, media_path]) |
|
|
|
|
|
console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]") |
|
|
|
|
|
case OutputFormat.JSON: |
|
|
|
|
|
json_data = [{"caption": caption, "media_path": media_path} for media_path, caption in captions.items()] |
|
|
|
|
|
with output_path.open("w", encoding="utf-8") as f: |
|
|
json.dump(json_data, f, indent=2, ensure_ascii=False) |
|
|
|
|
|
console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]") |
|
|
|
|
|
case OutputFormat.JSONL: |
|
|
with output_path.open("w", encoding="utf-8") as f: |
|
|
for media_path, caption in captions.items(): |
|
|
f.write(json.dumps({"caption": caption, "media_path": media_path}, ensure_ascii=False) + "\n") |
|
|
|
|
|
console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]") |
|
|
|
|
|
case _: |
|
|
raise ValueError(f"Unsupported output format: {format_type}") |
|
|
|
|
|
|
|
|
def _load_existing_captions( |
|
|
output_path: Path, |
|
|
format_type: OutputFormat, |
|
|
) -> dict[str, str]: |
|
|
"""Load existing captions from a file. |
|
|
|
|
|
Args: |
|
|
output_path: Path to the captions file |
|
|
format_type: Format of the captions file |
|
|
|
|
|
Returns: |
|
|
Dictionary mapping media paths to captions, or empty dict if file doesn't exist |
|
|
""" |
|
|
if not output_path.exists(): |
|
|
return {} |
|
|
|
|
|
console.print(f"[bold blue]Loading existing captions from [cyan]{output_path}[/]...[/]") |
|
|
|
|
|
existing_captions = {} |
|
|
|
|
|
try: |
|
|
match format_type: |
|
|
case OutputFormat.TXT: |
|
|
|
|
|
captions_file = output_path.with_stem(f"{output_path.stem}_captions") |
|
|
paths_file = output_path.with_stem(f"{output_path.stem}_paths") |
|
|
|
|
|
if captions_file.exists() and paths_file.exists(): |
|
|
captions = captions_file.read_text(encoding="utf-8").splitlines() |
|
|
paths = paths_file.read_text(encoding="utf-8").splitlines() |
|
|
|
|
|
if len(captions) == len(paths): |
|
|
existing_captions = dict(zip(paths, captions, strict=False)) |
|
|
|
|
|
case OutputFormat.CSV: |
|
|
with output_path.open("r", encoding="utf-8", newline="") as f: |
|
|
reader = csv.reader(f) |
|
|
|
|
|
next(reader, None) |
|
|
for row in reader: |
|
|
if len(row) >= 2: |
|
|
caption, media_path = row[0], row[1] |
|
|
existing_captions[media_path] = caption |
|
|
|
|
|
case OutputFormat.JSON: |
|
|
with output_path.open("r", encoding="utf-8") as f: |
|
|
json_data = json.load(f) |
|
|
for item in json_data: |
|
|
if "caption" in item and "media_path" in item: |
|
|
existing_captions[item["media_path"]] = item["caption"] |
|
|
|
|
|
case OutputFormat.JSONL: |
|
|
with output_path.open("r", encoding="utf-8") as f: |
|
|
for line in f: |
|
|
item = json.loads(line) |
|
|
if "caption" in item and "media_path" in item: |
|
|
existing_captions[item["media_path"]] = item["caption"] |
|
|
|
|
|
case _: |
|
|
raise ValueError(f"Unsupported output format: {format_type}") |
|
|
|
|
|
console.print(f"[bold green]✓[/] Loaded [bold]{len(existing_captions)}[/] existing captions") |
|
|
return existing_captions |
|
|
|
|
|
except Exception as e: |
|
|
console.print(f"[bold yellow]Warning: Could not load existing captions: {e}[/]") |
|
|
return {} |
|
|
|
|
|
|
|
|
# NOTE: Typer renders this function's docstring and the `help=` strings as the
# CLI help text, so they are user-facing output and are kept verbatim.
@app.command()
def main(
    input_path: Path = typer.Argument(
        ...,
        help="Path to input video/image file or directory containing media files",
        exists=True,
    ),
    output: Path | None = typer.Option(
        None,
        "--output",
        "-o",
        help="Path to output file for captions. Format determined by file extension.",
    ),
    captioner_type: CaptionerType = typer.Option(
        CaptionerType.QWEN_OMNI,
        "--captioner-type",
        "-c",
        help="Type of captioner to use. Valid values: 'qwen_omni' (local), 'gemini_flash' (API)",
        case_sensitive=False,
    ),
    device: str | None = typer.Option(
        None,
        "--device",
        "-d",
        help="Device to use for inference (e.g., 'cuda', 'cuda:0', 'cpu'). Only for local models.",
    ),
    use_8bit: bool = typer.Option(
        False,
        "--use-8bit",
        help="Whether to use 8-bit precision for the captioning model (reduces memory usage)",
    ),
    instruction: str | None = typer.Option(
        None,
        "--instruction",
        "-i",
        help="Custom instruction for the captioning model. If not provided, uses an appropriate default.",
    ),
    extensions: str = typer.Option(
        ",".join(MEDIA_EXTENSIONS),
        "--extensions",
        "-e",
        help="Comma-separated list of media file extensions to process",
    ),
    recursive: bool = typer.Option(
        False,
        "--recursive",
        "-r",
        help="Search for media files in subdirectories recursively",
    ),
    fps: int = typer.Option(
        3,
        "--fps",
        "-f",
        help="Frames per second to sample from videos (ignored for images)",
    ),
    include_audio: bool = typer.Option(
        True,
        "--audio/--no-audio",
        help="Whether to include audio in captioning (for videos with audio tracks)",
    ),
    clean_caption: bool = typer.Option(
        True,
        "--clean-caption/--raw-caption",
        help="Whether to clean up captions by removing common VLM patterns",
    ),
    override: bool = typer.Option(
        False,
        "--override",
        help="Whether to override existing captions for media",
    ),
    api_key: str | None = typer.Option(
        None,
        "--api-key",
        envvar=["GOOGLE_API_KEY", "GEMINI_API_KEY"],
        help="API key for Gemini Flash (can also use GOOGLE_API_KEY or GEMINI_API_KEY env var)",
    ),
) -> None:
    """Auto-caption videos with audio using multimodal models.

    This script supports audio-visual captioning using:
    - Qwen2.5-Omni: Local model (default) - processes both video and audio
    - Gemini Flash: Cloud API - requires GOOGLE_API_KEY environment variable

    The paths in the output file will be relative to the output file's directory.

    Examples:
        # Caption videos with audio using Qwen2.5-Omni (default)
        caption_videos.py videos_dir/ -o captions.json

        # Caption using Gemini Flash API
        caption_videos.py videos_dir/ -o captions.json -c gemini_flash

        # Caption without audio (video-only)
        caption_videos.py videos_dir/ -o captions.json --no-audio

        # Caption with custom instruction
        caption_videos.py video.mp4 -o captions.json -i "Describe this video in detail"
    """
    # Pick the inference device: an explicit --device wins, otherwise CUDA when available.
    device_str = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # Parse the comma-separated extension list, ignoring empty entries such as
    # those produced by "mp4,,mov" or a trailing comma.
    ext_list = [ext.strip() for ext in extensions.split(",") if ext.strip()]

    if output is None:
        # No --output given: default to a JSON dataset file next to the input.
        output_format = OutputFormat.JSON
        if input_path.is_file():
            # Single file: replace its suffix, e.g. video.mp4 -> video.dataset.json
            output = input_path.with_suffix(".dataset.json")
        else:
            # Directory: write dataset.json inside it.
            output = input_path / "dataset.json"
    else:
        # The output format is derived from the file extension. Report an unknown
        # or missing extension as a normal CLI usage error instead of a traceback.
        suffix = Path(output).suffix.lstrip(".").lower()
        try:
            output_format = OutputFormat(suffix)
        except ValueError:
            supported = ", ".join(f".{fmt.value}" for fmt in OutputFormat)
            raise typer.BadParameter(
                f"Unsupported output file extension '.{suffix}'. Supported extensions: {supported}",
                param_hint="--output",
            ) from None

    # Work with an absolute output path from here on.
    output = Path(output).resolve()
    console.print(f"Output will be saved to [bold blue]{output}[/]")

    # Build the captioner; the arguments passed depend on the captioner type.
    with console.status("Loading captioning model...", spinner="dots"):
        if captioner_type == CaptionerType.QWEN_OMNI:
            captioner = create_captioner(
                captioner_type=captioner_type,
                device=device_str,
                use_8bit=use_8bit,
                instruction=instruction,
            )
        elif captioner_type == CaptionerType.GEMINI_FLASH:
            captioner = create_captioner(
                captioner_type=captioner_type,
                api_key=api_key,
                instruction=instruction,
            )
        else:
            raise ValueError(f"Unsupported captioner type: {captioner_type}")

    console.print(f"[bold green]✓[/] {captioner_type.value} captioning model loaded successfully")

    # Run the captioning pass over the discovered media.
    caption_media(
        input_path=input_path,
        output_path=output,
        captioner=captioner,
        extensions=ext_list,
        recursive=recursive,
        fps=fps,
        include_audio=include_audio,
        clean_caption=clean_caption,
        output_format=output_format,
        override=override,
    )
|
|
|
|
|
|
|
|
# Script entry point: hand control to the Typer application when this file is
# executed directly (not when it is imported as a module).
if __name__ == "__main__":
    app()
|
|
|