import gradio as gr
import torch
import numpy as np
from PIL import Image
import cv2
from segment_anything import sam_model_registry, SamPredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
import os
import urllib.request
# Download the SAM ViT-H checkpoint on first run
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
SAM_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
if not os.path.exists(SAM_CHECKPOINT):
    print("Downloading SAM checkpoint...")
    urllib.request.urlretrieve(SAM_CHECKPOINT_URL, SAM_CHECKPOINT)
    print("SAM checkpoint downloaded!")
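# Note: the ViT-H checkpoint is roughly 2.4 GB, so the first launch may take a while.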
# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load Grounding DINO from Hugging Face
grounding_dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
    "IDEA-Research/grounding-dino-tiny"
).to(device)
# Load SAM
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT)
sam.to(device=device)
sam_predictor = SamPredictor(sam)
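# Two-stage "Grounded SAM" pipeline: Grounding DINO turns the free-form text
# prompt into bounding boxes, and each box is then passed to SAM as a prompt
# to produce a pixel-level segmentation mask.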
def process_image(image, text_prompt, box_threshold, text_threshold, quality):
    """
    Run Grounded SAM: detect objects matching text_prompt with Grounding DINO,
    then segment each detection with SAM. `image` is an RGB numpy array, as
    provided by gr.Image(type="numpy").
    """
    try:
        # Cap the longest image side according to the quality setting
        if quality == "Low":
            max_size = 800
        elif quality == "Medium":
            max_size = 1024
        else:  # High
            max_size = 1920

        # Downscale if needed, preserving aspect ratio
        h, w = image.shape[:2]
        if max(h, w) > max_size:
            scale = max_size / max(h, w)
            new_h, new_w = int(h * scale), int(w * scale)
            image = cv2.resize(image, (new_w, new_h))
        # Grounding DINO expects lowercase text queries that end with a period
        text_prompt = text_prompt.lower().strip()
        if not text_prompt.endswith("."):
            text_prompt += "."

        # gr.Image(type="numpy") already yields RGB, so no channel swap is needed
        pil_image = Image.fromarray(image)

        # Grounding DINO inference
        inputs = grounding_dino_processor(images=pil_image, text=text_prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = grounding_dino_model(**inputs)
        # Post-process logits into thresholded boxes scaled to pixel coordinates
        results = grounding_dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[pil_image.size[::-1]]  # (height, width)
        )[0]

        # Extract boxes and labels
        boxes = results["boxes"].cpu().numpy()
        labels = results["labels"]
        if len(boxes) == 0:
            return image, "No objects detected. Try adjusting the thresholds or text prompt."

        # Boxes already come in xyxy format, which is what SAM expects
        boxes_xyxy = boxes
        # SAM inference: set_image expects RGB, which the image already is
        sam_predictor.set_image(image)
        masks = []
        for box in boxes_xyxy:
            # With multimask_output=False SAM returns a single (1, H, W) mask per box
            mask, _, _ = sam_predictor.predict(
                box=box,
                multimask_output=False
            )
            masks.append(mask[0])
        # Visualize results
        result_image = image.copy()

        # Blend each mask into the image with a random color at 50% opacity
        for mask in masks:
            color = np.random.randint(0, 255, 3)
            result_image[mask] = (result_image[mask] * 0.5 + color * 0.5).astype(np.uint8)

        # Draw boxes and labels
        for box, label in zip(boxes_xyxy, labels):
            x1, y1, x2, y2 = map(int, box)
            color = np.random.randint(0, 255, 3).tolist()
            cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 2)
            cv2.putText(result_image, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        metadata = f"βœ… Detected {len(boxes)} objects: {', '.join(labels)}"
        return result_image, metadata
    except Exception as e:
        return image, f"❌ Error: {str(e)}"
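# A quick way to exercise process_image without the UI (hypothetical local test;
# assumes an RGB image file exists at the given path):
#   img = np.array(Image.open("examples/fish1.jpg").convert("RGB"))
#   out, info = process_image(img, "fish", 0.35, 0.25, "Medium")
#   print(info)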
# Gradio Interface
with gr.Blocks(title="Grounded SAM") as demo:
    gr.Markdown("# 🎯 Grounded SAM - Object Detection & Segmentation")
    gr.Markdown("Upload an image and describe what you want to detect (e.g., 'fish', 'all fish', 'person').")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="numpy")
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'fish', 'person', 'car'",
                value="fish"
            )
            with gr.Accordion("Advanced Settings", open=False):
                box_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.35, step=0.05,
                    label="Box Threshold (detection confidence)"
                )
                text_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.25, step=0.05,
                    label="Text Threshold (text matching confidence)"
                )
                quality = gr.Radio(
                    choices=["Low", "Medium", "High"],
                    value="Medium",
                    label="Processing Quality"
                )
            submit_btn = gr.Button("πŸš€ Process Image", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Output with Masks & Boxes", type="numpy")
            output_metadata = gr.Textbox(label="Detection Metadata", lines=3)

    submit_btn.click(
        fn=process_image,
        inputs=[input_image, text_prompt, box_threshold, text_threshold, quality],
        outputs=[output_image, output_metadata]
    )

    gr.Examples(
        examples=[
            ["examples/fish1.jpg", "fish", 0.35, 0.25, "Medium"],
            ["examples/fish2.jpg", "all fish", 0.35, 0.25, "Medium"],
        ],
        inputs=[input_image, text_prompt, box_threshold, text_threshold, quality],
    )

demo.launch()