File size: 2,195 Bytes
add2ebd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import io
import os
import threading
import time
from typing import List

import numpy as np
import supervision as sv
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image

from groundingdino.util.inference import load_model, predict, annotate
import groundingdino.datasets.transforms as T

app = FastAPI()

# Grounding DINO weights. Swin-B is used because Swin-L is not available.
# The defaults are resolved relative to the current working directory; each
# path can be overridden through an environment variable so the service can be
# started from another directory without editing code.
CONFIG_PATH = os.getenv(
    "GROUNDING_DINO_CONFIG_PATH", "weights/GroundingDINO_SwinB_cfg.py"
)
CHECKPOINT_PATH = os.getenv(
    "GROUNDING_DINO_CHECKPOINT_PATH", "weights/groundingdino_swinb_cogcoor.pth"
)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {DEVICE}")
# Loaded once at import time and shared by every request handled by this process.
model = load_model(CONFIG_PATH, CHECKPOINT_PATH, device=DEVICE)

# Preprocessing pipeline, built once at import time rather than on every call.
# RandomResize is given a single candidate size (800, with max_size=1333), so
# the chosen size does not vary between calls. The image is then converted to
# a tensor and normalised with the listed per-channel mean/std.
_IMAGE_TRANSFORM = T.Compose(
    [
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

def load_image_from_pil(image_pil: Image.Image) -> torch.Tensor:
    """Convert a PIL image into the normalised tensor passed to the model.

    Args:
        image_pil: Input image; callers in this file pass an RGB image.

    Returns:
        The transformed image tensor (presumably shaped (3, H, W) after
        ``ToTensor`` — confirm against the groundingdino transforms).
    """
    # The transform takes and returns an (image, target) pair; there is no
    # target for inference, so pass None and discard the second element.
    image_transformed, _ = _IMAGE_TRANSFORM(image_pil, None)
    return image_transformed

# Serialises model inference. When inference ran directly on the event loop,
# requests were implicitly processed one at a time; this lock keeps that
# guarantee now that the work runs in FastAPI's worker thread pool.
# NOTE(review): presumably still needed because the model (and its text
# tokenizer) are not known to be safe for concurrent use from several threads,
# and parallel GPU runs could exhaust memory — confirm before removing.
_INFERENCE_LOCK = threading.Lock()

def _run_inference(image_data, text_prompt, box_threshold, text_threshold):
    """Decode *image_data*, preprocess it and run Grounding DINO (blocking).

    Returns:
        The ``(boxes, logits, phrases)`` tuple produced by ``predict``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    try:
        # Image.open only reads the header; convert() forces the full decode,
        # so corrupt or truncated files fail here instead of inside the model.
        image_pil = Image.open(io.BytesIO(image_data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    transformed_image = load_image_from_pil(image_pil)

    with _INFERENCE_LOCK:
        return predict(
            model=model,
            image=transformed_image,
            caption=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=DEVICE,
        )

@app.post("/predict")
async def predict_api(
    image: UploadFile = File(...),
    text_prompt: str = Form(...),
    box_threshold: float = Form(0.35),
    text_threshold: float = Form(0.25)
):
    """Detect regions of *image* matching *text_prompt* with Grounding DINO.

    Args:
        image: Uploaded image file.
        text_prompt: Free-text caption describing what to detect.
        box_threshold: Passed through to ``predict`` as ``box_threshold``.
        text_threshold: Passed through to ``predict`` as ``text_threshold``.

    Returns:
        A JSON-serialisable dict with ``boxes`` and ``logits`` (converted from
        the tensors ``predict`` returns), ``phrases``, and ``inference_time``
        in seconds. The time covers upload read, decode, preprocessing and
        inference. NOTE(review): box coordinates are presumably normalised
        (cx, cy, w, h) as produced by Grounding DINO's ``predict`` — confirm.

    Raises:
        HTTPException: 400 for an empty prompt or an undecodable image.
    """
    # perf_counter is monotonic, so the measured duration is immune to
    # wall-clock adjustments.
    start_time = time.perf_counter()

    if not text_prompt.strip():
        raise HTTPException(status_code=400, detail="text_prompt must not be empty.")

    image_data = await image.read()

    # Decoding and model inference are blocking CPU/GPU work; running them in
    # a worker thread keeps the event loop free to serve other requests.
    boxes, logits, phrases = await run_in_threadpool(
        _run_inference, image_data, text_prompt, box_threshold, text_threshold
    )

    inference_time = time.perf_counter() - start_time

    return {
        "boxes": boxes.tolist(),
        "logits": logits.tolist(),
        "phrases": phrases,
        "inference_time": inference_time
    }

@app.get("/")
async def root():
    """Return a fixed message confirming that the API process is up."""
    status_message = "Grounding DINO FastAPI is running!"
    payload = {"message": status_message}
    return payload

# Report whether a Hugging Face token is present in the environment. The token
# is only read here, never set. Its value (even a prefix of it) is deliberately
# not printed, so that secrets do not end up in stdout or log files.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    print("Hugging Face Token found in environment variables.")
else:
    print("Hugging Face Token not found in environment variables.")