#!/usr/bin/env python3 """ Auto-caption videos with audio using multimodal models. This script provides a command-line interface for generating captions for videos (including audio) using multimodal models. It supports: - Qwen2.5-Omni: Local model for audio-visual captioning (default) - Gemini Flash: Cloud-based API for audio-visual captioning The paths to videos in the generated dataset/captions file will be RELATIVE to the directory where the output file is stored. This makes the dataset more portable and easier to use in different environments. Basic usage: # Caption a single video (includes audio by default) caption_videos.py video.mp4 --output captions.json # Caption all videos in a directory caption_videos.py videos_dir/ --output captions.csv # Caption with custom instruction caption_videos.py video.mp4 --instruction "Describe what happens in this video in detail." Advanced usage: # Use Gemini Flash API (requires GEMINI_API_KEY or GOOGLE_API_KEY env var) caption_videos.py videos_dir/ --captioner-type gemini_flash # Disable audio processing (video-only captions) caption_videos.py videos_dir/ --no-audio # Process videos with specific extensions and save as JSON caption_videos.py videos_dir/ --extensions mp4,mov,avi --output captions.json """ import csv import json from enum import Enum from pathlib import Path import torch import typer from rich.console import Console from rich.progress import ( BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn, ) from transformers.utils.logging import disable_progress_bar from ltx_trainer.captioning import ( CaptionerType, MediaCaptioningModel, create_captioner, ) VIDEO_EXTENSIONS = ["mp4", "avi", "mov", "mkv", "webm"] IMAGE_EXTENSIONS = ["jpg", "jpeg", "png"] MEDIA_EXTENSIONS = VIDEO_EXTENSIONS + IMAGE_EXTENSIONS SAVE_INTERVAL = 5 console = Console() app = typer.Typer( pretty_exceptions_enable=False, no_args_is_help=True, help="Auto-caption videos with audio using multimodal models.", ) disable_progress_bar() class OutputFormat(str, Enum): """Available output formats for captions.""" TXT = "txt" # Separate files for captions and video paths, one caption / video path per line CSV = "csv" # CSV file with video path and caption columns JSON = "json" # JSON file with video paths as keys and captions as values JSONL = "jsonl" # JSON Lines file with one JSON object per line def caption_media( input_path: Path, output_path: Path, captioner: MediaCaptioningModel, extensions: list[str], recursive: bool, fps: int, include_audio: bool, clean_caption: bool, output_format: OutputFormat, override: bool, ) -> None: """Caption videos and images using the provided captioning model. Args: input_path: Path to input video file or directory output_path: Path to output caption file captioner: Media captioning model extensions: List of media file extensions to include recursive: Whether to search subdirectories recursively fps: Frames per second to sample from videos (ignored for images) include_audio: Whether to include audio in captioning clean_caption: Whether to clean up captions output_format: Format to save the captions in override: Whether to override existing captions """ # Get list of media files to process media_files = _get_media_files(input_path, extensions, recursive) if not media_files: console.print("[bold yellow]No media files found to process.[/]") return console.print(f"Found [bold]{len(media_files)}[/] media files to process.") # Load existing captions and determine which files need processing base_dir = output_path.parent.resolve() existing_captions = _load_existing_captions(output_path, output_format) existing_abs_paths = {str((base_dir / p).resolve()) for p in existing_captions} if override: media_to_process = media_files else: media_to_process = [f for f in media_files if str(f.resolve()) not in existing_abs_paths] if skipped := len(media_files) - len(media_to_process): console.print(f"[bold yellow]Skipping {skipped} media that already have captions.[/]") if not media_to_process: console.print("[bold yellow]All media already have captions. Use --override to recaption.[/]") return # Process media files captions = existing_captions.copy() successfully_captioned = 0 progress = Progress( SpinnerColumn(), TextColumn("{task.description}"), BarColumn(bar_width=40), MofNCompleteColumn(), TimeElapsedColumn(), TextColumn("•"), TimeRemainingColumn(), console=console, ) with progress: task = progress.add_task("Captioning", total=len(media_to_process)) for i, media_file in enumerate(media_to_process): progress.update(task, description=f"Captioning [bold blue]{media_file.name}[/]") try: # Generate caption for the media caption = captioner.caption( path=media_file, fps=fps, include_audio=include_audio, clean_caption=clean_caption, ) # Convert absolute path to relative path (relative to the output file's directory) rel_path = str(media_file.resolve().relative_to(base_dir)) # Store the caption with the relative path as key captions[rel_path] = caption successfully_captioned += 1 except Exception as e: console.print(f"[bold red]Error captioning {media_file}: {e}[/]") if i % SAVE_INTERVAL == 0: _save_captions(captions, output_path, output_format) # Advance progress bar progress.advance(task) # Save captions to file _save_captions(captions, output_path, output_format) # Print summary console.print( f"[bold green]✓[/] Captioned [bold]{successfully_captioned}/{len(media_to_process)}[/] media successfully.", ) def _get_media_files( input_path: Path, extensions: list[str] = MEDIA_EXTENSIONS, recursive: bool = False, ) -> list[Path]: """Get all media files from the input path.""" input_path = Path(input_path) # Normalize extensions to lowercase without dots extensions = [ext.lower().lstrip(".") for ext in extensions] if input_path.is_file(): # If input is a file, check if it has a valid extension if input_path.suffix.lstrip(".").lower() in extensions: return [input_path] else: typer.echo(f"Warning: {input_path} is not a recognized media file. Skipping.") return [] elif input_path.is_dir(): # If input is a directory, find all media files media_files = [] # Define the glob pattern based on whether we're searching recursively glob_pattern = "**/*" if recursive else "*" # Find all files with the specified extensions for ext in extensions: media_files.extend(input_path.glob(f"{glob_pattern}.{ext}")) return sorted(media_files) else: typer.echo(f"Error: {input_path} does not exist.") raise typer.Exit(code=1) def _save_captions( captions: dict[str, str], output_path: Path, format_type: OutputFormat, ) -> None: """Save captions to a file in the specified format. Args: captions: Dictionary mapping media paths to captions output_path: Path to save the output file format_type: Format to save the captions in """ # Create parent directories if they don't exist output_path.parent.mkdir(parents=True, exist_ok=True) console.print("[bold blue]Saving captions...[/]") match format_type: case OutputFormat.TXT: # Create two separate files for captions and media paths captions_file = output_path.with_stem(f"{output_path.stem}_captions") paths_file = output_path.with_stem(f"{output_path.stem}_paths") with captions_file.open("w", encoding="utf-8") as f: for caption in captions.values(): f.write(f"{caption}\n") with paths_file.open("w", encoding="utf-8") as f: for media_path in captions: f.write(f"{media_path}\n") console.print(f"[bold green]✓[/] Captions saved to [cyan]{captions_file}[/]") console.print(f"[bold green]✓[/] Media paths saved to [cyan]{paths_file}[/]") case OutputFormat.CSV: with output_path.open("w", encoding="utf-8", newline="") as f: writer = csv.writer(f) writer.writerow(["caption", "media_path"]) for media_path, caption in captions.items(): writer.writerow([caption, media_path]) console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]") case OutputFormat.JSON: # Format as list of dictionaries with caption and media_path keys json_data = [{"caption": caption, "media_path": media_path} for media_path, caption in captions.items()] with output_path.open("w", encoding="utf-8") as f: json.dump(json_data, f, indent=2, ensure_ascii=False) console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]") case OutputFormat.JSONL: with output_path.open("w", encoding="utf-8") as f: for media_path, caption in captions.items(): f.write(json.dumps({"caption": caption, "media_path": media_path}, ensure_ascii=False) + "\n") console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]") case _: raise ValueError(f"Unsupported output format: {format_type}") def _load_existing_captions( # noqa: PLR0912 output_path: Path, format_type: OutputFormat, ) -> dict[str, str]: """Load existing captions from a file. Args: output_path: Path to the captions file format_type: Format of the captions file Returns: Dictionary mapping media paths to captions, or empty dict if file doesn't exist """ if not output_path.exists(): return {} console.print(f"[bold blue]Loading existing captions from [cyan]{output_path}[/]...[/]") existing_captions = {} try: match format_type: case OutputFormat.TXT: # For TXT format, we have two separate files captions_file = output_path.with_stem(f"{output_path.stem}_captions") paths_file = output_path.with_stem(f"{output_path.stem}_paths") if captions_file.exists() and paths_file.exists(): captions = captions_file.read_text(encoding="utf-8").splitlines() paths = paths_file.read_text(encoding="utf-8").splitlines() if len(captions) == len(paths): existing_captions = dict(zip(paths, captions, strict=False)) case OutputFormat.CSV: with output_path.open("r", encoding="utf-8", newline="") as f: reader = csv.reader(f) # Skip header next(reader, None) for row in reader: if len(row) >= 2: caption, media_path = row[0], row[1] existing_captions[media_path] = caption case OutputFormat.JSON: with output_path.open("r", encoding="utf-8") as f: json_data = json.load(f) for item in json_data: if "caption" in item and "media_path" in item: existing_captions[item["media_path"]] = item["caption"] case OutputFormat.JSONL: with output_path.open("r", encoding="utf-8") as f: for line in f: item = json.loads(line) if "caption" in item and "media_path" in item: existing_captions[item["media_path"]] = item["caption"] case _: raise ValueError(f"Unsupported output format: {format_type}") console.print(f"[bold green]✓[/] Loaded [bold]{len(existing_captions)}[/] existing captions") return existing_captions except Exception as e: console.print(f"[bold yellow]Warning: Could not load existing captions: {e}[/]") return {} @app.command() def main( # noqa: PLR0913 input_path: Path = typer.Argument( # noqa: B008 ..., help="Path to input video/image file or directory containing media files", exists=True, ), output: Path | None = typer.Option( # noqa: B008 None, "--output", "-o", help="Path to output file for captions. Format determined by file extension.", ), captioner_type: CaptionerType = typer.Option( # noqa: B008 CaptionerType.QWEN_OMNI, "--captioner-type", "-c", help="Type of captioner to use. Valid values: 'qwen_omni' (local), 'gemini_flash' (API)", case_sensitive=False, ), device: str | None = typer.Option( None, "--device", "-d", help="Device to use for inference (e.g., 'cuda', 'cuda:0', 'cpu'). Only for local models.", ), use_8bit: bool = typer.Option( False, "--use-8bit", help="Whether to use 8-bit precision for the captioning model (reduces memory usage)", ), instruction: str | None = typer.Option( None, "--instruction", "-i", help="Custom instruction for the captioning model. If not provided, uses an appropriate default.", ), extensions: str = typer.Option( ",".join(MEDIA_EXTENSIONS), "--extensions", "-e", help="Comma-separated list of media file extensions to process", ), recursive: bool = typer.Option( False, "--recursive", "-r", help="Search for media files in subdirectories recursively", ), fps: int = typer.Option( 3, "--fps", "-f", help="Frames per second to sample from videos (ignored for images)", ), include_audio: bool = typer.Option( True, "--audio/--no-audio", help="Whether to include audio in captioning (for videos with audio tracks)", ), clean_caption: bool = typer.Option( True, "--clean-caption/--raw-caption", help="Whether to clean up captions by removing common VLM patterns", ), override: bool = typer.Option( False, "--override", help="Whether to override existing captions for media", ), api_key: str | None = typer.Option( None, "--api-key", envvar=["GOOGLE_API_KEY", "GEMINI_API_KEY"], help="API key for Gemini Flash (can also use GOOGLE_API_KEY or GEMINI_API_KEY env var)", ), ) -> None: """Auto-caption videos with audio using multimodal models. This script supports audio-visual captioning using: - Qwen2.5-Omni: Local model (default) - processes both video and audio - Gemini Flash: Cloud API - requires GOOGLE_API_KEY environment variable The paths in the output file will be relative to the output file's directory. Examples: # Caption videos with audio using Qwen2.5-Omni (default) caption_videos.py videos_dir/ -o captions.json # Caption using Gemini Flash API caption_videos.py videos_dir/ -o captions.json -c gemini_flash # Caption without audio (video-only) caption_videos.py videos_dir/ -o captions.json --no-audio # Caption with custom instruction caption_videos.py video.mp4 -o captions.json -i "Describe this video in detail" """ # Determine device for local models device_str = device or ("cuda" if torch.cuda.is_available() else "cpu") # Parse extensions ext_list = [ext.strip() for ext in extensions.split(",")] # Determine output path and format if output is None: output_format = OutputFormat.JSON if input_path.is_file(): # noqa: SIM108 # Default to a JSON file with the same name as the input media output = input_path.with_suffix(".dataset.json") else: # Default to a JSON file in the input directory output = input_path / "dataset.json" else: # Determine format from file extension output_format = OutputFormat(Path(output).suffix.lstrip(".").lower()) # Ensure output path is absolute output = Path(output).resolve() console.print(f"Output will be saved to [bold blue]{output}[/]") # Initialize captioning model with console.status("Loading captioning model...", spinner="dots"): if captioner_type == CaptionerType.QWEN_OMNI: captioner = create_captioner( captioner_type=captioner_type, device=device_str, use_8bit=use_8bit, instruction=instruction, ) elif captioner_type == CaptionerType.GEMINI_FLASH: captioner = create_captioner( captioner_type=captioner_type, api_key=api_key, instruction=instruction, ) else: raise ValueError(f"Unsupported captioner type: {captioner_type}") console.print(f"[bold green]✓[/] {captioner_type.value} captioning model loaded successfully") # Caption media files caption_media( input_path=input_path, output_path=output, captioner=captioner, extensions=ext_list, recursive=recursive, fps=fps, include_audio=include_audio, clean_caption=clean_caption, output_format=output_format, override=override, ) if __name__ == "__main__": app()