| |
|
| | |
| | import gradio as gr |
| | import torch |
| | from PIL import Image, ImageDraw, ImageFont |
| | from transformers import AutoImageProcessor, AutoModelForObjectDetection |
| |
|
| | |
| |
|
| | |
# Hugging Face Hub repo holding the fine-tuned RT-DETRv2 checkpoint and its processor config.
model_save_path = "James2236/rt_detrv2_finetuned_trashify_box_detector_v1"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

image_processor = AutoImageProcessor.from_pretrained(model_save_path)
# Override the processor's target input size to 640x640.
image_processor.size = dict(height=640, width=640)

# Load the detector and move its weights onto the chosen device in one step.
model = AutoModelForObjectDetection.from_pretrained(model_save_path).to(device)

# Lookup from the integer class ids the model predicts to class-name strings.
id2label = model.config.id2label

# Box outline colour per class name; all of the "not_*" classes share red.
colour_dict = dict(
    bin="green",
    trash="blue",
    hand="purple",
    trash_arm="yellow",
    not_trash="red",
    not_bin="red",
    not_hand="red",
)
| |
|
| | |
def predict_on_image(image, conf_threshold):
    """Detect trash, bins and hands in an image and award a point if all are present.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.
        conf_threshold: Minimum detection score a box must reach to be kept.

    Returns:
        A ``(annotated_image, message)`` tuple. ``annotated_image`` is an RGB
        copy of the input with a labelled box drawn for every kept detection
        (``None`` when no image was given). ``message`` explains the outcome
        and is also printed to stdout.
    """
    if image is None:
        return_string = "No image provided. Please upload an image and try again."
        print(return_string)
        return None, return_string

    # Work on an RGB copy: the caller's image is left untouched and coloured
    # boxes render correctly even for greyscale/palette inputs.
    image = image.convert("RGB")

    results = _detect(image, conf_threshold)
    detected_names = _draw_detections(image, results)
    return_string = _summarise_detections(detected_names, conf_threshold)
    print(return_string)
    return image, return_string


def _detect(image, conf_threshold):
    """Run the detector on ``image``; return post-processed results moved to the CPU."""
    model.eval()
    with torch.no_grad():
        inputs = image_processor(images=image, return_tensors="pt")
        model_outputs = model(**inputs.to(device))

    # Post-processing wants (height, width); PIL's ``.size`` is (width, height).
    target_sizes = torch.tensor([[image.size[1], image.size[0]]], device=device)
    results = image_processor.post_process_object_detection(
        model_outputs,
        threshold=conf_threshold,
        target_sizes=target_sizes,
    )[0]

    # Move tensors off the accelerator so they can be read for drawing.
    return {
        key: value.cpu() if isinstance(value, torch.Tensor) else value
        for key, value in results.items()
    }


def _draw_detections(image, results):
    """Draw a labelled box on ``image`` (in place) per detection; return the label names."""
    # Colour for any class name missing from ``colour_dict``, instead of a KeyError.
    fallback_colour = "orange"

    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.load_default(size=20)
    except TypeError:
        # Pillow < 10.1 does not accept ``size``; use the fixed bitmap font.
        font = ImageFont.load_default()

    detected_names = []
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        x, y, x2, y2 = box.tolist()
        label_name = id2label[label.item()]
        detected_names.append(label_name)

        draw.rectangle(
            xy=(x, y, x2, y2),
            outline=colour_dict.get(label_name, fallback_colour),
            width=3,
        )
        draw.text(
            xy=(x, y),
            text=f"{label_name} {round(score.item(), 4)}",
            fill="white",
            font=font,
        )
    return detected_names


def _summarise_detections(detected_names, conf_threshold):
    """Build the user-facing message describing which target classes were found."""
    target_items = {"trash", "bin", "hand"}
    detected_items = set(detected_names)
    found_targets = detected_items & target_items

    if not found_targets:
        return (
            f"No trash, bin, or hand detected at confidence threshold {conf_threshold}. "
            "Try another image or lower the confidence threshold."
        )

    missing_items = target_items - detected_items
    if missing_items:
        # Sorted so the message is deterministic regardless of set ordering.
        return (
            f"Detected the following items: {sorted(found_targets)}. "
            f"Missing the following: {sorted(missing_items)}. "
            "In order to get +1 point all items need to be detected."
        )

    return (
        f"+1 point! Found the following items: {sorted(detected_items)}, "
        "thank you for cleaning your local area!"
    )
| |
|
| | |
| |
|
# Markdown shown above the Gradio interface.
description = """
Help clean up your local area! Upload an image and get +1 point if all of the following items are detected: trash, bin, hand.

Model is a fine-tuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the [Trashify dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).

See the full data loading and training code on [learnhuggingface.com](https://www.learnhuggingface.com/notebooks/hugging_face_object_detection_tutorial).

"""

# Single source for the default threshold used by the slider and every example.
DEFAULT_CONF_THRESHOLD = 0.3

demo = gr.Interface(
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Target Input Image"),
        gr.Slider(
            minimum=0,
            maximum=1,
            value=DEFAULT_CONF_THRESHOLD,
            label="Confidence Threshold (set higher for more confident boxes)",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Target Image Output"),
        gr.Text(label="Text Output"),
    ],
    description=description,
    title="🚮 Trashify Object Detection",
    # Example images are expected on disk next to this script.
    examples=[
        [f"trashify_examples/trashify_example_{i}.jpeg", DEFAULT_CONF_THRESHOLD]
        for i in range(1, 4)
    ],
    cache_examples=True,
)

demo.launch(debug=False)
| |
|