File size: 4,604 Bytes
998fc2d
a54fa15
a04d092
 
bb2bd9b
070a65e
a54fa15
070a65e
a04d092
070a65e
a04d092
 
 
 
 
 
bb2bd9b
 
 
 
 
 
 
 
 
 
 
 
 
998fc2d
070a65e
 
a54fa15
070a65e
 
 
 
 
 
 
 
a54fa15
070a65e
577d286
3b1b323
bb2bd9b
070a65e
a54fa15
bb2bd9b
070a65e
 
 
 
 
 
 
b287adc
070a65e
 
 
a54fa15
 
 
343ac5d
070a65e
343ac5d
 
 
 
 
 
 
 
 
a54fa15
 
343ac5d
a54fa15
343ac5d
a54fa15
 
343ac5d
a54fa15
343ac5d
998fc2d
a54fa15
1ad2815
 
006c2ee
1ad2815
343ac5d
 
006c2ee
1ad2815
 
343ac5d
a54fa15
 
1ad2815
68243d8
 
1ad2815
68243d8
3b1b323
 
68243d8
343ac5d
3b1b323
 
 
 
 
 
 
 
 
 
 
 
 
a54fa15
02c397a
998fc2d
070a65e
a54fa15
a04d092
 
070a65e
 
 
 
 
dfac13c
070a65e
a54fa15
a63426f
 
a54fa15
070a65e
c1df300
a04d092
2e32330
343ac5d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import html
import re
from collections import Counter
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

# Load the Grounding DINO checkpoint a single time at import; every request
# handled by detect_objects reuses these module-level objects.
model_id = "IDEA-Research/grounding-dino-tiny"

# Prefer the GPU when torch reports one is available, otherwise run on CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
model = model.to(device)

# Separators users commonly type between object names: the standalone word
# "and" (the \b word boundaries keep words such as "sandwich", "hand" or
# "android" intact), commas and ampersands, plus any surrounding whitespace.
_PROMPT_SEPARATORS = re.compile(r"\s*(?:\band\b|[,&])\s*")


def normalize_prompt(text_prompt):
    """Rewrite a free-form object list into the "cat. dog." prompt form.

    The text is lowercased and stripped. The word "and", commas and "&" are
    treated as separators, and the resulting phrases are joined with ". " and
    terminated with a final ".".

    Args:
        text_prompt: User-entered prompt, e.g. "Cat and Dog" or "cat, dog".

    Returns:
        The normalized prompt, e.g. "cat. dog.". An input with no words left
        after normalization (empty or separators only) yields ".".
    """
    text_prompt = text_prompt.lower().strip()

    # Turn every separator into a period so a single split handles them all.
    text_prompt = _PROMPT_SEPARATORS.sub(".", text_prompt)

    # Drop empty fragments produced by repeated or trailing separators.
    parts = [p.strip() for p in text_prompt.split(".") if p.strip()]

    return ". ".join(parts) + "."

def _load_image(image_url, uploaded_image):
    """Return an RGB PIL image from the upload (preferred) or URL, else None.

    Raises:
        requests.RequestException: if fetching the URL fails or returns an
            HTTP error status.
        PIL.UnidentifiedImageError: if the downloaded bytes are not an image.
    """
    if uploaded_image is not None:
        # PIL's convert() returns a converted copy, so later drawing does not
        # touch the caller's upload object.
        return uploaded_image.convert("RGB")

    if image_url and image_url.strip():
        headers = {"User-Agent": "Mozilla/5.0"}
        response = requests.get(image_url.strip(), headers=headers, timeout=10)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")

    return None


def _draw_detections(image, results):
    """Draw every detected box and its "label score" caption onto *image*.

    Args:
        image: PIL image that is annotated in place.
        results: Per-image dicts from post_process_grounded_object_detection,
            each with "boxes", "scores" and text labels.

    Returns:
        A list with one label string per drawn box.
    """
    draw = ImageDraw.Draw(image)
    detected_labels = []

    for result in results:
        # Recent transformers releases return string labels under
        # "text_labels"; fall back to "labels" for releases without it.
        labels = result.get("text_labels", result.get("labels"))

        for box, score, label in zip(result["boxes"], result["scores"], labels):
            x1, y1, x2, y2 = box.tolist()
            detected_labels.append(label)

            draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
            # Clamp to 0 so captions of boxes touching the top edge stay visible.
            draw.text((x1, max(0, y1 - 15)), f"{label} {float(score):.2f}", fill="red")

    return detected_labels


def _build_summary(display_prompt, detected_labels):
    """Build the Markdown/HTML summary shown next to the annotated image.

    Lines are emitted without leading indentation: Markdown renders lines
    indented by 4+ spaces as a literal code block. User/model text is
    HTML-escaped before being embedded.
    """
    if not detected_labels:
        searched_object = html.escape(
            display_prompt.replace(".", ", ").strip(", ").strip()
        )
        return "\n".join([
            f"## No {searched_object} found in the image",
            "",
            "Try:",
            "- lowering the threshold",
            "- using a clearer image",
            "- changing the detection prompt",
        ])

    counts = Counter(detected_labels)
    rows = "".join(
        f"<tr><td style='padding:4px 12px'>{html.escape(str(label))}</td>"
        f"<td style='padding:4px 12px'><b>{count}</b></td></tr>"
        for label, count in counts.items()
    )

    return "\n".join([
        f"<h3>Detected {len(counts)} object type(s) for: {html.escape(display_prompt)}</h3>",
        "",
        "<table style='border-collapse: collapse; width: 100%;'>",
        "<tr>",
        "<th style='text-align:left; padding:4px 12px;'>Object</th>",
        "<th style='text-align:left; padding:4px 12px;'>Count</th>",
        "</tr>",
        rows,
        "</table>",
    ])


def detect_objects(image_url, uploaded_image, text_prompt, threshold=0.4):
    """Run Grounding DINO on an image and return a summary plus annotated image.

    Args:
        image_url: URL to fetch the image from; used only when no upload is given.
        uploaded_image: PIL image uploaded through the UI, or None.
        text_prompt: Objects to detect; blank/None falls back to "capsule".
        threshold: Score threshold passed to
            post_process_grounded_object_detection (default 0.4).

    Returns:
        A (markdown_summary, image) tuple matching the interface outputs
        [Markdown, Image]. On a missing input or any failure, the message is
        returned in the Markdown slot and the image slot is None.
    """
    try:
        image = _load_image(image_url, uploaded_image)
        if image is None:
            return "Please provide an image URL or upload an image.", None

        # Default prompt fallback
        display_prompt = text_prompt.strip() if text_prompt and text_prompt.strip() else "capsule"
        model_prompt = normalize_prompt(display_prompt)

        inputs = processor(images=image, text=model_prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        results = processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            threshold=threshold,
            # PIL size is (width, height); target_sizes expects (height, width).
            target_sizes=[image.size[::-1]],
        )

        detected_labels = _draw_detections(image, results)
        return _build_summary(display_prompt, detected_labels), image

    except Exception as e:
        # UI boundary: surface the failure in the summary slot instead of
        # letting the request crash.
        return f"Error: {e}", None


# Gradio UI: the three inputs map positionally to detect_objects'
# (image_url, uploaded_image, text_prompt); its returned tuple fills the
# two outputs in order (summary first, then the annotated image).
app = gr.Interface(
    fn=detect_objects,
    inputs=[
        gr.Textbox(label="Image URL"),
        gr.Image(type="pil", label="Upload JPG/PNG"),
        gr.Textbox(label="Detection Prompt", placeholder="e.g. a cat"),
    ],
    outputs=[
        gr.Markdown(label="Detection Summary"),
        gr.Image(label="Annotated Image"),
    ],
    title="Grounding DINO Object Detection",
    description="Upload an image or provide an image URL, then enter objects to detect.",
)

# Only start the server when run as a script, so importing this module
# (e.g. from tests or another app) does not block on launch().
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)