import html
import re
from collections import Counter
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
|
| |
# Checkpoint identifier handed to both ``from_pretrained`` calls below.
model_id = "IDEA-Research/grounding-dino-tiny"

# Use the GPU when PyTorch reports one as available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Loaded once at import time and shared by every request.
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
model = model.to(device)
|
|
| |
# Separators a user may type between object names: the standalone word "and",
# a comma, or an ampersand, together with any surrounding whitespace.  The
# word boundaries keep "and" inside words such as "sandwich" or "panda" intact.
_PROMPT_SEPARATOR_RE = re.compile(r"\s*(?:\band\b|[,&])\s*")


def normalize_prompt(text_prompt):
    """Convert a free-form object list into a period-separated prompt.

    The prompt is lower-cased, the separators "and", "," and "&" are turned
    into periods, empty fragments are dropped, and the remaining phrases are
    joined as ``"phrase. phrase."``.

    Args:
        text_prompt: Raw prompt text, e.g. ``"Cat, Dog & bird"``.

    Returns:
        The normalized prompt, e.g. ``"cat. dog. bird."``.  A prompt that
        contains no phrases at all yields ``"."``.
    """
    lowered = text_prompt.lower().strip()

    # Unify every supported separator into a period so one split handles all.
    dotted = _PROMPT_SEPARATOR_RE.sub(".", lowered)

    # Drop empty fragments left by doubled or trailing separators.
    phrases = [piece.strip() for piece in dotted.split(".") if piece.strip()]

    return ". ".join(phrases) + "."
|
|
def _load_image(image_url, uploaded_image):
    """Return the request's image as an RGB PIL image, or None if none was given."""
    if uploaded_image is not None:
        # An uploaded image takes precedence over the URL.
        return uploaded_image.convert("RGB")

    url = (image_url or "").strip()
    if not url:
        return None

    # NOTE(security): the URL is user-supplied and fetched from the server, so
    # a deployment reachable by untrusted users should restrict target hosts.
    # Browser-like User-Agent; presumably some hosts reject the default one.
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, headers=headers, timeout=10)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


def _annotate(image, results):
    """Draw every detection onto *image* in place and return the labels found."""
    draw = ImageDraw.Draw(image)
    detected_labels = []

    for result in results:
        detections = zip(result["boxes"], result["scores"], result["text_labels"])
        for box, score, label in detections:
            x1, y1, x2, y2 = box.tolist()
            detected_labels.append(label)

            draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
            # Put the caption just above the box, clamped to the top edge.
            draw.text(
                (x1, max(0, y1 - 15)),
                f"{label} {float(score):.2f}",
                fill="red",
            )

    return detected_labels


def _no_detection_message(display_prompt):
    """Build the Markdown shown when nothing matched *display_prompt*."""
    searched_object = display_prompt.replace(".", ", ").strip(", ").strip()
    # Lines are joined without leading indentation so the heading and the
    # list are parsed as Markdown rather than as an indented code block.
    return "\n".join(
        [
            f"## No {html.escape(searched_object)} found in the image",
            "",
            "Try:",
            "- lowering the threshold",
            "- using a clearer image",
            "- changing the detection prompt",
        ]
    )


def _build_summary(display_prompt, detected_labels):
    """Build the HTML table of per-label detection counts."""
    counts = Counter(detected_labels)

    # Labels and the prompt originate from user text; escape before embedding.
    rows = "".join(
        f"<tr><td style='padding:4px 12px'>{html.escape(str(label))}</td>"
        f"<td style='padding:4px 12px'><b>{count}</b></td></tr>"
        for label, count in counts.items()
    )

    return "\n".join(
        [
            f"<h3>Detected {len(counts)} object type(s) for: "
            f"{html.escape(display_prompt)}</h3>",
            "",
            "<table style='border-collapse: collapse; width: 100%;'>",
            "<tr>",
            "<th style='text-align:left; padding:4px 12px;'>Object</th>",
            "<th style='text-align:left; padding:4px 12px;'>Count</th>",
            "</tr>",
            rows,
            "</table>",
        ]
    )


def detect_objects(image_url, uploaded_image, text_prompt):
    """Detect the prompted objects in an image and annotate them.

    Uses the module-level ``processor``, ``model`` and ``device`` and
    normalizes the prompt with ``normalize_prompt``.

    Args:
        image_url: URL of the image to fetch; used only when no image is
            uploaded.  May be empty.
        uploaded_image: PIL image to analyse, or ``None``.
        text_prompt: Objects to look for.  A blank prompt falls back to
            ``"capsule"``.

    Returns:
        A ``(summary, image)`` tuple in that order.  ``summary`` is a
        Markdown/HTML string: a count table, a "nothing found" hint, or an
        error/usage message.  ``image`` is the annotated RGB image, or
        ``None`` when no image could be processed.
    """
    try:
        image = _load_image(image_url, uploaded_image)
        if image is None:
            return "Please provide an image URL or upload an image.", None

        display_prompt = (text_prompt or "").strip() or "capsule"
        model_prompt = normalize_prompt(display_prompt)

        inputs = processor(
            images=image, text=model_prompt, return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        results = processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            threshold=0.4,
            # PIL reports (width, height); reversed here to (height, width).
            target_sizes=[image.size[::-1]],
        )

        detected_labels = _annotate(image, results)
        if not detected_labels:
            return _no_detection_message(display_prompt), image

        return _build_summary(display_prompt, detected_labels), image

    except Exception as exc:
        # UI boundary: report the failure as text instead of raising; the
        # message comes first so it lands in the summary slot of the tuple.
        return f"Error: {exc}", None
|
|
|
|
# Input widgets, listed in the positional order of detect_objects' parameters.
_input_widgets = [
    gr.Textbox(label="Image URL"),
    gr.Image(type="pil", label="Upload JPG/PNG"),
    gr.Textbox(label="Detection Prompt", placeholder="e.g. a cat"),
]

# Output widgets, listed in the order of the tuple detect_objects returns.
_output_widgets = [
    gr.Markdown(label="Detection Summary"),
    gr.Image(label="Annotated Image"),
]

app = gr.Interface(
    fn=detect_objects,
    inputs=_input_widgets,
    outputs=_output_widgets,
    title="Grounding DINO Object Detection",
    description="Upload an image or provide an image URL, then enter objects to detect.",
)

# Listen on all network interfaces, port 7860.
app.launch(server_name="0.0.0.0", server_port=7860)