# ai-rag/cv_module/src/models/captioner.py
# Commit dafe938 ("Clean rebuild: all features fixed") by robrtt
"""
Image captioning β€” dua strategi tergantung mode:
MODE 1 (default, FAST): YOLO detections β†’ Groq text LLM β†’ caption
- Image TIDAK dikirim ke API
- Latency: YOLO ~0.1s + Groq text ~0.5-1s = total ~1s
- Tidak bergantung pada Groq Vision quota/latency
MODE 2 (vision, opt-in): kirim gambar ke Groq Vision API
- Lebih akurat untuk gambar tanpa objek COCO jelas
- Latency: 3-15s (tergantung Groq server load)
- Aktifkan dengan env GROQ_CAPTION_MODE=vision
Default mode=fast karena Groq Vision sering timeout dari HF US server.
Env vars:
GROQ_API_KEY - wajib
GROQ_CAPTION_MODE - "fast" (default) atau "vision"
GROQ_TEXT_MODEL - default "llama-3.3-70b-versatile"
GROQ_VISION_MODEL - default "meta-llama/llama-4-scout-17b-16e-instruct"
GROQ_TEXT_TIMEOUT - default 15
GROQ_VISION_TIMEOUT - default 30
GROQ_VISION_MAX_SIDE - default 1024
"""
from __future__ import annotations
import os
import io
import base64
from dataclasses import dataclass
from typing import Optional, List
import httpx
from PIL import Image
from loguru import logger
from ..config import get_cv_settings
from ..processors.image_preprocessor import ImageInput
# Groq's OpenAI-compatible chat-completions endpoint (shared by text and vision calls).
_GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
# Default model ids; overridable via GROQ_TEXT_MODEL / GROQ_VISION_MODEL env vars.
_DEFAULT_TEXT_MODEL = "llama-3.3-70b-versatile"
_DEFAULT_VISION_MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"
@dataclass
class CaptionResult:
    """Result of a captioning or visual-QA call."""

    # Generated caption (or answer, for visual QA).
    caption: str
    # Name of the model that produced the text; "(fast)" suffix marks the
    # text-only path, "offline-fallback" marks the no-API fallback.
    model: str
    # The Groq API does not report a confidence score, so this defaults to 1.0.
    confidence: float = 1.0
class ImageCaptioner:
    """
    Smart image captioner with two strategies.

    Mode FAST (default):
        Combine YOLO detections with the Groq *text* LLM to produce a natural
        caption. The image itself is never sent to the API. Latency ~0.5-1s.

    Mode VISION (GROQ_CAPTION_MODE=vision):
        Encode the image and send it to the Groq Vision API. Latency 3-15s.
    """

    def __init__(self) -> None:
        # Touch the settings singleton so any configuration side effects
        # (e.g. .env loading) happen before we read the environment below.
        _ = get_cv_settings()
        self.api_key = os.environ.get("GROQ_API_KEY", "").strip()
        self.mode = os.environ.get("GROQ_CAPTION_MODE", "fast").lower()
        self.text_model = os.environ.get("GROQ_TEXT_MODEL", _DEFAULT_TEXT_MODEL).strip()
        self.vision_model = os.environ.get("GROQ_VISION_MODEL", _DEFAULT_VISION_MODEL).strip()
        self._text_timeout = float(os.environ.get("GROQ_TEXT_TIMEOUT", "15"))
        self._vision_timeout = float(os.environ.get("GROQ_VISION_TIMEOUT", "30"))
        self._max_side = int(os.environ.get("GROQ_VISION_MAX_SIDE", "1024"))
        if not self.api_key:
            logger.warning("GROQ_API_KEY tidak di-set.")
        logger.info(
            f"ImageCaptioner ready. mode={self.mode} | "
            f"text={self.text_model} | API key: {'SET' if self.api_key else 'NOT SET'}"
        )

    # ── Internal helpers ──────────────────────────────────────────────────

    def _groq_headers(self) -> dict:
        """Auth + content-type headers for the Groq REST API."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _summarize_detections(detections) -> dict:
        """Count detections per label, preserving first-seen label order."""
        counts: dict = {}
        for det in detections:
            counts[det.label] = counts.get(det.label, 0) + 1
        return counts

    def _call_groq_text(self, prompt: str, system: str, max_tokens: int = 80) -> str:
        """Run a plain text chat completion against the Groq API.

        Raises:
            RuntimeError: when the API key is missing, the request times out,
                a network error occurs, the API returns >= 400, or the
                response body does not have the expected shape.
        """
        if not self.api_key:
            raise RuntimeError("GROQ_API_KEY belum di-set.")
        payload = {
            "model": self.text_model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": prompt},
            ],
            "max_tokens": max_tokens,
            "temperature": 0.3,
        }
        with httpx.Client(timeout=self._text_timeout) as client:
            try:
                resp = client.post(_GROQ_API_URL, json=payload, headers=self._groq_headers())
            except httpx.TimeoutException as e:
                raise RuntimeError(f"Groq text timeout ({self._text_timeout}s): {e}") from e
            except httpx.HTTPError as e:
                raise RuntimeError(f"Groq text network error: {e}") from e
        # Response body is fully read by client.post(), so it is safe to
        # inspect after the client is closed.
        if resp.status_code >= 400:
            try:
                err = resp.json().get("error", {}).get("message", resp.text)
            except Exception:
                err = resp.text[:200]
            raise RuntimeError(f"Groq text error {resp.status_code}: {err}")
        try:
            return resp.json()["choices"][0]["message"]["content"].strip()
        except (KeyError, IndexError, TypeError, AttributeError, ValueError) as e:
            # ValueError covers a non-JSON 200 body (json.JSONDecodeError);
            # TypeError/AttributeError cover a null/odd "content" field.
            raise RuntimeError(f"Groq text response unexpected: {e}") from e

    def _image_to_data_url(self, pil_image: Image.Image) -> str:
        """Downscale (longest side <= GROQ_VISION_MAX_SIDE) and encode as a JPEG data URL."""
        img = pil_image.convert("RGB")
        w, h = img.size
        if max(w, h) > self._max_side:
            # Keep aspect ratio; max(1, ...) guards against 0-pixel dimensions.
            scale = self._max_side / max(w, h)
            img = img.resize((max(1, int(w * scale)), max(1, int(h * scale))), Image.LANCZOS)
        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=85, optimize=True)
        return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode()

    def _call_groq_vision(self, image: ImageInput, user_prompt: str, system_prompt: str, max_tokens: int) -> str:
        """Run a multimodal chat completion (image + text) against the Groq Vision API.

        Raises:
            RuntimeError: same failure modes as `_call_groq_text`.
        """
        if not self.api_key:
            raise RuntimeError("GROQ_API_KEY belum di-set.")
        payload = {
            "model": self.vision_model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": [
                    {"type": "text", "text": user_prompt},
                    {"type": "image_url", "image_url": {"url": self._image_to_data_url(image.pil_image)}},
                ]},
            ],
            "max_tokens": max_tokens,
            "temperature": 0.2,
        }
        with httpx.Client(timeout=self._vision_timeout) as client:
            try:
                resp = client.post(_GROQ_API_URL, json=payload, headers=self._groq_headers())
            except httpx.TimeoutException as e:
                raise RuntimeError(
                    f"Groq Vision timeout ({self._vision_timeout}s). "
                    "Set GROQ_CAPTION_MODE=fast di HF Space Settings untuk mode cepat."
                ) from e
            except httpx.HTTPError as e:
                raise RuntimeError(f"Groq Vision network error: {e}") from e
        if resp.status_code >= 400:
            try:
                err = resp.json().get("error", {}).get("message", resp.text)
            except Exception:
                err = resp.text[:300]
            raise RuntimeError(f"Groq Vision error {resp.status_code}: {err}")
        try:
            return resp.json()["choices"][0]["message"]["content"].strip()
        except (KeyError, IndexError, TypeError, AttributeError, ValueError) as e:
            raise RuntimeError(f"Groq Vision response unexpected: {e}") from e

    # ── Public API ────────────────────────────────────────────────────────

    def caption(
        self,
        image: ImageInput,
        prompt: Optional[str] = None,
        detections=None,
        max_new_tokens: int = 80,
    ) -> CaptionResult:
        """Generate a caption for `image`.

        Dispatches on the configured mode: "fast" uses YOLO `detections` plus
        the text LLM (max_new_tokens is not used on that path); any other mode
        sends the image to the Vision API.
        """
        if self.mode == "fast":
            return self._caption_fast(image, detections, prompt)
        return self._caption_vision(image, prompt, max_new_tokens)

    def _caption_fast(self, image: ImageInput, detections=None, custom_prompt: Optional[str] = None) -> CaptionResult:
        """Caption from YOLO detections via the Groq text LLM — the image is never sent to the API.

        If the API call fails for any reason, falls back to a purely offline
        caption assembled from the detection labels.
        """
        img_info = f"{image.width}x{image.height}px"
        if detections and len(detections) > 0:
            counts = self._summarize_detections(detections)
            det_str = ", ".join(
                f"{n} {label}{'s' if n > 1 else ''}" for label, n in counts.items()
            )
            context = f"Image {img_info}. Detected: {det_str}."
        else:
            context = f"Image {img_info}. No COCO objects detected (may be a scene, document, or abstract)."
        user_msg = (
            f"{context}\n\n"
            + (custom_prompt if custom_prompt else
               "Write a short natural caption (max 20 words) for this image based on what was detected. "
               "Be specific. Do not start with 'The image shows' or 'This is'.")
        )
        system = (
            "You are a concise image captioning assistant. "
            "Write natural, specific captions based on object detection results. "
            "Never start with 'The image shows', 'This is', or similar filler phrases."
        )
        try:
            text = self._call_groq_text(user_msg, system, max_tokens=60)
            logger.debug(f"Fast caption: {text}")
            return CaptionResult(caption=text, model=f"{self.text_model}(fast)")
        except Exception as e:
            logger.warning(f"Fast caption Groq call failed, pure fallback: {e}")
            # Fallback 100% offline — no API.
            if detections and len(detections) > 0:
                counts = self._summarize_detections(detections)
                parts = [f"{n} {label}" for label, n in counts.items()]
                caption = "Scene with: " + ", ".join(parts)
            else:
                caption = f"Image ({image.width}x{image.height})"
            return CaptionResult(caption=caption, model="offline-fallback")

    def _caption_vision(self, image: ImageInput, prompt: Optional[str] = None, max_new_tokens: int = 80) -> CaptionResult:
        """Caption via the Groq Vision API (slower, but sees the pixels)."""
        system = (
            "You are a precise image captioning assistant. "
            "Describe the image in one short sentence (under 25 words). "
            "Be factual. Do NOT start with 'The image shows' or 'This is a picture of'."
        )
        text = self._call_groq_vision(
            image,
            user_prompt=(prompt.strip() if prompt else "Describe this image."),
            system_prompt=system,
            max_tokens=max_new_tokens,
        )
        return CaptionResult(caption=text, model=self.vision_model)

    def visual_qa(self, image: ImageInput, question: str) -> CaptionResult:
        """Visual QA — always uses the Vision API.

        Raises:
            ValueError: if `question` is empty or whitespace-only.
        """
        question = (question or "").strip()
        if not question:
            raise ValueError("Question tidak boleh kosong.")
        system = (
            "You are a visual QA assistant. "
            "Answer briefly and factually (under 20 words). "
            "If not visible in image, say so."
        )
        text = self._call_groq_vision(image, question, system, max_tokens=80)
        return CaptionResult(caption=text, model=self.vision_model)