# 1. Import required dependencies
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoImageProcessor, AutoModelForObjectDetection

# 2. Setup preprocessing and model functions
model_save_path = "James2236/rt_detrv2_finetuned_trashify_box_detector_v1"

image_processor = AutoImageProcessor.from_pretrained(model_save_path)
# Resize every input image to 640x640 before it goes into the model.
image_processor.size = {"height": 640, "width": 640}

# Load the model
model = AutoModelForObjectDetection.from_pretrained(model_save_path)

# Setup the target device (GPU if it's accessible)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
# This app only runs inference, so switch to eval mode once rather than per request.
model.eval()

# Get the id2label dictionary from the model (class index -> class name)
id2label = model.config.id2label

# Setup a colour dictionary: bounding-box outline colour per class name
colour_dict = {
    "bin": "green",
    "trash": "blue",
    "hand": "purple",
    "trash_arm": "yellow",
    "not_trash": "red",
    "not_bin": "red",
    "not_hand": "red",
}
# Outline colour used for any label the model emits that is not in colour_dict.
default_colour = "white"

# Font for the label text drawn on each box; loaded once instead of per request.
# (ImageFont.load_default(size=...) needs Pillow >= 10.1.)
font = ImageFont.load_default(size=20)


# 3. Create a function 'predict_on_image'
def predict_on_image(image, conf_threshold):
    """Detect objects in ``image``, draw them, and report whether trash/bin/hand are all present.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None`` when the
            user submitted without uploading an image.
        conf_threshold: Minimum detection score (0-1) a box needs to be kept.

    Returns:
        A ``(image, message)`` tuple:
        - an RGB copy of the input with each kept box outlined and labelled with
          "<class name> <score>" (``None`` if no image was supplied);
        - a text summary of which of trash/bin/hand were detected, and whether the
          +1 point is earned. The message is also printed.
    """
    if image is None:
        return None, "Please upload an image."

    # Work on an RGB copy: convert() always returns a new image, so drawing below does
    # not mutate the caller's object, and non-RGB uploads (e.g. RGBA PNGs, greyscale)
    # are turned into the 3-channel input the processor normalises.
    image = image.convert("RGB")

    # Make a prediction on the target image
    with torch.no_grad():
        inputs = image_processor(images=image, return_tensors="pt").to(device)
        model_outputs = model(**inputs)

    # PIL's .size is (width, height); post-processing wants [[height, width]] per image
    # so the predicted boxes are rescaled to the original image's pixel coordinates.
    target_sizes = torch.tensor([[image.size[1], image.size[0]]])

    # Post process the raw outputs; [0] selects the single image in this batch.
    results = image_processor.post_process_object_detection(
        model_outputs, threshold=conf_threshold, target_sizes=target_sizes
    )[0]

    # Bring every result tensor (boxes/scores/labels) back to the CPU for drawing.
    results = {key: value.cpu() for key, value in results.items()}

    # 4. Draw the predictions on the target image
    draw = ImageDraw.Draw(image)

    # Class names of every kept detection, used for the +1 point logic below.
    detected_class_names_text_labels = []

    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        x, y, x2, y2 = box.tolist()  # XYXY
        label_name = id2label[label.item()]
        detected_class_names_text_labels.append(label_name)

        # Draw the bounding box
        draw.rectangle(
            xy=(x, y, x2, y2),
            outline=colour_dict.get(label_name, default_colour),
            width=3,
        )

        # Draw "<label> <score>" at the box's top-left corner
        text_string_to_show = f"{label_name} {round(score.item(), 4)}"
        draw.text(xy=(x, y), text=text_string_to_show, fill="white", font=font)

    # Release the drawing context now that all boxes are drawn.
    del draw

    # 5. Report whether the target items (bin, trash, hand) are all present for +1 point
    target_items = {"trash", "bin", "hand"}
    detected_items = set(detected_class_names_text_labels)

    # None of bin/trash/hand detected
    if not detected_items & target_items:
        return_string = (
            f"No trash, bin, or hand detected at confidence threshold {conf_threshold}. "
            "Try another image or lowering the confidence threshold."
        )
        print(return_string)
        return image, return_string

    # Some target items found, others missing. Sorted so the message is
    # deterministic (set iteration order varies between runs).
    missing_items = target_items - detected_items
    if missing_items:
        return_string = (
            f"Detected the following items: {sorted(detected_items & target_items)}. "
            f"Missing the following: {sorted(missing_items)}. "
            "In order to get + 1 points all items need to be detected."
        )
        print(return_string)
        return image, return_string

    # Final case, all target items are detected
    return_string = (
        f"+1 point! Found the following items: {sorted(detected_items)}, "
        "thank you for cleaning your local area!"
    )
    print(return_string)
    return image, return_string


# 6. Setup the demo application to take in an image/conf threshold, pass it through
# our function & show the output image/text
description = """
Help clean up your local area! Upload an image and get +1 if there is all of the following items detected: trash, bin, hand.

Model is a fine-tuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the [Trashify dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).

See the full data loading and training code on [learnhuggingface.com](https://www.learnhuggingface.com/notebooks/hugging_face_object_detection_tutorial).
"""

# Create Gradio interface
demo = gr.Interface(
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Target Input Image"),
        gr.Slider(
            minimum=0,
            maximum=1,
            value=0.3,
            label="Confidence Threshold (set higher for more confident boxes)",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Target Image Output"),
        gr.Text(label="Text Output"),
    ],
    description=description,
    title="🚮 Trashify Object Detection",
    examples=[
        ["trashify_examples/trashify_example_1.jpeg", 0.3],
        ["trashify_examples/trashify_example_2.jpeg", 0.3],
        ["trashify_examples/trashify_example_3.jpeg", 0.3],
    ],
    cache_examples=True,
)

# Launch Demo
demo.launch(debug=False)