Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| from transformers import AutoFeatureExtractor, YolosForObjectDetection | |
| from PIL import Image | |
# Palette used to outline detections. Each entry is an [R, G, B] triple with
# components in the 0-1 range, passed straight to matplotlib as a colour.
COLORS = [
    [0.000, 0.447, 0.741],  # blue
    [0.850, 0.325, 0.098],  # orange-red
    [0.929, 0.694, 0.125],  # yellow
    [0.494, 0.184, 0.556],  # purple
    [0.466, 0.674, 0.188],  # green
    [0.301, 0.745, 0.933],  # light blue
]
def process_class_list(classes_string: str) -> list:
    """Parse a comma-separated string of class names into a list.

    Each entry is stripped of surrounding whitespace. Blank entries (from
    whitespace-only input, doubled commas, or a trailing comma) are dropped so
    that they cannot turn an effectively empty filter into one that matches
    nothing.

    Args:
        classes_string: Text such as ``"car, dog"``. ``None`` and ``""`` are
            both treated as "no filter".

    Returns:
        The list of non-empty class names, in input order; ``[]`` when no
        class name was given.
    """
    if not classes_string:
        return []
    stripped = (name.strip() for name in classes_string.split(","))
    return [name for name in stripped if name]
_YOLOS_CHECKPOINT = "hustvl/yolos-small-dwr"
# Holds the loaded (feature_extractor, model) pair under the key "pair" so the
# checkpoint is loaded once per process instead of once per request.
_YOLOS_CACHE = {}


def _load_yolos():
    """Return the cached ``(feature_extractor, model)`` pair, loading it on first use."""
    if "pair" not in _YOLOS_CACHE:
        feature_extractor = AutoFeatureExtractor.from_pretrained(_YOLOS_CHECKPOINT)
        model = YolosForObjectDetection.from_pretrained(_YOLOS_CHECKPOINT)
        # Store both objects in a single assignment so a concurrent caller
        # never observes a half-filled cache.
        _YOLOS_CACHE["pair"] = (feature_extractor, model)
    return _YOLOS_CACHE["pair"]


def model_inference(img, prob_threshold, classes_to_show):
    """Run YOLOS object detection on an image and return an annotated picture.

    Args:
        img: Image as an array accepted by ``PIL.Image.fromarray`` (this is
            what the Gradio image input hands over); ``None`` when nothing was
            uploaded.
        prob_threshold: Detections whose best class probability is not
            strictly greater than this value are discarded.
        classes_to_show: Comma-separated class names to keep; an empty string
            keeps every class.

    Returns:
        A PIL image showing the input with boxes and labels drawn on it.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    feature_extractor, model = _load_yolos()

    img = Image.fromarray(img)
    pixel_values = feature_extractor(img, return_tensors="pt").pixel_values

    # Inference only: no gradients are needed, and attention maps are not
    # requested because nothing downstream reads them.
    with torch.no_grad():
        outputs = model(pixel_values)

    # First batch item, every query, all classes except the last one
    # (presumably the "no object" class of the detection head).
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > prob_threshold

    # PIL reports (width, height); reversing yields (height, width).
    target_sizes = torch.tensor(img.size[::-1]).unsqueeze(0)
    postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
    bboxes_scaled = postprocessed_outputs[0]["boxes"]

    classes_list = process_class_list(classes_to_show)
    return plot_results(img, probas[keep], bboxes_scaled[keep], model, classes_list)
def plot_results(pil_img, prob, boxes, model, classes_list):
    """Draw labelled bounding boxes over an image and return the rendering.

    Args:
        pil_img: Image to display as the background.
        prob: Per-detection class probabilities; iterated row by row, with the
            arg-max of each row taken as the predicted class.
        boxes: Per-detection boxes as ``(xmin, ymin, xmax, ymax)``; must
            support ``.tolist()``.
        model: Model whose ``config.id2label`` maps class ids to names.
        classes_list: Class names to draw. An empty list draws every class.

    Returns:
        A PIL image of the rendered figure.
    """
    # Use an explicit figure/axes pair rather than pyplot's implicit
    # "current figure", and always close it: every call otherwise leaves a
    # 16x10 figure registered with pyplot for the life of the process.
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(pil_img)
        for idx, (p, (xmin, ymin, xmax, ymax)) in enumerate(zip(prob, boxes.tolist())):
            cl = p.argmax()
            object_class = model.config.id2label[cl.item()]
            if classes_list and object_class not in classes_list:
                continue
            # Cycle through the palette by detection index, so any number of
            # boxes can be drawn and skipped detections still advance the colour.
            color = COLORS[idx % len(COLORS)]
            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    fill=False,
                    color=color,
                    linewidth=3,
                )
            )
            text = f"{object_class}: {p[cl]:0.2f}"
            ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
        ax.axis("off")
        return fig2img(fig)
    finally:
        plt.close(fig)
def fig2img(fig):
    """Save *fig* into an in-memory buffer and open that buffer as a PIL image.

    Args:
        fig: Object exposing ``savefig(file_like)``, e.g. a matplotlib figure.

    Returns:
        The PIL image opened from the saved bytes.
    """
    buffer = io.BytesIO()
    fig.savefig(buffer)
    # Rewind so the image reader starts at the first byte written.
    buffer.seek(0)
    return Image.open(buffer)
# --- Gradio user interface -------------------------------------------------

description = """Upload an image and get the detected classes"""
title = """Object Detection"""

# Example inputs are currently disabled. Earlier variants, kept for reference:
# example_list = [["examples/" + example] for example in os.listdir("examples")]
# example_list = [["carplane.webp"]]
# gr.Examples([['carplane.webp'], ['CTH.png']], inputs=image_in)

# Input widgets, in the positional order model_inference expects them.
image_in = gr.components.Image(label="Upload an image")
prob_threshold_slider = gr.components.Slider(
    minimum=0,
    maximum=1.0,
    step=0.01,
    value=0.7,
    label="Probability Threshold",
)
classes_to_show = gr.components.Textbox(
    placeholder="e.g. car, dog",
    label="Classes to filter (leave empty to detect all classes)",
)
inputs = [image_in, prob_threshold_slider, classes_to_show]

# Output widget showing the annotated image.
image_out = gr.components.Image()

demo = gr.Interface(
    fn=model_inference,
    inputs=inputs,
    outputs=image_out,
    title=title,
    description=description,
    # examples=example_list
)
demo.launch()