"""Gradio demo: zero-shot object detection with Grounding DINO.

The user supplies an image (upload or URL) and a free-text prompt naming the
objects to find; the app returns a Markdown summary of the detections and a
copy of the image annotated with bounding boxes.
"""

import logging
import re
from collections import Counter
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image, ImageDraw
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor

logger = logging.getLogger(__name__)

DEFAULT_PROMPT = "capsule"
DEFAULT_THRESHOLD = 0.4
# Cap on remote image size so an arbitrary user-supplied URL cannot make the
# server buffer an unbounded response in memory.
MAX_DOWNLOAD_BYTES = 20 * 1024 * 1024
REQUEST_TIMEOUT = 10  # seconds
REQUEST_HEADERS = {"User-Agent": "Mozilla/5.0"}

# Separators users commonly type between object names. "and" must be a whole
# word so names such as "sandwich" or "candle" are not split apart.
_SEPARATOR_RE = re.compile(r"\s*(?:\band\b|,|&)\s*")

# Load model once at import time; every request reuses it.
model_id = "IDEA-Research/grounding-dino-tiny"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)


def normalize_prompt(text_prompt):
    """Convert a free-form object list into "term. term." prompt form.

    The text is lowercased; ",", "&" and the whole word "and" are treated as
    separators in addition to ".", and the remaining terms are rejoined with
    ". " and terminated by a final ".".

    Example: "Cat, Dog and sandwich" -> "cat. dog. sandwich."

    Returns "" when the prompt contains no terms at all.
    """
    text_prompt = _SEPARATOR_RE.sub(".", text_prompt.lower().strip())
    parts = [part.strip() for part in text_prompt.split(".") if part.strip()]
    return ". ".join(parts) + "." if parts else ""


def _fetch_image(url, timeout=REQUEST_TIMEOUT, max_bytes=MAX_DOWNLOAD_BYTES):
    """Download *url* (at most *max_bytes*) and decode it as an RGB image.

    Raises:
        requests.RequestException: on network errors or a non-2xx status.
        ValueError: if the response body exceeds *max_bytes*.
        PIL.UnidentifiedImageError: if the bytes are not a readable image.
    """
    with requests.get(url, headers=REQUEST_HEADERS, timeout=timeout, stream=True) as response:
        response.raise_for_status()
        data = bytearray()
        # Stream in chunks so an oversized body is rejected before it is
        # fully buffered.
        for chunk in response.iter_content(chunk_size=64 * 1024):
            data.extend(chunk)
            if len(data) > max_bytes:
                raise ValueError(
                    f"Image is larger than {max_bytes // (1024 * 1024)} MB."
                )
    return Image.open(BytesIO(bytes(data))).convert("RGB")


def _load_image(image_url, uploaded_image):
    """Return an RGB image from the upload (preferred) or the URL, else None."""
    if uploaded_image is not None:
        # convert() returns a new image, so drawing never mutates the upload.
        return uploaded_image.convert("RGB")
    if image_url and image_url.strip():
        return _fetch_image(image_url.strip())
    return None


def _draw_detections(image, results):
    """Draw each box and its "label score" caption onto *image* in place.

    Returns the list of detected labels, one entry per drawn box.
    """
    draw = ImageDraw.Draw(image)
    detected_labels = []
    for result in results:
        for box, score, label in zip(result["boxes"], result["scores"], result["text_labels"]):
            x1, y1, x2, y2 = box.tolist()
            draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
            # Keep the caption inside the image when the box touches the top.
            draw.text((x1, max(0, y1 - 15)), f"{label} {float(score):.2f}", fill="red")
            detected_labels.append(label)
    return detected_labels


def _no_detection_message(display_prompt):
    """Markdown shown when nothing passed the threshold."""
    searched_object = display_prompt.replace(".", ", ").strip(", ").strip()
    return (
        f"## No {searched_object} found in the image\n"
        "\n"
        "Try:\n"
        "- lowering the threshold\n"
        "- using a clearer image\n"
        "- changing the detection prompt\n"
    )


def _build_summary(counts, display_prompt):
    """Render per-label detection counts as a Markdown table."""
    lines = [
        f"Detected {len(counts)} object type(s) for: {display_prompt}",
        "",
        "| Object | Count |",
        "| --- | --- |",
    ]
    # Escape "|" so a label cannot break the table layout.
    lines.extend(
        f"| {str(label).replace('|', chr(92) + '|')} | {count} |"
        for label, count in counts.most_common()
    )
    return "\n".join(lines)


def detect_objects(image_url, uploaded_image, text_prompt, threshold=DEFAULT_THRESHOLD):
    """Run Grounding DINO on the supplied image and prompt.

    Args:
        image_url: URL of an image to download; used only when no image is
            uploaded.
        uploaded_image: PIL image from the upload widget, or None.
        text_prompt: Objects to look for, e.g. "cat, dog". A blank prompt (or
            one with no terms) falls back to DEFAULT_PROMPT.
        threshold: Minimum box score for a detection to be kept.

    Returns:
        ``(markdown_summary, annotated_image)`` in the same order as the
        Interface outputs. On failure the image is None and the summary holds
        the error message.
    """
    try:
        image = _load_image(image_url, uploaded_image)
        if image is None:
            return "Please provide an image URL or upload an image.", None

        display_prompt = (text_prompt or "").strip()
        model_prompt = normalize_prompt(display_prompt)
        if not model_prompt:
            display_prompt = DEFAULT_PROMPT
            model_prompt = normalize_prompt(DEFAULT_PROMPT)

        inputs = processor(images=image, text=model_prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        results = processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            threshold=threshold,
            target_sizes=[image.size[::-1]],  # PIL size is (w, h); expects (h, w)
        )

        detected_labels = _draw_detections(image, results)
        if not detected_labels:
            return _no_detection_message(display_prompt), image
        return _build_summary(Counter(detected_labels), display_prompt), image
    except Exception as e:
        # UI boundary: keep the traceback for operators, show a short message.
        logger.exception("Object detection failed")
        return f"Error: {str(e)}", None


app = gr.Interface(
    fn=detect_objects,
    inputs=[
        gr.Textbox(label="Image URL"),
        gr.Image(type="pil", label="Upload JPG/PNG"),
        gr.Textbox(label="Detection Prompt", placeholder="e.g. a cat"),
        gr.Slider(
            minimum=0.05,
            maximum=0.95,
            value=DEFAULT_THRESHOLD,
            step=0.05,
            label="Detection Threshold",
        ),
    ],
    outputs=[
        gr.Markdown(label="Detection Summary"),
        gr.Image(label="Annotated Image"),
    ],
    title="Grounding DINO Object Detection",
    description="Upload an image or provide an image URL, then enter objects to detect.",
)

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)