import math
import os
import logging
import time
import warnings
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Iterable, Mapping, Optional

import requests
from PIL import Image
import torch

# Prevent repeated warning spam:
# FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated ... Use `HF_HOME` instead.
# Migrate the legacy variable to HF_HOME, then drop it so the warning never fires.
if "TRANSFORMERS_CACHE" in os.environ and "HF_HOME" not in os.environ:
    os.environ["HF_HOME"] = os.environ["TRANSFORMERS_CACHE"]
    os.environ.pop("TRANSFORMERS_CACHE", None)
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    message=r"Using `TRANSFORMERS_CACHE` is deprecated.*",
)

# --- BLIP offline mode ---
# If True, BLIP model/processor will ONLY load from local cache and will never
# attempt to contact huggingface.co (no timeouts/retries). If the model isn't
# cached yet, you'll get a fast error telling you to run once online.
#
# For deployment, keep this OFF unless you have a pre-populated model cache.
# NOTE: parsed case-insensitively, so "1", "true", "True", "TRUE" all enable it
# (previously only the exact strings "1"/"true"/"True" were accepted).
BLIP_OFFLINE = os.environ.get("BLIP_OFFLINE", "0").strip().lower() in ("1", "true")

if BLIP_OFFLINE:
    os.environ["HF_HUB_OFFLINE"] = "1"
    os.environ["TRANSFORMERS_OFFLINE"] = "1"


# If your network is slow, Hugging Face Hub's default 10s read timeout can cause
# repeated retries when resolving files (HEAD/ETag). Bump timeouts to be more tolerant.
def _bump_env_timeout(name: str, minimum_seconds: int) -> None:
    """Raise the integer env var `name` to at least `minimum_seconds`.

    Non-integer or missing values are treated as unset and overwritten.
    """
    raw = os.environ.get(name)
    try:
        current = int(raw) if raw is not None else None
    except ValueError:
        current = None
    if current is None or current < minimum_seconds:
        os.environ[name] = str(minimum_seconds)


_bump_env_timeout("HF_HUB_ETAG_TIMEOUT", 60)
_bump_env_timeout("HF_HUB_DOWNLOAD_TIMEOUT", 300)

# Reduce noisy hub retry logs (optional). Comment these out if you want detailed logs.
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)

from transformers import BlipForConditionalGeneration, BlipProcessor

MODEL_ID = "Salesforce/blip-image-captioning-base"
IMAGE_PATH = "test3.jpg"

# text_style examples: casual, formal, genz, funny, dry, educational, gen alpha, inspirational, mysterious, direct
# caption_length examples: short, medium, long, or a number of words like "20"
USER_OPTIONS = {
    "text_style": "funny",
    "platform": "LinkedIn",
    "keywords": "",
    "hashtags": False,
    "language": "English",
    "caption_length": "medium",
}

SHOW_PROGRESS = True


def _progress_printer(enabled: bool = True):
    """Return a `log(percent, message)` callable that prints progress.

    Output is bucketed to 10% increments so repeated calls within the same
    bucket print nothing (avoids log spam in tight loops).
    """
    last_bucket = {"v": -1}

    def log(percent: int, message: str = "") -> None:
        if not enabled:
            return
        percent = max(0, min(int(percent), 100))
        bucket = (percent // 10) * 10
        if bucket != last_bucket["v"]:
            last_bucket["v"] = bucket
            if message:
                print(f"[{bucket:>3}%] {message}")
            else:
                print(f"[{bucket:>3}%]")

    return log


@dataclass(frozen=True)
class UserOptions:
    # Normalized caption preferences. See _normalize_user_options for the
    # accepted dict key aliases.
    text_style: str = "casual"
    platform: str = "Instagram"
    keywords: Optional[str] = None
    description: Optional[str] = None
    hashtags: bool = True
    emojis: bool = False
    language: str = "English"
    caption_length: str = "short"


def _normalize_user_options(user_options: Any) -> UserOptions:
    """Coerce None / UserOptions / dict-like input into a UserOptions.

    Dict inputs accept legacy key aliases: tone_style -> text_style,
    add_hashtags -> hashtags, use_emojis -> emojis, length -> caption_length.
    Empty-string keywords/description are normalized to None.

    Raises TypeError for any other input type.
    """
    if user_options is None:
        return UserOptions()
    if isinstance(user_options, UserOptions):
        return user_options
    if isinstance(user_options, Mapping):
        style = user_options.get("text_style", user_options.get("tone_style", "casual"))
        # keywords falls back to description when absent.
        keywords = user_options.get("keywords", user_options.get("description"))
        description = user_options.get("description")
        hashtags = user_options.get("hashtags", user_options.get("add_hashtags", True))
        emojis = user_options.get("emojis", user_options.get("use_emojis", False))
        caption_length = user_options.get("caption_length", user_options.get("length", "short"))
        return UserOptions(
            text_style=str(style),
            platform=str(user_options.get("platform", "Instagram")),
            keywords=None if keywords in (None, "") else str(keywords),
            description=None if description in (None, "") else str(description),
            hashtags=bool(hashtags),
            emojis=bool(emojis),
            language=str(user_options.get("language", "English")),
            caption_length=str(caption_length),
        )
    raise TypeError("user_options must be a UserOptions, dict-like, or None")


def _caption_length_instruction(caption_length: str) -> str:
    """Translate a length preference into a prompt instruction string.

    Accepts named sizes (short/medium/long and synonyms), a bare word count
    ("20"), or any free-form text (passed through). Returns "" for empty input.
    """
    value = (caption_length or "").strip().lower()
    if not value:
        return ""
    if value.isdigit():
        return f"Target length: about {int(value)} words."
    if value in {"small", "short", "brief"}:
        return "Target length: short (1–2 sentences)."
    if value in {"medium", "normal"}:
        return "Target length: medium (2–4 sentences)."
    if value in {"large", "long", "detailed"}:
        return "Target length: long (a short paragraph)."
    return f"Target length: {caption_length}."


def _additional_context_instruction(user_options: UserOptions) -> str:
    """Format user-provided context so the LLM actually uses it."""
    parts: list[str] = []
    if user_options.keywords:
        parts.append(f"Keywords to incorporate: {user_options.keywords}")
    if user_options.description:
        parts.append(f"Additional description/context: {user_options.description}")
    if not parts:
        return ""
    return "\n".join(parts)


def _get_device() -> str:
    """Return "cuda" when a GPU is available, else "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"


@lru_cache(maxsize=1)
def _load_blip(model_id: str = MODEL_ID):
    """Load and cache the BLIP processor/model pair.

    Returns (processor, model, device). Honors BLIP_OFFLINE by restricting
    loading to the local cache. Raises RuntimeError with a recovery hint when
    offline mode is on but the model isn't cached.
    """
    device = _get_device()
    local_only = bool(BLIP_OFFLINE)
    try:
        processor = BlipProcessor.from_pretrained(
            model_id,
            # Avoid torchvision dependency in minimal environments (e.g. HF Spaces)
            # by forcing the "slow" image processor.
            use_fast=False,
            local_files_only=local_only,
        )
        model = BlipForConditionalGeneration.from_pretrained(
            model_id,
            local_files_only=local_only,
        ).to(device)
    except Exception as e:
        if local_only:
            raise RuntimeError(
                "BLIP_OFFLINE is enabled but the BLIP model isn't available in the local cache yet. "
                "Temporarily set BLIP_OFFLINE = False and run once online to download the model, "
                "then set BLIP_OFFLINE back to True."
            ) from e
        raise
    model.eval()
    return processor, model, device


def caption_image(image: Image.Image) -> str:
    """Generate a short caption for a PIL image using BLIP.

    Keeps the existing generation settings intact.
    """
    processor, model, device = _load_blip(MODEL_ID)
    image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=20,  # lower = faster
            num_beams=3,  # lower = faster (5 -> 3 is a good tradeoff)
        )
    return processor.decode(out[0], skip_special_tokens=True)


def caption_image_path(image_path: str) -> str:
    """Open an image file and caption it with BLIP."""
    image = Image.open(image_path).convert("RGB")
    return caption_image(image)


_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tif", ".tiff"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".wmv", ".m4v"}


def _looks_like_video(path: str) -> bool:
    """Heuristically decide whether `path` is a video.

    Known extensions decide immediately; for unknown extensions we try opening
    the file as an image and treat failure as "video".
    """
    ext = os.path.splitext(path)[1].lower()
    if ext in _VIDEO_EXTS:
        return True
    if ext in _IMAGE_EXTS:
        return False
    # Unknown extension: try opening as image; if that fails, treat as video.
    try:
        Image.open(path)
        return False
    except Exception:
        return True


def extract_frames(video_path: str, frames_per_minute: int = 8, min_frames: int = 8) -> list[Image.Image]:
    """Extract frames from a video as PIL images.

    Sampling rules:
    - Target 8 frames per minute.
    - Ensure a minimum of 8 frames total.
    - If video < 1 minute: pick 8 evenly spaced frames across whole duration.
    - If video >= 1 minute: pick 8 evenly spaced frames within each minute.

    Raises FileNotFoundError for a missing path and RuntimeError when OpenCV
    is unavailable, the video can't be opened, the duration is unknown, or
    fewer than `min_frames` frames could be decoded.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video not found: {video_path}")
    try:
        import cv2  # type: ignore
    except Exception as e:  # pragma: no cover
        raise RuntimeError(
            "opencv-python is required for video support. Install with: pip install opencv-python"
        ) from e
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Failed to open video: {video_path}")
    fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
    if fps <= 0.0:
        # Some containers report 0 FPS; assume a common default.
        fps = 30.0
    duration_sec = frame_count / fps if frame_count > 0 else float(cap.get(cv2.CAP_PROP_POS_MSEC) or 0.0) / 1000.0
    if duration_sec <= 0.0 and frame_count > 0:
        duration_sec = frame_count / fps
    if duration_sec <= 0.0:
        cap.release()
        raise RuntimeError("Could not determine video duration.")
    minutes = int(math.ceil(duration_sec / 60.0))
    sample_times: list[float] = []
    if duration_sec < 60.0:
        total = max(min_frames, frames_per_minute)
        # Evenly spaced across full duration.
        for i in range(total):
            t = (duration_sec * i) / total
            sample_times.append(t)
    else:
        # 8 frames per minute, evenly spaced within each minute.
        for m in range(minutes):
            start = 60.0 * m
            end = min(60.0 * (m + 1), duration_sec)
            if end <= start:
                continue
            for i in range(frames_per_minute):
                t = start + (end - start) * (i / frames_per_minute)
                sample_times.append(t)
    # Convert to frame indices and dedupe while preserving order.
    seen: set[int] = set()
    frame_indices: list[int] = []
    for t in sample_times:
        idx = int(t * fps)
        if frame_count > 0:
            idx = max(0, min(idx, frame_count - 1))
        if idx not in seen:
            seen.add(idx)
            frame_indices.append(idx)
    images: list[Image.Image] = []
    for idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ok, frame = cap.read()
        if not ok or frame is None:
            continue
        # OpenCV gives BGR; convert to RGB.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        images.append(Image.fromarray(frame_rgb))
    cap.release()
    if not images:
        raise RuntimeError("No frames extracted from video.")
    # Enforce minimum frames if possible by re-sampling across entire duration.
    if len(images) < min_frames and frame_count > 0:
        cap = cv2.VideoCapture(video_path)
        extra_indices: list[int] = []
        for i in range(min_frames):
            idx = int((frame_count * i) / min_frames)
            idx = max(0, min(idx, frame_count - 1))
            if idx not in seen:
                extra_indices.append(idx)
                seen.add(idx)
        for idx in extra_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ok, frame = cap.read()
            if not ok or frame is None:
                continue
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            images.append(Image.fromarray(frame_rgb))
        cap.release()
    if len(images) < min_frames:
        raise RuntimeError(f"Extracted {len(images)} frames; expected at least {min_frames}.")
    return images


def generate_frame_captions(images: Iterable[Image.Image]) -> list[str]:
    """Caption each frame with BLIP, reporting progress.

    A failed frame yields a "[caption_failed: ...]" placeholder instead of
    aborting the whole batch.
    """
    captions: list[str] = []
    imgs = list(images)
    total = len(imgs)
    progress = _progress_printer(SHOW_PROGRESS)
    if total:
        progress(0, f"Captioning {total} frames...")
    for i, img in enumerate(imgs, start=1):
        try:
            captions.append(caption_image(img))
        except Exception as e:
            captions.append(f"[caption_failed: {e}]")
        if total:
            pct = int((i / total) * 100)
            progress(pct, f"Captioned {i}/{total} frames")
    return captions


def _openrouter_chat(messages: list[dict], model: str, timeout_s: int = 60) -> str:
    """POST a chat-completions request to OpenRouter and return the content.

    Retries transient network failures with progressive backoff and a growing
    read timeout. Raises RuntimeError on a missing API key, exhausted retries,
    HTTP error status, non-JSON body, or unexpected response shape.
    """
    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        raise RuntimeError("Missing OPENROUTER_API_KEY environment variable.")
    # HF Spaces (and some networks) can be bursty/slow. Allow tuning via env.
    try:
        timeout_s = int(os.environ.get("OPENROUTER_TIMEOUT_S", str(timeout_s)))
    except ValueError:
        pass
    try:
        max_retries = int(os.environ.get("OPENROUTER_MAX_RETRIES", "2"))
    except ValueError:
        max_retries = 2
    # Clamp so a negative env value can't skip the loop and leave `resp` unbound.
    max_retries = max(0, max_retries)
    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": messages,
        # Keep responses snappy to reduce timeouts.
        "max_tokens": 350,
    }
    for attempt in range(max_retries + 1):
        # Progressive backoff + longer timeout per attempt.
        attempt_timeout = timeout_s + (attempt * 45)
        try:
            resp = requests.post(
                url,
                headers=headers,
                json=payload,
                # (connect timeout, read timeout)
                timeout=(10, attempt_timeout),
            )
            break
        # requests.Timeout subclasses RequestException, so one clause covers both.
        except requests.RequestException as e:
            if attempt >= max_retries:
                raise RuntimeError(f"OpenRouter request failed: {e}") from e
            time.sleep(0.8 * (attempt + 1))
    if resp.status_code >= 400:
        raise RuntimeError(f"OpenRouter API error {resp.status_code}: {resp.text}")
    try:
        data = resp.json()
    except ValueError as e:
        raise RuntimeError(f"OpenRouter returned non-JSON response: {resp.text}") from e
    try:
        return data["choices"][0]["message"]["content"].strip()
    except Exception as e:
        raise RuntimeError(f"Unexpected OpenRouter response format: {data}") from e


def _clean_caption_output(text: str) -> str:
    """Extract only the final caption text from an LLM response.

    Some reasoning-style models occasionally append analysis/metadata (e.g.
    bullet lists like "- Maintains casual French tone..."). We defensively
    strip those so the UI only shows the caption.
    """
    t = (text or "").strip()
    if not t:
        return ""
    # Remove common reasoning blocks (DeepSeek R1 style).
    # BUG FIX: the delimiters were empty strings, making `"" in t` always true
    # and `t.split("")` raise ValueError on every non-empty input. Restore the
    # intended <think>...</think> markers emitted by R1-style models.
    if "<think>" in t and "</think>" in t:
        t = t.split("</think>")[-1].strip()
    # Unwrap fenced blocks if the model returns ```text ...```.
    if t.startswith("```"):
        parts = t.split("```")
        # parts[0] is empty; parts[1] contains optional language + content.
        if len(parts) >= 3:
            inner = parts[1]
            inner_lines = inner.splitlines()
            if inner_lines and inner_lines[0].strip().isalpha() and len(inner_lines[0].strip()) <= 12:
                inner = "\n".join(inner_lines[1:])
            t = inner.strip() or t
    # Drop leading labels.
    for prefix in (
        "final caption:",
        "final:",
        "caption:",
        "output:",
        "réponse finale:",
        "réponse:",
        "résultat:",
    ):
        if t.lower().startswith(prefix):
            t = t[len(prefix) :].strip()
            break
    # Hard cut at explicit meta sections.
    lower = t.lower()
    for marker in (
        "with this caption:",
        "with this caption",
        "explanation:",
        "analysis:",
        "notes:",
        "justification:",
        "raison:",
        "pourquoi:",
    ):
        idx = lower.find(marker)
        if idx != -1:
            t = t[:idx].strip()
            lower = t.lower()
            break
    # If there's a meta bullet section appended, cut it off.
    meta_keywords = (
        "maintain",
        "maintains",
        "tone",
        "style",
        "language",
        "hashtags",
        "emoji",
        "length",
        "platform",
        "casual",
        "formal",
        "explication",
        "analyse",
    )
    lines = t.splitlines()
    for i, line in enumerate(lines):
        s = line.strip()
        if not s.startswith("-"):
            continue
        s_lower = s.lower()
        if any(k in s_lower for k in meta_keywords):
            t = "\n".join(lines[:i]).strip()
            break
    # Remove surrounding quotes (common in model outputs).
    if len(t) >= 2 and ((t[0] == t[-1] == '"') or (t[0] == t[-1] == "'")):
        t = t[1:-1].strip()
    return t


def compose_video_caption(frame_captions: list[str], user_options: Any, use_openrouter: bool = True) -> str:
    """Fuse per-frame captions into one coherent video caption.

    Uses OpenRouter when enabled; otherwise (or when the LLM output cleans to
    empty) falls back to a deterministic local summary.
    """
    if not frame_captions:
        raise ValueError("frame_captions is empty")
    user_options = _normalize_user_options(user_options)

    def _fallback_summary() -> str:
        # Cheap deterministic fallback: return a compact summary based on a few unique frame captions.
        uniq: list[str] = []
        seen: set[str] = set()
        for c in frame_captions:
            c2 = (c or "").strip()
            if not c2 or c2.startswith("[caption_failed"):
                continue
            k = c2.lower()
            if k in seen:
                continue
            seen.add(k)
            uniq.append(c2)
            if len(uniq) >= 6:
                break
        if not uniq:
            return "A short video clip."
        return "Video shows: " + "; ".join(uniq)

    if not use_openrouter:
        return _fallback_summary()
    model = "deepseek/deepseek-r1-0528:free"
    captions_text = "\n".join(f"- {c}" for c in frame_captions)
    system = (
        "You are a video caption summarizer. "
        "Given short captions of sampled frames from a video, infer a single coherent caption "
        "that describes the overall video. "
        "Do NOT list frames one-by-one. Output ONLY the final caption text, no quotes."
    )
    extra_ctx = _additional_context_instruction(user_options)
    user = (
        "Requirements:\n"
        f"- Output language: {user_options.language}\n"
        f"- Target platform: {user_options.platform}\n"
        f"- Text style: {user_options.text_style}\n"
        f"- Emojis: {'allowed' if user_options.emojis else 'do not use'}\n"
        f"- Hashtags: {'include' if user_options.hashtags else 'do not include'}\n"
        f"- {_caption_length_instruction(user_options.caption_length)}\n"
        + (
            f"\nAdditional context (use this to steer the caption; incorporate keywords if provided):\n{extra_ctx}\n"
            if extra_ctx
            else "\n"
        )
        + f"\nFrame captions:\n{captions_text}"
    )
    out = _openrouter_chat(
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        model=model,
    )
    cleaned = _clean_caption_output(out)
    return cleaned if cleaned else _fallback_summary()


def compose_image_caption(base_caption: str, user_options: Any, use_openrouter: bool = True) -> str:
    """Polish a base caption per user preferences via OpenRouter.

    Returns the stripped base caption unchanged when OpenRouter is disabled or
    when the model's response cleans to an empty string.
    """
    if not base_caption.strip():
        raise ValueError("base_caption is empty")
    user_options = _normalize_user_options(user_options)
    if not use_openrouter:
        return base_caption.strip()
    model = "deepseek/deepseek-r1-0528:free"
    system = (
        "You are a caption fuser/editor for social platforms. "
        "Given a base caption and preferences, output a polished final caption. "
        "If additional context/keywords are provided, incorporate them naturally and consistently. "
        "Output ONLY the final caption text, no quotes, no explanations."
    )
    extra_ctx = _additional_context_instruction(user_options)
    user = (
        f"Base caption: {base_caption}\n"
        "Requirements:\n"
        f"- Text style: {user_options.text_style}\n"
        f"- Target platform: {user_options.platform}\n"
        f"- Output language: {user_options.language}\n"
        f"- Emojis: {'allowed' if user_options.emojis else 'do not use'}\n"
        f"- Hashtags: {'include' if user_options.hashtags else 'do not include'}\n"
        f"- {_caption_length_instruction(user_options.caption_length)}\n"
        + (
            f"\nAdditional context (use this to steer the caption; incorporate keywords if provided):\n{extra_ctx}\n"
            if extra_ctx
            else ""
        )
    )
    out = _openrouter_chat(
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        model=model,
    )
    cleaned = _clean_caption_output(out)
    # If the model returns an empty/fully-stripped response, fall back to the base caption.
    return cleaned if cleaned else base_caption.strip()


def caption_video(video_path: str, user_options: Any = None, use_openrouter: bool = True) -> str:
    """End-to-end video captioning: extract frames, caption each, summarize, polish."""
    user_options = _normalize_user_options(user_options)
    progress = _progress_printer(SHOW_PROGRESS)
    progress(0, "Starting video captioning")
    frames = extract_frames(video_path)
    progress(10, f"Extracted {len(frames)} frames")
    # Map frame caption progress into 10%..80% so we still have room for compose steps.
    frame_captions: list[str] = []
    total = len(frames)
    for i, img in enumerate(frames, start=1):
        try:
            frame_captions.append(caption_image(img))
        except Exception as e:
            frame_captions.append(f"[caption_failed: {e}]")
        if total:
            pct = 10 + int((i / total) * 70)
            progress(pct, f"Captioned {i}/{total} frames")
    progress(80, "Summarizing video")
    base_caption = compose_video_caption(frame_captions, user_options, use_openrouter=use_openrouter)
    progress(90, "Polishing final caption")
    final_caption = compose_image_caption(base_caption, user_options, use_openrouter=use_openrouter)
    progress(100, "Done")
    return final_caption


def main():
    """Simple manual entrypoint.

    - No args: captions IMAGE_PATH (image or video auto-detected).
    - One arg: captions that path (image or video auto-detected).
    """
    import sys

    media_path = IMAGE_PATH if len(sys.argv) == 1 else sys.argv[1]
    opts = _normalize_user_options(USER_OPTIONS)
    if _looks_like_video(media_path):
        try:
            final = caption_video(media_path, opts, use_openrouter=True)
        except RuntimeError as e:
            # If OpenRouter isn't configured, still return a deterministic fallback.
            if "OPENROUTER_API_KEY" in str(e):
                final = caption_video(media_path, opts, use_openrouter=False)
            else:
                raise
        print(final)
    else:
        progress = _progress_printer(SHOW_PROGRESS)
        progress(0, "Starting image captioning")
        base = caption_image_path(media_path)
        progress(50, "Polishing caption with AI")
        try:
            final = compose_image_caption(base, opts, use_openrouter=True)
        except RuntimeError as e:
            # If OpenRouter isn't configured, still return the raw BLIP caption.
            if "OPENROUTER_API_KEY" in str(e):
                final = base
            else:
                raise
        progress(100, "Done")
        print(final)


if __name__ == "__main__":
    main()