| import gradio as gr |
| import numpy as np |
| from PIL import Image, ImageDraw, ImageFont |
|
|
| import torch |
| from transformers import AutoImageProcessor, AutoModelForObjectDetection |
|
|
# Hugging Face Hub repository holding the fine-tuned detection checkpoint.
repo_id = "magomerob/yolo_finetuned_raccoons"

# The image processor and the model are loaded once at import time so that
# every request handled by the demo reuses the same objects.
processor = AutoImageProcessor.from_pretrained(repo_id)

# ``eval()`` returns the module itself, so it can be chained onto the load;
# it switches the model to inference mode.
model = AutoModelForObjectDetection.from_pretrained(repo_id).eval()
|
|
def _draw_label(draw, text, x, y, font, pad=3):
    """Draw white ``text`` on a black strip sitting just above the point (x, y).

    The strip's origin is clamped to the top/left image edges and the text is
    positioned relative to that same clamped origin, so the text always stays
    inside its background and the rectangle is never inverted.
    """
    text_w, text_h = draw.textbbox((0, 0), text, font=font)[2:]
    left = max(0, x)
    top = max(0, y - text_h - 2 * pad)
    draw.rectangle(
        [left, top, left + text_w + 2 * pad, top + text_h + 2 * pad],
        fill="black",
    )
    draw.text((left + pad, top + pad), text, font=font, fill="white")


def show_preds(input_image, display_label=True, display_bbox=True, detection_threshold=0.5):
    """Run the detector on an image and return a copy annotated with its predictions.

    Args:
        input_image: A NumPy array or a PIL image. ``None`` (what the UI sends
            when no image was provided) is accepted and yields ``None``.
        display_label: When true, write "<class name> <score>" above each box.
        display_bbox: When true, outline each detected box.
        detection_threshold: Minimum score for a detection to be kept. A
            missing or zero value falls back to 0.5.

    Returns:
        An RGB PIL image with the requested annotations drawn on it, or
        ``None`` when ``input_image`` is ``None``.
    """
    # Nothing to annotate: return early instead of failing on ``None.convert``.
    if input_image is None:
        return None

    # A threshold of 0 (or None) is treated as "unset" and replaced by 0.5.
    if not detection_threshold:
        detection_threshold = 0.5

    if isinstance(input_image, np.ndarray):
        img = Image.fromarray(input_image).convert("RGB")
    else:
        img = input_image.convert("RGB")

    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports size as (width, height); reversing it gives (height, width)
    # for the post-processing step that rescales boxes to the image.
    target_sizes = torch.tensor([img.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, threshold=float(detection_threshold), target_sizes=target_sizes
    )[0]

    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", 16)
    except (OSError, ImportError):
        # Font file not found or FreeType support unavailable: fall back to
        # Pillow's built-in font rather than failing the request.
        font = ImageFont.load_default()

    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        x1, y1, x2, y2 = [int(round(v)) for v in box.tolist()]
        # Fall back to the numeric class id when the config has no name for it.
        name = model.config.id2label.get(int(label), str(int(label)))
        s = float(score)

        if display_bbox:
            # No colour is passed, so Pillow's default outline colour is used.
            draw.rectangle([x1, y1, x2, y2], width=3)

        if display_label:
            _draw_label(draw, f"{name} {s:.2f}", x1, y1, font)

    return img
|
|
# Gradio UI: one image plus three controls whose order matches the
# positional parameters of ``show_preds``
# (input_image, display_label, display_bbox, detection_threshold).
demo = gr.Interface(
    fn=show_preds,
    inputs=[
        gr.Image(type="numpy"),
        gr.Checkbox(value=True, label="Label"),
        gr.Checkbox(value=True, label="Box"),
        gr.Slider(
            minimum=0,
            maximum=1,
            value=0.5,
            step=0.05,
            label="Detection Threshold",
        ),
    ],
    outputs=gr.Image(type="pil"),
    # One example row; its values follow the same order as ``inputs``.
    examples=[
        ["raccoon-101.jpg", True, True, 0.5],
    ],
)


# Start the web server only when the file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    demo.launch()