# Trashify Box Detector Demo by James2236 (revision bb013f7)
# 1. Import required dependencies
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoImageProcessor, AutoModelForObjectDetection

# model_path = James2236/rt_detrv2_finetuned_trashify_box_detector_v1

# 2. Setup preprocessing and model functions
# Hugging Face Hub repo ID of the fine-tuned RT-DETRv2 trashify box detector.
model_save_path = "James2236/rt_detrv2_finetuned_trashify_box_detector_v1"

image_processor = AutoImageProcessor.from_pretrained(model_save_path)
# Resize every input to 640x640 before inference
# (presumably the resolution used during fine-tuning -- TODO confirm).
image_processor.size = {"height": 640,
                        "width": 640}

# Load the model
model = AutoModelForObjectDetection.from_pretrained(model_save_path)

# Setup the target device (GPU if it's accessible)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Get the id2label dictionary from the model (class index -> class name)
id2label = model.config.id2label

# Map each class name to the colour used for its bounding box outline;
# the "not_*" failure classes are all drawn in red.
colour_dict = {"bin": "green",
               "trash": "blue",
               "hand": "purple",
               "trash_arm": "yellow",
               "not_trash": "red",
               "not_bin": "red",
               "not_hand": "red"}
# 3. Create a function 'predict_on_image'
def predict_on_image(image, conf_threshold):
    """Detect trashify classes on a PIL image and annotate it in place.

    Args:
        image: PIL.Image to run detection on. Bounding boxes and labels are
            drawn directly onto this image.
        conf_threshold: Minimum confidence score (0-1) for a detection to be
            kept by post-processing.

    Returns:
        Tuple of (annotated image, status string). The string reports whether
        all three target items ("trash", "bin", "hand") were detected.
    """
    model.eval()

    # Make a prediction on the target image
    with torch.no_grad():
        inputs = image_processor(images=image, return_tensors="pt")
        model_outputs = model(**inputs.to(device))

    # PIL's Image.size is (width, height); post-processing expects
    # [batch_size, height, width]
    target_sizes = torch.tensor([[image.size[1], image.size[0]]])

    # Post process the raw outputs from the model (single image -> index 0)
    results = image_processor.post_process_object_detection(model_outputs,
                                                            threshold=conf_threshold,
                                                            target_sizes=target_sizes)[0]

    # Move tensor values back to the CPU for drawing. The original code tried
    # `value.item().cpu()` inside a bare try/except -- Python scalars have no
    # .cpu(), so that line ALWAYS raised and silently fell through to the
    # except branch. An explicit isinstance check is correct and doesn't
    # swallow unrelated errors.
    results = {key: (value.cpu() if isinstance(value, torch.Tensor) else value)
               for key, value in results.items()}

    # 4. Draw the predictions on the target image
    draw = ImageDraw.Draw(image)

    # Get a font to draw on our image
    font = ImageFont.load_default(size=20)

    # Collect the detected class names for the scoring logic below
    detected_class_names_text_labels = []

    # Iterate through the predictions of the model and draw on the target image
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        # Box coordinates in XYXY format
        x, y, x2, y2 = tuple(box.tolist())

        # Get the text-based label and its display colour
        label_name = id2label[label.item()]
        targ_colour = colour_dict[label_name]
        detected_class_names_text_labels.append(label_name)

        # Draw the bounding box
        draw.rectangle(xy=(x, y, x2, y2),
                       outline=targ_colour,
                       width=3)

        # Draw "<label> <score>" at the box's top-left corner
        text_string_to_show = f"{label_name} {round(score.item(), 4)}"
        draw.text(xy=(x, y),
                  text=text_string_to_show,
                  fill="white",
                  font=font)

    # Remove the draw object so it doesn't get caught in memory
    del draw

    # 5. Scoring logic: +1 point only when all target items (bin, trash, hand)
    # are present in the detections.
    target_items = {"trash", "bin", "hand"}
    detected_items = set(detected_class_names_text_labels)

    # If no target items were detected at all, return a notification
    # (spaces added between sentences -- the original strings concatenated
    # with no separator, e.g. "...threshold 0.3.Try another image...").
    if not detected_items & target_items:
        return_string = (f"No trash, bin, or hand detected at confidence threshold {conf_threshold}. "
                         "Try another image or lowering the confidence threshold.")
        print(return_string)
        return image, return_string

    # If some but not all target items were detected, report what's missing
    missing_items = target_items - detected_items
    if missing_items:
        return_string = (f"Detected the following items: {sorted(detected_items & target_items)}. "
                         f"Missing the following: {missing_items}. "
                         "In order to get + 1 points all items need to be detected.")
        print(return_string)
        return image, return_string

    # Final case, all items are detected
    return_string = f"+1 point! Found the following items: {sorted(detected_items)}, thank you for cleaning your local area!"
    print(return_string)
    return image, return_string
# 6. Build the Gradio demo: image + confidence threshold in,
# annotated image + status text out.
DESCRIPTION = """
Help clean up your local area! Upload an image and get +1 if there is all of the following items detected: trash, bin, hand.
Model is a fine-tuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the [Trashify dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).
See the full data loading and training code on [learnhuggingface.com](https://www.learnhuggingface.com/notebooks/hugging_face_object_detection_tutorial).
"""

# Create Gradio interface
demo = gr.Interface(
    fn=predict_on_image,
    title="🚮 Trashify Object Detection",
    description=DESCRIPTION,
    inputs=[
        gr.Image(type="pil", label="Target Input Image"),
        gr.Slider(minimum=0, maximum=1, value=0.3,
                  label="Confidence Threshold (set higher for more confident boxes)"),
    ],
    outputs=[
        gr.Image(type="pil", label="Target Image Output"),
        gr.Text(label="Text Output"),
    ],
    # Three bundled example images, all at the default 0.3 threshold
    examples=[[f"trashify_examples/trashify_example_{i}.jpeg", 0.3]
              for i in (1, 2, 3)],
    cache_examples=True,
)

# Launch Demo
demo.launch(debug=False)