Spaces:
Build error
Build error
import io

import matplotlib.pyplot as plt
import torch
from PIL import Image
from transformers import DetrForObjectDetection, DetrImageProcessor

# Hugging Face Hub checkpoint shared by the processor and the model, defined
# once so the two loads below always refer to the same weights/config.
DETR_CHECKPOINT = "facebook/detr-resnet-50"
# Per the checkpoint's model card, the "no_timm" revision loads without the
# optional `timm` package. The default revision's backbone requires `timm`,
# which nothing in this file imports or declares.
DETR_REVISION = "no_timm"

# Box outline colors (RGB floats in [0, 1]) used by plot_results, reused in
# order when there are more boxes than colors.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
]

# Loaded once at import time and shared by every prediction call.
processor = DetrImageProcessor.from_pretrained(DETR_CHECKPOINT, revision=DETR_REVISION)
model = DetrForObjectDetection.from_pretrained(DETR_CHECKPOINT, revision=DETR_REVISION)
def fig2img(fig):
    """Render a Matplotlib figure in memory and reopen it as a PIL image.

    The figure is saved with ``fig.savefig`` in Matplotlib's default output
    format (PNG unless rcParams override it) into an in-memory buffer, so
    nothing touches the filesystem.

    Args:
        fig: The Matplotlib figure to render.

    Returns:
        A ``PIL.Image.Image`` opened from the rendered bytes.
    """
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer)
    # Hand PIL a fresh stream that starts at offset 0 over the rendered bytes.
    return Image.open(io.BytesIO(png_buffer.getvalue()))
def plot_results(image, results):
    """Draw detection boxes and labels over *image* and return the rendering.

    Args:
        image: The picture drawn underneath the boxes (passed to ``imshow``).
        results: One entry of ``processor.post_process_object_detection``'s
            output, with parallel ``"boxes"``, ``"labels"`` and ``"scores"``.
            Each box is read as ``(xmin, ymin, xmax, ymax)``.

    Returns:
        The annotated figure as a PIL image, produced by ``fig2img``.
    """
    # Use a standalone Figure instead of pyplot. Pyplot keeps every figure it
    # creates registered until plt.close(), so one would leak per call. Its
    # implicit "current figure/axes" state is also shared between concurrent
    # callers, which could interleave drawing from different requests.
    from matplotlib.figure import Figure
    from matplotlib.patches import Rectangle

    fig = Figure(figsize=(16, 10))
    ax = fig.add_subplot()
    ax.imshow(image)

    detections = zip(results["boxes"], results["labels"], results["scores"])
    for index, (box, label, prob) in enumerate(detections):
        xmin, ymin, xmax, ymax = box.tolist()
        # Reuse the palette in order when there are more boxes than colors.
        color = COLORS[index % len(COLORS)]
        ax.add_patch(Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                               fill=False, color=color, linewidth=3))
        text = f'{model.config.id2label[label.item()]}: {prob.item():0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))

    ax.axis("off")
    return fig2img(fig)
def predict(input_img, threshold=0.9):
    """Run DETR on one image and return it annotated with the detections.

    Args:
        input_img: A PIL image (what ``gr.Image(type="pil")`` supplies), or
            ``None`` when the form is submitted without an image.
        threshold: Minimum score a detection needs in order to be kept and
            drawn. Defaults to 0.9.

    Returns:
        The annotated PIL image from ``plot_results``, or ``None`` when no
        image was provided.
    """
    if input_img is None:
        return None
    # Normalize RGBA / grayscale / palette inputs to 3-channel RGB before
    # preprocessing. Conversion does not change the image size used below.
    if input_img.mode != "RGB":
        input_img = input_img.convert("RGB")

    inputs = processor(images=input_img, return_tensors="pt")
    # Inference only: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports size as (width, height); post-processing wants (height, width).
    target_sizes = torch.tensor([input_img.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]
    return plot_results(input_img, results)
import gradio as gr

# Minimal web UI: one uploaded image (delivered to `predict` as a PIL image)
# in, the annotated image returned by `predict` out.
demo = gr.Interface(predict, gr.Image(type="pil"), "image")
demo.launch()