File size: 10,486 Bytes
0a0d999
dafe938
0a0d999
dafe938
 
 
 
 
 
 
 
 
 
 
0a0d999
 
dafe938
 
 
 
 
 
 
0a0d999
 
7d07e42
 
0a0d999
 
 
7d07e42
dafe938
0a0d999
 
 
7d07e42
 
 
 
 
 
dafe938
 
 
0a0d999
 
7d07e42
 
 
 
 
 
 
 
 
dafe938
 
 
 
 
 
 
 
7d07e42
 
 
0a0d999
dafe938
 
 
 
 
 
 
0a0d999
 
dafe938
 
0a0d999
dafe938
 
0a0d999
7d07e42
dafe938
 
 
 
 
0a0d999
dafe938
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a0d999
dafe938
 
 
0a0d999
 
dafe938
0a0d999
 
dafe938
 
0a0d999
 
dafe938
0a0d999
dafe938
0a0d999
dafe938
0a0d999
dafe938
0a0d999
 
dafe938
 
 
 
0a0d999
 
 
 
dafe938
 
 
 
 
 
 
 
 
 
0a0d999
 
 
 
 
dafe938
0a0d999
dafe938
 
 
7d07e42
dafe938
7d07e42
0a0d999
 
 
 
dafe938
 
0a0d999
dafe938
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a0d999
 
dafe938
 
0a0d999
7d07e42
dafe938
7d07e42
 
dafe938
0a0d999
 
 
dafe938
 
 
 
0a0d999
dafe938
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
"""
Image captioning — dua strategi tergantung mode:

MODE 1 (default, FAST): YOLO detections → Groq text LLM → caption
  - Image TIDAK dikirim ke API
  - Latency: YOLO ~0.1s + Groq text ~0.5-1s = total ~1s
  - Tidak bergantung pada Groq Vision quota/latency

MODE 2 (vision, opt-in): kirim gambar ke Groq Vision API
  - Lebih akurat untuk gambar tanpa objek COCO jelas
  - Latency: 3-15s (tergantung Groq server load)
  - Aktifkan dengan env GROQ_CAPTION_MODE=vision

Default mode=fast karena Groq Vision sering timeout dari HF US server.

Env vars:
  GROQ_API_KEY           - wajib
  GROQ_CAPTION_MODE      - "fast" (default) atau "vision"
  GROQ_TEXT_MODEL        - default "llama-3.3-70b-versatile"
  GROQ_VISION_MODEL      - default "meta-llama/llama-4-scout-17b-16e-instruct"
  GROQ_TEXT_TIMEOUT      - default 15
  GROQ_VISION_TIMEOUT    - default 30
  GROQ_VISION_MAX_SIDE   - default 1024
"""

from __future__ import annotations

import base64
import io
import os
from collections import Counter
from dataclasses import dataclass
from typing import Optional, List

import httpx
from PIL import Image
from loguru import logger

from ..config import get_cv_settings
from ..processors.image_preprocessor import ImageInput


# Groq chat-completions endpoint; both the text and the vision helpers POST to it.
_GROQ_API_URL         = "https://api.groq.com/openai/v1/chat/completions"
# Model ids used when GROQ_TEXT_MODEL / GROQ_VISION_MODEL are not set in the environment.
_DEFAULT_TEXT_MODEL   = "llama-3.3-70b-versatile"
_DEFAULT_VISION_MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"


@dataclass
class CaptionResult:
    """Result returned by the captioning and visual-QA entry points."""

    # Generated caption (or visual-QA answer) text.
    caption: str
    # Label of what produced the text: the vision model id, "<text model>(fast)",
    # or "offline-fallback" when no API reply was used.
    model: str
    # Defaults to 1.0; nothing in this module assigns a different value.
    confidence: float = 1.0


class ImageCaptioner:
    """
    Smart image captioner.

    Mode FAST (default):
      Gunakan YOLO detections + Groq text LLM → caption natural.
      Image TIDAK dikirim ke API. Latency ~0.5-1s.

    Mode VISION (GROQ_CAPTION_MODE=vision):
      Encode gambar → Groq Vision API. Latency 3-15s.
    """

    def __init__(self):
        """Read the captioner configuration from environment variables.

        Reads GROQ_API_KEY, GROQ_CAPTION_MODE, GROQ_TEXT_MODEL,
        GROQ_VISION_MODEL, GROQ_TEXT_TIMEOUT, GROQ_VISION_TIMEOUT and
        GROQ_VISION_MAX_SIDE. A missing API key only logs a warning here.

        Raises:
            ValueError: If a timeout or GROQ_VISION_MAX_SIDE value cannot be
                parsed as a number.
        """
        # The return value is unused; the call is kept as-is
        # (presumably it loads/validates the CV settings -- confirm).
        _ = get_cv_settings()

        env = os.environ
        self.api_key = env.get("GROQ_API_KEY", "").strip()
        # Strip like every other setting: " fast" or "fast\n" must still equal
        # "fast", otherwise caption() would silently take the vision path.
        # An empty value is treated like an unset variable (default "fast").
        self.mode = env.get("GROQ_CAPTION_MODE", "fast").strip().lower() or "fast"
        self.text_model = env.get("GROQ_TEXT_MODEL", _DEFAULT_TEXT_MODEL).strip()
        self.vision_model = env.get("GROQ_VISION_MODEL", _DEFAULT_VISION_MODEL).strip()
        self._text_timeout = float(env.get("GROQ_TEXT_TIMEOUT", "15"))
        self._vision_timeout = float(env.get("GROQ_VISION_TIMEOUT", "30"))
        self._max_side = int(env.get("GROQ_VISION_MAX_SIDE", "1024"))

        if not self.api_key:
            logger.warning("GROQ_API_KEY tidak di-set.")

        if self.mode not in ("fast", "vision"):
            # Dispatch is unchanged; this only makes a typo visible in the logs.
            logger.warning(
                f"Unknown GROQ_CAPTION_MODE={self.mode!r}; expected 'fast' or 'vision'. "
                "Any value other than 'fast' is handled as vision mode."
            )

        logger.info(
            f"ImageCaptioner ready. mode={self.mode} | "
            f"text={self.text_model} | API key: {'SET' if self.api_key else 'NOT SET'}"
        )

    def _groq_headers(self):
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _call_groq_text(self, prompt: str, system: str, max_tokens: int = 80) -> str:
        if not self.api_key:
            raise RuntimeError("GROQ_API_KEY belum di-set.")
        payload = {
            "model": self.text_model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user",   "content": prompt},
            ],
            "max_tokens": max_tokens,
            "temperature": 0.3,
        }
        with httpx.Client(timeout=self._text_timeout) as client:
            try:
                resp = client.post(_GROQ_API_URL, json=payload, headers=self._groq_headers())
            except httpx.TimeoutException as e:
                raise RuntimeError(f"Groq text timeout ({self._text_timeout}s): {e}")
            except httpx.HTTPError as e:
                raise RuntimeError(f"Groq text network error: {e}")
        if resp.status_code >= 400:
            try:
                err = resp.json().get("error", {}).get("message", resp.text)
            except Exception:
                err = resp.text[:200]
            raise RuntimeError(f"Groq text error {resp.status_code}: {err}")
        try:
            return resp.json()["choices"][0]["message"]["content"].strip()
        except (KeyError, IndexError) as e:
            raise RuntimeError(f"Groq text response unexpected: {e}")

    def _image_to_data_url(self, pil_image: Image.Image) -> str:
        """Encode a PIL image as a base64 JPEG ``data:`` URL.

        The image is converted to RGB and, when its longer side exceeds
        ``self._max_side``, scaled down proportionally with LANCZOS resampling.
        It is then saved as JPEG (quality 85, optimized) and base64-encoded.
        """
        rgb = pil_image.convert("RGB")
        width, height = rgb.size
        longest = max(width, height)
        if longest > self._max_side:
            ratio = self._max_side / longest
            # Clamp each side to at least 1px so an extreme aspect ratio
            # cannot produce a zero-sized dimension.
            new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
            rgb = rgb.resize(new_size, Image.LANCZOS)
        jpeg = io.BytesIO()
        rgb.save(jpeg, format="JPEG", quality=85, optimize=True)
        encoded = base64.b64encode(jpeg.getvalue()).decode()
        return "data:image/jpeg;base64," + encoded

    def _call_groq_vision(self, image: ImageInput, user_prompt: str, system_prompt: str, max_tokens: int) -> str:
        if not self.api_key:
            raise RuntimeError("GROQ_API_KEY belum di-set.")
        payload = {
            "model": self.vision_model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": [
                    {"type": "text",      "text": user_prompt},
                    {"type": "image_url", "image_url": {"url": self._image_to_data_url(image.pil_image)}},
                ]},
            ],
            "max_tokens": max_tokens,
            "temperature": 0.2,
        }
        with httpx.Client(timeout=self._vision_timeout) as client:
            try:
                resp = client.post(_GROQ_API_URL, json=payload, headers=self._groq_headers())
            except httpx.TimeoutException:
                raise RuntimeError(
                    f"Groq Vision timeout ({self._vision_timeout}s). "
                    "Set GROQ_CAPTION_MODE=fast di HF Space Settings untuk mode cepat."
                )
            except httpx.HTTPError as e:
                raise RuntimeError(f"Groq Vision network error: {e}")
        if resp.status_code >= 400:
            try:
                err = resp.json().get("error", {}).get("message", resp.text)
            except Exception:
                err = resp.text[:300]
            raise RuntimeError(f"Groq Vision error {resp.status_code}: {err}")
        try:
            return resp.json()["choices"][0]["message"]["content"].strip()
        except (KeyError, IndexError) as e:
            raise RuntimeError(f"Groq Vision response unexpected: {e}")

    # ── Public API ────────────────────────────────────────────────────────

    def caption(
        self,
        image: ImageInput,
        prompt: Optional[str] = None,
        detections=None,
        max_new_tokens: int = 80,
    ) -> CaptionResult:
        """Generate caption. Mode fast = YOLO+text, mode vision = Groq Vision."""
        if self.mode == "fast":
            return self._caption_fast(image, detections, prompt)
        return self._caption_vision(image, prompt, max_new_tokens)

    def _caption_fast(self, image: ImageInput, detections=None, custom_prompt: Optional[str] = None) -> CaptionResult:
        """Caption from object detections via the Groq text LLM; the image is not sent to the API.

        Args:
            image: Input whose ``width`` and ``height`` are included in the prompt.
            detections: Optional iterable of objects exposing a ``label``
                attribute; counted per label to describe the scene.
            custom_prompt: Optional instruction that replaces the default
                captioning instruction appended after the detection summary.

        Returns:
            A ``CaptionResult`` with the LLM caption, or an offline caption
            built from the label counts (model ``"offline-fallback"``) when the
            Groq call fails or returns an empty reply.
        """
        img_info = f"{image.width}x{image.height}px"

        # Count labels once; the prompt and the offline fallback share the
        # result. Counter keeps first-seen order, like the plain dict it replaces.
        summary = Counter(d.label for d in detections) if detections else Counter()

        if summary:
            det_str = ", ".join(
                f"{count} {label}{'s' if count > 1 else ''}" for label, count in summary.items()
            )
            context = f"Image {img_info}. Detected: {det_str}."
        else:
            context = f"Image {img_info}. No COCO objects detected (may be a scene, document, or abstract)."

        user_msg = (
            f"{context}\n\n"
            + (custom_prompt if custom_prompt else
               "Write a short natural caption (max 20 words) for this image based on what was detected. "
               "Be specific. Do not start with 'The image shows' or 'This is'.")
        )
        system = (
            "You are a concise image captioning assistant. "
            "Write natural, specific captions based on object detection results. "
            "Never start with 'The image shows', 'This is', or similar filler phrases."
        )

        try:
            text = self._call_groq_text(user_msg, system, max_tokens=60)
            if not text:
                # An empty reply is useless as a caption; use the fallback instead.
                raise RuntimeError("Groq returned an empty caption")
        except Exception as e:
            # Deliberate best effort: any failure degrades to the offline caption.
            logger.warning(f"Fast caption Groq call failed, pure fallback: {e}")
        else:
            logger.debug(f"Fast caption: {text}")
            return CaptionResult(caption=text, model=f"{self.text_model}(fast)")

        # Fully offline fallback -- no API involved.
        if summary:
            caption = "Scene with: " + ", ".join(
                f"{count} {label}" for label, count in summary.items()
            )
        else:
            caption = f"Image ({image.width}x{image.height})"
        return CaptionResult(caption=caption, model="offline-fallback")

    def _caption_vision(self, image: ImageInput, prompt: Optional[str] = None, max_new_tokens: int = 80) -> CaptionResult:
        """Caption the image through the Groq Vision API.

        Args:
            image: Input forwarded to ``_call_groq_vision``.
            prompt: Optional user prompt; when missing, empty or whitespace-only
                the default "Describe this image." is used.
            max_new_tokens: Forwarded as the request's ``max_tokens``.

        Returns:
            A ``CaptionResult`` holding the reply and the vision model id.

        Raises:
            RuntimeError: Propagated from ``_call_groq_vision`` on any API failure.
        """
        system = (
            "You are a precise image captioning assistant. "
            "Describe the image in one short sentence (under 25 words). "
            "Be factual. Do NOT start with 'The image shows' or 'This is a picture of'."
        )
        # Strip before choosing the default: a whitespace-only prompt is truthy
        # and would otherwise be sent as an empty user message.
        user_prompt = (prompt or "").strip() or "Describe this image."
        text = self._call_groq_vision(
            image,
            user_prompt=user_prompt,
            system_prompt=system,
            max_tokens=max_new_tokens,
        )
        return CaptionResult(caption=text, model=self.vision_model)

    def visual_qa(self, image: ImageInput, question: str) -> CaptionResult:
        """Visual QA β€” selalu pakai Vision API."""
        question = (question or "").strip()
        if not question:
            raise ValueError("Question tidak boleh kosong.")
        system = (
            "You are a visual QA assistant. "
            "Answer briefly and factually (under 20 words). "
            "If not visible in image, say so."
        )
        text = self._call_groq_vision(image, question, system, max_tokens=80)
        return CaptionResult(caption=text, model=self.vision_model)