piky's picture
Update app.py
c1df300 verified
import html
import re
from collections import Counter
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
# Load the checkpoint a single time at import; the module-level `processor`,
# `model` and `device` are then reused for every detection request.
model_id = "IDEA-Research/grounding-dino-tiny"
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
model = model.to(device)
# Separators users commonly type between object names: the standalone word
# "and", commas and ampersands (with any surrounding whitespace). The word
# boundaries on "and" stop it from matching inside words such as "sandwich",
# "candle" or "hand".
_SEPARATOR_RE = re.compile(r"\s*(?:\band\b|,|&)\s*")


def normalize_prompt(text_prompt):
    """Convert a free-form object list into the ``"a. b."`` prompt format.

    The text is lower-cased, the separators "and", "," and "&" are treated
    like ".", and empty segments are dropped. Segments are re-joined with
    ". " and a trailing "." is appended.

    >>> normalize_prompt("Cat and dog")
    'cat. dog.'

    Note: input containing no object names (e.g. ``""``) yields ``"."``.
    """
    # Grounding DINO text queries are expected lower-cased and "."-terminated.
    text_prompt = text_prompt.lower().strip()
    text_prompt = _SEPARATOR_RE.sub(".", text_prompt)
    parts = [part.strip() for part in text_prompt.split(".") if part.strip()]
    return ". ".join(parts) + "."
def _load_image(image_url, uploaded_image):
    """Return an RGB PIL image from the upload or the URL, or None if neither is given.

    An uploaded image takes precedence over the URL. Network and decoding
    errors propagate to the caller.
    """
    if uploaded_image is not None:
        return uploaded_image.convert("RGB")
    url = image_url.strip() if image_url else ""
    if not url:
        return None
    # NOTE(review): the URL is user-supplied and fetched server-side with no
    # size cap; consider restricting hosts / response size if this runs with
    # access to an internal network.
    # Browser-like User-Agent, presumably because some hosts reject the
    # default requests UA.
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, headers=headers, timeout=10)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


def _draw_detections(image, results):
    """Draw every detected box and caption onto ``image`` in place; return the labels."""
    draw = ImageDraw.Draw(image)
    labels_found = []
    for result in results:
        for box, score, label in zip(result["boxes"], result["scores"], result["text_labels"]):
            x1, y1, x2, y2 = box.tolist()
            draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
            # Keep the caption inside the image when the box touches the top edge.
            draw.text((x1, max(0, y1 - 15)), f"{label} {score.item():.2f}", fill="red")
            labels_found.append(label)
    return labels_found


def _build_summary(counts, display_prompt):
    """Render per-label counts as the HTML table shown in the Markdown output.

    ``display_prompt`` is user text and the labels are presumably derived from
    it by the model, so both are HTML-escaped before interpolation.
    """
    rows = [
        f"<tr><td style='padding:4px 12px'>{html.escape(label)}</td>"
        f"<td style='padding:4px 12px'><b>{count}</b></td></tr>"
        for label, count in counts.items()
    ]
    # Lines stay unindented: 4+ leading spaces would make Markdown render
    # them as a code block instead of HTML.
    lines = [
        "",
        f"<h3>Detected {len(counts)} object type(s) for: {html.escape(display_prompt)}</h3>",
        "<table style='border-collapse: collapse; width: 100%;'>",
        "<tr>",
        "<th style='text-align:left; padding:4px 12px;'>Object</th>",
        "<th style='text-align:left; padding:4px 12px;'>Count</th>",
        "</tr>",
        "".join(rows),
        "</table>",
        "",
    ]
    return "\n".join(lines)


def detect_objects(image_url, uploaded_image, text_prompt, threshold=0.4):
    """Run zero-shot detection on an image and return ``(summary, annotated_image)``.

    The tuple order matches the Gradio ``outputs`` list: the Markdown/HTML
    summary first, then the annotated PIL image (None when there is no image).

    Args:
        image_url: URL of the image; used only when no upload is given.
        uploaded_image: PIL image from the upload widget, or None.
        text_prompt: Objects to look for, e.g. "cat and dog". Blank or None
            falls back to "capsule".
        threshold: Score cut-off passed to
            ``processor.post_process_grounded_object_detection``.

    Returns:
        ``(summary, image)``. When no image is provided or an error occurs,
        ``summary`` holds the message and ``image`` is None.
    """
    try:
        image = _load_image(image_url, uploaded_image)
        if image is None:
            return "Please provide an image URL or upload an image.", None

        stripped = text_prompt.strip() if text_prompt else ""
        display_prompt = stripped or "capsule"
        model_prompt = normalize_prompt(display_prompt)

        inputs = processor(images=image, text=model_prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        results = processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            threshold=threshold,
            # PIL size is (width, height); target_sizes expects (height, width).
            target_sizes=[image.size[::-1]],
        )

        detected_labels = _draw_detections(image, results)
        if not detected_labels:
            # Report what was actually searched (after fallback/normalization).
            searched_object = ", ".join(
                part.strip() for part in model_prompt.split(".") if part.strip()
            )
            message = "\n".join([
                "",
                f"## No {html.escape(searched_object)} found in the image",
                "Try:",
                "- lowering the threshold",
                "- using a clearer image",
                "- changing the detection prompt",
                "",
            ])
            return message, image

        return _build_summary(Counter(detected_labels), display_prompt), image
    except Exception as e:
        # UI boundary: surface any failure as text in the summary slot rather
        # than crashing the request.
        return f"Error: {e}", None
# Gradio UI: image URL, image upload and prompt in; Markdown summary and
# annotated image out. detect_objects must return values in this output order.
input_components = [
    gr.Textbox(label="Image URL"),
    gr.Image(type="pil", label="Upload JPG/PNG"),
    gr.Textbox(label="Detection Prompt", placeholder="e.g. a cat"),
]
output_components = [
    gr.Markdown(label="Detection Summary"),
    gr.Image(label="Annotated Image"),
]
app = gr.Interface(
    fn=detect_objects,
    inputs=input_components,
    outputs=output_components,
    title="Grounding DINO Object Detection",
    description="Upload an image or provide an image URL, then enter objects to detect.",
)
# 0.0.0.0 binds all network interfaces so the server is reachable from other hosts.
app.launch(server_name="0.0.0.0", server_port=7860)