# Hugging Face Space application script (copied from the Space's file viewer).
# Space status at the time of capture: Runtime error
# File size: 2,966 bytes
# Commits shown in the viewer's blame gutter: 463eb87, f104dc9, 1fe2e3f, 468afca, f8e92b8
import torch
from transformers import pipeline
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from random import choice
import io
# Detection pipelines are built once, at import time, so infer() can reuse the
# already-constructed object on every call instead of rebuilding it per request.
detector50 = pipeline(model="facebook/detr-resnet-50")
# NOTE(review): detector101 is only referenced by the commented-out model
# switch in infer(); as the code stands it is constructed but never called.
# Confirm whether it is still needed before keeping the extra model load.
detector101 = pipeline(model="facebook/detr-resnet-101")
import gradio as gr
# Pastel palette; get_figure() draws each bounding box in one entry chosen at
# random from this list.
COLORS = [
    "#ff7f7f",
    "#ff7fbf",
    "#ff7fff",
    "#bf7fff",
    "#7f7fff",
    "#7fbfff",
    "#7fffff",
    "#7fffbf",
    "#7fff7f",
    "#bfff7f",
    "#ffff7f",
    "#ffbf7f",
]

# Font properties passed as `fontdict` to the text label of every detection.
# NOTE(review): a size of 1 looks too small to read on a 16x10 figure --
# confirm whether the labels are meant to be effectively hidden.
fdic = dict(
    family="Impact",
    style="italic",
    size=1,
    color="yellow",
    weight="bold",
)
def get_figure(in_pil_img, in_results):
    """Render the image with one outlined box and caption per prediction.

    Args:
        in_pil_img: Image to display; it is handed unchanged to ``plt.imshow``.
        in_results: Iterable of predictions. Each item is indexed as
            ``item['box']`` (keys ``xmin``, ``ymin``, ``xmax``, ``ymax``),
            ``item['label']`` and ``item['score']``.

    Returns:
        The matplotlib figure (16x10 inches, axes hidden) containing the
        image, a randomly coloured unfilled rectangle for every prediction
        and a ``"<label>: <score as percent>%"`` caption at the rectangle's
        (xmin, ymin) corner.
    """
    fig = plt.figure(figsize=(16, 10))
    plt.imshow(in_pil_img)
    axes = plt.gca()

    for detection in in_results:
        # Colour is drawn first so one random pick is made per prediction.
        edge_color = choice(COLORS)

        box = detection['box']
        left, top = box['xmin'], box['ymin']
        width = box['xmax'] - box['xmin']
        height = box['ymax'] - box['ymin']

        outline = plt.Rectangle(
            (left, top), width, height,
            fill=False, color=edge_color, linewidth=3,
        )
        axes.add_patch(outline)

        caption = f"{detection['label']}: {round(detection['score']*100, 1)}%"
        axes.text(left, top, caption, fontdict=fdic)

    axes.axis("off")
    return fig
def infer(in_pil_img):
    """Detect objects in an image and return a picture with the boxes drawn.

    The image is run through ``detector50`` and the predictions are rendered
    by ``get_figure``; the resulting figure is saved into an in-memory buffer
    and reopened as a PIL image.

    Args:
        in_pil_img: Input image from the Gradio image component, or ``None``
            when the user has not provided one.

    Returns:
        A PIL image of the rendered figure, or ``None`` when ``in_pil_img``
        is ``None`` (nothing to detect, so the output is left empty).
    """
    # Pressing "Infer" with an empty input delivers None; return early
    # instead of letting the detector raise on it.
    if in_pil_img is None:
        return None

    results = detector50(in_pil_img)
    figure = get_figure(in_pil_img, results)

    buf = io.BytesIO()
    try:
        figure.savefig(buf, bbox_inches='tight')
    finally:
        # pyplot keeps a reference to every figure it creates until it is
        # closed explicitly; without this each request would leak a figure
        # for the lifetime of the server process.
        plt.close(figure)

    buf.seek(0)
    # Image.open reads lazily; the returned image keeps `buf` alive.
    return Image.open(buf)
# Gradio UI definition. The custom CSS hides the page footer.
# NOTE(review): the indentation of this section was lost in the copy this file
# was recovered from; the nesting below (only the two image components inside
# the Row) is a reconstruction -- confirm against the deployed layout.
with gr.Blocks(title="Object Detection",
               css="footer {visibility: hidden}"
               ) as demo:
    #sample_index = gr.State([])
    # Disabled title banner and model-selector widgets, kept for reference.
    # gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">DETR Object Detection</div>""")
    # gr.HTML("""<h4 style="color:navy;">1. Select a model.</h4>""")
    # model = gr.Radio(["detr-resnet-50", "detr-resnet-101"], value="detr-resnet-50", label="Model name")
    # gr.HTML("""<br/>""")
    # gr.HTML("""<h4>Select an example by clicking a thumbnail below.</h4>""")
    # gr.HTML("""<h4>Or upload an image by clicking on the canvas.</h4>""")
    # Input and output image components grouped in a single row; both
    # exchange PIL images with the handler (type="pil").
    with gr.Row():
        input_image = gr.Image(label="Input image", type="pil")
        output_image = gr.Image(label="Output image with object detection", type="pil")
    # Sample images that can be loaded into the input component.
    gr.Examples(['samples/cats.jpg', 'samples/detectron2.png', 'samples/cat.jpg', 'samples/hotdog.jpg'], inputs=input_image)
    # gr.HTML("""<br/>""")
    gr.HTML("""<h4>Click "Infer" button to predict object instances. It will take about 10-15 seconds</h4>""")
    send_btn = gr.Button("Infer")
    # Clicking the button calls infer() with the input image and shows its
    # return value in the output image component.
    send_btn.click(fn=infer, inputs=[input_image], outputs=[output_image])
#demo.queue()
# Start the app; this runs at import time because there is no __main__ guard.
demo.launch()
### EOF ###