Spaces:
Sleeping
Sleeping
File size: 3,604 Bytes
ca67edc 4f81c3c ca67edc 4f81c3c ca67edc 4f81c3c ca67edc 8eb1b60 4f81c3c ca67edc 4f81c3c ca67edc 4f81c3c ca67edc 4f81c3c ca67edc 4f81c3c ca67edc 4f81c3c ca67edc 4f81c3c ca67edc 4f81c3c ca67edc 4f81c3c ca67edc 4f81c3c ca67edc 4f81c3c ca67edc a69fe39 ca67edc 4f81c3c ca67edc 4f81c3c ca67edc 4f81c3c ca67edc 4f81c3c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 | import torch
from transformers import pipeline
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from random import choice
import io
detector50 = pipeline(model="facebook/detr-resnet-50")
detector101 = pipeline(model="facebook/detr-resnet-101")
import gradio as gr
COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
"#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
"#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
fdic = {
"family" : "Impact",
"style" : "italic",
"size" : 15,
"color" : "yellow",
"weight" : "bold"
}
def get_figure(in_pil_img, in_results):
plt.figure(figsize=(16, 10))
plt.imshow(in_pil_img)
#pyplot.gcf()
ax = plt.gca()
for prediction in in_results:
selected_color = choice(COLORS)
x, y = prediction['box']['xmin'], prediction['box']['ymin'],
w, h = prediction['box']['xmax'] - prediction['box']['xmin'], prediction['box']['ymax'] - prediction['box']['ymin']
ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
plt.axis("off")
return plt.gcf()
def infer(model, in_pil_img):
results = None
if model == "detr-resnet-101":
results = detector101(in_pil_img)
else:
results = detector50(in_pil_img)
figure = get_figure(in_pil_img, results)
buf = io.BytesIO()
figure.savefig(buf, bbox_inches='tight')
buf.seek(0)
output_pil_img = Image.open(buf)
return output_pil_img
with gr.Blocks(title="DETR Object Detection - ClassCat",
css=".gradio-container {background:lightyellow;}"
) as demo:
#sample_index = gr.State([])
gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">DETR Object Detection</div>""")
gr.HTML("""<h4 style="color:navy;">1. Select a model.</h4>""")
model = gr.Radio(["detr-resnet-50", "detr-resnet-101"], value="detr-resnet-50", label="Model name")
gr.HTML("""<br/>""")
gr.HTML("""<h4 style="color:navy;">2-a. Select an example by clicking a thumbnail below.</h4>""")
gr.HTML("""<h4 style="color:navy;">2-b. Or upload an image by clicking on the canvas.</h4>""")
with gr.Row():
input_image = gr.Image(label="Input image", type="pil")
output_image = gr.Image(label="Output image with predicted instances", type="pil")
gr.Examples(['samples/anime.jpg','samples/cartoon.jpg','samples/CATTY.jpg','samples/food.png','samples/ghibli.jpeg', 'samples/news.png', 'samples/people.png','samples/road.png','samples/Paris.jpg','samples/Blobfish.jpg', 'samples/Pokemon.jpg', 'samples/Egg Tart.jpg', 'samples/Rose.jpg', 'samples/CEO.jpg',
'samples/juice.jpeg','samples/mountain.jpeg'], inputs=input_image)
gr.HTML("""<br/>""")
gr.HTML("""<h4 style="color:navy;">3. Then, click "Infer" button to predict object instances. It will take about 10 seconds (on cpu)</h4>""")
send_btn = gr.Button("Infer")
send_btn.click(fn=infer, inputs=[model, input_image], outputs=[output_image])
gr.HTML("""<br/>""")
gr.HTML("""<h4 style="color:navy;">Reference</h4>""")
gr.HTML("""<ul>""")
gr.HTML("""<li><a href="https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_attention.ipynb" target="_blank">Hands-on tutorial for DETR</a>""")
gr.HTML("""</ul>""")
#demo.queue()
demo.launch(debug=True)
### EOF ###
|