reygml commited on
Commit
b7bc425
·
1 Parent(s): 029cf9d

feat: grounding dino

Browse files
Files changed (4) hide show
  1. app.py +81 -3
  2. grounding_dino2.py +155 -0
  3. ui.py +62 -10
  4. util.py +73 -1
app.py CHANGED
@@ -1,17 +1,21 @@
1
  # app.py
2
  from time import perf_counter
3
- from typing import List, Optional
 
4
 
5
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
6
  from pydantic import BaseModel, Field, HttpUrl
 
7
  import uvicorn
8
 
9
  from util import get_runner, SmolVLMRunner
10
 
11
- app = FastAPI(title="SmolVLM Inference API", version="1.1.0")
12
  _runner: Optional[SmolVLMRunner] = None
13
 
14
 
 
 
15
  class URLRequest(BaseModel):
16
  prompt: str = Field(..., description="Text prompt to accompany the images.")
17
  image_urls: List[HttpUrl] = Field(..., description="List of image URLs.")
@@ -19,18 +23,32 @@ class URLRequest(BaseModel):
19
  temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
20
  top_p: Optional[float] = Field(None, gt=0.0, le=1.0)
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  @app.on_event("startup")
24
  async def _load_model_on_startup():
25
  global _runner
26
  _runner = get_runner()
27
 
28
-
29
  @app.get("/")
30
  def health():
31
  return {"status": "ok", "model": _runner.model_id if _runner else None}
32
 
33
 
 
 
34
  @app.post("/generate")
35
  async def generate_from_files(
36
  prompt: str = Form(...),
@@ -105,6 +123,66 @@ async def generate_from_urls(req: URLRequest):
105
  return {"text": text, "metrics": metrics}
106
 
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  if __name__ == "__main__":
109
  # Run with: python app.py (or: uvicorn app:app --host 0.0.0.0 --port 8000)
110
  uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)
 
1
  # app.py
2
  from time import perf_counter
3
+ from io import BytesIO
4
+ from typing import List, Optional, Union
5
 
6
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
7
  from pydantic import BaseModel, Field, HttpUrl
8
+ from PIL import Image
9
  import uvicorn
10
 
11
  from util import get_runner, SmolVLMRunner
12
 
13
+ app = FastAPI(title="SmolVLM Inference API", version="1.2.0")
14
  _runner: Optional[SmolVLMRunner] = None
15
 
16
 
17
+ # ----------------------- Pydantic models -----------------------
18
+
19
  class URLRequest(BaseModel):
20
  prompt: str = Field(..., description="Text prompt to accompany the images.")
21
  image_urls: List[HttpUrl] = Field(..., description="List of image URLs.")
 
23
  temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
24
  top_p: Optional[float] = Field(None, gt=0.0, le=1.0)
25
 
26
class DetectDescribeURLRequest(BaseModel):
    """Request body for /detect_describe_url: one remote image plus detector knobs."""

    # Publicly reachable image URL; fetched server-side before detection.
    image_url: HttpUrl
    # Either a comma-separated string ("a man,a dog") or an explicit list of phrases.
    labels: Union[str, List[str]]
    # Detection thresholds; bounded to (0, 1] for consistency with the
    # constrained sampling knobs on URLRequest (previously unvalidated).
    box_threshold: float = Field(0.40, gt=0.0, le=1.0)
    text_threshold: float = Field(0.30, gt=0.0, le=1.0)
    # Fractional padding applied to each detected box before cropping.
    pad_frac: float = Field(0.06, ge=0.0, le=1.0)
    max_new_tokens: int = Field(160, ge=1)
    # When true, the response includes a base64 PNG with boxes drawn on it.
    return_overlay: bool = True
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(None, gt=0.0, le=1.0)
36
+
37
+
38
+ # ----------------------- Startup / health -----------------------
39
 
40
  @app.on_event("startup")
41
  async def _load_model_on_startup():
42
  global _runner
43
  _runner = get_runner()
44
 
 
45
  @app.get("/")
46
  def health():
47
  return {"status": "ok", "model": _runner.model_id if _runner else None}
48
 
49
 
50
+ # ----------------------- Core VLM endpoints -----------------------
51
+
52
  @app.post("/generate")
53
  async def generate_from_files(
54
  prompt: str = Form(...),
 
123
  return {"text": text, "metrics": metrics}
124
 
125
 
126
# ----------------------- Detect & Describe endpoints -----------------------

@app.post("/detect_describe")
async def detect_describe(
    image: UploadFile = File(..., description="One image file (image/*)"),
    labels: str = Form(..., description='Comma-separated phrases, e.g. "a man,a dog"'),
    box_threshold: float = Form(0.40),
    text_threshold: float = Form(0.30),
    pad_frac: float = Form(0.06),
    max_new_tokens: int = Form(160),
    temperature: Optional[float] = Form(None),
    top_p: Optional[float] = Form(None),
    return_overlay: bool = Form(True),
):
    """Detect `labels` in an uploaded image and describe each detected crop.

    Returns whatever SmolVLMRunner.detect_and_describe produces: a list of
    detections, or (with return_overlay=True) a dict with detections plus a
    base64-encoded PNG overlay.
    """
    # BUGFIX: previously a request arriving before startup finished (or after a
    # failed model load) crashed with AttributeError on None; report 503 instead.
    if _runner is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet.")
    if not image.content_type or not image.content_type.startswith("image/"):
        raise HTTPException(status_code=415, detail=f"Unsupported file type: {image.content_type}")

    try:
        raw = await image.read()
        pil = Image.open(BytesIO(raw)).convert("RGB")
    except Exception as e:
        # Chain the original decode error for easier debugging in logs.
        raise HTTPException(status_code=400, detail=f"Failed to read image: {e}") from e

    out = _runner.detect_and_describe(
        image=pil,
        labels=labels,  # comma-separated string OK
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        pad_frac=pad_frac,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        return_overlay=return_overlay,
    )
    return out
161
+
162
+
163
@app.post("/detect_describe_url")
async def detect_describe_url(req: DetectDescribeURLRequest):
    """Fetch `req.image_url`, detect `req.labels`, and describe each detection."""
    # BUGFIX: guard against requests racing model startup; previously this
    # dereferenced a None `_runner` and surfaced as an opaque 500.
    if _runner is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet.")
    try:
        # load_pil_from_urls takes a list; we pass exactly one URL and unwrap it.
        pil = _runner.load_pil_from_urls([str(req.image_url)])[0]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to fetch image: {e}") from e

    out = _runner.detect_and_describe(
        image=pil,
        labels=req.labels,
        box_threshold=req.box_threshold,
        text_threshold=req.text_threshold,
        pad_frac=req.pad_frac,
        max_new_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        return_overlay=req.return_overlay,
    )
    return out
182
+
183
+
184
+ # ----------------------- Entrypoint -----------------------
185
+
186
  if __name__ == "__main__":
187
  # Run with: python app.py (or: uvicorn app:app --host 0.0.0.0 --port 8000)
188
  uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)
grounding_dino2.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # grounding_dino_runner.py
2
+ # Lightweight Grounding DINO wrapper for box detection + cropping.
3
+ # Works on CPU or GPU; safe on T4 (no flash-attn).
4
+ from __future__ import annotations
5
+
6
+ import os
7
+ import threading
8
+ from pathlib import Path
9
+ from typing import List, Dict, Any, Tuple, Optional
10
+
11
+ import torch
12
+ from PIL import Image, ImageDraw, ImageFont
13
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
14
+
15
def visualize_detections(
    image: Image.Image,
    detections: list[dict],
    *,
    box_color: tuple[int, int, int] = (0, 255, 0),
    text_color: tuple[int, int, int] = (0, 0, 0),
    box_width: int = 3,
) -> Image.Image:
    """Return a copy of `image` with every detection's box and caption drawn on it.

    Each detection dict supplies 'label' (str), 'score' (float), and
    'box_xyxy' as an (x0, y0, x1, y1) tuple.
    """
    annotated = image.copy()
    canvas = ImageDraw.Draw(annotated)
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", 16)
    except Exception:
        font = None  # font file unavailable; fall back to PIL's default rendering

    for detection in detections:
        x0, y0, x1, y1 = detection["box_xyxy"]
        caption = f'{detection.get("label", "")} {detection.get("score", 0.0):.2f}'
        canvas.rectangle((x0, y0, x1, y1), outline=box_color, width=box_width)
        if font:
            caption_w = canvas.textlength(caption, font=font)
        else:
            caption_w = len(caption) * 8  # rough width estimate without a TTF font
        margin = 4
        # Filled banner just above the box, then the caption text on top of it.
        canvas.rectangle((x0, y0 - 20, x0 + int(caption_w) + margin * 2, y0), fill=box_color)
        canvas.text((x0 + margin, y0 - 18), caption, fill=text_color, font=font)
    return annotated
45
+
46
+ def _clamp_xyxy(box: List[float], w: int, h: int) -> Tuple[int, int, int, int]:
47
+ x0, y0, x1, y1 = box
48
+ x0 = max(0, min(int(round(x0)), w - 1))
49
+ y0 = max(0, min(int(round(y0)), h - 1))
50
+ x1 = max(0, min(int(round(x1)), w - 1))
51
+ y1 = max(0, min(int(round(y1)), h - 1))
52
+ if x1 < x0:
53
+ x0, x1 = x1, x0
54
+ if y1 < y0:
55
+ y0, y1 = y1, y0
56
+ return x0, y0, x1, y1
57
+
58
+ def _pad_box(box: Tuple[int, int, int, int], w: int, h: int, frac: float = 0.06) -> Tuple[int, int, int, int]:
59
+ x0, y0, x1, y1 = box
60
+ bw, bh = x1 - x0, y1 - y0
61
+ dx, dy = int(bw * frac), int(bh * frac)
62
+ return max(0, x0 - dx), max(0, y0 - dy), min(w - 1, x1 + dx), min(h - 1, y1 + dy)
63
+
64
def crop_from_box(img: Image.Image, box_xyxy: Tuple[int, int, int, int]) -> Image.Image:
    """Return the sub-image of `img` delimited by an (x0, y0, x1, y1) box."""
    x0, y0, x1, y1 = box_xyxy
    return img.crop((x0, y0, x1, y1))
66
+
67
class GroundingDINORunner:
    """
    Minimal singleton-style wrapper for Grounding DINO zero-shot detection.

    Loads the processor/model once at construction and serializes forward
    passes through an internal lock so concurrent callers are safe.
    """

    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
        self.model_id = model_id or os.getenv("GDINO_MODEL_ID", "IDEA-Research/grounding-dino-tiny")
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._lock = threading.Lock()

        # BUGFIX: the original referenced an undefined module-level CACHE_DIR,
        # which raised NameError on first instantiation. Resolve the cache
        # directory from the environment instead; None lets transformers use
        # its default HF cache location.
        cache_dir = os.getenv("GDINO_CACHE_DIR") or None
        self.processor = AutoProcessor.from_pretrained(self.model_id, cache_dir=cache_dir)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(
            self.model_id, cache_dir=cache_dir
        ).to(self.device)
        self.model.eval()

    @staticmethod
    def _normalize_labels(labels: List[str] | str) -> List[List[str]]:
        """Accept a comma-separated string or list of phrases; return the nested list GDINO expects."""
        if isinstance(labels, str):
            items = [x.strip() for x in labels.split(",") if x.strip()]
        else:
            items = [x.strip() for x in labels if x and x.strip()]
        if not items:
            raise ValueError("No labels provided.")
        # Grounding DINO expects nested list of phrases: [["a cat", "a remote control"]]
        return [items]

    def detect(
        self,
        image: Image.Image,
        labels: List[str] | str,
        box_threshold: float = 0.4,
        text_threshold: float = 0.3,
        pad_frac: float = 0.06,
    ) -> List[Dict[str, Any]]:
        """
        Runs zero-shot detection and returns a list of dicts:
            { 'label': str, 'score': float, 'box_xyxy': (x0,y0,x1,y1), 'crop': PIL.Image }

        Boxes are clamped to the image, padded by `pad_frac`, and cropped.
        """
        w, h = image.size
        phrases = self._normalize_labels(labels)
        inputs = self.processor(images=image, text=phrases, return_tensors="pt").to(self.device)

        # Lock so concurrent requests do not interleave forward passes.
        with self._lock, torch.no_grad():
            outputs = self.model(**inputs)

        # transformers>=4.51 renamed box_threshold -> threshold; try new name
        # first and fall back to the legacy keyword for older versions.
        try:
            post = self.processor.post_process_grounded_object_detection(
                outputs=outputs,
                input_ids=inputs.input_ids,
                threshold=float(box_threshold),
                text_threshold=float(text_threshold),
                target_sizes=[(h, w)],
            )
        except TypeError:
            post = self.processor.post_process_grounded_object_detection(
                outputs=outputs,
                input_ids=inputs.input_ids,
                box_threshold=float(box_threshold),
                text_threshold=float(text_threshold),
                target_sizes=[(h, w)],
            )

        det = post[0]
        boxes = det.get("boxes", [])
        scores = det.get("scores", [])
        # Newer transformers returns phrase strings under 'text_labels'.
        labels_out = det.get("text_labels", det.get("labels", []))

        results: List[Dict[str, Any]] = []
        for b, s, lab in zip(boxes, scores, labels_out):
            b = b.tolist() if hasattr(b, "tolist") else list(b)
            bx = _clamp_xyxy(b, w, h)
            bx = _pad_box(bx, w, h, pad_frac)
            crop = crop_from_box(image, bx)
            score = float(s.item()) if torch.is_tensor(s) else float(s)
            results.append({"label": lab, "score": score, "box_xyxy": bx, "crop": crop})

        return results
146
+
147
# convenience singleton
# Holds the one process-wide GroundingDINORunner; None until first use.
_runner_singleton: GroundingDINORunner | None = None

def get_runner() -> GroundingDINORunner:
    """Return the shared GroundingDINORunner, constructing it lazily on first call."""
    # NOTE(review): not lock-guarded — two threads racing here could each load a
    # model (wasteful, not incorrect). Confirm callers initialize from one thread.
    global _runner_singleton
    if _runner_singleton is None:
        _runner_singleton = GroundingDINORunner()
    return _runner_singleton
155
+
ui.py CHANGED
@@ -1,13 +1,11 @@
1
  # ui.py
2
  import os
3
-
4
  import io
5
  import json
6
  import requests
7
  import streamlit as st
8
  from PIL import Image
9
 
10
-
11
  st.set_page_config(page_title="SmolVLM UI", layout="wide")
12
  st.title("SmolVLM")
13
 
@@ -22,9 +20,6 @@ with st.sidebar:
22
  top_p = st.slider("top_p", 0.05, 1.0, 0.95, step=0.05) if topp_on else None
23
  st.caption("API base: " + API_BASE)
24
 
25
- tabs = st.tabs(["Upload images", "Image URLs"])
26
- prompt = st.text_area("Prompt", "Can you describe the image(s)?", height=80)
27
-
28
  def show_metrics(metrics: dict):
29
  if not metrics:
30
  return
@@ -40,9 +35,13 @@ def show_metrics(metrics: dict):
40
  cols[3].metric("GPU reserved (MB)", f"{vram:.0f}" if vram is not None else "—")
41
  st.expander("All metrics").json(info)
42
 
43
- with tabs[0]:
 
 
 
44
  st.subheader("Upload one or more images")
45
  files = st.file_uploader("Images", type=["png", "jpg", "jpeg", "webp"], accept_multiple_files=True)
 
46
  run = st.button("Generate from uploads", type="primary", use_container_width=True, key="run_files")
47
 
48
  if run:
@@ -87,19 +86,22 @@ with tabs[0]:
87
  except Exception:
88
  st.write(e.response.text)
89
 
90
- with tabs[1]:
 
91
  st.subheader("Use remote image URLs")
92
- urls_raw = st.text_area("One URL per line", "", height=120, placeholder="https://example.com/a.jpg\nhttps://example.com/b.png")
 
 
93
  run2 = st.button("Generate from URLs", type="primary", use_container_width=True, key="run_urls")
94
 
95
  if run2:
96
  urls = [u.strip() for u in urls_raw.splitlines() if u.strip()]
97
- if not urls or not prompt.strip():
98
  st.error("Please add at least one URL and a prompt.")
99
  else:
100
  with st.spinner("Calling FastAPI…"):
101
  body = {
102
- "prompt": prompt,
103
  "image_urls": urls,
104
  "max_new_tokens": max_new_tokens,
105
  "temperature": temperature, # FastAPI model allows null
@@ -123,3 +125,53 @@ with tabs[1]:
123
  except Exception:
124
  st.write(e.response.text)
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # ui.py
2
  import os
 
3
  import io
4
  import json
5
  import requests
6
  import streamlit as st
7
  from PIL import Image
8
 
 
9
  st.set_page_config(page_title="SmolVLM UI", layout="wide")
10
  st.title("SmolVLM")
11
 
 
20
  top_p = st.slider("top_p", 0.05, 1.0, 0.95, step=0.05) if topp_on else None
21
  st.caption("API base: " + API_BASE)
22
 
 
 
 
23
  def show_metrics(metrics: dict):
24
  if not metrics:
25
  return
 
35
  cols[3].metric("GPU reserved (MB)", f"{vram:.0f}" if vram is not None else "—")
36
  st.expander("All metrics").json(info)
37
 
38
+ tab_upload, tab_urls, tab_detect = st.tabs(["Upload images", "Image URLs", "Detect & Describe"])
39
+
40
+ # -------------------- Tab 1: uploads -> /generate --------------------
41
+ with tab_upload:
42
  st.subheader("Upload one or more images")
43
  files = st.file_uploader("Images", type=["png", "jpg", "jpeg", "webp"], accept_multiple_files=True)
44
+ prompt = st.text_area("Prompt", "Can you describe the image(s)?", height=80)
45
  run = st.button("Generate from uploads", type="primary", use_container_width=True, key="run_files")
46
 
47
  if run:
 
86
  except Exception:
87
  st.write(e.response.text)
88
 
89
+ # -------------------- Tab 2: URLs -> /generate_urls --------------------
90
+ with tab_urls:
91
  st.subheader("Use remote image URLs")
92
+ prompt2 = st.text_area("Prompt", "Can you describe the image(s)?", height=80, key="prompt_urls")
93
+ urls_raw = st.text_area("One URL per line", "", height=120,
94
+ placeholder="https://example.com/a.jpg\nhttps://example.com/b.png")
95
  run2 = st.button("Generate from URLs", type="primary", use_container_width=True, key="run_urls")
96
 
97
  if run2:
98
  urls = [u.strip() for u in urls_raw.splitlines() if u.strip()]
99
+ if not urls or not prompt2.strip():
100
  st.error("Please add at least one URL and a prompt.")
101
  else:
102
  with st.spinner("Calling FastAPI…"):
103
  body = {
104
+ "prompt": prompt2,
105
  "image_urls": urls,
106
  "max_new_tokens": max_new_tokens,
107
  "temperature": temperature, # FastAPI model allows null
 
125
  except Exception:
126
  st.write(e.response.text)
127
 
128
# -------------------- Tab 3: Detect & Describe -> /detect_describe --------------------
# Uploads one image plus label phrases, posts them to the FastAPI
# /detect_describe endpoint, then renders the returned overlay and
# per-detection descriptions.
with tab_detect:
    st.subheader("Grounding DINO + SmolVLM")
    det_image = st.file_uploader("Image", type=["jpg", "jpeg", "png", "webp"], accept_multiple_files=False)
    det_labels = st.text_input("Labels (comma-separated)", "a man,a dog")
    det_box_thr = st.slider("box_threshold", 0.05, 0.95, 0.40, 0.01)
    det_text_thr = st.slider("text_threshold", 0.05, 0.95, 0.30, 0.01)
    det_pad = st.slider("crop padding (fraction)", 0.0, 0.2, 0.06, 0.01)
    det_max_new = st.slider("max_new_tokens", 1, 512, 160, 1)

    run_det = st.button("Detect & Describe", type="primary", use_container_width=True)
    if run_det:
        if not det_image or not det_labels.strip():
            st.error("Please provide an image and at least one label.")
        else:
            with st.spinner("Calling FastAPI…"):
                # Multipart form fields are text, so numeric knobs are stringified;
                # FastAPI coerces them back server-side.
                data = {
                    "labels": det_labels,
                    "box_threshold": str(det_box_thr),
                    "text_threshold": str(det_text_thr),
                    "pad_frac": str(det_pad),
                    "max_new_tokens": str(det_max_new),
                    "return_overlay": "true",
                }
                # NOTE(review): rebinds `files` used by the uploads tab above —
                # harmless since the script runs top-to-bottom, but worth renaming.
                files = [("image", (det_image.name, det_image.read(), det_image.type or "application/octet-stream"))]
                try:
                    r = requests.post(f"{API_BASE}/detect_describe", data=data, files=files, timeout=300)
                    r.raise_for_status()
                    out = r.json()

                    # Show overlay
                    # The API returns a base64 PNG with boxes drawn when return_overlay is true.
                    b64 = out.get("overlay_png_b64")
                    if b64:
                        st.image(f"data:image/png;base64,{b64}", caption="Detections", use_column_width=True)

                    # List detections
                    dets = out.get("detections", [])
                    if not dets:
                        st.info("No detections at current thresholds.")
                    for i, d in enumerate(dets, 1):
                        st.markdown(f"**{i}. {d['label']}** (score={d['score']:.2f}, box={d['box_xyxy']})")
                        st.write(d["description"])
                except requests.RequestException as e:
                    st.error(f"Request failed: {e}")
                    # Surface the server's error body (usually JSON detail) when present.
                    if hasattr(e, "response") and e.response is not None:
                        try:
                            st.code(e.response.text, language="json")
                        except Exception:
                            st.write(e.response.text)
+
util.py CHANGED
@@ -27,7 +27,7 @@ from PIL import Image
27
  from transformers import AutoProcessor, AutoModelForVision2Seq
28
  from transformers.image_utils import load_image as hf_load_image
29
 
30
-
31
 
32
 
33
  def _has_flash_attn() -> bool:
@@ -102,6 +102,78 @@ class SmolVLMRunner:
102
  return [cls._ensure_rgb(Image.open(BytesIO(b))) for b in blobs]
103
 
104
  # ---------- Inference ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def generate(
106
  self,
107
  prompt: str,
 
27
  from transformers import AutoProcessor, AutoModelForVision2Seq
28
  from transformers.image_utils import load_image as hf_load_image
29
 
30
+ from grounding_dino2 import get_runner as get_gdino_runner, visualize_detections
31
 
32
 
33
  def _has_flash_attn() -> bool:
 
102
  return [cls._ensure_rgb(Image.open(BytesIO(b))) for b in blobs]
103
 
104
  # ---------- Inference ----------
105
+ def detect_and_describe(
106
+ self,
107
+ image: Image.Image,
108
+ labels: list[str] | str,
109
+ *,
110
+ box_threshold: float = 0.4,
111
+ text_threshold: float = 0.3,
112
+ pad_frac: float = 0.06,
113
+ max_new_tokens: int = 160,
114
+ temperature: float | None = None,
115
+ top_p: float | None = None,
116
+ return_overlay: bool = False,
117
+ ) -> list[dict] | dict:
118
+ """
119
+ Uses Grounding DINO to detect boxes for `labels`, then asks SmolVLM to
120
+ describe each cropped box.
121
+
122
+ If return_overlay=False (default): returns a list of dicts:
123
+ [{ 'label','score','box_xyxy','description' }, ...]
124
+ If return_overlay=True: returns a dict:
125
+ { 'detections': [...], 'overlay_png_b64': '<base64 PNG>' }
126
+ """
127
+ gdino = get_gdino_runner()
128
+ detections = gdino.detect(
129
+ image=image,
130
+ labels=labels,
131
+ box_threshold=box_threshold,
132
+ text_threshold=text_threshold,
133
+ pad_frac=pad_frac,
134
+ )
135
+ if not detections:
136
+ return [] if not return_overlay else {"detections": [], "overlay_png_b64": None}
137
+
138
+ results: list[dict] = []
139
+ for det in detections:
140
+ crop = det["crop"]
141
+ prompt_txt = f"Describe the object inside this crop in detail. It was detected with the phrase: '{det['label']}'."
142
+ content = [{"type": "image"}, {"type": "text", "text": prompt_txt}]
143
+ messages = [{"role": "user", "content": content}]
144
+ chat_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
145
+
146
+ inputs = self.processor(text=chat_prompt, images=[crop], return_tensors="pt")
147
+ inputs = {k: (v.to(self.device) if hasattr(v, "to") else v) for k, v in inputs.items()}
148
+
149
+ gen_kwargs = dict(max_new_tokens=max_new_tokens)
150
+ if temperature is not None:
151
+ gen_kwargs["temperature"] = float(temperature)
152
+ if top_p is not None:
153
+ gen_kwargs["top_p"] = float(top_p)
154
+
155
+ with self._lock, torch.inference_mode():
156
+ out_ids = self.model.generate(**inputs, **gen_kwargs)
157
+ text = self.processor.batch_decode(out_ids, skip_special_tokens=True)[0].strip()
158
+ if text.startswith("Assistant:"):
159
+ text = text[len("Assistant:"):].strip()
160
+
161
+ results.append({
162
+ "label": det["label"],
163
+ "score": det["score"],
164
+ "box_xyxy": det["box_xyxy"],
165
+ "description": text,
166
+ })
167
+
168
+ if not return_overlay:
169
+ return results
170
+
171
+ # Build overlay image (PNG -> base64 string)
172
+ overlay = visualize_detections(image, detections)
173
+ buf = io.BytesIO()
174
+ overlay.save(buf, format="PNG")
175
+ b64 = base64.b64encode(buf.getvalue()).decode("ascii")
176
+ return {"detections": results, "overlay_png_b64": b64}
177
  def generate(
178
  self,
179
  prompt: str,