File size: 2,966 Bytes
463eb87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f104dc9
 
 
463eb87
 
 
 
1fe2e3f
463eb87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fe2e3f
463eb87
 
 
 
 
 
f104dc9
463eb87
 
468afca
 
 
 
 
463eb87
 
 
 
 
 
 
 
 
 
 
468afca
 
463eb87
 
 
468afca
463eb87
468afca
463eb87
468afca
463eb87
468afca
f8e92b8
 
463eb87
 
 
f8e92b8
463eb87
 
 
468afca
f8e92b8
463eb87
 
f104dc9
468afca
463eb87
f8e92b8
463eb87
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104

import torch
from transformers import pipeline

from PIL import Image

import matplotlib.pyplot as plt
import matplotlib.patches as patches

from random import choice
import io

# Instantiate both DETR object-detection pipelines once, at import time
# (ResNet-50 first, then ResNet-101).
# NOTE(review): only detector50 is used by infer(); detector101 is loaded but
# idle while the model selector is commented out, costing startup time/memory.
detector50, detector101 = (
    pipeline(model=checkpoint)
    for checkpoint in ("facebook/detr-resnet-50", "facebook/detr-resnet-101")
)


import gradio as gr

# Palette of pastel outline colors; one is chosen at random for each detected box.
COLORS = [
    "#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
    "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
    "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f",
]

# Font properties for the "label: score%" text drawn at each box's top-left corner.
# "size" is in points: it must stay large enough to be readable on the
# 16x10 inch figure (a 1 pt font renders as an illegible ~1 px smudge).
fdic = {
    "family": "Impact",
    "style": "italic",
    "size": 15,
    "color": "yellow",
    "weight": "bold",
}


def get_figure(in_pil_img, in_results):
    """Draw detection boxes and labels over an image and return the figure.

    Args:
        in_pil_img: Image to annotate; passed to ``Axes.imshow`` (in this app,
            the PIL image from the Gradio input).
        in_results: Iterable of detection dicts. Each must have a ``"box"``
            dict with ``xmin``/``ymin``/``xmax``/``ymax``, a ``"label"``
            and a numeric ``"score"``. The score is displayed as a percentage
            rounded to one decimal, so it is presumably in [0, 1].

    Returns:
        A 16x10 inch ``matplotlib.figure.Figure`` with axes hidden. It is not
        registered with pyplot, so it is freed once the caller drops it.
    """
    # Build the figure directly instead of via pyplot. pyplot keeps every
    # figure alive until plt.close() and draws on a process-global "current
    # figure". That leaks memory and races when requests run concurrently.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(16, 10))
    ax = fig.subplots()
    ax.imshow(in_pil_img)

    for prediction in in_results:
        selected_color = choice(COLORS)

        box = prediction["box"]
        x, y = box["xmin"], box["ymin"]
        w, h = box["xmax"] - x, box["ymax"] - y

        ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
        ax.text(x, y, f"{prediction['label']}: {round(prediction['score'] * 100, 1)}%", fontdict=fdic)

    ax.axis("off")
    return fig


def infer(in_pil_img):
    """Run DETR object detection on an image and return an annotated copy.

    Args:
        in_pil_img: Input PIL image from the Gradio ``Image`` component
            (``type="pil"``). It is ``None`` when "Infer" is clicked before
            any image has been chosen.

    Returns:
        A PIL image decoded from an in-memory PNG rendering of the input, with
        boxes and "label: score%" text drawn on it.

    Raises:
        gr.Error: If no image was provided. Gradio then shows a readable
            message instead of a failed model call.
    """
    if in_pil_img is None:
        raise gr.Error("Please upload an image or select one of the examples first.")

    # Only the ResNet-50 model is wired up; the 50/101 selector is disabled.
    results = detector50(in_pil_img)

    figure = get_figure(in_pil_img, results)
    buf = io.BytesIO()
    try:
        figure.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # A pyplot-managed figure otherwise stays alive for the life of the
        # server process, so memory grows with every request. plt.close only
        # affects figures pyplot knows about.
        plt.close(figure)
    buf.seek(0)

    # Image.open is lazy; the returned image keeps a reference to buf, so the
    # pixel data is still readable when Gradio encodes it.
    return Image.open(buf)


# Two-pane UI: choose or upload an image on the left, view the annotated result
# on the right. Custom CSS hides the Gradio footer.
# NOTE(review): a model selector (detr-resnet-50 / detr-resnet-101) was
# previously wired in here and in infer(); it is disabled, so only detector50 runs.
with gr.Blocks(
    title="Object Detection",
    css="footer {visibility: hidden}",
) as demo:
    with gr.Row():
        input_image = gr.Image(label="Input image", type="pil")
        output_image = gr.Image(label="Output image with object detection", type="pil")

    # Clickable sample thumbnails that populate the input image.
    gr.Examples(
        examples=[
            'samples/cats.jpg',
            'samples/detectron2.png',
            'samples/cat.jpg',
            'samples/hotdog.jpg',
        ],
        inputs=input_image,
    )

    gr.HTML("""<h4>Click "Infer" button to predict object instances. It will take about 10-15 seconds</h4>""")

    # The button feeds the input image to infer() and shows its result.
    send_btn = gr.Button("Infer")
    send_btn.click(
        fn=infer,
        inputs=[input_image],
        outputs=[output_image],
    )

# NOTE(review): request queuing (demo.queue()) is left disabled.
# The server is started as soon as this module is executed.
demo.launch()


### EOF ###