# Grounded-SAM / app.py
# Hugging Face Space by XtewaldX (commit b1d4f19)
import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
from transformers import (
AutoProcessor,
AutoModelForZeroShotObjectDetection,
BlipProcessor,
BlipForConditionalGeneration
)
from segment_anything import sam_model_registry, SamPredictor
# Run on GPU when available; all models and tensors below are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# --------------------------------------------------
# MODELS
# --------------------------------------------------
DINO_MODEL = "IDEA-Research/grounding-dino-base"      # zero-shot, text-conditioned detector
BLIP_MODEL = "Salesforce/blip-image-captioning-base"  # captioner used by the "automatic" mode
SAM_TYPE = "vit_b"                                    # smallest SAM backbone variant
SAM_CHECKPOINT = "sam_vit_b.pth"                      # local filename for the SAM weights
SAM_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
BOX_THRESHOLD = 0.3                                   # minimum detection score kept by detect()
# --------------------------------------------------
# DOWNLOAD SAM
# --------------------------------------------------
# Fetch the SAM weights on first run; skipped when the checkpoint file already exists.
if not os.path.exists(SAM_CHECKPOINT):
    import urllib.request
    print("Downloading SAM checkpoint...")
    urllib.request.urlretrieve(SAM_URL, SAM_CHECKPOINT)
# --------------------------------------------------
# LOAD MODELS
# --------------------------------------------------
# All three models are loaded once at import time and shared as module
# globals by the handler functions below.
print("Loading GroundingDINO...")
processor = AutoProcessor.from_pretrained(DINO_MODEL)
dino = AutoModelForZeroShotObjectDetection.from_pretrained(DINO_MODEL).to(DEVICE)
print("Loading SAM...")
sam = sam_model_registry[SAM_TYPE](checkpoint=SAM_CHECKPOINT)
sam.to(device=DEVICE)
predictor = SamPredictor(sam)  # wrapper exposing set_image / predict_torch
print("Loading BLIP...")
blip_processor = BlipProcessor.from_pretrained(BLIP_MODEL)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_MODEL).to(DEVICE)
# --------------------------------------------------
# BLIP CAPTION
# --------------------------------------------------
def generate_caption(image):
    """Describe *image* with BLIP and return the caption as a string."""
    batch = blip_processor(image, return_tensors="pt").to(DEVICE)
    # Inference only — no gradients needed.
    with torch.no_grad():
        token_ids = blip_model.generate(**batch)
    return blip_processor.decode(token_ids[0], skip_special_tokens=True)
# --------------------------------------------------
# DETECT OBJECTS
# --------------------------------------------------
def detect(image, prompt):
    """Run GroundingDINO on *image* conditioned on a free-text *prompt*.

    Returns the detected boxes (xyxy, absolute pixel coordinates) whose
    confidence exceeds BOX_THRESHOLD; may be empty.
    """
    # GroundingDINO's text encoder expects lower-cased queries terminated
    # by a period — normalize so arbitrary user input still matches.
    text = prompt.strip().lower()
    if text and not text.endswith("."):
        text += "."
    inputs = processor(images=image, text=text, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = dino(**inputs)
    # PIL reports size as (W, H); the post-processor wants (H, W).
    results = processor.post_process_grounded_object_detection(
        outputs,
        target_sizes=[image.size[::-1]],
    )[0]
    boxes = results["boxes"]
    scores = results["scores"]
    keep = scores > BOX_THRESHOLD
    return boxes[keep]
# --------------------------------------------------
# DRAW BOXES
# --------------------------------------------------
def draw_boxes(image, boxes):
    """Overlay 3-px red rectangle outlines for each xyxy box on *image*.

    Accepts a PIL image in any mode (converted to RGB) and a tensor of
    boxes in absolute pixel coordinates. Returns a new PIL RGB image.
    """
    # Force 3 channels: grayscale (2-D array) or RGBA (4-channel) inputs
    # would otherwise break the RGB-triple assignments below.
    canvas = np.array(image.convert("RGB"))
    h, w = canvas.shape[:2]
    red = [255, 0, 0]
    for box in boxes:
        x1, y1, x2, y2 = box.cpu().numpy().astype(int)
        # Clamp into the canvas so edge-touching boxes still get visible
        # lines instead of empty (out-of-range) slices.
        x1, x2 = max(x1, 0), min(x2, w - 1)
        y1, y2 = max(y1, 0), min(y2, h - 1)
        canvas[y1:y1 + 3, x1:x2] = red  # top edge
        canvas[y2:y2 + 3, x1:x2] = red  # bottom edge
        canvas[y1:y2, x1:x1 + 3] = red  # left edge
        canvas[y1:y2, x2:x2 + 3] = red  # right edge
    return Image.fromarray(canvas)
# --------------------------------------------------
# SEGMENT
# --------------------------------------------------
def segment(image, prompt):
    """Detect objects matching *prompt*, then blend SAM masks in green."""
    rgb = image.convert("RGB")
    pixels = np.array(rgb)
    detected = detect(rgb, prompt)
    # Nothing matched the prompt: hand the (RGB-converted) image back untouched.
    if len(detected) == 0:
        return rgb
    predictor.set_image(pixels)
    # SAM wants boxes transformed into its own input resolution.
    sam_boxes = predictor.transform.apply_boxes_torch(
        detected.to(DEVICE), pixels.shape[:2]
    )
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=sam_boxes,
        multimask_output=False,
    )
    blended = pixels.copy()
    green = np.array([0, 255, 0])
    for mask in masks:
        region = mask[0].cpu().numpy() > 0
        # 50/50 blend of original pixels with solid green inside the mask.
        blended[region] = (blended[region] * 0.5 + green * 0.5).astype(np.uint8)
    return Image.fromarray(blended)
# --------------------------------------------------
# PIPELINE
# --------------------------------------------------
def run_pipeline(image, prompt, mode):
    """Dispatch the UI request to detection, segmentation, or auto mode.

    Returns None for an unrecognized mode (the dropdown restricts the
    values, so that does not occur in practice).
    """
    if mode == "det":
        return draw_boxes(image, detect(image, prompt))
    if mode == "automatic":
        # Let BLIP describe the scene, then segment whatever it saw.
        caption = generate_caption(image)
        print("BLIP caption:", caption)
        return segment(image, caption)
    if mode == "seg":
        return segment(image, prompt)
# --------------------------------------------------
# UI
# --------------------------------------------------
# Gradio front-end: image + text prompt + mode selector -> annotated image.
demo = gr.Interface(
    fn=run_pipeline,
    inputs=[
        gr.Image(type="pil"),  # input photo, delivered as a PIL image
        gr.Textbox(label="Prompt", value="person"),  # free-text query for GroundingDINO
        gr.Dropdown(
            # seg: green masks, det: red boxes, automatic: BLIP caption -> masks
            ["seg", "det", "automatic"],
            value="seg",
            label="Mode"
        ),
    ],
    outputs=gr.Image(),
    title="GroundingDINO + SAM + BLIP (CPU version)",
)
if __name__ == "__main__":
    demo.launch()