import gradio as gr
import supervision as sv
from PIL import Image
from rfdetr import RFDETRMedium, RFDETRSegPreview
from rfdetr.detr import RFDETR
from rfdetr.util.coco_classes import COCO_CLASSES

# Markdown heading rendered at the top of the Gradio page.
MARKDOWN = """
# (RF-DETR Demo-Alliance Bioversity-CIAT)
RF-DETR object detection and segmentation
"""

# Fixed 12-color palette; the annotators cycle through it per class id.
COLOR = sv.ColorPalette.from_hex([
    "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
    "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00"
])

def load_model(resolution: int, checkpoint: str) -> "RFDETR":
    """Instantiate the RF-DETR model matching *checkpoint*.

    Args:
        resolution: Inference resolution in pixels (should already be a
            valid multiple for the checkpoint; see ``adjust_resolution``).
        checkpoint: Either ``"medium (object detection)"`` or
            ``"segmentation preview"``.

    Returns:
        A freshly constructed RF-DETR model.

    Raises:
        ValueError: If *checkpoint* is not one of the supported names.
    """
    if checkpoint == "medium (object detection)":
        return RFDETRMedium(resolution=resolution)
    if checkpoint == "segmentation preview":
        return RFDETRSegPreview(resolution=resolution)
    # ValueError, not TypeError: the argument has the correct type but an
    # unsupported value — consistent with adjust_resolution's error.
    raise ValueError("Checkpoint must be medium (object detection) or segmentation preview.")

def adjust_resolution(checkpoint: str, resolution: int) -> int:
    """Snap *resolution* to the nearest multiple required by *checkpoint*.

    The segmentation checkpoint needs multiples of 24; the detection
    checkpoint needs multiples of 32. Exact ties round up.

    Args:
        checkpoint: Model checkpoint name.
        resolution: Requested inference resolution in pixels.

    Returns:
        The nearest valid resolution, never less than one divisor, so a
        tiny request cannot collapse to an unusable 0.

    Raises:
        ValueError: If *checkpoint* is not a known name.
    """
    if checkpoint == "segmentation preview":
        divisor = 24
    elif checkpoint == "medium (object detection)":
        divisor = 32
    else:
        raise ValueError(f"Unknown checkpoint: {checkpoint}")
    remainder = resolution % divisor
    if remainder == 0:
        snapped = resolution
    else:
        lower = resolution - remainder
        upper = lower + divisor
        # Round to the nearest multiple; exact midpoints go up.
        snapped = lower if resolution - lower < upper - resolution else upper
    # Guard: requests below divisor/2 would otherwise snap down to 0.
    return max(snapped, divisor)

def image_processing_inference(
    input_image: Image.Image,
    confidence: float,
    resolution: int,
    checkpoint: str
):
    """Run RF-DETR on *input_image* and return an annotated copy.

    Args:
        input_image: Image to process. ``None`` short-circuits to ``None``
            (Gradio passes None when no image was uploaded).
        confidence: Detection confidence threshold in [0, 1].
        resolution: Requested inference resolution; snapped to a valid
            multiple for the chosen checkpoint before model construction.
        checkpoint: Checkpoint name selecting detection vs. segmentation.

    Returns:
        A new PIL image with masks (segmentation only), boxes, and class
        labels drawn, or ``None`` when no image was provided.
    """
    if input_image is None:
        return None
    resolution = adjust_resolution(checkpoint=checkpoint, resolution=resolution)
    # NOTE: the model is rebuilt on every submit; acceptable for a demo.
    model = load_model(resolution=resolution, checkpoint=checkpoint)
    detections = model.predict(input_image, threshold=confidence)

    resolution_wh = (input_image.width, input_image.height)
    # Slightly smaller than supervision's suggested scale so labels stay
    # unobtrusive on small images.
    text_scale = sv.calculate_optimal_text_scale(resolution_wh=resolution_wh) - 0.2
    thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh)

    mask_annotator = sv.MaskAnnotator(color=COLOR)
    bbox_annotator = sv.BoxAnnotator(color=COLOR, thickness=thickness)
    label_annotator = sv.LabelAnnotator(
        color=COLOR,
        text_color=sv.Color.BLACK,
        text_scale=text_scale
    )

    # `score` avoids shadowing the `confidence` threshold parameter.
    labels = [
        f"{COCO_CLASSES[class_id]} {score:.2f}"
        for class_id, score
        in zip(detections.class_id, detections.confidence)
    ]

    annotated_image = input_image.copy()
    # Draw masks first so boxes and labels remain visible on top; the
    # original drew masks last, painting over the labels.
    if checkpoint == "segmentation preview":
        annotated_image = mask_annotator.annotate(annotated_image, detections)
    annotated_image = bbox_annotator.annotate(annotated_image, detections)
    annotated_image = label_annotator.annotate(annotated_image, detections, labels)
    return annotated_image

# Declarative UI: an input/output image pair, inference controls, and a
# submit button wired to image_processing_inference.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        # Source image; type="pil" so the callback receives a PIL image.
        image_processing_input_image = gr.Image(
            label="Upload image",
            image_mode="RGB",
            type="pil",
            height=600
        )
        # Annotated result produced by the callback.
        image_processing_output_image = gr.Image(
            label="Output image",
            image_mode="RGB",
            type="pil",
            height=600
        )
    with gr.Row():
        with gr.Column():
            # Detection confidence threshold passed to model.predict.
            image_processing_confidence_slider = gr.Slider(
                label="Confidence",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.5,
            )
            # Requested resolution; snapped to a checkpoint-specific
            # multiple by adjust_resolution before inference.
            image_processing_resolution_slider = gr.Slider(
                label="Inference resolution",
                minimum=224,
                maximum=2240,
                step=1,
                value=896,
            )
            # Choice strings must match those tested in load_model /
            # adjust_resolution exactly.
            image_processing_checkpoint_dropdown = gr.Dropdown(
                label="Checkpoint",
                choices=[
                    "medium (object detection)",
                    "segmentation preview"
                ],
                value="medium (object detection)"
            )
        with gr.Column():
            image_processing_submit_button = gr.Button("Submit")

    # Input order must match image_processing_inference's parameters.
    image_processing_submit_button.click(
        image_processing_inference,
        inputs=[
            image_processing_input_image,
            image_processing_confidence_slider,
            image_processing_resolution_slider,
            image_processing_checkpoint_dropdown
        ],
        outputs=image_processing_output_image,
    )

if __name__ == "__main__":
    # Launch the Gradio server; show_error surfaces callback exceptions in
    # the browser instead of failing silently.
    demo.launch(debug=False, show_error=True)