# caption_generator / caption_blip.py
# Refactor caption composition: implement fallback summary logic for empty outputs
# (scraped repo-page header — author 3v324v23, commit 3a03c0c; kept as comments,
#  the original lines were not executable Python)
import math
import os
import logging
import time
import warnings
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Iterable, Mapping, Optional
import requests
from PIL import Image
import torch
# Prevent repeated warning spam:
# FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated ... Use `HF_HOME` instead.
# Migrate the legacy TRANSFORMERS_CACHE env var to HF_HOME (its replacement),
# but only when HF_HOME isn't already set so an explicit HF_HOME always wins.
if "TRANSFORMERS_CACHE" in os.environ and "HF_HOME" not in os.environ:
    os.environ["HF_HOME"] = os.environ["TRANSFORMERS_CACHE"]
    os.environ.pop("TRANSFORMERS_CACHE", None)
# Also silence the FutureWarning itself, in case something else re-introduces
# the legacy variable after this module has run.
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    message=r"Using `TRANSFORMERS_CACHE` is deprecated.*",
)
# --- BLIP offline mode ---
# If True, BLIP model/processor will ONLY load from local cache and will never
# attempt to contact huggingface.co (no timeouts/retries). If the model isn't
# cached yet, you'll get a fast error telling you to run once online.
#
# For deployment, keep this OFF unless you have a pre-populated model cache.
# Accepts "1"/"true"/"True"; anything else (or unset) means online mode.
BLIP_OFFLINE = os.environ.get("BLIP_OFFLINE", "0").strip() in ("1", "true", "True")
if BLIP_OFFLINE:
    # These env vars tell huggingface_hub / transformers to never hit the network.
    os.environ["HF_HUB_OFFLINE"] = "1"
    os.environ["TRANSFORMERS_OFFLINE"] = "1"
# If your network is slow, Hugging Face Hub's default 10s read timeout can cause
# repeated retries when resolving files (HEAD/ETag). Bump timeouts to be more tolerant.
def _bump_env_timeout(name: str, minimum_seconds: int) -> None:
raw = os.environ.get(name)
try:
current = int(raw) if raw is not None else None
except ValueError:
current = None
if current is None or current < minimum_seconds:
os.environ[name] = str(minimum_seconds)
# Raise Hub timeouts well above the defaults so slow networks don't trigger
# the Hub's retry loop while resolving or downloading model files.
_bump_env_timeout("HF_HUB_ETAG_TIMEOUT", 60)
_bump_env_timeout("HF_HUB_DOWNLOAD_TIMEOUT", 300)
# Reduce noisy hub retry logs (optional). Comment these out if you want detailed logs.
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
from transformers import BlipForConditionalGeneration, BlipProcessor
# Hugging Face model id of the BLIP base captioning checkpoint loaded by _load_blip.
MODEL_ID = "Salesforce/blip-image-captioning-base"
# Default media path used by main() when no CLI argument is given.
IMAGE_PATH = "test3.jpg"
# text_style examples: casual, formal, genz, funny, dry, educational, gen alpha, inspirational, mysterious, direct
# caption_length examples: short, medium, long, or a number of words like "20"
# Default options for the manual main() entrypoint; normalized via
# _normalize_user_options before use.
USER_OPTIONS = {
    "text_style": "funny",
    "platform": "LinkedIn",
    "keywords": "",
    "hashtags": False,
    "language": "English",
    "caption_length": "medium",
}
# When True, progress messages are printed to stdout (see _progress_printer).
SHOW_PROGRESS = True
def _progress_printer(enabled: bool = True):
last_bucket = {"v": -1}
def log(percent: int, message: str = "") -> None:
if not enabled:
return
percent = max(0, min(int(percent), 100))
bucket = (percent // 10) * 10
if bucket != last_bucket["v"]:
last_bucket["v"] = bucket
if message:
print(f"[{bucket:>3}%] {message}")
else:
print(f"[{bucket:>3}%]")
return log
@dataclass(frozen=True)
class UserOptions:
    """Normalized caption preferences shared by the image and video flows."""

    # Writing tone, e.g. "casual", "funny", "formal".
    text_style: str = "casual"
    # Target social platform, e.g. "Instagram", "LinkedIn".
    platform: str = "Instagram"
    # Optional keywords to weave into the caption (None when not provided).
    keywords: Optional[str] = None
    # Optional free-form context about the media (None when not provided).
    description: Optional[str] = None
    # Whether the final caption should include hashtags.
    hashtags: bool = True
    # Whether emojis are allowed in the final caption.
    emojis: bool = False
    # Output language for the caption.
    language: str = "English"
    # "short" | "medium" | "long" or a word count such as "20".
    caption_length: str = "short"
def _normalize_user_options(user_options: Any) -> UserOptions:
    """Coerce None / UserOptions / dict-like input into a UserOptions instance.

    Dict inputs accept legacy key aliases (tone_style, add_hashtags, use_emojis,
    length); when "keywords" is absent, "description" is used as the keyword
    fallback. Any other input type raises TypeError.
    """
    if user_options is None:
        return UserOptions()
    if isinstance(user_options, UserOptions):
        return user_options
    if not isinstance(user_options, Mapping):
        raise TypeError("user_options must be a UserOptions, dict-like, or None")

    def _opt_str(value: Any) -> Optional[str]:
        # Empty string and None both mean "not provided".
        return None if value in (None, "") else str(value)

    opts = user_options
    return UserOptions(
        text_style=str(opts.get("text_style", opts.get("tone_style", "casual"))),
        platform=str(opts.get("platform", "Instagram")),
        keywords=_opt_str(opts.get("keywords", opts.get("description"))),
        description=_opt_str(opts.get("description")),
        hashtags=bool(opts.get("hashtags", opts.get("add_hashtags", True))),
        emojis=bool(opts.get("emojis", opts.get("use_emojis", False))),
        language=str(opts.get("language", "English")),
        caption_length=str(opts.get("caption_length", opts.get("length", "short"))),
    )
def _caption_length_instruction(caption_length: str) -> str:
value = (caption_length or "").strip().lower()
if not value:
return ""
if value.isdigit():
return f"Target length: about {int(value)} words."
if value in {"small", "short", "brief"}:
return "Target length: short (1–2 sentences)."
if value in {"medium", "normal"}:
return "Target length: medium (2–4 sentences)."
if value in {"large", "long", "detailed"}:
return "Target length: long (a short paragraph)."
return f"Target length: {caption_length}."
def _additional_context_instruction(user_options: UserOptions) -> str:
"""Format user-provided context so the LLM actually uses it."""
parts: list[str] = []
if user_options.keywords:
parts.append(f"Keywords to incorporate: {user_options.keywords}")
if user_options.description:
parts.append(f"Additional description/context: {user_options.description}")
if not parts:
return ""
return "\n".join(parts)
def _get_device() -> str:
return "cuda" if torch.cuda.is_available() else "cpu"
@lru_cache(maxsize=1)
def _load_blip(model_id: str = MODEL_ID):
    """Load the BLIP processor and model once; return (processor, model, device).

    Cached with lru_cache(maxsize=1) so the (expensive) load happens a single
    time per process. Honors BLIP_OFFLINE by restricting loading to the local
    Hugging Face cache.

    Raises:
        RuntimeError: when BLIP_OFFLINE is set but the model is not cached locally.
    """
    device = _get_device()
    local_only = bool(BLIP_OFFLINE)
    try:
        processor = BlipProcessor.from_pretrained(
            model_id,
            # Avoid torchvision dependency in minimal environments (e.g. HF Spaces)
            # by forcing the "slow" image processor.
            use_fast=False,
            local_files_only=local_only,
        )
        model = BlipForConditionalGeneration.from_pretrained(
            model_id,
            local_files_only=local_only,
        ).to(device)
    except Exception as e:
        if local_only:
            # Fail fast with an actionable message instead of a cryptic cache-miss error.
            raise RuntimeError(
                "BLIP_OFFLINE is enabled but the BLIP model isn't available in the local cache yet. "
                "Temporarily set BLIP_OFFLINE = False and run once online to download the model, "
                "then set BLIP_OFFLINE back to True."
            ) from e
        raise
    # Inference only: switch off dropout / training-mode layers.
    model.eval()
    return processor, model, device
def caption_image(image: Image.Image) -> str:
    """Generate a short caption for a PIL image using BLIP.

    Keeps the existing generation settings intact.

    Args:
        image: Any PIL image; converted to RGB before preprocessing.

    Returns:
        The decoded caption string with special tokens removed.
    """
    processor, model, device = _load_blip(MODEL_ID)
    image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    # No gradients needed for inference; saves memory and time.
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=20,  # lower = faster
            num_beams=3,  # lower = faster (5 -> 3 is a good tradeoff)
        )
    return processor.decode(out[0], skip_special_tokens=True)
def caption_image_path(image_path: str) -> str:
    """Open the image at `image_path` and return its BLIP caption.

    Uses a context manager so the underlying file handle is released
    deterministically (Pillow otherwise defers closing to load()/GC, which can
    leak descriptors on some formats/platforms). `convert("RGB")` forces the
    pixel data to load before the file closes, so the returned image is
    self-contained.
    """
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    return caption_image(rgb)
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tif", ".tiff"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".wmv", ".m4v"}
def _looks_like_video(path: str) -> bool:
ext = os.path.splitext(path)[1].lower()
if ext in _VIDEO_EXTS:
return True
if ext in _IMAGE_EXTS:
return False
# Unknown extension: try opening as image; if that fails, treat as video.
try:
Image.open(path)
return False
except Exception:
return True
def extract_frames(video_path: str, frames_per_minute: int = 8, min_frames: int = 8) -> list[Image.Image]:
    """Extract frames from a video as PIL images.

    Sampling rules:
    - Target 8 frames per minute.
    - Ensure a minimum of 8 frames total.
    - If video < 1 minute: pick 8 evenly spaced frames across whole duration.
    - If video >= 1 minute: pick 8 evenly spaced frames within each minute.

    Args:
        video_path: Path to a video file readable by OpenCV.
        frames_per_minute: Sampling density for videos of one minute or longer.
        min_frames: Lower bound on the number of frames returned.

    Returns:
        RGB PIL images in chronological order.

    Raises:
        FileNotFoundError: if `video_path` does not exist.
        RuntimeError: if OpenCV is missing, the video cannot be opened, its
            duration cannot be determined, or fewer than `min_frames` frames
            could be decoded.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video not found: {video_path}")
    # Import cv2 lazily so image-only deployments don't need OpenCV installed.
    try:
        import cv2  # type: ignore
    except Exception as e:  # pragma: no cover
        raise RuntimeError(
            "opencv-python is required for video support. Install with: pip install opencv-python"
        ) from e
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Failed to open video: {video_path}")
    fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
    if fps <= 0.0:
        # Some containers report 0 FPS; assume a common default so the math below works.
        fps = 30.0
    # NOTE(review): CAP_PROP_POS_MSEC is the *current playback position* (0 before
    # any read), so this fallback presumably yields 0 when frame_count is 0 and the
    # duration check below then raises — confirm whether a better probe is needed.
    duration_sec = frame_count / fps if frame_count > 0 else float(cap.get(cv2.CAP_PROP_POS_MSEC) or 0.0) / 1000.0
    if duration_sec <= 0.0 and frame_count > 0:
        duration_sec = frame_count / fps
    if duration_sec <= 0.0:
        cap.release()
        raise RuntimeError("Could not determine video duration.")
    minutes = int(math.ceil(duration_sec / 60.0))
    sample_times: list[float] = []  # sample timestamps in seconds
    if duration_sec < 60.0:
        total = max(min_frames, frames_per_minute)
        # Evenly spaced across full duration.
        for i in range(total):
            t = (duration_sec * i) / total
            sample_times.append(t)
    else:
        # 8 frames per minute, evenly spaced within each minute.
        for m in range(minutes):
            start = 60.0 * m
            end = min(60.0 * (m + 1), duration_sec)
            if end <= start:
                continue
            for i in range(frames_per_minute):
                t = start + (end - start) * (i / frames_per_minute)
                sample_times.append(t)
    # Convert to frame indices and dedupe while preserving order.
    seen: set[int] = set()
    frame_indices: list[int] = []
    for t in sample_times:
        idx = int(t * fps)
        if frame_count > 0:
            idx = max(0, min(idx, frame_count - 1))
        if idx not in seen:
            seen.add(idx)
            frame_indices.append(idx)
    images: list[Image.Image] = []
    for idx in frame_indices:
        # Random-access seek per frame; may be slow on some codecs but keeps memory low.
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ok, frame = cap.read()
        if not ok or frame is None:
            # Skip frames the decoder can't produce rather than failing the whole video.
            continue
        # OpenCV gives BGR; convert to RGB.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        images.append(Image.fromarray(frame_rgb))
    cap.release()
    if not images:
        raise RuntimeError("No frames extracted from video.")
    # Enforce minimum frames if possible by re-sampling across entire duration.
    if len(images) < min_frames and frame_count > 0:
        # Re-open the capture; `seen` still tracks indices already extracted above.
        cap = cv2.VideoCapture(video_path)
        extra_indices: list[int] = []
        for i in range(min_frames):
            idx = int((frame_count * i) / min_frames)
            idx = max(0, min(idx, frame_count - 1))
            if idx not in seen:
                extra_indices.append(idx)
                seen.add(idx)
        for idx in extra_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ok, frame = cap.read()
            if not ok or frame is None:
                continue
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            images.append(Image.fromarray(frame_rgb))
        cap.release()
    if len(images) < min_frames:
        raise RuntimeError(f"Extracted {len(images)} frames; expected at least {min_frames}.")
    return images
def generate_frame_captions(images: Iterable[Image.Image]) -> list[str]:
    """Caption every frame with BLIP, reporting progress as we go.

    A failing frame does not abort the batch: it yields a
    "[caption_failed: ...]" placeholder instead.
    """
    frames = list(images)
    count = len(frames)
    progress = _progress_printer(SHOW_PROGRESS)
    if count:
        progress(0, f"Captioning {count} frames...")
    results: list[str] = []
    for index, frame in enumerate(frames, start=1):
        try:
            results.append(caption_image(frame))
        except Exception as e:
            results.append(f"[caption_failed: {e}]")
        if count:
            progress(int((index / count) * 100), f"Captioned {index}/{count} frames")
    return results
def _openrouter_chat(messages: list[dict], model: str, timeout_s: int = 60) -> str:
    """Send a chat-completions request to OpenRouter and return the reply text.

    Retries transient transport failures with progressive backoff and a growing
    read timeout; HTTP errors (status >= 400) are NOT retried. Tunable via the
    OPENROUTER_TIMEOUT_S and OPENROUTER_MAX_RETRIES env vars.

    Args:
        messages: Chat messages in OpenAI format ({"role", "content"} dicts).
        model: OpenRouter model id.
        timeout_s: Base read timeout in seconds (overridable via env).

    Returns:
        The stripped `content` of the first choice.

    Raises:
        RuntimeError: missing API key, transport failure after all retries,
            HTTP error status, non-JSON body, or unexpected response shape.
    """
    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        raise RuntimeError("Missing OPENROUTER_API_KEY environment variable.")
    # HF Spaces (and some networks) can be bursty/slow. Allow tuning via env.
    try:
        timeout_s = int(os.environ.get("OPENROUTER_TIMEOUT_S", str(timeout_s)))
    except ValueError:
        pass
    try:
        max_retries = int(os.environ.get("OPENROUTER_MAX_RETRIES", "2"))
    except ValueError:
        max_retries = 2
    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": messages,
        # Keep responses snappy to reduce timeouts.
        "max_tokens": 350,
    }
    for attempt in range(max_retries + 1):
        # Progressive backoff + longer read timeout per attempt.
        attempt_timeout = timeout_s + (attempt * 45)
        try:
            resp = requests.post(
                url,
                headers=headers,
                json=payload,
                # (connect timeout, read timeout)
                timeout=(10, attempt_timeout),
            )
            break
        # requests.Timeout is a RequestException subclass, so one catch covers both.
        except requests.RequestException as e:
            if attempt >= max_retries:
                raise RuntimeError(f"OpenRouter request failed: {e}") from e
            time.sleep(0.8 * (attempt + 1))
    if resp.status_code >= 400:
        raise RuntimeError(f"OpenRouter API error {resp.status_code}: {resp.text}")
    try:
        data = resp.json()
    except ValueError as e:
        raise RuntimeError(f"OpenRouter returned non-JSON response: {resp.text}") from e
    try:
        return data["choices"][0]["message"]["content"].strip()
    except Exception as e:
        raise RuntimeError(f"Unexpected OpenRouter response format: {data}") from e
def _clean_caption_output(text: str) -> str:
"""Extract only the final caption text from an LLM response.
Some reasoning-style models occasionally append analysis/metadata (e.g. bullet lists like
"- Maintains casual French tone..."). We defensively strip those so the UI only shows the caption.
"""
t = (text or "").strip()
if not t:
return ""
# Remove common reasoning blocks (DeepSeek R1 style).
if "<think>" in t and "</think>" in t:
t = t.split("</think>")[-1].strip()
# Unwrap fenced blocks if the model returns ```text ...```.
if t.startswith("```"):
parts = t.split("```")
# parts[0] is empty; parts[1] contains optional language + content.
if len(parts) >= 3:
inner = parts[1]
inner_lines = inner.splitlines()
if inner_lines and inner_lines[0].strip().isalpha() and len(inner_lines[0].strip()) <= 12:
inner = "\n".join(inner_lines[1:])
t = inner.strip() or t
# Drop leading labels.
for prefix in (
"final caption:",
"final:",
"caption:",
"output:",
"réponse finale:",
"réponse:",
"résultat:",
):
if t.lower().startswith(prefix):
t = t[len(prefix) :].strip()
break
# Hard cut at explicit meta sections.
lower = t.lower()
for marker in (
"with this caption:",
"with this caption",
"explanation:",
"analysis:",
"notes:",
"justification:",
"raison:",
"pourquoi:",
):
idx = lower.find(marker)
if idx != -1:
t = t[:idx].strip()
lower = t.lower()
break
# If there's a meta bullet section appended, cut it off.
meta_keywords = (
"maintain",
"maintains",
"tone",
"style",
"language",
"hashtags",
"emoji",
"length",
"platform",
"casual",
"formal",
"explication",
"analyse",
)
lines = t.splitlines()
for i, line in enumerate(lines):
s = line.strip()
if not s.startswith("-"):
continue
s_lower = s.lower()
if any(k in s_lower for k in meta_keywords):
t = "\n".join(lines[:i]).strip()
break
# Remove surrounding quotes (common in model outputs).
if len(t) >= 2 and ((t[0] == t[-1] == '"') or (t[0] == t[-1] == "'")):
t = t[1:-1].strip()
return t
def compose_video_caption(frame_captions: list[str], user_options: Any, use_openrouter: bool = True) -> str:
    """Fuse per-frame BLIP captions into a single caption for the whole video.

    When OpenRouter is disabled — or its reply strips down to nothing — a
    deterministic local summary built from the frame captions is returned.

    Raises:
        ValueError: if `frame_captions` is empty.
    """
    if not frame_captions:
        raise ValueError("frame_captions is empty")
    user_options = _normalize_user_options(user_options)

    def _fallback_summary() -> str:
        # Cheap deterministic fallback: join up to 6 unique, successful frame captions.
        picked: list[str] = []
        seen_keys: set[str] = set()
        for raw in frame_captions:
            cleaned = (raw or "").strip()
            if not cleaned or cleaned.startswith("[caption_failed"):
                continue
            key = cleaned.lower()
            if key in seen_keys:
                continue
            seen_keys.add(key)
            picked.append(cleaned)
            if len(picked) >= 6:
                break
        if not picked:
            return "A short video clip."
        return "Video shows: " + "; ".join(picked)

    if not use_openrouter:
        return _fallback_summary()

    model = "deepseek/deepseek-r1-0528:free"
    captions_text = "\n".join(f"- {c}" for c in frame_captions)
    system = (
        "You are a video caption summarizer. "
        "Given short captions of sampled frames from a video, infer a single coherent caption "
        "that describes the overall video. "
        "Do NOT list frames one-by-one. Output ONLY the final caption text, no quotes."
    )
    extra_ctx = _additional_context_instruction(user_options)
    user = (
        "Requirements:\n"
        f"- Output language: {user_options.language}\n"
        f"- Target platform: {user_options.platform}\n"
        f"- Text style: {user_options.text_style}\n"
        f"- Emojis: {'allowed' if user_options.emojis else 'do not use'}\n"
        f"- Hashtags: {'include' if user_options.hashtags else 'do not include'}\n"
        f"- {_caption_length_instruction(user_options.caption_length)}\n"
    )
    if extra_ctx:
        user += f"\nAdditional context (use this to steer the caption; incorporate keywords if provided):\n{extra_ctx}\n"
    else:
        user += "\n"
    user += f"\nFrame captions:\n{captions_text}"

    reply = _openrouter_chat(
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        model=model,
    )
    summary = _clean_caption_output(reply)
    # Empty/fully-stripped replies fall back to the deterministic summary.
    return summary if summary else _fallback_summary()
def compose_image_caption(base_caption: str, user_options: Any, use_openrouter: bool = True) -> str:
    """Polish a raw BLIP caption with an LLM according to user preferences.

    When OpenRouter is disabled, or the model reply strips down to nothing,
    the (stripped) base caption is returned unchanged.

    Raises:
        ValueError: if `base_caption` is empty or whitespace-only.
    """
    if not base_caption.strip():
        raise ValueError("base_caption is empty")
    user_options = _normalize_user_options(user_options)
    if not use_openrouter:
        return base_caption.strip()

    model = "deepseek/deepseek-r1-0528:free"
    system = (
        "You are a caption fuser/editor for social platforms. "
        "Given a base caption and preferences, output a polished final caption. "
        "If additional context/keywords are provided, incorporate them naturally and consistently. "
        "Output ONLY the final caption text, no quotes, no explanations."
    )
    extra_ctx = _additional_context_instruction(user_options)
    user = (
        f"Base caption: {base_caption}\n"
        "Requirements:\n"
        f"- Text style: {user_options.text_style}\n"
        f"- Target platform: {user_options.platform}\n"
        f"- Output language: {user_options.language}\n"
        f"- Emojis: {'allowed' if user_options.emojis else 'do not use'}\n"
        f"- Hashtags: {'include' if user_options.hashtags else 'do not include'}\n"
        f"- {_caption_length_instruction(user_options.caption_length)}\n"
    )
    if extra_ctx:
        user += f"\nAdditional context (use this to steer the caption; incorporate keywords if provided):\n{extra_ctx}\n"

    reply = _openrouter_chat(
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        model=model,
    )
    polished = _clean_caption_output(reply)
    # If the model returns an empty/fully-stripped response, fall back to the base caption.
    return polished if polished else base_caption.strip()
def caption_video(video_path: str, user_options: Any = None, use_openrouter: bool = True) -> str:
    """End-to-end video captioning: sample frames, caption each, summarize, polish.

    Progress is printed in stages: extraction (0-10%), per-frame captioning
    (10-80%), summarization (80-90%), and final polishing (90-100%).
    """
    user_options = _normalize_user_options(user_options)
    progress = _progress_printer(SHOW_PROGRESS)
    progress(0, "Starting video captioning")
    frames = extract_frames(video_path)
    progress(10, f"Extracted {len(frames)} frames")
    # Map frame caption progress into 10%..80% so we still have room for compose steps.
    per_frame_captions: list[str] = []
    frame_total = len(frames)
    for position, frame in enumerate(frames, start=1):
        try:
            per_frame_captions.append(caption_image(frame))
        except Exception as e:
            per_frame_captions.append(f"[caption_failed: {e}]")
        if frame_total:
            progress(10 + int((position / frame_total) * 70), f"Captioned {position}/{frame_total} frames")
    progress(80, "Summarizing video")
    summary = compose_video_caption(per_frame_captions, user_options, use_openrouter=use_openrouter)
    progress(90, "Polishing final caption")
    polished = compose_image_caption(summary, user_options, use_openrouter=use_openrouter)
    progress(100, "Done")
    return polished
def main():
    """Simple manual entrypoint.

    - No args: captions IMAGE_PATH (image or video auto-detected).
    - One arg: captions that path (image or video auto-detected).
    """
    import sys

    media_path = sys.argv[1] if len(sys.argv) > 1 else IMAGE_PATH
    opts = _normalize_user_options(USER_OPTIONS)
    if _looks_like_video(media_path):
        try:
            final = caption_video(media_path, opts, use_openrouter=True)
        except RuntimeError as e:
            # Without an OpenRouter key, degrade to the deterministic local summary.
            if "OPENROUTER_API_KEY" not in str(e):
                raise
            final = caption_video(media_path, opts, use_openrouter=False)
        print(final)
        return
    progress = _progress_printer(SHOW_PROGRESS)
    progress(0, "Starting image captioning")
    base = caption_image_path(media_path)
    progress(50, "Polishing caption with AI")
    try:
        final = compose_image_caption(base, opts, use_openrouter=True)
    except RuntimeError as e:
        # Without an OpenRouter key, fall back to the raw BLIP caption.
        if "OPENROUTER_API_KEY" not in str(e):
            raise
        final = base
    progress(100, "Done")
    print(final)
if __name__ == "__main__":
main()