"""
Audio-visual media captioning using multimodal models.
This module provides captioning capabilities for videos with audio using:
- Qwen2.5-Omni: Local model supporting text, audio, image, and video inputs (default)
- Gemini Flash: Cloud-based API for audio-visual captioning
Requirements:
- Qwen2.5-Omni: transformers>=4.50, torch
- Gemini Flash: google-generativeai (pip install google-generativeai)
  Set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable
"""
import itertools
import re
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
import torch
# Instruction for audio-visual captioning (default) - includes speech transcription and sounds
DEFAULT_CAPTION_INSTRUCTION = """\
Analyze this media and provide a detailed caption in the following EXACT format. Fill in ALL sections:
[VISUAL]: <Detailed description of people, objects, actions, settings, colors, and movements>
[SPEECH]: <Word-for-word transcription of everything spoken.
Listen carefully and transcribe the exact words. If no speech, write "None">
[SOUNDS]: <Description of music, ambient sounds, sound effects. If none, write "None">
[TEXT]: <Any on-screen text visible. If none, write "None">
You MUST fill in all four sections. For [SPEECH], transcribe the actual words spoken, not a summary."""
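# For reference, a response following the instruction above might look like this
# (hypothetical output, not produced by any model here):
#
#     [VISUAL]: A woman in a red coat walks through a snowy park at dusk.
#     [SPEECH]: "It's freezing out here."
#     [SOUNDS]: Wind gusts and footsteps crunching on snow.
#     [TEXT]: None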
# Instruction for video-only captioning (no audio processing)
VIDEO_ONLY_CAPTION_INSTRUCTION = """\
Analyze this media and provide a detailed caption in the following EXACT format. Fill in ALL sections:
[VISUAL]: <Detailed description of people, objects, actions, settings, colors, and movements>
[TEXT]: <Any on-screen text visible. If none, write "None">
You MUST fill in both sections."""
class CaptionerType(str, Enum):
"""Enum for different types of media captioners."""
QWEN_OMNI = "qwen_omni" # Local Qwen2.5-Omni model (audio + video)
GEMINI_FLASH = "gemini_flash" # Gemini Flash API (audio + video)
def create_captioner(captioner_type: CaptionerType, **kwargs) -> "MediaCaptioningModel":
"""Factory function to create a media captioner.
Args:
captioner_type: The type of captioner to create
**kwargs: Additional arguments to pass to the captioner constructor
Returns:
An instance of a MediaCaptioningModel
"""
match captioner_type:
case CaptionerType.QWEN_OMNI:
return QwenOmniCaptioner(**kwargs)
case CaptionerType.GEMINI_FLASH:
return GeminiFlashCaptioner(**kwargs)
case _:
raise ValueError(f"Unsupported captioner type: {captioner_type}")
class MediaCaptioningModel(ABC):
"""Abstract base class for audio-visual media captioning models."""
@abstractmethod
def caption(self, path: str | Path, **kwargs) -> str:
"""Generate a caption for the given video or image.
Args:
            path: Path to the video/image file to caption
            **kwargs: Captioner-specific options (e.g. fps, include_audio, clean_caption)
Returns:
A string containing the generated caption
"""
@property
@abstractmethod
def supports_audio(self) -> bool:
"""Whether this captioner supports audio input."""
@staticmethod
def _is_image_file(path: str | Path) -> bool:
"""Check if the file is an image based on extension."""
return str(path).lower().endswith((".png", ".jpg", ".jpeg", ".heic", ".heif", ".webp"))
@staticmethod
def _is_video_file(path: str | Path) -> bool:
"""Check if the file is a video based on extension."""
return str(path).lower().endswith((".mp4", ".avi", ".mov", ".mkv", ".webm"))
@staticmethod
def _clean_raw_caption(caption: str) -> str:
"""Clean up the raw caption by removing common VLM patterns."""
start = ["The", "This"]
kind = ["video", "image", "scene", "animated sequence", "clip", "footage"]
act = ["displays", "shows", "features", "depicts", "presents", "showcases", "captures", "contains"]
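        # e.g. "The video shows a dog chasing a ball." -> "a dog chasing a ball."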
for x, y, z in itertools.product(start, kind, act):
caption = caption.replace(f"{x} {y} {z} ", "", 1)
return caption
class QwenOmniCaptioner(MediaCaptioningModel):
"""Audio-visual captioning using Alibaba's Qwen2.5-Omni model.
Qwen2.5-Omni is an end-to-end multimodal model that can perceive text, images, audio, and video.
It uses a Thinker-Talker architecture where the Thinker generates text and the Talker can
generate speech. For captioning, we use only the Thinker component for text generation.
Key features:
- Block-wise processing for streaming multimodal inputs
- TMRoPE (Time-aligned Multimodal RoPE) for synchronizing video and audio timestamps
- Can extract and process audio directly from video files
See: https://huggingface.co/docs/transformers/en/model_doc/qwen2_5_omni
Model: Qwen/Qwen2.5-Omni-7B (7B parameters)
"""
MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
# Default system prompt required by Qwen2.5-Omni for proper audio processing
DEFAULT_SYSTEM_PROMPT = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
"capable of perceiving auditory and visual inputs, as well as generating text and speech."
)
def __init__(
self,
device: str | torch.device | None = None,
use_8bit: bool = False,
instruction: str | None = None,
):
"""
Initialize the Qwen2.5-Omni captioner.
Args:
device: Device to use for inference (e.g., 'cuda', 'cuda:0', 'cpu')
use_8bit: Whether to use 8-bit quantization for reduced memory usage
instruction: Custom instruction prompt. If None, uses the default instruction
"""
self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
self.instruction = instruction
self._load_model(use_8bit=use_8bit)
@property
def supports_audio(self) -> bool:
return True
def caption(
self,
path: str | Path,
fps: int = 1,
include_audio: bool = True,
clean_caption: bool = True,
) -> str:
"""Generate a caption for the given video or image.
Args:
path: Path to the video/image file to caption
fps: Frames per second to sample from videos
include_audio: Whether to include audio in the captioning (for videos)
clean_caption: Whether to clean up the raw caption by removing common VLM patterns
Returns:
A string containing the generated caption
"""
path = Path(path)
is_image = self._is_image_file(path)
is_video = self._is_video_file(path)
# Determine if we should process audio
use_audio = include_audio and is_video
# Use custom instruction if provided, otherwise pick appropriate default
if self.instruction is not None:
instruction = self.instruction
else:
instruction = DEFAULT_CAPTION_INSTRUCTION if use_audio else VIDEO_ONLY_CAPTION_INSTRUCTION
# Build the user content based on media type
# Based on HuggingFace docs: https://huggingface.co/docs/transformers/en/model_doc/qwen2_5_omni
user_content = []
if is_image:
user_content.append({"type": "image", "image": str(path)})
elif is_video:
user_content.append({"type": "video", "video": str(path)})
# Add the instruction text
user_content.append({"type": "text", "text": instruction})
# Build conversation - use the default system prompt required by Qwen2.5-Omni
# Using a custom system prompt causes warnings and may affect audio processing
messages = [
{
"role": "system",
"content": [{"type": "text", "text": self.DEFAULT_SYSTEM_PROMPT}],
},
{"role": "user", "content": user_content},
]
# Process inputs using the processor's apply_chat_template
# For videos with audio, use load_audio_from_video=True and use_audio_in_video=True
inputs = self.processor.apply_chat_template(
messages,
load_audio_from_video=use_audio,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
fps=fps,
padding=True,
use_audio_in_video=use_audio,
).to(self.model.device)
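        # `inputs` is a dict-like batch: input_ids/attention_mask plus the video
        # feature tensors (and extracted audio features when use_audio is True).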
# Generate caption (text only, using Thinker-only model)
# Note: For Qwen2_5OmniThinkerForConditionalGeneration, use standard generate params
# (not thinker_ prefixed ones, those are for the full Qwen2_5OmniForConditionalGeneration)
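        # do_sample=False selects greedy decoding, so captions are deterministic
        # for a given input; max_new_tokens bounds the caption length.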
input_len = inputs["input_ids"].shape[1]
output_tokens = self.model.generate(
**inputs,
use_audio_in_video=use_audio,
do_sample=False,
max_new_tokens=1024,
)
# Extract only the generated tokens (exclude the input/prompt tokens)
generated_tokens = output_tokens[:, input_len:]
# Decode only the generated response
caption_raw = self.processor.batch_decode(
generated_tokens,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[0]
# Remove hallucinated conversation turns (e.g., "Human\nHuman\n..." or "Human: ...")
# This is a known issue with chat models continuing to generate fake turns
# We look for patterns that are clearly hallucinated chat turns, not legitimate uses of "human"
# Match "\nHuman" followed by ":", "\n", or end of string (chat turn patterns)
# This won't match "A human walks..." or "...the human body..."
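        # e.g. "A cat sits on a mat.\nHuman: tell me more" -> "A cat sits on a mat."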
caption_raw = re.split(r"\nHuman(?::|(?:\s*\n)|$)", caption_raw, maxsplit=1)[0]
caption_raw = caption_raw.strip()
# Clean up caption if requested
return self._clean_raw_caption(caption_raw) if clean_caption else caption_raw
def _load_model(self, use_8bit: bool) -> None:
"""Load the Qwen2.5-Omni model and processor.
Uses the Thinker-only model (Qwen2_5OmniThinkerForConditionalGeneration) for text generation
to save compute by not loading the audio generation components.
"""
from transformers import ( # noqa: PLC0415
BitsAndBytesConfig,
Qwen2_5OmniProcessor,
Qwen2_5OmniThinkerForConditionalGeneration,
)
quantization_config = BitsAndBytesConfig(load_in_8bit=True) if use_8bit else None
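        # Note: load_in_8bit requires the `bitsandbytes` package to be installed.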
# Use Thinker-only model for text generation (saves memory by not loading Talker)
self.model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
self.MODEL_ID,
            torch_dtype=torch.bfloat16,  # torch_dtype is accepted across the transformers versions listed in the module requirements
low_cpu_mem_usage=True,
quantization_config=quantization_config,
device_map="auto",
)
self.processor = Qwen2_5OmniProcessor.from_pretrained(self.MODEL_ID)
class GeminiFlashCaptioner(MediaCaptioningModel):
"""Audio-visual captioning using Google's Gemini Flash API.
Gemini Flash is a cloud-based multimodal model that natively supports
audio and video understanding. Requires a Google API key.
Note: This captioner requires the `google-generativeai` package and a valid API key.
Set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable, or pass the key directly.
"""
MODEL_ID = "gemini-flash-lite-latest"
def __init__(
self,
api_key: str | None = None,
instruction: str | None = None,
):
"""Initialize the Gemini Flash captioner.
Args:
api_key: Google API key. If not provided, will look for
GEMINI_API_KEY or GOOGLE_API_KEY environment variable.
instruction: Custom instruction prompt. If None, uses the default instruction
"""
self.instruction = instruction
self._init_client(api_key)
@property
def supports_audio(self) -> bool:
return True
def caption(
self,
path: str | Path,
fps: int = 3, # noqa: ARG002 - kept for API compatibility
include_audio: bool = True,
clean_caption: bool = True,
) -> str:
"""Generate a caption for the given video or image.
Args:
path: Path to the video/image file to caption
fps: Frames per second (not used for Gemini, kept for API compatibility)
include_audio: Whether to include audio content in the caption
clean_caption: Whether to clean up the raw caption
Returns:
A string containing the generated caption
"""
import time # noqa: PLC0415
path = Path(path)
is_video = self._is_video_file(path)
use_audio = include_audio and is_video
# Use custom instruction if provided, otherwise pick appropriate default
if self.instruction is not None:
instruction = self.instruction
else:
instruction = DEFAULT_CAPTION_INSTRUCTION if use_audio else VIDEO_ONLY_CAPTION_INSTRUCTION
# Upload the file to Gemini
uploaded_file = self._genai.upload_file(path)
# Wait for processing to complete (videos need time to process)
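        # Note: no timeout is applied; this polls at 1s intervals until the file
        # reaches a terminal state (ACTIVE or FAILED).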
while uploaded_file.state.name == "PROCESSING":
time.sleep(1)
uploaded_file = self._genai.get_file(uploaded_file.name)
        if uploaded_file.state.name == "FAILED":
            raise RuntimeError(f"Gemini failed to process file: {path}")
# Generate caption
response = self._model.generate_content([uploaded_file, instruction])
caption_raw = response.text
# Clean up the uploaded file
self._genai.delete_file(uploaded_file.name)
# Clean up caption if requested
return self._clean_raw_caption(caption_raw) if clean_caption else caption_raw
def _init_client(self, api_key: str | None) -> None:
"""Initialize the Gemini API client."""
import os # noqa: PLC0415
try:
import google.generativeai as genai # noqa: PLC0415
except ImportError as e:
raise ImportError(
"The `google-generativeai` package is required for Gemini Flash captioning. "
"Install it with: `uv pip install google-generativeai`"
) from e
# Get API key from argument or environment
# GEMINI_API_KEY is the recommended variable, GOOGLE_API_KEY also works
resolved_api_key = api_key or os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
if not resolved_api_key:
raise ValueError(
"Gemini API key is required. Provide it via the `api_key` argument "
"or set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable."
)
# Configure the genai library with the API key
genai.configure(api_key=resolved_api_key)
# Store reference to genai module for file operations
self._genai = genai
# Initialize the model
self._model = genai.GenerativeModel(self.MODEL_ID)
def example() -> None:
"""Example usage of the captioning module."""
import sys # noqa: PLC0415
if len(sys.argv) < 2:
print(f"Usage: python {sys.argv[0]} <video_path> [captioner_type]") # noqa: T201
print(" captioner_type: qwen_omni (default) or gemini_flash") # noqa: T201
sys.exit(1)
video_path = sys.argv[1]
captioner_type = CaptionerType(sys.argv[2]) if len(sys.argv) > 2 else CaptionerType.QWEN_OMNI
print(f"Using {captioner_type.value} captioner:") # noqa: T201
captioner = create_captioner(captioner_type)
caption = captioner.caption(video_path)
print(f"CAPTION: {caption}") # noqa: T201
if __name__ == "__main__":
example()