import gradio as gr
import torch
import numpy as np
from PIL import Image
import cv2
from segment_anything import sam_model_registry, SamPredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
import os
import urllib.request
# Download the SAM checkpoint if it is not already present
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
SAM_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
if not os.path.exists(SAM_CHECKPOINT):
    print("Downloading SAM checkpoint...")
    urllib.request.urlretrieve(SAM_CHECKPOINT_URL, SAM_CHECKPOINT)
    print("SAM checkpoint downloaded!")
# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load Grounding DINO from Hugging Face
grounding_dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
    "IDEA-Research/grounding-dino-tiny"
).to(device)
# Load SAM
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT)
sam.to(device=device)
sam_predictor = SamPredictor(sam)
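
# Optional sanity check (an addition, not part of the original file): confirm
# that both models actually ended up on the intended device before the UI starts.
print(f"Models ready on {device} "
      f"(DINO: {next(grounding_dino_model.parameters()).device}, "
      f"SAM: {next(sam.parameters()).device})")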
def process_image(image, text_prompt, box_threshold, text_threshold, quality):
    """
    Detect objects matching the text prompt with Grounding DINO,
    then segment each detection with SAM.
    """
    try:
        # Resize based on quality setting
        if quality == "Low":
            max_size = 800
        elif quality == "Medium":
            max_size = 1024
        else:  # High
            max_size = 1920

        # Downscale the image if its longer side exceeds max_size
        h, w = image.shape[:2]
        if max(h, w) > max_size:
            scale = max_size / max(h, w)
            new_h, new_w = int(h * scale), int(w * scale)
            image = cv2.resize(image, (new_w, new_h))

        # gr.Image(type="numpy") already delivers an RGB array, so no
        # BGR-to-RGB conversion is needed before handing it to PIL
        pil_image = Image.fromarray(image)

        # Grounding DINO expects lowercased queries terminated with a period
        text = text_prompt.lower().strip()
        if not text.endswith("."):
            text += "."

        # Grounding DINO inference
        inputs = grounding_dino_processor(
            images=pil_image, text=text, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            outputs = grounding_dino_model(**inputs)

        # Post-process results back to pixel coordinates
        results = grounding_dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[pil_image.size[::-1]],
        )[0]

        # Extract boxes and labels
        boxes = results["boxes"].cpu().numpy()
        labels = results["labels"]
        if len(boxes) == 0:
            return image, "No objects detected. Try adjusting the thresholds or text prompt."

        # Boxes are already in xyxy format, which is what SAM expects
        boxes_xyxy = boxes

        # SAM inference: one mask per detected box
        sam_predictor.set_image(image)
        masks = []
        for box in boxes_xyxy:
            mask, _, _ = sam_predictor.predict(
                box=box,
                multimask_output=False
            )
            masks.append(mask[0])

        # Visualize results
        result_image = image.copy()

        # Blend a random color into each mask region
        for mask in masks:
            color = np.random.randint(0, 255, 3).tolist()
            result_image[mask] = (
                result_image[mask] * 0.5 + np.array(color) * 0.5
            ).astype(np.uint8)

        # Draw boxes and labels
        for box, label in zip(boxes_xyxy, labels):
            x1, y1, x2, y2 = map(int, box)
            color = np.random.randint(0, 255, 3).tolist()
            cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 2)
            cv2.putText(result_image, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        metadata = f"✅ Detected {len(boxes)} objects: {', '.join(labels)}"
        return result_image, metadata
    except Exception as e:
        return image, f"❌ Error: {str(e)}"
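
# Example of driving process_image directly, outside Gradio (a sketch; the file
# path is hypothetical, and the array layout matches what gr.Image(type="numpy")
# supplies, i.e. an RGB uint8 array):
#
#   img = np.array(Image.open("examples/fish1.jpg").convert("RGB"))
#   result, info = process_image(img, "fish", 0.35, 0.25, "Medium")
#   Image.fromarray(result).save("result.png")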
# Gradio Interface
with gr.Blocks(title="Grounded SAM") as demo:
    gr.Markdown("# 🎯 Grounded SAM - Object Detection & Segmentation")
    gr.Markdown("Upload an image and describe what you want to detect (e.g., 'fish', 'all fish', 'person').")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="numpy")
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'fish', 'person', 'car'",
                value="fish"
            )
            with gr.Accordion("Advanced Settings", open=False):
                box_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.35, step=0.05,
                    label="Box Threshold (detection confidence)"
                )
                text_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.25, step=0.05,
                    label="Text Threshold (text matching confidence)"
                )
                quality = gr.Radio(
                    choices=["Low", "Medium", "High"],
                    value="Medium",
                    label="Processing Quality"
                )
            submit_btn = gr.Button("🚀 Process Image", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Output with Masks & Boxes", type="numpy")
            output_metadata = gr.Textbox(label="Detection Metadata", lines=3)

    submit_btn.click(
        fn=process_image,
        inputs=[input_image, text_prompt, box_threshold, text_threshold, quality],
        outputs=[output_image, output_metadata]
    )

    gr.Examples(
        examples=[
            ["examples/fish1.jpg", "fish", 0.35, 0.25, "Medium"],
            ["examples/fish2.jpg", "all fish", 0.35, 0.25, "Medium"],
        ],
        inputs=[input_image, text_prompt, box_threshold, text_threshold, quality],
    )

demo.launch()
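
# Running locally (a sketch; the package list is an assumption, not taken from
# this Space's actual requirements.txt):
#   pip install gradio torch torchvision opencv-python-headless pillow transformers
#   pip install git+https://github.com/facebookresearch/segment-anything.git
#   python app.py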