# NOTE: the text originally captured here was Hugging Face Spaces page chrome, not
# program text: "Spaces: Configuration error", "File size: 3,337 Bytes", the per-line
# commit hashes (e82f421 / 3c98920 / a6b4885) and the 1-107 line-number gutter.
import torch
from transformers import pipeline
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from random import choice
import io
# Build both DETR pipelines once, at import time; infer() picks one of these
# module-level objects per request instead of constructing a pipeline each time.
# NOTE(review): no task is passed to pipeline(); presumably it is inferred from the
# model as object detection -- get_figure() reads 'box', 'label' and 'score' from
# every result, so confirm the pipelines really return that shape.
detector50 = pipeline(model="facebook/detr-resnet-50")
detector101 = pipeline(model="facebook/detr-resnet-101")
import gradio as gr
# Palette from which get_figure() picks one outline colour at random per box.
COLORS = [
    "#ff7f7f",
    "#ff7fbf",
    "#ff7fff",
    "#bf7fff",
    "#7f7fff",
    "#7fbfff",
    "#7fffff",
    "#7fffbf",
    "#7fff7f",
    "#bfff7f",
    "#ffff7f",
    "#ffbf7f",
]

# Font properties handed to ax.text(..., fontdict=fdic) for the box labels.
fdic = dict(
    family="Impact",
    style="italic",
    size=15,
    color="yellow",
    weight="bold",
)
def get_figure(in_pil_img, in_results):
    """Draw the predicted boxes and labels over the image and return the figure.

    Args:
        in_pil_img: Image shown as the background of the plot.
        in_results: Iterable of predictions. Each one is read as a mapping with
            a 'box' entry (holding 'xmin', 'ymin', 'xmax', 'ymax'), a 'label'
            and a 'score'.

    Returns:
        The matplotlib figure holding the annotated image. The figure is not
        closed here; the caller owns it.
    """
    # Use explicit figure/axes objects rather than pyplot's implicit "current
    # figure": that state is process-global, so another call running at the same
    # time could otherwise redirect imshow/gca/gcf to a different figure.
    fig = plt.figure(figsize=(16, 10))
    ax = fig.gca()
    ax.imshow(in_pil_img)

    for prediction in in_results:
        selected_color = choice(COLORS)
        box = prediction['box']
        x, y = box['xmin'], box['ymin']
        # Rectangle takes an anchor corner plus width/height, not two corners.
        w, h = box['xmax'] - x, box['ymax'] - y
        ax.add_patch(patches.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
        ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)

    ax.axis("off")
    return fig
def infer(model, in_pil_img):
    """Run the selected detector on an image and return the annotated image.

    Args:
        model: Model name chosen in the UI. "detr-resnet-101" selects
            ``detector101``; any other value falls back to ``detector50``.
        in_pil_img: Input PIL image, or None when no image was supplied.

    Returns:
        A PIL image of the rendered figure with the predicted boxes and labels,
        or None when there is no input image.
    """
    # Nothing to detect on: hand None back instead of failing inside the pipeline.
    if in_pil_img is None:
        return None

    detector = detector101 if model == "detr-resnet-101" else detector50
    results = detector(in_pil_img)

    figure = get_figure(in_pil_img, results)
    buf = io.BytesIO()
    try:
        figure.savefig(buf, bbox_inches='tight')
    finally:
        # pyplot keeps every figure it created alive until it is closed, so
        # without this each request would leave one more figure in memory.
        plt.close(figure)

    # Rewind so PIL reads the rendered image from the start of the buffer.
    buf.seek(0)
    return Image.open(buf)
# Gradio UI: pick a model, supply an image, click "Infer" to see the detections.
with gr.Blocks(
    title="DETR Object Detection - ClassCat",
    css=".gradio-container {background:lightyellow;}",
) as demo:
    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">DETR Object Detection</div>""")

    # Step 1: model choice; the selected string is passed to infer() as `model`.
    gr.HTML("""<h4 style="color:navy;">1. Select a model.</h4>""")
    model = gr.Radio(["detr-resnet-50", "detr-resnet-101"], value="detr-resnet-50", label="Model name")

    gr.HTML("""<br/>""")

    # Step 2: input image (from the examples or an upload) next to the output.
    gr.HTML("""<h4 style="color:navy;">2-a. Select an example by clicking a thumbnail below.</h4>""")
    gr.HTML("""<h4 style="color:navy;">2-b. Or upload an image by clicking on the canvas.</h4>""")
    # NOTE(review): assumes only the two image components sit inside the Row and
    # the examples are below it -- confirm against the intended layout.
    with gr.Row():
        input_image = gr.Image(label="Input image", type="pil")
        output_image = gr.Image(label="Output image with predicted instances", type="pil")

    gr.Examples(['samples/cats.jpg', 'samples/detectron2.png', 'samples/cat.jpg', 'samples/hotdog.jpg'], inputs=input_image)

    gr.HTML("""<br/>""")

    # Step 3: run infer(model, input_image) and show its result in output_image.
    gr.HTML("""<h4 style="color:navy;">3. Then, click "Infer" button to predict object instances. It will take about 10 seconds (on cpu)</h4>""")
    send_btn = gr.Button("Infer")
    send_btn.click(fn=infer, inputs=[model, input_image], outputs=[output_image])

    gr.HTML("""<br/>""")

    gr.HTML("""<h4 style="color:navy;">Reference</h4>""")
    # The list is emitted as ONE fragment: every gr.HTML call is its own
    # component, so splitting <ul>, <li> and </ul> across calls leaves each
    # component with unbalanced markup.
    gr.HTML("""<ul><li><a href="https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_attention.ipynb" target="_blank">Hands-on tutorial for DETR</a></li></ul>""")

demo.launch(debug=True)
# (stray "|" table delimiter from the page capture removed; it is not program text)