File size: 5,075 Bytes
80cacd4
cd64594
82551bb
 
 
 
 
 
 
 
80cacd4
82551bb
cd64594
33708c6
80cacd4
be400de
 
82551bb
be400de
cd64594
80cacd4
 
 
e500a2a
80cacd4
2455309
d5c03fb
1319df4
be400de
82551bb
b915b77
 
1319df4
2455309
82551bb
 
2455309
33708c6
80cacd4
82551bb
cd64594
100dbc1
82551bb
 
 
 
cd64594
82551bb
cd64594
82551bb
 
80cacd4
 
 
 
cd64594
80cacd4
cd64594
80cacd4
cd64594
80cacd4
 
 
 
 
cd64594
80cacd4
 
 
 
 
 
 
 
 
cd64594
33708c6
 
d5c03fb
33708c6
 
 
80cacd4
 
 
 
 
 
 
cd64594
80cacd4
 
 
 
 
 
 
 
 
 
 
 
cd64594
80cacd4
 
 
 
 
 
 
cd64594
80cacd4
 
 
 
 
cd64594
80cacd4
 
 
82551bb
 
80cacd4
 
 
 
 
 
 
 
cd64594
80cacd4
 
 
 
 
82551bb
cd64594
80cacd4
 
 
 
82551bb
 
80cacd4
 
33708c6
80cacd4
 
 
 
cd64594
82551bb
 
cd64594
 
 
 
 
 
80cacd4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
""" Gradio app: D-FINE + SigLIP Classify. """

import os
import gradio as gr
from pathlib import Path

from dfine_jina_pipeline import run_single_image

# Absolute path of the directory containing this script.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]

# Initial contents of the comma-separated "Labels" textbox.
DEFAULT_LABELS = ", ".join(("gun", "knife", "cigarette", "phone"))


def run_dfine_classify(image, dfine_threshold, dfine_model_choice, classifier_choice, siglip_threshold, labels_text):
    """Run D-FINE detection first, then classify the crops with the chosen classifier.

    Args:
        image: Input image, presumably a PIL image from the ``gr.Image(type="pil")``
            input; ``None`` when nothing has been uploaded.
        dfine_threshold: D-FINE detection threshold (converted with ``float``).
        dfine_model_choice: D-FINE model name such as ``"medium-obj2coco"``.
            Missing or blank values fall back to ``"medium-obj2coco"``.
        classifier_choice: Classifier name such as ``"siglip-256"``.
            Missing or blank values fall back to ``"siglip-256"``.
        siglip_threshold: Minimum classifier confidence. It is used both as the
            classification threshold and as the minimum confidence for display.
        labels_text: Comma-separated class labels. ``None`` is treated like "".

    Returns:
        ``(group_crop_gallery, known_crop_gallery, status_message)``. Each gallery
        is a list of ``(image, None)`` pairs (image with no caption) for
        ``gr.Gallery``. The status message is always a string.
    """
    if image is None:
        return [], [], "Upload an image."

    # Guard against a None textbox value (e.g. an API call without labels).
    labels = [label.strip() for label in (labels_text or "").split(",") if label.strip()]
    if not labels:
        return [], [], "Enter at least one label."

    # Missing *and* whitespace-only choices both select the default, so an
    # empty model/classifier name is never passed to the pipeline.
    dfine_model = (dfine_model_choice or "").strip().lower() or "medium-obj2coco"
    classifier = (classifier_choice or "").strip() or "siglip-256"
    conf_thresh = float(siglip_threshold)

    # gap_threshold / min_side / crop_dedup_iou are fixed pipeline settings;
    # their meaning is defined by dfine_jina_pipeline.run_single_image.
    group_crops, known_crops, status = run_single_image(
        image,
        dfine_model=dfine_model,
        det_threshold=float(dfine_threshold),
        conf_threshold=conf_thresh,
        gap_threshold=0.0,
        min_side=24,
        crop_dedup_iou=0.4,
        min_display_conf=conf_thresh,
        classifier=classifier,
        labels=labels,
    )

    # gr.Gallery takes (image, caption) pairs; captions are left empty.
    group_gallery = [(crop, None) for crop in (group_crops or [])]
    known_gallery = [(crop, None) for crop in (known_crops or [])]
    return group_gallery, known_gallery, status or ""


# Pixel height shared by the input image and both output galleries.
IMG_HEIGHT = 400


# Build the UI. Gradio lays components out in the order they are created
# inside the Blocks / Row / Column context managers, so statement order here
# defines the page layout.
with gr.Blocks(title="Small Object Detection") as app:

    gr.Markdown("# Small Object Detection")

    gr.Markdown(
        "**D-FINE** detects persons/cars, then small-object crops are classified with **SigLIP** (zero-shot). "
        "Choose a D-FINE model and enter comma-separated class labels for SigLIP."
    )

    # Two columns with equal scale: inputs/controls on the left, results on the right.
    with gr.Row():

        with gr.Column(scale=1):

            inp_dfine = gr.Image(
                type="pil",
                label="Input image",
                height=IMG_HEIGHT
            )

            # Despite the "_radio" suffix this is a Dropdown. Choice names follow
            # "<size>-<dataset>", which update_dfine_threshold_default parses.
            dfine_model_radio = gr.Dropdown(
                choices=[
                    "small-obj365", "medium-obj365", "large-obj365",
                    "small-coco", "medium-coco", "large-coco",
                    "small-obj2coco", "medium-obj2coco", "large-obj2coco",
                ],
                value="medium-obj2coco",
                label="D-FINE model",
            )

            classifier_dropdown = gr.Dropdown(
                choices=["siglip-224", "siglip-256", "siglip-384"],
                value="siglip-256",
                label="Classifier model",
            )

            # Initial value 0.15 is the "medium" default, matching the
            # initially selected "medium-obj2coco" model.
            dfine_threshold_slider = gr.Slider(
                minimum=0.05,
                maximum=0.5,
                value=0.15,
                step=0.05,
                label="D-FINE detection threshold",
            )

            def update_dfine_threshold_default(choice):
                """Return a slider update holding the default threshold for *choice*.

                The size is the text before the first "-" in the model name
                (e.g. "small" in "small-coco"): large -> 0.2, medium -> 0.15,
                small -> 0.1. An empty choice or unknown size gives 0.15.
                """
                if not choice:
                    return gr.update(value=0.15)
                size = choice.strip().lower().split("-")[0]
                defaults = {"large": 0.2, "medium": 0.15, "small": 0.1}
                return gr.update(value=defaults.get(size, 0.15))

            # Reset the detection threshold whenever the model selection changes.
            dfine_model_radio.change(
                fn=update_dfine_threshold_default,
                inputs=[dfine_model_radio],
                outputs=[dfine_threshold_slider],
            )

            siglip_threshold_slider = gr.Slider(
                minimum=0.001,
                maximum=0.1,
                value=0.005,
                step=0.001,
                label="SigLIP: min confidence threshold",
            )

            labels_input = gr.Textbox(
                label="Labels (comma-separated)",
                value=DEFAULT_LABELS,
                placeholder="e.g. gun, knife, cigarette, phone",
            )

            btn_dfine = gr.Button(
                "Run D-FINE + Classify",
                variant="primary"
            )

        with gr.Column(scale=1):

            # Receives the first element of run_dfine_classify's return tuple.
            out_gallery_dfine = gr.Gallery(
                label="Person/car crops (all D-FINE objects inside drawn with label + score)",
                height=IMG_HEIGHT,
                columns=2,
                object_fit="contain",
            )

            # Receives the second element of run_dfine_classify's return tuple.
            out_gallery_known = gr.Gallery(
                label="Known objects (class + score above each crop)",
                height=IMG_HEIGHT,
                columns=4,
                object_fit="contain",
            )

            # Read-only box for the status string returned by run_dfine_classify.
            out_status_dfine = gr.Textbox(
                label="Classification details",
                lines=8,
                interactive=False,
            )

    # `inputs` are passed positionally, so their order must match
    # run_dfine_classify's parameters. concurrency_limit=1 allows at most one
    # run of this event at a time.
    btn_dfine.click(
        fn=run_dfine_classify,
        inputs=[inp_dfine, dfine_threshold_slider, dfine_model_radio, classifier_dropdown, siglip_threshold_slider, labels_input],
        outputs=[out_gallery_dfine, out_gallery_known, out_status_dfine],
        concurrency_limit=1,
    )


def _server_port(default=7860):
    """Return the TCP port to serve on, read from the environment.

    ``PORT`` takes precedence over ``GRADIO_SERVER_PORT``, and *default* is used
    when neither is set. A variable that is set but empty or whitespace-only
    (e.g. ``PORT=``) is treated as unset instead of crashing on ``int("")``.

    Raises:
        ValueError: if the selected variable is not a valid integer.
    """
    for var in ("PORT", "GRADIO_SERVER_PORT"):
        raw = os.environ.get(var, "").strip()
        if raw:
            return int(raw)
    return default


app.launch(
    # "0.0.0.0" listens on all network interfaces. An unset or empty
    # GRADIO_SERVER_NAME falls back to it.
    server_name=os.environ.get("GRADIO_SERVER_NAME") or "0.0.0.0",
    server_port=_server_port(),
)