| |
|
|
| """ |
| Auto-caption videos with audio using multimodal models. |
| This script provides a command-line interface for generating captions for videos |
| (including audio) using multimodal models. It supports: |
| - Qwen2.5-Omni: Local model for audio-visual captioning (default) |
| - Gemini Flash: Cloud-based API for audio-visual captioning |
| The paths to videos in the generated dataset/captions file will be RELATIVE to the |
| directory where the output file is stored. This makes the dataset more portable and |
| easier to use in different environments. |
| Basic usage: |
| # Caption a single video (includes audio by default) |
| caption_videos.py video.mp4 --output captions.json |
| # Caption all videos in a directory |
| caption_videos.py videos_dir/ --output captions.csv |
| # Caption with custom instruction |
| caption_videos.py video.mp4 --instruction "Describe what happens in this video in detail." |
| Advanced usage: |
| # Use Gemini Flash API (requires GEMINI_API_KEY or GOOGLE_API_KEY env var) |
| caption_videos.py videos_dir/ --captioner-type gemini_flash |
| # Disable audio processing (video-only captions) |
| caption_videos.py videos_dir/ --no-audio |
| # Process videos with specific extensions and save as JSON |
| caption_videos.py videos_dir/ --extensions mp4,mov,avi --output captions.json |
| """ |
|
|
import csv
import io
import json
import os
from enum import Enum
from pathlib import Path

import torch
import typer
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from transformers.utils.logging import disable_progress_bar

from ltx_trainer.captioning import CaptionerType, MediaCaptioningModel, create_captioner
|
|
# File extensions (lowercase, without the leading dot) recognized as media inputs.
VIDEO_EXTENSIONS = ["mp4", "avi", "mov", "mkv", "webm"]
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png"]
MEDIA_EXTENSIONS = [*VIDEO_EXTENSIONS, *IMAGE_EXTENSIONS]

# caption_media writes the captions file whenever its loop index is a multiple of this value.
SAVE_INTERVAL = 5

# Rich console used for status output throughout the script.
console = Console()

# CLI application; `main` below is registered on it as the command.
app = typer.Typer(
    help="Auto-caption videos with audio using multimodal models.",
    no_args_is_help=True,
    pretty_exceptions_enable=False,
)

# Turn off transformers' progress bars; this script renders its own rich progress bar.
disable_progress_bar()
|
|
|
|
class OutputFormat(str, Enum):
    """Caption file formats, selected by the output file's extension.

    Each member's value is the lowercase extension (without the dot), so
    ``OutputFormat("csv")`` maps a path suffix directly to a format.
    """

    # Two plain-text files: one caption per line, plus a parallel file of media paths.
    TXT = "txt"
    # A header row ("caption", "media_path") followed by one row per media file.
    CSV = "csv"
    # A single JSON array of {"caption": ..., "media_path": ...} objects.
    JSON = "json"
    # One {"caption": ..., "media_path": ...} JSON object per line.
    JSONL = "jsonl"
|
|
|
|
def caption_media(
    input_path: Path,
    output_path: Path,
    captioner: MediaCaptioningModel,
    extensions: list[str],
    recursive: bool,
    fps: int,
    include_audio: bool,
    clean_caption: bool,
    output_format: OutputFormat,
    override: bool,
) -> None:
    """Caption videos and images using the provided captioning model.

    Media paths are stored relative to the output file's directory. Media located
    outside that directory get ``..`` components instead of failing. When no relative
    path exists at all (e.g. a different drive on Windows), the absolute path is stored.

    Captions are written periodically (every ``SAVE_INTERVAL`` loop iterations) and
    once more when processing ends. The final write also happens when processing is
    interrupted (e.g. Ctrl+C), so completed captions are not lost.

    Args:
        input_path: Path to input video file or directory
        output_path: Path to output caption file
        captioner: Media captioning model
        extensions: List of media file extensions to include
        recursive: Whether to search subdirectories recursively
        fps: Frames per second to sample from videos (ignored for images)
        include_audio: Whether to include audio in captioning
        clean_caption: Whether to clean up captions
        output_format: Format to save the captions in
        override: Whether to override existing captions
    """
    media_files = _get_media_files(input_path, extensions, recursive)

    if not media_files:
        console.print("[bold yellow]No media files found to process.[/]")
        return

    console.print(f"Found [bold]{len(media_files)}[/] media files to process.")

    # Stored paths are relative to the output file's directory; resolve them back to
    # absolute paths so they can be compared against the discovered media files.
    base_dir = output_path.parent.resolve()
    existing_captions = _load_existing_captions(output_path, output_format)
    existing_abs_paths = {str((base_dir / p).resolve()) for p in existing_captions}

    if override:
        media_to_process = media_files
    else:
        media_to_process = [f for f in media_files if str(f.resolve()) not in existing_abs_paths]
        if skipped := len(media_files) - len(media_to_process):
            console.print(f"[bold yellow]Skipping {skipped} media that already have captions.[/]")

    if not media_to_process:
        console.print("[bold yellow]All media already have captions. Use --override to recaption.[/]")
        return

    # Start from the existing captions so entries that are not reprocessed survive each save.
    captions = existing_captions.copy()
    successfully_captioned = 0
    progress = Progress(
        SpinnerColumn(),
        TextColumn("{task.description}"),
        BarColumn(bar_width=40),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn("•"),
        TimeRemainingColumn(),
        console=console,
    )

    try:
        with progress:
            task = progress.add_task("Captioning", total=len(media_to_process))

            for i, media_file in enumerate(media_to_process):
                progress.update(task, description=f"Captioning [bold blue]{media_file.name}[/]")

                try:
                    caption = captioner.caption(
                        path=media_file,
                        fps=fps,
                        include_audio=include_audio,
                        clean_caption=clean_caption,
                    )
                    rel_path = _relative_media_path(media_file, base_dir)
                except Exception as e:
                    # A single failing file is reported but does not abort the whole run.
                    console.print(f"[bold red]Error captioning {media_file}: {e}[/]")
                else:
                    captions[rel_path] = caption
                    successfully_captioned += 1

                # Periodically persist progress so a crash loses at most a few captions.
                if i % SAVE_INTERVAL == 0:
                    _save_captions(captions, output_path, output_format)

                progress.advance(task)
    finally:
        # Always write the final state, including when the loop is interrupted (e.g. Ctrl+C).
        _save_captions(captions, output_path, output_format)

    console.print(
        f"[bold green]✓[/] Captioned [bold]{successfully_captioned}/{len(media_to_process)}[/] media successfully.",
    )


def _relative_media_path(media_file: Path, base_dir: Path) -> str:
    """Return ``media_file`` as a path string relative to ``base_dir``.

    Unlike ``Path.relative_to``, this also works when ``media_file`` is not located
    inside ``base_dir``; the result then contains ``..`` components. If no relative
    path exists at all (``os.path.relpath`` raises ``ValueError``, e.g. across Windows
    drives), the resolved absolute path is returned. ``base_dir / result`` still
    points at the file in that case.
    """
    resolved = media_file.resolve()
    try:
        return os.path.relpath(resolved, base_dir)
    except ValueError:
        return str(resolved)
|
|
|
|
| def _get_media_files( |
| input_path: Path, |
| extensions: list[str] = MEDIA_EXTENSIONS, |
| recursive: bool = False, |
| ) -> list[Path]: |
| """Get all media files from the input path.""" |
| input_path = Path(input_path) |
| |
| extensions_set = {ext.lower().lstrip(".") for ext in extensions} |
|
|
| if input_path.is_file(): |
| |
| if input_path.suffix.lstrip(".").lower() in extensions_set: |
| return [input_path] |
| else: |
| typer.echo(f"Warning: {input_path} is not a recognized media file. Skipping.") |
| return [] |
| elif input_path.is_dir(): |
| |
| glob_pattern = "**/*" if recursive else "*" |
| media_files = [ |
| f for f in input_path.glob(glob_pattern) if f.is_file() and f.suffix.lstrip(".").lower() in extensions_set |
| ] |
| return sorted(media_files) |
| else: |
| typer.echo(f"Error: {input_path} does not exist.") |
| raise typer.Exit(code=1) |
|
|
|
|
| def _save_captions( |
| captions: dict[str, str], |
| output_path: Path, |
| format_type: OutputFormat, |
| ) -> None: |
| """Save captions to a file in the specified format. |
| Args: |
| captions: Dictionary mapping media paths to captions |
| output_path: Path to save the output file |
| format_type: Format to save the captions in |
| """ |
| |
| output_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
| console.print("[bold blue]Saving captions...[/]") |
|
|
| match format_type: |
| case OutputFormat.TXT: |
| |
| captions_file = output_path.with_stem(f"{output_path.stem}_captions") |
| paths_file = output_path.with_stem(f"{output_path.stem}_paths") |
|
|
| with captions_file.open("w", encoding="utf-8") as f: |
| for caption in captions.values(): |
| f.write(f"{caption}\n") |
|
|
| with paths_file.open("w", encoding="utf-8") as f: |
| for media_path in captions: |
| f.write(f"{media_path}\n") |
|
|
| console.print(f"[bold green]✓[/] Captions saved to [cyan]{captions_file}[/]") |
| console.print(f"[bold green]✓[/] Media paths saved to [cyan]{paths_file}[/]") |
|
|
| case OutputFormat.CSV: |
| with output_path.open("w", encoding="utf-8", newline="") as f: |
| writer = csv.writer(f) |
| writer.writerow(["caption", "media_path"]) |
| for media_path, caption in captions.items(): |
| writer.writerow([caption, media_path]) |
|
|
| console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]") |
|
|
| case OutputFormat.JSON: |
| |
| json_data = [{"caption": caption, "media_path": media_path} for media_path, caption in captions.items()] |
|
|
| with output_path.open("w", encoding="utf-8") as f: |
| json.dump(json_data, f, indent=2, ensure_ascii=False) |
|
|
| console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]") |
|
|
| case OutputFormat.JSONL: |
| with output_path.open("w", encoding="utf-8") as f: |
| for media_path, caption in captions.items(): |
| f.write(json.dumps({"caption": caption, "media_path": media_path}, ensure_ascii=False) + "\n") |
|
|
| console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]") |
|
|
| case _: |
| raise ValueError(f"Unsupported output format: {format_type}") |
|
|
|
|
| def _load_existing_captions( |
| output_path: Path, |
| format_type: OutputFormat, |
| ) -> dict[str, str]: |
| """Load existing captions from a file. |
| Args: |
| output_path: Path to the captions file |
| format_type: Format of the captions file |
| Returns: |
| Dictionary mapping media paths to captions, or empty dict if file doesn't exist |
| """ |
| if not output_path.exists(): |
| return {} |
|
|
| console.print(f"[bold blue]Loading existing captions from [cyan]{output_path}[/]...[/]") |
|
|
| existing_captions = {} |
|
|
| try: |
| match format_type: |
| case OutputFormat.TXT: |
| |
| captions_file = output_path.with_stem(f"{output_path.stem}_captions") |
| paths_file = output_path.with_stem(f"{output_path.stem}_paths") |
|
|
| if captions_file.exists() and paths_file.exists(): |
| captions = captions_file.read_text(encoding="utf-8").splitlines() |
| paths = paths_file.read_text(encoding="utf-8").splitlines() |
|
|
| if len(captions) == len(paths): |
| existing_captions = dict(zip(paths, captions, strict=False)) |
|
|
| case OutputFormat.CSV: |
| with output_path.open("r", encoding="utf-8", newline="") as f: |
| reader = csv.reader(f) |
| |
| next(reader, None) |
| for row in reader: |
| if len(row) >= 2: |
| caption, media_path = row[0], row[1] |
| existing_captions[media_path] = caption |
|
|
| case OutputFormat.JSON: |
| with output_path.open("r", encoding="utf-8") as f: |
| json_data = json.load(f) |
| for item in json_data: |
| if "caption" in item and "media_path" in item: |
| existing_captions[item["media_path"]] = item["caption"] |
|
|
| case OutputFormat.JSONL: |
| with output_path.open("r", encoding="utf-8") as f: |
| for line in f: |
| item = json.loads(line) |
| if "caption" in item and "media_path" in item: |
| existing_captions[item["media_path"]] = item["caption"] |
|
|
| case _: |
| raise ValueError(f"Unsupported output format: {format_type}") |
|
|
| console.print(f"[bold green]✓[/] Loaded [bold]{len(existing_captions)}[/] existing captions") |
| return existing_captions |
|
|
| except Exception as e: |
| console.print(f"[bold yellow]Warning: Could not load existing captions: {e}[/]") |
| return {} |
|
|
|
|
@app.command()
def main(
    input_path: Path = typer.Argument(
        ...,
        help="Path to input video/image file or directory containing media files",
        exists=True,
    ),
    output: Path | None = typer.Option(
        None,
        "--output",
        "-o",
        help="Path to output file for captions. Format determined by file extension.",
    ),
    captioner_type: CaptionerType = typer.Option(
        CaptionerType.QWEN_OMNI,
        "--captioner-type",
        "-c",
        help="Type of captioner to use. Valid values: 'qwen_omni' (local), 'gemini_flash' (API)",
        case_sensitive=False,
    ),
    device: str | None = typer.Option(
        None,
        "--device",
        "-d",
        help="Device to use for inference (e.g., 'cuda', 'cuda:0', 'cpu'). Only for local models.",
    ),
    use_8bit: bool = typer.Option(
        False,
        "--use-8bit",
        help="Whether to use 8-bit precision for the captioning model (reduces memory usage)",
    ),
    instruction: str | None = typer.Option(
        None,
        "--instruction",
        "-i",
        help="Custom instruction for the captioning model. If not provided, uses an appropriate default.",
    ),
    extensions: str = typer.Option(
        ",".join(MEDIA_EXTENSIONS),
        "--extensions",
        "-e",
        help="Comma-separated list of media file extensions to process",
    ),
    recursive: bool = typer.Option(
        False,
        "--recursive",
        "-r",
        help="Search for media files in subdirectories recursively",
    ),
    fps: int = typer.Option(
        3,
        "--fps",
        "-f",
        help="Frames per second to sample from videos (ignored for images)",
    ),
    include_audio: bool = typer.Option(
        True,
        "--audio/--no-audio",
        help="Whether to include audio in captioning (for videos with audio tracks)",
    ),
    clean_caption: bool = typer.Option(
        True,
        "--clean-caption/--raw-caption",
        help="Whether to clean up captions by removing common VLM patterns",
    ),
    override: bool = typer.Option(
        False,
        "--override",
        help="Whether to override existing captions for media",
    ),
    api_key: str | None = typer.Option(
        None,
        "--api-key",
        envvar=["GOOGLE_API_KEY", "GEMINI_API_KEY"],
        help="API key for Gemini Flash (can also use GOOGLE_API_KEY or GEMINI_API_KEY env var)",
    ),
) -> None:
    """Auto-caption videos with audio using multimodal models.

    This script supports audio-visual captioning using:
    - Qwen2.5-Omni: Local model (default) - processes both video and audio
    - Gemini Flash: Cloud API - requires --api-key or the GOOGLE_API_KEY / GEMINI_API_KEY environment variable

    The paths in the output file will be relative to the output file's directory.

    Examples:
    # Caption videos with audio using Qwen2.5-Omni (default)
    caption_videos.py videos_dir/ -o captions.json
    # Caption using Gemini Flash API
    caption_videos.py videos_dir/ -o captions.json -c gemini_flash
    # Caption without audio (video-only)
    caption_videos.py videos_dir/ -o captions.json --no-audio
    # Caption with custom instruction
    caption_videos.py video.mp4 -o captions.json -i "Describe this video in detail"
    """
    # Use the explicitly requested device, otherwise prefer CUDA when it is available.
    device_str = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # Split the comma-separated extensions and drop blank entries (e.g. from "mp4,").
    ext_list = [ext.strip() for ext in extensions.split(",") if ext.strip()]

    # Without --output, default to a JSON dataset next to the input.
    if output is None:
        output_format = OutputFormat.JSON
        if input_path.is_file():
            # For a single file: <name>.dataset.json beside it.
            output = input_path.with_suffix(".dataset.json")
        else:
            # For a directory: dataset.json inside it.
            output = input_path / "dataset.json"
    else:
        # The output format is inferred from the file extension; reject unknown ones with a
        # CLI usage error instead of an unhandled ValueError traceback.
        try:
            output_format = OutputFormat(Path(output).suffix.lstrip(".").lower())
        except ValueError:
            valid = ", ".join(f".{fmt.value}" for fmt in OutputFormat)
            raise typer.BadParameter(
                f"Cannot infer a caption format from '{Path(output).name}'. Use one of these extensions: {valid}",
                param_hint="'--output' / '-o'",
            ) from None

    # Work with an absolute path so relative media paths are computed from a stable base.
    output = Path(output).resolve()
    console.print(f"Output will be saved to [bold blue]{output}[/]")

    with console.status("Loading captioning model...", spinner="dots"):
        if captioner_type == CaptionerType.QWEN_OMNI:
            captioner = create_captioner(
                captioner_type=captioner_type,
                device=device_str,
                use_8bit=use_8bit,
                instruction=instruction,
            )
        elif captioner_type == CaptionerType.GEMINI_FLASH:
            captioner = create_captioner(
                captioner_type=captioner_type,
                api_key=api_key,
                instruction=instruction,
            )
        else:
            raise ValueError(f"Unsupported captioner type: {captioner_type}")

    console.print(f"[bold green]✓[/] {captioner_type.value} captioning model loaded successfully")

    caption_media(
        input_path=input_path,
        output_path=output,
        captioner=captioner,
        extensions=ext_list,
        recursive=recursive,
        fps=fps,
        include_audio=include_audio,
        clean_caption=clean_caption,
        output_format=output_format,
        override=override,
    )
|
|
|
|
if __name__ == "__main__":
    # Only when executed as a script (not on import): hand control to the Typer CLI.
    app()
|
|