# Ibrahim125's picture
# Uploading Trashify Box Detector Demo from the videos!
# 5cf2d0d verified
# 1. Import the required libraries and packages
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont # Could also use torch utilities for drawing
from transformers import AutoImageProcessor
from transformers import AutoModelForObjectDetection
### 2. Setup preprocessing and helper functions ###

# Identifier of the fine-tuned detector to load
# Note: this can be a Hugging Face Hub repo ID or a local path
model_save_path = "Ibrahim125/rt_detrv2_finetuned_trashify_box_detector_v1"

# Pick the target device up front (CUDA/GPU when available, otherwise CPU)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the image preprocessor and override its size setting with 640x640
# (this app.py runs on Hugging Face Spaces, so the files come from the Hugging Face Hub)
image_processor = AutoImageProcessor.from_pretrained(model_save_path)
image_processor.size = {"height": 640, "width": 640}

# Load the detection model and move it onto the target device in one step
model = AutoModelForObjectDetection.from_pretrained(model_save_path).to(device)

# Mapping from the model's integer class IDs to class names
id2label = model.config.id2label

# Outline colour per class name; every "not_*" class shares the same colour
color_dict = {
    "bin": "green",
    "trash": "blue",
    "hand": "purple",
    "trash_arm": "yellow",
    **dict.fromkeys(("not_trash", "not_bin", "not_hand"), "red"),
}
### 3. Create function to predict on a given image with a given confidence threshold ###
def predict_on_image(image, conf_threshold):
    """Run the detector on an image, draw the detections and summarise them.

    Args:
        image: PIL image to predict on, or None when no image was supplied.
            Boxes and labels are drawn onto this image in place.
        conf_threshold: Minimum score a detection needs in order to be kept.

    Returns:
        A ``(image, message)`` tuple: the annotated image (None when no image
        was supplied) and a text message describing which of the target items
        (trash, bin, hand) were found.
    """
    # Nothing to predict on when no image was supplied, so report that
    # instead of failing inside the image processor.
    if image is None:
        return_string = "No image provided. Please upload an image and try again."
        print(return_string)
        return None, return_string

    # Make sure model is in eval mode
    model.eval()

    # Make a prediction on target image (no gradients needed for inference)
    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        model_outputs = model(**inputs.to(device))

    # PIL reports size as (width, height); post-processing is handed
    # one (height, width) row per image -> shape [batch_size, 2]
    target_sizes = torch.tensor([[image.height, image.width]])

    # Post process the raw outputs from the model (single image -> index 0)
    results = image_processor.post_process_object_detection(
        model_outputs,
        threshold=conf_threshold,
        target_sizes=target_sizes,
    )[0]

    # Move every tensor in the results to the CPU before drawing.
    # (Calling .item() first can never work here: it returns a plain Python
    # scalar, which has no .cpu() method.)
    results = {
        key: value.cpu() if isinstance(value, torch.Tensor) else value
        for key, value in results.items()
    }

    ### 4. Draw the predictions on the target image ###
    detected_class_name_text_labels = _draw_detections(image, results)

    ### 5. Create logic for outputting information message ###
    return_string = _build_result_message(detected_class_name_text_labels, conf_threshold)
    print(return_string)
    return image, return_string


def _draw_detections(image, results):
    """Draw each detected box and its label on ``image``; return the detected class names."""
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default(size=20)
    fallback_color = "red"  # used for any class name missing from color_dict

    detected_class_names = []
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        # Box corners as plain floats: top-left (x, y) and bottom-right (x2, y2)
        x, y, x2, y2 = box.tolist()

        label_name = id2label[label.item()]
        detected_class_names.append(label_name)

        # Draw the rectangle
        draw.rectangle(
            xy=(x, y, x2, y2),
            outline=color_dict.get(label_name, fallback_color),
            width=3,
        )

        # Draw the class name and rounded score at the box's top-left corner
        draw.text(
            xy=(x, y),
            text=f"{label_name} ({round(score.item(), 3)})",
            fill="white",
            font=font,
        )

    return detected_class_names


def _build_result_message(detected_class_names, conf_threshold):
    """Build the message saying which of trash, bin and hand were detected."""
    target_items = {"trash", "bin", "hand"}
    detected_items = set(detected_class_names)
    found_items = detected_items & target_items

    # None of the target items were detected at all
    if not found_items:
        return (
            f"No trash, bin or hand detected at confidence threshold {conf_threshold}. "
            "Try another image or lowering the confidence threshold."
        )

    # Some target items were detected, but not all of them
    missing_items = target_items - detected_items
    if missing_items:
        return (
            f"Detected the following items: {sorted(found_items)}. But missing the following in order to get +1: {sorted(missing_items)}. "
            "If this is an error, try another image or altering the confidence threshold. "
            "Otherwise, the model may need to be updated with better data."
        )

    # Final case, all items detected
    return f"+1! Found the following items: {sorted(detected_items)}, thank you for cleaning up the area!"
### 6. Setup the demo application to take in image/threshold, pass it through our function, show the output image/text
description = """
Help clean up your local area! Upload an image and get +1 if there is all of the following items detected: trash, bin, hand.
Model is a fine-tuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the [Trashify dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).
See the full data loading and training code on [learnhuggingface.com](https://www.learnhuggingface.com/notebooks/hugging_face_object_detection_tutorial).
This version is v4 because the first three versions were using a different model and did not perform as well, see the [README](https://huggingface.co/spaces/mrdbourke/trashify_demo_v4/blob/main/README.md) for more.
"""

# Input widgets, listed in the same order as predict_on_image's parameters
input_components = [
    gr.Image(type="pil", label="Target Input Image"),
    gr.Slider(
        minimum=0,
        maximum=1,
        value=0.3,
        label="Confidence Threshold (set higher for more confident boxes)",
    ),
]

# Output widgets, listed in the same order as predict_on_image's return values
output_components = [
    gr.Image(type="pil", label="Target Image Outputs"),
    gr.Text(label="Text output"),
]

# Each example pairs an image path with a confidence threshold
example_inputs = [
    [f"trashify_examples/trashify_example_{example_number}.jpeg", 0.3]
    for example_number in range(1, 4)
]

# Create the Gradio interface
demo = gr.Interface(
    fn=predict_on_image,
    inputs=input_components,
    outputs=output_components,
    title="♻️ Trashify Object Detection Demo V4",
    description=description,
    examples=example_inputs,
    cache_examples=True,
)

# Launch demo
# Tip: launch with debug=True instead to see errors in Google Colab
demo.launch(share=False)