feat: grounding dino
app.py CHANGED

```diff
@@ -1,17 +1,21 @@
 # app.py
 from time import perf_counter
-from …
+from io import BytesIO
+from typing import List, Optional, Union
 
 from fastapi import FastAPI, UploadFile, File, Form, HTTPException
 from pydantic import BaseModel, Field, HttpUrl
+from PIL import Image
 import uvicorn
 
 from util import get_runner, SmolVLMRunner
 
-app = FastAPI(title="SmolVLM Inference API", version="1.…
+app = FastAPI(title="SmolVLM Inference API", version="1.2.0")
 _runner: Optional[SmolVLMRunner] = None
 
 
+# ----------------------- Pydantic models -----------------------
+
 class URLRequest(BaseModel):
     prompt: str = Field(..., description="Text prompt to accompany the images.")
     image_urls: List[HttpUrl] = Field(..., description="List of image URLs.")
@@ -19,18 +23,32 @@ class URLRequest(BaseModel):
     temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
     top_p: Optional[float] = Field(None, gt=0.0, le=1.0)
 
+class DetectDescribeURLRequest(BaseModel):
+    image_url: HttpUrl
+    labels: Union[str, List[str]]
+    box_threshold: float = 0.40
+    text_threshold: float = 0.30
+    pad_frac: float = 0.06
+    max_new_tokens: int = 160
+    return_overlay: bool = True
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+
+
+# ----------------------- Startup / health -----------------------
 
 @app.on_event("startup")
 async def _load_model_on_startup():
     global _runner
     _runner = get_runner()
 
-
 @app.get("/")
 def health():
     return {"status": "ok", "model": _runner.model_id if _runner else None}
 
 
+# ----------------------- Core VLM endpoints -----------------------
+
 @app.post("/generate")
 async def generate_from_files(
     prompt: str = Form(...),
@@ -105,6 +123,66 @@ async def generate_from_urls(req: URLRequest):
     return {"text": text, "metrics": metrics}
 
 
+# ----------------------- Detect & Describe endpoints -----------------------
+
+@app.post("/detect_describe")
+async def detect_describe(
+    image: UploadFile = File(..., description="One image file (image/*)"),
+    labels: str = Form(..., description='Comma-separated phrases, e.g. "a man,a dog"'),
+    box_threshold: float = Form(0.40),
+    text_threshold: float = Form(0.30),
+    pad_frac: float = Form(0.06),
+    max_new_tokens: int = Form(160),
+    temperature: Optional[float] = Form(None),
+    top_p: Optional[float] = Form(None),
+    return_overlay: bool = Form(True),
+):
+    if not image.content_type or not image.content_type.startswith("image/"):
+        raise HTTPException(status_code=415, detail=f"Unsupported file type: {image.content_type}")
+
+    try:
+        raw = await image.read()
+        pil = Image.open(BytesIO(raw)).convert("RGB")
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=f"Failed to read image: {e}")
+
+    out = _runner.detect_and_describe(
+        image=pil,
+        labels=labels,  # comma-separated string OK
+        box_threshold=box_threshold,
+        text_threshold=text_threshold,
+        pad_frac=pad_frac,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        return_overlay=return_overlay,
+    )
+    return out
+
+
+@app.post("/detect_describe_url")
+async def detect_describe_url(req: DetectDescribeURLRequest):
+    try:
+        pil = _runner.load_pil_from_urls([str(req.image_url)])[0]
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=f"Failed to fetch image: {e}")
+
+    out = _runner.detect_and_describe(
+        image=pil,
+        labels=req.labels,
+        box_threshold=req.box_threshold,
+        text_threshold=req.text_threshold,
+        pad_frac=req.pad_frac,
+        max_new_tokens=req.max_new_tokens,
+        temperature=req.temperature,
+        top_p=req.top_p,
+        return_overlay=req.return_overlay,
+    )
+    return out
+
+
+# ----------------------- Entrypoint -----------------------
+
 if __name__ == "__main__":
     # Run with: python app.py (or: uvicorn app:app --host 0.0.0.0 --port 8000)
     uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)
```
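A quick way to smoke-test the two new endpoints is a plain `requests` client. The sketch below is illustrative rather than part of the commit: it assumes the API is serving on `http://localhost:8000` and that a sample file `dog.jpg` exists locally.

```python
# Hypothetical client for the new Detect & Describe endpoints.
import requests

API = "http://localhost:8000"  # assumed local deployment

# Multipart upload -> /detect_describe
with open("dog.jpg", "rb") as f:  # hypothetical sample image
    r = requests.post(
        f"{API}/detect_describe",
        files={"image": ("dog.jpg", f, "image/jpeg")},
        data={"labels": "a man,a dog", "box_threshold": "0.4", "return_overlay": "true"},
        timeout=300,
    )
r.raise_for_status()
for d in r.json()["detections"]:
    print(d["label"], f"{d['score']:.2f}", d["box_xyxy"])
    print(d["description"])

# JSON body -> /detect_describe_url
r = requests.post(
    f"{API}/detect_describe_url",
    json={"image_url": "https://example.com/dog.jpg", "labels": ["a dog"]},
    timeout=300,
)
r.raise_for_status()
print(r.json()["detections"])
```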
grounding_dino2.py ADDED

```python
# grounding_dino2.py
# Lightweight Grounding DINO wrapper for box detection + cropping.
# Works on CPU or GPU; safe on T4 (no flash-attn).
from __future__ import annotations

import os
import threading
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional

import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

# NOTE: CACHE_DIR is used below but was never defined in the committed file
# (a NameError at construction time). This default is an assumption.
CACHE_DIR = os.getenv("MODEL_CACHE_DIR", str(Path.home() / ".cache" / "huggingface"))


def visualize_detections(
    image: Image.Image,
    detections: list[dict],
    *,
    box_color: tuple[int, int, int] = (0, 255, 0),
    text_color: tuple[int, int, int] = (0, 0, 0),
    box_width: int = 3,
) -> Image.Image:
    """
    Draw boxes + labels on a copy of `image`.
    Each detection item expects: {'label': str, 'score': float, 'box_xyxy': (x0,y0,x1,y1)}
    """
    vis = image.copy()
    draw = ImageDraw.Draw(vis)
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", 16)
    except Exception:
        font = None

    for det in detections:
        x0, y0, x1, y1 = det["box_xyxy"]
        lab = det.get("label", "")
        sc = det.get("score", 0.0)
        draw.rectangle((x0, y0, x1, y1), outline=box_color, width=box_width)
        text = f"{lab} {sc:.2f}"
        text_w = draw.textlength(text, font=font) if font else len(text) * 8
        pad = 4
        draw.rectangle((x0, y0 - 20, x0 + int(text_w) + pad * 2, y0), fill=box_color)
        draw.text((x0 + pad, y0 - 18), text, fill=text_color, font=font)
    return vis


def _clamp_xyxy(box: List[float], w: int, h: int) -> Tuple[int, int, int, int]:
    x0, y0, x1, y1 = box
    x0 = max(0, min(int(round(x0)), w - 1))
    y0 = max(0, min(int(round(y0)), h - 1))
    x1 = max(0, min(int(round(x1)), w - 1))
    y1 = max(0, min(int(round(y1)), h - 1))
    if x1 < x0:
        x0, x1 = x1, x0
    if y1 < y0:
        y0, y1 = y1, y0
    return x0, y0, x1, y1


def _pad_box(box: Tuple[int, int, int, int], w: int, h: int, frac: float = 0.06) -> Tuple[int, int, int, int]:
    x0, y0, x1, y1 = box
    bw, bh = x1 - x0, y1 - y0
    dx, dy = int(bw * frac), int(bh * frac)
    return max(0, x0 - dx), max(0, y0 - dy), min(w - 1, x1 + dx), min(h - 1, y1 + dy)


def crop_from_box(img: Image.Image, box_xyxy: Tuple[int, int, int, int]) -> Image.Image:
    return img.crop(box_xyxy)


class GroundingDINORunner:
    """
    Minimal singleton-style wrapper for the Grounding DINO zero-shot detector.
    """

    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
        self.model_id = model_id or os.getenv("GDINO_MODEL_ID", "IDEA-Research/grounding-dino-tiny")
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._lock = threading.Lock()

        self.processor = AutoProcessor.from_pretrained(self.model_id, cache_dir=CACHE_DIR)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(
            self.model_id, cache_dir=CACHE_DIR
        ).to(self.device)
        self.model.eval()

    @staticmethod
    def _normalize_labels(labels: List[str] | str) -> List[List[str]]:
        if isinstance(labels, str):
            items = [x.strip() for x in labels.split(",") if x.strip()]
        else:
            items = [x.strip() for x in labels if x and x.strip()]
        if not items:
            raise ValueError("No labels provided.")
        # Grounding DINO expects a nested list of phrases: [["a cat", "a remote control"]]
        return [items]

    def detect(
        self,
        image: Image.Image,
        labels: List[str] | str,
        box_threshold: float = 0.4,
        text_threshold: float = 0.3,
        pad_frac: float = 0.06,
    ) -> List[Dict[str, Any]]:
        """
        Runs zero-shot detection and returns a list of dicts:
            { 'label': str, 'score': float, 'box_xyxy': (x0,y0,x1,y1), 'crop': PIL.Image }
        """
        w, h = image.size
        phrases = self._normalize_labels(labels)
        inputs = self.processor(images=image, text=phrases, return_tensors="pt").to(self.device)

        with self._lock, torch.no_grad():
            outputs = self.model(**inputs)

        # transformers>=4.51 renamed box_threshold -> threshold
        try:
            post = self.processor.post_process_grounded_object_detection(
                outputs=outputs,
                input_ids=inputs.input_ids,
                threshold=float(box_threshold),
                text_threshold=float(text_threshold),
                target_sizes=[(h, w)],
            )
        except TypeError:
            post = self.processor.post_process_grounded_object_detection(
                outputs=outputs,
                input_ids=inputs.input_ids,
                box_threshold=float(box_threshold),
                text_threshold=float(text_threshold),
                target_sizes=[(h, w)],
            )

        det = post[0]
        boxes = det.get("boxes", [])
        scores = det.get("scores", [])
        labels_out = det.get("text_labels", det.get("labels", []))

        results: List[Dict[str, Any]] = []
        for b, s, lab in zip(boxes, scores, labels_out):
            b = b.tolist() if hasattr(b, "tolist") else list(b)
            bx = _clamp_xyxy(b, w, h)
            bx = _pad_box(bx, w, h, pad_frac)
            crop = crop_from_box(image, bx)
            score = float(s.item()) if torch.is_tensor(s) else float(s)
            results.append({"label": lab, "score": score, "box_xyxy": bx, "crop": crop})

        return results


# convenience singleton
_runner_singleton: GroundingDINORunner | None = None


def get_runner() -> GroundingDINORunner:
    global _runner_singleton
    if _runner_singleton is None:
        _runner_singleton = GroundingDINORunner()
    return _runner_singleton
```
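The runner can also be driven directly, without the API layer. A minimal sketch (not part of the commit; assumes a local `street.jpg` and that the default `IDEA-Research/grounding-dino-tiny` weights can be downloaded on first use):

```python
# Hypothetical standalone use of GroundingDINORunner.
from PIL import Image
from grounding_dino2 import get_runner, visualize_detections

img = Image.open("street.jpg").convert("RGB")  # assumed sample image
runner = get_runner()  # loads processor + model once (singleton)

dets = runner.detect(img, "a car, a person", box_threshold=0.35, text_threshold=0.25)
for d in dets:
    print(f"{d['label']}: {d['score']:.2f} at {d['box_xyxy']}")
    d["crop"].save(f"crop_{d['label'].replace(' ', '_')}.png")  # padded crop per detection

visualize_detections(img, dets).save("overlay.png")  # boxes + labels drawn on a copy
```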
ui.py CHANGED

```diff
@@ -1,13 +1,11 @@
 # ui.py
 import os
-
 import io
 import json
 import requests
 import streamlit as st
 from PIL import Image
 
-
 st.set_page_config(page_title="SmolVLM UI", layout="wide")
 st.title("SmolVLM")
 
@@ -22,9 +20,6 @@ with st.sidebar:
     top_p = st.slider("top_p", 0.05, 1.0, 0.95, step=0.05) if topp_on else None
     st.caption("API base: " + API_BASE)
 
-tabs = st.tabs(["Upload images", "Image URLs"])
-prompt = st.text_area("Prompt", "Can you describe the image(s)?", height=80)
-
 def show_metrics(metrics: dict):
     if not metrics:
         return
@@ -40,9 +35,13 @@ def show_metrics(metrics: dict):
     cols[3].metric("GPU reserved (MB)", f"{vram:.0f}" if vram is not None else "—")
     st.expander("All metrics").json(info)
 
-with tabs[0]:
+tab_upload, tab_urls, tab_detect = st.tabs(["Upload images", "Image URLs", "Detect & Describe"])
+
+# -------------------- Tab 1: uploads -> /generate --------------------
+with tab_upload:
     st.subheader("Upload one or more images")
     files = st.file_uploader("Images", type=["png", "jpg", "jpeg", "webp"], accept_multiple_files=True)
+    prompt = st.text_area("Prompt", "Can you describe the image(s)?", height=80)
     run = st.button("Generate from uploads", type="primary", use_container_width=True, key="run_files")
 
     if run:
@@ -87,19 +86,22 @@ with tabs[0]:
             except Exception:
                 st.write(e.response.text)
 
-with tabs[1]:
+# -------------------- Tab 2: URLs -> /generate_urls --------------------
+with tab_urls:
     st.subheader("Use remote image URLs")
-    …
+    prompt2 = st.text_area("Prompt", "Can you describe the image(s)?", height=80, key="prompt_urls")
+    urls_raw = st.text_area("One URL per line", "", height=120,
+                            placeholder="https://example.com/a.jpg\nhttps://example.com/b.png")
     run2 = st.button("Generate from URLs", type="primary", use_container_width=True, key="run_urls")
 
     if run2:
         urls = [u.strip() for u in urls_raw.splitlines() if u.strip()]
-        if not urls or not …
+        if not urls or not prompt2.strip():
            st.error("Please add at least one URL and a prompt.")
         else:
             with st.spinner("Calling FastAPI…"):
                 body = {
-                    "prompt": …
+                    "prompt": prompt2,
                     "image_urls": urls,
                     "max_new_tokens": max_new_tokens,
                     "temperature": temperature,  # FastAPI model allows null
@@ -123,3 +125,53 @@ with tabs[1]:
             except Exception:
                 st.write(e.response.text)
 
+# -------------------- Tab 3: Detect & Describe -> /detect_describe --------------------
+with tab_detect:
+    st.subheader("Grounding DINO + SmolVLM")
+    det_image = st.file_uploader("Image", type=["jpg", "jpeg", "png", "webp"], accept_multiple_files=False)
+    det_labels = st.text_input("Labels (comma-separated)", "a man,a dog")
+    det_box_thr = st.slider("box_threshold", 0.05, 0.95, 0.40, 0.01)
+    det_text_thr = st.slider("text_threshold", 0.05, 0.95, 0.30, 0.01)
+    det_pad = st.slider("crop padding (fraction)", 0.0, 0.2, 0.06, 0.01)
+    det_max_new = st.slider("max_new_tokens", 1, 512, 160, 1)
+
+    run_det = st.button("Detect & Describe", type="primary", use_container_width=True)
+    if run_det:
+        if not det_image or not det_labels.strip():
+            st.error("Please provide an image and at least one label.")
+        else:
+            with st.spinner("Calling FastAPI…"):
+                data = {
+                    "labels": det_labels,
+                    "box_threshold": str(det_box_thr),
+                    "text_threshold": str(det_text_thr),
+                    "pad_frac": str(det_pad),
+                    "max_new_tokens": str(det_max_new),
+                    "return_overlay": "true",
+                }
+                files = [("image", (det_image.name, det_image.read(), det_image.type or "application/octet-stream"))]
+                try:
+                    r = requests.post(f"{API_BASE}/detect_describe", data=data, files=files, timeout=300)
+                    r.raise_for_status()
+                    out = r.json()
+
+                    # Show overlay
+                    b64 = out.get("overlay_png_b64")
+                    if b64:
+                        st.image(f"data:image/png;base64,{b64}", caption="Detections", use_column_width=True)
+
+                    # List detections
+                    dets = out.get("detections", [])
+                    if not dets:
+                        st.info("No detections at current thresholds.")
+                    for i, d in enumerate(dets, 1):
+                        st.markdown(f"**{i}. {d['label']}** (score={d['score']:.2f}, box={d['box_xyxy']})")
+                        st.write(d["description"])
+                except requests.RequestException as e:
+                    st.error(f"Request failed: {e}")
+                    if hasattr(e, "response") and e.response is not None:
+                        try:
+                            st.code(e.response.text, language="json")
+                        except Exception:
+                            st.write(e.response.text)
+
```
util.py CHANGED

```diff
@@ -27,7 +27,7 @@ from PIL import Image
 from transformers import AutoProcessor, AutoModelForVision2Seq
 from transformers.image_utils import load_image as hf_load_image
 
-
+from grounding_dino2 import get_runner as get_gdino_runner, visualize_detections
 
 
 def _has_flash_attn() -> bool:
@@ -102,6 +102,78 @@ class SmolVLMRunner:
         return [cls._ensure_rgb(Image.open(BytesIO(b))) for b in blobs]
 
     # ---------- Inference ----------
+    def detect_and_describe(
+        self,
+        image: Image.Image,
+        labels: list[str] | str,
+        *,
+        box_threshold: float = 0.4,
+        text_threshold: float = 0.3,
+        pad_frac: float = 0.06,
+        max_new_tokens: int = 160,
+        temperature: float | None = None,
+        top_p: float | None = None,
+        return_overlay: bool = False,
+    ) -> list[dict] | dict:
+        """
+        Uses Grounding DINO to detect boxes for `labels`, then asks SmolVLM to
+        describe each cropped box.
+
+        If return_overlay=False (default): returns a list of dicts:
+            [{ 'label','score','box_xyxy','description' }, ...]
+        If return_overlay=True: returns a dict:
+            { 'detections': [...], 'overlay_png_b64': '<base64 PNG>' }
+        """
+        gdino = get_gdino_runner()
+        detections = gdino.detect(
+            image=image,
+            labels=labels,
+            box_threshold=box_threshold,
+            text_threshold=text_threshold,
+            pad_frac=pad_frac,
+        )
+        if not detections:
+            return [] if not return_overlay else {"detections": [], "overlay_png_b64": None}
+
+        results: list[dict] = []
+        for det in detections:
+            crop = det["crop"]
+            prompt_txt = f"Describe the object inside this crop in detail. It was detected with the phrase: '{det['label']}'."
+            content = [{"type": "image"}, {"type": "text", "text": prompt_txt}]
+            messages = [{"role": "user", "content": content}]
+            chat_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
+
+            inputs = self.processor(text=chat_prompt, images=[crop], return_tensors="pt")
+            inputs = {k: (v.to(self.device) if hasattr(v, "to") else v) for k, v in inputs.items()}
+
+            gen_kwargs = dict(max_new_tokens=max_new_tokens)
+            if temperature is not None:
+                gen_kwargs["temperature"] = float(temperature)
+            if top_p is not None:
+                gen_kwargs["top_p"] = float(top_p)
+
+            with self._lock, torch.inference_mode():
+                out_ids = self.model.generate(**inputs, **gen_kwargs)
+            text = self.processor.batch_decode(out_ids, skip_special_tokens=True)[0].strip()
+            if text.startswith("Assistant:"):
+                text = text[len("Assistant:"):].strip()
+
+            results.append({
+                "label": det["label"],
+                "score": det["score"],
+                "box_xyxy": det["box_xyxy"],
+                "description": text,
+            })
+
+        if not return_overlay:
+            return results
+
+        # Build overlay image (PNG -> base64 string)
+        overlay = visualize_detections(image, detections)
+        buf = io.BytesIO()
+        overlay.save(buf, format="PNG")
+        b64 = base64.b64encode(buf.getvalue()).decode("ascii")
+        return {"detections": results, "overlay_png_b64": b64}
     def generate(
         self,
         prompt: str,
```