import gradio as gr
import torch
import numpy as np
from PIL import Image
import cv2
from segment_anything import sam_model_registry, SamPredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
import os
import urllib.request

# Download the SAM checkpoint if it does not already exist
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
SAM_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
if not os.path.exists(SAM_CHECKPOINT):
    print("Downloading SAM checkpoint...")
    urllib.request.urlretrieve(SAM_CHECKPOINT_URL, SAM_CHECKPOINT)
    print("SAM checkpoint downloaded!")

# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load Grounding DINO from Hugging Face
grounding_dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
    "IDEA-Research/grounding-dino-tiny"
).to(device)

# Load SAM
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT)
sam.to(device=device)
sam_predictor = SamPredictor(sam)


def process_image(image, text_prompt, box_threshold, text_threshold, quality):
    """Detect objects with Grounding DINO, then segment them with SAM."""
    try:
        # Cap the image's longest side based on the quality setting
        if quality == "Low":
            max_size = 800
        elif quality == "Medium":
            max_size = 1024
        else:  # High
            max_size = 1920

        h, w = image.shape[:2]
        if max(h, w) > max_size:
            scale = max_size / max(h, w)
            new_h, new_w = int(h * scale), int(w * scale)
            image = cv2.resize(image, (new_w, new_h))

        # gr.Image(type="numpy") delivers RGB arrays, so no BGR->RGB conversion is needed
        pil_image = Image.fromarray(image)

        # Grounding DINO expects lowercase queries terminated with a period
        caption = text_prompt.lower().strip()
        if not caption.endswith("."):
            caption += "."

        # Grounding DINO inference
        inputs = grounding_dino_processor(
            images=pil_image, text=caption, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            outputs = grounding_dino_model(**inputs)

        # Post-process detections; target_sizes expects (height, width)
        results = grounding_dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[pil_image.size[::-1]],
        )[0]

        # Boxes come back in pixel xyxy format, which is what SAM expects
        boxes = results["boxes"].cpu().numpy()
        labels = results["labels"]

        if len(boxes) == 0:
            return image, "No objects detected. Try adjusting the thresholds or text prompt."
        # SAM inference: one mask per detected box
        sam_predictor.set_image(image)
        masks = []
        for box in boxes:
            mask, _, _ = sam_predictor.predict(
                box=box,
                multimask_output=False,
            )
            masks.append(mask[0])

        # Visualize results
        result_image = image.copy()

        # Overlay each mask with a random color at 50% opacity
        for mask in masks:
            color = np.random.randint(0, 255, 3)
            result_image[mask] = (result_image[mask] * 0.5 + color * 0.5).astype(np.uint8)

        # Draw boxes and labels
        for box, label in zip(boxes, labels):
            x1, y1, x2, y2 = map(int, box)
            color = np.random.randint(0, 255, 3).tolist()
            cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 2)
            cv2.putText(result_image, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        metadata = f"✅ Detected {len(boxes)} objects: {', '.join(labels)}"
        return result_image, metadata

    except Exception as e:
        return image, f"❌ Error: {e}"


# Gradio interface
with gr.Blocks(title="Grounded SAM") as demo:
    gr.Markdown("# 🎯 Grounded SAM - Object Detection & Segmentation")
    gr.Markdown("Upload an image and describe what you want to detect (e.g., 'fish', 'all fish', 'person').")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="numpy")
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'fish', 'person', 'car'",
                value="fish",
            )
            with gr.Accordion("Advanced Settings", open=False):
                box_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.35, step=0.05,
                    label="Box Threshold (detection confidence)",
                )
                text_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.25, step=0.05,
                    label="Text Threshold (text matching confidence)",
                )
                quality = gr.Radio(
                    choices=["Low", "Medium", "High"],
                    value="Medium",
                    label="Processing Quality",
                )
            submit_btn = gr.Button("🚀 Process Image", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Output with Masks & Boxes", type="numpy")
            output_metadata = gr.Textbox(label="Detection Metadata", lines=3)

    submit_btn.click(
        fn=process_image,
        inputs=[input_image, text_prompt, box_threshold, text_threshold, quality],
        outputs=[output_image, output_metadata],
    )

    gr.Examples(
        examples=[
            ["examples/fish1.jpg", "fish", 0.35, 0.25, "Medium"],
            ["examples/fish2.jpg", "all fish", 0.35, 0.25, "Medium"],
        ],
        inputs=[input_image, text_prompt, box_threshold, text_threshold, quality],
    )

demo.launch()
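
# Setup notes (a sketch of the assumed environment; versions are not pinned here):
#   pip install gradio torch opencv-python pillow transformers
#   pip install git+https://github.com/facebookresearch/segment-anything.git
# The gr.Examples block above assumes images exist at examples/fish1.jpg and
# examples/fish2.jpg; remove or update those paths for your own data.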