| import os |
| import time |
| from typing import List |
| import torch |
| from fastapi import FastAPI, UploadFile, File, Form |
| from PIL import Image |
| import io |
| import numpy as np |
| import supervision as sv |
|
|
| from groundingdino.util.inference import load_model, predict, annotate |
| import groundingdino.datasets.transforms as T |
|
|
# ASGI application object; the route handlers in this file register on it.
app = FastAPI()

# Grounding DINO config and checkpoint locations. Both are relative paths.
CONFIG_PATH = "weights/GroundingDINO_SwinB_cfg.py"
CHECKPOINT_PATH = "weights/groundingdino_swinb_cogcoor.pth"

# Use the GPU when torch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")

# Loaded once, at import time; the /predict handler reads this module global.
model = load_model(
    CONFIG_PATH,
    CHECKPOINT_PATH,
    device=DEVICE,
)
|
|
def load_image_from_pil(image_pil):
    """Turn a PIL image into the transformed image handed to ``predict``.

    The pipeline applies, in order: ``T.RandomResize([800], max_size=1333)``,
    ``T.ToTensor()`` and ``T.Normalize`` with the per-channel mean/std
    constants below.

    Args:
        image_pil: The PIL image to preprocess. The caller in this file
            passes an image already converted to RGB.

    Returns:
        The image produced by the transform pipeline. The second element of
        the pipeline's output (the target slot) is discarded.
    """
    transform = T.Compose(
        [
            # Only one candidate size is listed, so the "random" resize has a
            # single choice to pick from.
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            # NOTE(review): these look like the standard ImageNet mean/std
            # statistics - confirm against the checkpoint's training setup.
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    # The transforms are called with an (image, target) pair; there is no
    # target at inference time, so None is passed and the slot is ignored.
    image_transformed, _ = transform(image_pil, None)
    return image_transformed
|
|
@app.post("/predict")
async def predict_api(
    image: UploadFile = File(...),
    text_prompt: str = Form(...),
    box_threshold: float = Form(0.35),
    text_threshold: float = Form(0.25)
):
    """Run Grounding DINO on an uploaded image for the given text prompt.

    Args:
        image: Uploaded image file (multipart form field). It is decoded with
            PIL and converted to RGB before preprocessing.
        text_prompt: Caption passed to ``predict`` as ``caption``.
        box_threshold: Forwarded to ``predict`` as ``box_threshold``.
        text_threshold: Forwarded to ``predict`` as ``text_threshold``.

    Returns:
        A dict with ``boxes`` and ``logits`` (the values returned by
        ``predict``, converted with ``.tolist()``), ``phrases``, and
        ``inference_time``: the seconds elapsed for the whole request body
        (upload read, decode, preprocessing and prediction).

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    # Imported here so the handler does not rely on the file-level import
    # list exposing HTTPException.
    from fastapi import HTTPException

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    started = time.perf_counter()

    image_data = await image.read()
    try:
        image_pil = Image.open(io.BytesIO(image_data)).convert("RGB")
    except OSError as exc:
        # PIL's UnidentifiedImageError is an OSError; truncated or corrupt
        # files also surface as OSError. Report a client error, not a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    transformed_image = load_image_from_pil(image_pil)

    # NOTE(review): the box coordinate format is whatever ``predict``
    # returns - confirm against the groundingdino documentation before
    # relying on it client-side.
    boxes, logits, phrases = predict(
        model=model,
        image=transformed_image,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=DEVICE
    )

    inference_time = time.perf_counter() - started

    return {
        "boxes": boxes.tolist(),
        "logits": logits.tolist(),
        "phrases": phrases,
        "inference_time": inference_time
    }
|
|
# Landing route: answers GET / with a fixed JSON status message.
# Documented with comments rather than a docstring so that no new text is fed
# into the generated API description.
@app.get("/")
async def root():
    status_payload = {"message": "Grounding DINO FastAPI is running!"}
    return status_payload
|
|
| |
# Report, at import time, whether HF_TOKEN is present in the environment.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    # Reached for an unset variable and for an empty string alike.
    print("Hugging Face Token not found in environment variables.")
else:
    # NOTE(review): this writes the first five characters of the token to
    # stdout - confirm that exposing a token prefix in logs is acceptable.
    print(f"Hugging Face Token found: {hf_token[:5]}...")
|
|