File size: 4,604 Bytes
998fc2d a54fa15 a04d092 bb2bd9b 070a65e a54fa15 070a65e a04d092 070a65e a04d092 bb2bd9b 998fc2d 070a65e a54fa15 070a65e a54fa15 070a65e 577d286 3b1b323 bb2bd9b 070a65e a54fa15 bb2bd9b 070a65e b287adc 070a65e a54fa15 343ac5d 070a65e 343ac5d a54fa15 343ac5d a54fa15 343ac5d a54fa15 343ac5d a54fa15 343ac5d 998fc2d a54fa15 1ad2815 006c2ee 1ad2815 343ac5d 006c2ee 1ad2815 343ac5d a54fa15 1ad2815 68243d8 1ad2815 68243d8 3b1b323 68243d8 343ac5d 3b1b323 a54fa15 02c397a 998fc2d 070a65e a54fa15 a04d092 070a65e dfac13c 070a65e a54fa15 a63426f a54fa15 070a65e c1df300 a04d092 2e32330 343ac5d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 | from io import BytesIO
from collections import Counter
import requests
import torch
import re
import gradio as gr
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
# Grounding DINO checkpoint, loaded a single time at import so that every
# request served by the app reuses the same processor and weights.
model_id = "IDEA-Research/grounding-dino-tiny"

# Run on the GPU whenever PyTorch can see one; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
# nn.Module.to() moves the parameters in place and returns the same module.
model = model.to(device)
# prompt normalization function
def normalize_prompt(text_prompt):
# lowercase
text_prompt = text_prompt.lower().strip()
# replace common separators with "."
text_prompt = re.sub(r"\s*(and|,|&)\s*", ".", text_prompt)
# split words and remove empties
parts = [p.strip() for p in text_prompt.split(".") if p.strip()]
# rebuild as "cat. dog."
return ". ".join(parts) + "."
def _load_image(image_url, uploaded_image):
    """Return the input image converted to RGB, or None if none was given.

    An uploaded image takes precedence over the URL. Errors from the HTTP
    request or from PIL decoding propagate to the caller.
    """
    if uploaded_image is not None:
        # convert() returns a new image, so later drawing never touches the upload.
        return uploaded_image.convert("RGB")
    url = (image_url or "").strip()
    if not url:
        return None
    # NOTE(review): this fetches an arbitrary user-supplied URL from the server
    # (SSRF risk) and buffers the whole body in memory; consider an allow-list
    # and a size cap before exposing the app publicly.
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, headers=headers, timeout=10)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


def _draw_detections(image, result):
    """Draw every box of one post-processed *result* onto *image*.

    Returns the list of text labels that were drawn, one per box.
    """
    draw = ImageDraw.Draw(image)
    labels_drawn = []
    for box, score, label in zip(result["boxes"], result["scores"], result["text_labels"]):
        x1, y1, x2, y2 = box.tolist()
        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
        # Keep the caption on-canvas when the box touches the top edge.
        draw.text((x1, max(0, y1 - 15)), f"{label} {float(score):.2f}", fill="red")
        labels_drawn.append(label)
    return labels_drawn


def _no_detection_message(display_prompt):
    """Build the Markdown shown when nothing was detected for *display_prompt*."""
    searched_object = display_prompt.replace(".", ", ").strip(", ").strip()
    return "\n".join([
        f"## No {searched_object} found in the image",
        "Try:",
        "- lowering the threshold",
        "- using a clearer image",
        "- changing the detection prompt",
    ])


def _build_summary(display_prompt, counts):
    """Render an HTML table with one row per detected label and its count.

    User-controlled text (the prompt and the labels derived from it) is
    HTML-escaped before it is interpolated into the markup.
    """
    import html

    rows = "".join(
        f"<tr><td style='padding:4px 12px'>{html.escape(str(label))}</td>"
        f"<td style='padding:4px 12px'><b>{count}</b></td></tr>"
        for label, count in counts.items()
    )
    return "\n".join([
        f"<h3>Detected {len(counts)} object type(s) for: {html.escape(display_prompt)}</h3>",
        "<table style='border-collapse: collapse; width: 100%;'>",
        "<tr>",
        "<th style='text-align:left; padding:4px 12px;'>Object</th>",
        "<th style='text-align:left; padding:4px 12px;'>Count</th>",
        "</tr>",
        rows,
        "</table>",
    ])


def detect_objects(image_url, uploaded_image, text_prompt, threshold=0.4):
    """Run Grounding DINO on an image and return a summary plus annotated image.

    Args:
        image_url: URL to download the image from; used only when no image
            was uploaded. Blank or whitespace-only values count as missing.
        uploaded_image: PIL image from the upload widget, or None.
        text_prompt: Objects to detect. Blank or None falls back to "capsule".
        threshold: Detection threshold forwarded to
            ``post_process_grounded_object_detection``.

    Returns:
        A ``(summary_markdown, image)`` tuple, in the order of the Gradio
        outputs (Markdown, Image). On missing input or on any error, the
        image slot is None and the message goes in the summary slot.
    """
    try:
        image = _load_image(image_url, uploaded_image)
        if image is None:
            return "Please provide an image URL or upload an image.", None

        display_prompt = text_prompt.strip() if text_prompt and text_prompt.strip() else "capsule"
        model_prompt = normalize_prompt(display_prompt)

        inputs = processor(images=image, text=model_prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        results = processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            threshold=threshold,
            # PIL size is (width, height); the post-processor wants (height, width).
            target_sizes=[image.size[::-1]],
        )

        detected_labels = []
        for result in results:
            detected_labels.extend(_draw_detections(image, result))

        if not detected_labels:
            return _no_detection_message(display_prompt), image
        return _build_summary(display_prompt, Counter(detected_labels)), image
    except Exception as e:
        # UI boundary: surface the failure in the summary pane rather than
        # crashing the request; the image output is cleared.
        return f"Error: {e}", None
# Gradio UI. The three inputs map positionally onto detect_objects'
# (image_url, uploaded_image, text_prompt) parameters. The two outputs take
# the returned tuple in order: Markdown summary first, annotated image second.
app = gr.Interface(
    fn=detect_objects,
    inputs=[
        gr.Textbox(label="Image URL"),
        gr.Image(type="pil", label="Upload JPG/PNG"),
        gr.Textbox(label="Detection Prompt", placeholder="e.g. a cat"),
    ],
    outputs=[
        gr.Markdown(label="Detection Summary"),
        gr.Image(label="Annotated Image"),
    ],
    title="Grounding DINO Object Detection",
    description="Upload an image or provide an image URL, then enter objects to detect.",
)

# Start the server only when this file is run as a script, so importing the
# module (e.g. from tests) does not block inside launch(). Binding to 0.0.0.0
# listens on all network interfaces.
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)