bigsmoke07 commited on
Commit
f329fb2
·
verified ·
1 Parent(s): 79bd333

Uploading Trashify box detection model app.py

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ trashify_examples/trashify_example_2.jpeg filter=lfs diff=lfs merge=lfs -text
37
+ trashify_examples/trashify_example_3.jpeg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,36 @@
1
  ---
2
- title: Trashify Demo
3
- emoji: 🌍
4
- colorFrom: blue
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 6.4.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Trashify Demo V4 🚮
3
+ emoji: 🗑️
4
+ colorFrom: purple
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 5.34.0
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
  ---
12
 
13
+ # 🚮 Trashify Object Detector V4
14
+
15
+ Object detection demo to detect `trash`, `bin`, `hand`, `trash_arm`, `not_trash`, `not_bin`, `not_hand`.
16
+
17
+ Used as an example for encouraging people to clean up their local area.
18
+
19
+ If `trash`, `hand`, `bin` all detected = +1 point.
20
+
21
+ ## Dataset
22
+
23
+ All Trashify models are trained on a custom hand-labelled dataset of people picking up trash and placing it in a bin.
24
+
25
+ The dataset can be found on Hugging Face as [`mrdbourke/trashify_manual_labelled_images`](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).
26
+
27
+ ## Demos
28
+
29
+ * [V1](https://huggingface.co/spaces/mrdbourke/trashify_demo_v1) = Fine-tuned [Conditional DETR](https://huggingface.co/docs/transformers/en/model_doc/conditional_detr) model trained *without* data augmentation.
30
+ * [V2](https://huggingface.co/spaces/mrdbourke/trashify_demo_v2) = Fine-tuned Conditional DETR model trained *with* data augmentation.
31
+ * [V3](https://huggingface.co/spaces/mrdbourke/trashify_demo_v3) = Fine-tuned Conditional DETR model trained *with* data augmentation (same as V2) with an NMS (Non Maximum Suppression) post-processing step.
32
+ * [V4](https://huggingface.co/spaces/mrdbourke/trashify_demo_v4) = Fine-tuned [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2) model trained *without* data augmentation or NMS post-processing (current best mAP).
33
+
34
+ ## Learn more
35
+
36
+ See the full end-to-end code of how this demo was built at [learnhuggingface.com](https://www.learnhuggingface.com/notebooks/hugging_face_object_detection_tutorial).
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # 1. Import the required libraries and packages
3
+ import gradio as gr
4
+ import torch
5
+ from PIL import Image, ImageDraw, ImageFont # could also use torch utilities for drawing
6
+
7
+ from transformers import AutoImageProcessor
8
+ from transformers import AutoModelForObjectDetection
9
+
10
+ ### 2. Setup preprocessing and helper functions ###
11
+
12
+ # Setup target model path to load
13
+ # Note: Can load from Hugging Face or can load from local
14
+ model_save_path = "mrdbourke/rt_detrv2_finetuned_trashify_box_detector_v1"
15
+
16
+ # Load the model and preprocessor
17
+ # Because this app.py file is running directly on Hugging Face Spaces, the model will be loaded from the Hugging Face Hub
18
+ image_processor = AutoImageProcessor.from_pretrained(model_save_path)
19
+ model = AutoModelForObjectDetection.from_pretrained(model_save_path)
20
+
21
+ # Set the target device (use CUDA/GPU if it is available)
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ model = model.to(device)
24
+
25
+ # Get the id2label dictionary from the model
26
+ id2label = model.config.id2label
27
+
28
+ # Set up a colour dictionary for plotting boxes with different colours
29
+ color_dict = {
30
+ "bin": "green",
31
+ "trash": "blue",
32
+ "hand": "purple",
33
+ "trash_arm": "yellow",
34
+ "not_trash": "red",
35
+ "not_bin": "red",
36
+ "not_hand": "red",
37
+ }
38
+
39
+ ### 3. Create function to predict on a given image with a given confidence threshold ###
40
+ def predict_on_image(image, conf_threshold):
41
+ # Make sure model is in eval mode
42
+ model.eval()
43
+
44
+ # Make a prediction on target image
45
+ with torch.no_grad():
46
+ inputs = image_processor(images=[image], return_tensors="pt")
47
+ model_outputs = model(**inputs.to(device))
48
+
49
+ target_sizes = torch.tensor([[image.size[1], image.size[0]]]) # -> [batch_size, height, width]
50
+
51
+ # Post process the raw outputs from the model
52
+ results = image_processor.post_process_object_detection(model_outputs,
53
+ threshold=conf_threshold,
54
+ target_sizes=target_sizes)[0]
55
+
56
+ # Return all items in results to CPU (we'll want this for displaying outputs with matplotlib)
57
+ for key, value in results.items():
58
+ try:
59
+ results[key] = value.item().cpu() # can't get scalar as .item() so add try/except block
60
+ except:
61
+ results[key] = value.cpu()
62
+
63
+ ### 4. Draw the predictions on the target image ###
64
+
65
+ # Can return results as plotted on a PIL image (then display the image)
66
+ draw = ImageDraw.Draw(image)
67
+
68
+ # Get a font from ImageFont
69
+ font = ImageFont.load_default(size=20)
70
+
71
+ # Get class names as text for print out
72
+ detected_class_name_text_labels = []
73
+
74
+ # Iterate through the predictions of the model and draw them on the target image
75
+ for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
76
+ # Create coordinates
77
+ x, y, x2, y2 = tuple(box.tolist())
78
+
79
+ # Get label_name
80
+ label_name = id2label[label.item()]
81
+ targ_color = color_dict[label_name]
82
+ detected_class_name_text_labels.append(label_name)
83
+
84
+ # Draw the rectangle
85
+ draw.rectangle(xy=(x, y, x2, y2),
86
+ outline=targ_color,
87
+ width=3)
88
+
89
+ # Create a text string to display
90
+ text_string_to_show = f"{label_name} ({round(score.item(), 3)})"
91
+
92
+ # Draw the text on the image
93
+ draw.text(xy=(x, y),
94
+ text=text_string_to_show,
95
+ fill="white",
96
+ font=font)
97
+
98
+ # Remove the draw each time
99
+ del draw
100
+
101
+ ### 5. Create logic for outputting information message ###
102
+
103
+ # Setup set of target items to discover
104
+ target_items = {"trash", "bin", "hand"}
105
+ detected_items = set(detected_class_name_text_labels)
106
+
107
+ # If no items detected or trash, bin, hand not in detected items, return notification
108
+ if not detected_items & target_items:
109
+ return_string = (
110
+ f"No trash, bin or hand detected at confidence threshold {conf_threshold}. "
111
+ "Try another image or lowering the confidence threshold."
112
+ )
113
+ print(return_string)
114
+ return image, return_string
115
+
116
+ # If there are missing items, say what the missing items are
117
+ missing_items = target_items - detected_items
118
+ if missing_items:
119
+ return_string = (
120
+ f"Detected the following items: {sorted(detected_items & target_items)}. But missing the following in order to get +1: {sorted(missing_items)}. "
121
+ "If this is an error, try another image or altering the confidence threshold. "
122
+ "Otherwise, the model may need to be updated with better data."
123
+ )
124
+ print(return_string)
125
+ return image, return_string
126
+
127
+ # If all target items are present (the final remaining case)
128
+ return_string = f"+1! Found the following items: {sorted(detected_items)}, thank you for cleaning up the area!"
129
+ print(return_string)
130
+ return image, return_string
131
+
132
+ ### 6. Setup the demo application to take in image, make a prediction with our model, return the image with drawn predicitons ###
133
+
134
+ # Write description for our demo application
135
+ description = """
136
+ Help clean up your local area! Upload an image and get +1 if there is all of the following items detected: trash, bin, hand.
137
+
138
+ Model is a fine-tuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the [Trashify dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).
139
+
140
+ See the full data loading and training code on [learnhuggingface.com](https://www.learnhuggingface.com/notebooks/hugging_face_object_detection_tutorial).
141
+
142
+ This version is v4 because the first three versions were using a different model and did not perform as well, see the [README](https://huggingface.co/spaces/mrdbourke/trashify_demo_v4/blob/main/README.md) for more.
143
+ """
144
+
145
+ # Create the Gradio interface to accept an image and confidence threshold and return an image with drawn prediction boxes
146
+ demo = gr.Interface(
147
+ fn=predict_on_image,
148
+ inputs=[
149
+ gr.Image(type="pil", label="Target Image"),
150
+ gr.Slider(minimum=0, maximum=1, value=0.3, label="Confidence Threshold")
151
+ ],
152
+ outputs=[
153
+ gr.Image(type="pil", label="Image Output"),
154
+ gr.Text(label="Text Output")
155
+ ],
156
+ title="🚮 Trashify Object Detection Demo V4",
157
+ description=description,
158
+ # Examples come in the form of a list of lists, where each inner list contains elements to prefill the `inputs` parameter with
159
+ # See where the examples originate from here: https://huggingface.co/datasets/mrdbourke/trashify_examples/
160
+ examples=[
161
+ ["trashify_examples/trashify_example_1.jpeg", 0.3],
162
+ ["trashify_examples/trashify_example_2.jpeg", 0.3],
163
+ ["trashify_examples/trashify_example_3.jpeg", 0.3],
164
+ ],
165
+ cache_examples=True
166
+ )
167
+
168
+ # Launch the demo
169
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ timm
2
+ gradio
3
+ torch
4
+ transformers
trashify_examples/trashify_example_1.jpeg ADDED
trashify_examples/trashify_example_2.jpeg ADDED

Git LFS Details

  • SHA256: e1c170311bdc358d5158049f42aa38fba3794c91bcb2d11578f7eb92d924c55c
  • Pointer size: 131 Bytes
  • Size of remote file: 361 kB
trashify_examples/trashify_example_3.jpeg ADDED

Git LFS Details

  • SHA256: 666068a4e4e92384bce54c5f9fa533ccef96da46df065e8760f03d49a04e3fd3
  • Pointer size: 131 Bytes
  • Size of remote file: 278 kB