File size: 3,337 Bytes
e82f421
 
3c98920
e82f421
3c98920
e82f421
3c98920
 
a6b4885
3c98920
e82f421
 
 
3c98920
 
 
 
 
e82f421
3c98920
 
e82f421
 
3c98920
 
 
 
 
e82f421
 
3c98920
e82f421
 
 
3c98920
e82f421
 
 
 
3c98920
 
 
e82f421
 
 
 
 
3c98920
e82f421
 
3c98920
e82f421
3c98920
 
e82f421
 
 
 
 
 
3c98920
e82f421
 
 
 
3c98920
e82f421
 
3c98920
e82f421
3c98920
 
 
e82f421
3c98920
 
 
e82f421
 
 
3c98920
 
 
e82f421
 
 
 
 
 
 
3c98920
 
 
e82f421
 
 
3c98920
 
 
 
 
 
e82f421
3c98920
e82f421
3c98920
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
from transformers import pipeline

from PIL import Image

import matplotlib.pyplot as plt
import matplotlib.patches as patches

from random import choice
import io

# Both detectors are constructed once, at module level; infer() picks one of
# them according to the model name selected in the UI.
# NOTE(review): no task is passed to pipeline(); presumably transformers
# resolves it from the model's hub metadata -- confirm.
detector50 = pipeline(model="facebook/detr-resnet-50")

detector101 = pipeline(model="facebook/detr-resnet-101")


import gradio as gr

# Candidate outline colors; get_figure() draws each bounding box with one
# entry picked at random from this list.
COLORS = [
    "#ff7f7f",
    "#ff7fbf",
    "#ff7fff",
    "#bf7fff",
    "#7f7fff",
    "#7fbfff",
    "#7fffff",
    "#7fffbf",
    "#7fff7f",
    "#bfff7f",
    "#ffff7f",
    "#ffbf7f",
]

# Font properties handed to ax.text(..., fontdict=fdic) for the box labels.
fdic = {
    "family": "Impact",
    "style": "italic",
    "size": 15,
    "color": "yellow",
    "weight": "bold",
}


def get_figure(in_pil_img, in_results):
    """Render an image with one outlined, labelled box per prediction.

    Args:
        in_pil_img: Image shown as the figure background via ``imshow``.
        in_results: Iterable of predictions. Each one is indexed with
            ``'box'`` (itself holding ``'xmin'``, ``'ymin'``, ``'xmax'``,
            ``'ymax'``), ``'label'`` and ``'score'``.

    Returns:
        The matplotlib figure (16x10 inches, axes hidden) holding the image,
        the box outlines and the text labels.
    """
    fig = plt.figure(figsize=(16, 10))
    ax = fig.gca()
    ax.imshow(in_pil_img)

    for detection in in_results:
        # One random palette entry per box.
        edge_color = choice(COLORS)

        box = detection['box']
        left, top = box['xmin'], box['ymin']
        width = box['xmax'] - left
        height = box['ymax'] - top

        outline = plt.Rectangle(
            (left, top), width, height,
            fill=False, color=edge_color, linewidth=3,
        )
        ax.add_patch(outline)

        # Label text is "<label>: <score * 100 rounded to one decimal>%",
        # anchored at the box's (xmin, ymin) corner.
        percent = round(detection['score'] * 100, 1)
        ax.text(left, top, f"{detection['label']}: {percent}%", fontdict=fdic)

    ax.axis("off")

    return fig


def infer(model, in_pil_img):
    """Run the selected detector on an image and return the annotated picture.

    Args:
        model: Model name from the UI. ``"detr-resnet-101"`` selects
            ``detector101``; any other value falls back to ``detector50``.
        in_pil_img: Image passed to the detector and used as the drawing
            background, or ``None`` when no image is available (presumably
            what Gradio passes for an empty image component -- confirm).

    Returns:
        A PIL image decoded from the figure rendered by ``get_figure`` and
        saved as PNG into an in-memory buffer, or ``None`` when
        ``in_pil_img`` is ``None``.
    """
    # Nothing to detect: bail out instead of handing None to the detector.
    if in_pil_img is None:
        return None

    detector = detector101 if model == "detr-resnet-101" else detector50
    results = detector(in_pil_img)

    figure = get_figure(in_pil_img, results)

    buf = io.BytesIO()
    try:
        figure.savefig(buf, format="png", bbox_inches='tight')
    finally:
        # pyplot keeps every figure it created registered until it is closed;
        # without this, each request would leave one more figure in memory.
        plt.close(figure)
    buf.seek(0)

    # The returned image reads from buf, which it keeps referenced.
    return Image.open(buf)


# Gradio UI: a model selector, input/output image panes, clickable examples
# and an "Infer" button that calls infer(model, input_image).
with gr.Blocks(title="DETR Object Detection - ClassCat",
               css=".gradio-container {background:lightyellow;}"
               ) as demo:

    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">DETR Object Detection</div>""")

    gr.HTML("""<h4 style="color:navy;">1. Select a model.</h4>""")

    model = gr.Radio(["detr-resnet-50", "detr-resnet-101"], value="detr-resnet-50", label="Model name")

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">2-a. Select an example by clicking a thumbnail below.</h4>""")
    gr.HTML("""<h4 style="color:navy;">2-b. Or upload an image by clicking on the canvas.</h4>""")

    with gr.Row():
        input_image = gr.Image(label="Input image", type="pil")
        output_image = gr.Image(label="Output image with predicted instances", type="pil")

    # Clicking an example fills the input image component.
    gr.Examples(['samples/cats.jpg', 'samples/detectron2.png', 'samples/cat.jpg', 'samples/hotdog.jpg'], inputs=input_image)

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">3. Then, click "Infer" button to predict object instances. It will take about 10 seconds (on cpu)</h4>""")

    send_btn = gr.Button("Infer")
    send_btn.click(fn=infer, inputs=[model, input_image], outputs=[output_image])

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">Reference</h4>""")
    # The list is emitted as ONE HTML fragment: splitting <ul>, <li> and </ul>
    # across separate gr.HTML components leaves each component with
    # unbalanced markup, and the <li> was never closed.
    gr.HTML("""<ul>
<li><a href="https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_attention.ipynb" target="_blank">Hands-on tutorial for DETR</a></li>
</ul>""")


demo.launch(debug=True)