Kunitomi committed on
Commit
b5ef777
·
0 Parent(s):

feat: add confidence/bean number toggle and fix mask visualization

Browse files
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Coffee Bean Detection
3
+ emoji: ☕
4
+ colorFrom: red
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ # ☕ Coffee Bean Detection with Mask R-CNN
14
+
15
+ An interactive demo for detecting and segmenting coffee beans using a fine-tuned Mask R-CNN model.
16
+
17
+ ## Features
18
+
19
+ 🎯 **High Accuracy Detection**
20
+ - Precision: 99.92%
21
+ - Recall: 96.71%
22
+ - Average IoU: 90.93%
23
+
24
+ 🔧 **Adjustable Parameters**
25
+ - Confidence threshold for detection sensitivity
26
+ - NMS threshold for overlap handling
27
+ - Maximum detection limits
28
+
29
+ 📊 **Detailed Results**
30
+ - Individual bean segmentation masks
31
+ - Confidence scores for each detection
32
+ - Summary statistics
33
+
34
+ ## How to Use
35
+
36
+ 1. **Upload an Image**: Drop or select an image of coffee beans
37
+ 2. **Adjust Settings** (optional): Fine-tune detection parameters
38
+ 3. **View Results**: See detected beans with masks and confidence scores
39
+
40
+ ## Model Details
41
+
42
+ - **Architecture**: Mask R-CNN with ResNet-50 FPN backbone
43
+ - **Framework**: PyTorch/TorchVision
44
+ - **Training**: Fine-tuned on 128 coffee bean images
45
+ - **Hardware**: Trained on Mac Mini M2 (CPU only)
46
+ - **Model Size**: 176MB in SafeTensors format
47
+
48
+ ## Applications
49
+
50
+ - Coffee bean quality control
51
+ - Automated inventory counting
52
+ - Bean size and shape analysis
53
+ - Agricultural research
54
+ - Educational demonstrations
55
+
56
+ ## Links
57
+
58
+ - 🤗 [Model Repository](https://huggingface.co/Kunitomi/coffee-bean-maskrcnn)
59
+ - 💻 [Source Code](https://github.com/Markkunitomi/bean-vision)
60
+ - 📖 [Documentation](https://github.com/Markkunitomi/bean-vision/blob/main/README.md)
61
+
62
+ ---
63
+
64
+ Built by [Mark Kunitomi](https://huggingface.co/Kunitomi)
app.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision
4
+ from torchvision.models.detection import maskrcnn_resnet50_fpn
5
+ from torchvision.transforms import functional as F
6
+ import torchvision.ops as ops
7
+ import numpy as np
8
+ from PIL import Image, ImageDraw, ImageFont
9
+ import colorsys
10
+ from huggingface_hub import hf_hub_download
11
+
12
# Download model from Hugging Face Hub
def load_model():
    """Download the fine-tuned Mask R-CNN weights and build the model.

    Fetches the SafeTensors checkpoint from the Hugging Face Hub (cached
    locally by hf_hub_download after the first call), instantiates a
    Mask R-CNN with a ResNet-50 FPN backbone for two classes
    (background + bean), loads the weights, and switches to eval mode.

    Returns:
        The torchvision Mask R-CNN model, in eval mode, ready for inference.
    """
    model_path = hf_hub_download(
        repo_id="Kunitomi/coffee-bean-maskrcnn",
        filename="maskrcnn_coffeebeans_v1.safetensors"
    )

    model = maskrcnn_resnet50_fpn(num_classes=2)  # background + bean

    # Local import: safetensors is only needed at load time.
    from safetensors.torch import load_file
    state_dict = load_file(model_path)
    model.load_state_dict(state_dict)
    model.eval()

    return model
28
+
29
# Load model once at startup so every request reuses the same instance
# (avoids re-downloading/re-building the network per inference call).
model = load_model()
31
+
32
# Pre-generate colors for visualization
def generate_colors(n=20):
    """Build a palette of *n* visually distinct RGB tuples.

    Hues are spread evenly around the HSV wheel, while saturation and
    value alternate between 0.8 and 1.0 on consecutive indices so that
    neighbouring colors differ in brightness as well as hue.
    """
    def _color_at(i):
        hue = i / n
        sat = 0.8 + 0.2 * (i % 2)        # 0.8, 1.0, 0.8, ...
        val = 0.8 + 0.2 * ((i + 1) % 2)  # 1.0, 0.8, 1.0, ...
        return tuple(int(255 * channel) for channel in colorsys.hsv_to_rgb(hue, sat, val))

    return [_color_at(i) for i in range(n)]
43
+
44
+ COLORS = generate_colors(20)
45
+
46
def draw_detection_pil(image, predictions, bean_count, show_confidence=True):
    """Fast PIL-based visualization instead of matplotlib.

    For each of the first ``bean_count`` detections draws a
    semi-transparent colored mask overlay, a bounding box, and a label
    showing either the confidence score or the 1-based bean number.

    Args:
        image: PIL image to annotate (a copy is made; the input is not
            modified).
        predictions: Mask R-CNN style dict of torch tensors with keys
            'boxes', 'scores', 'masks'.
        bean_count: number of detections to draw.
        show_confidence: if True, label boxes with the score; otherwise
            with "#<index>".

    Returns:
        A new RGB PIL image with the annotations drawn.
    """
    result_img = image.copy()

    # Try to load a font, fall back to PIL's bundled default if arial is
    # not installed (truetype raises OSError on a missing font file).
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except OSError:
        try:
            font = ImageFont.load_default()
        except Exception:
            font = None

    OVERLAY_ALPHA = 120  # max opacity of the mask overlay (out of 255)

    # Draw each detection
    for i in range(bean_count):
        color = COLORS[i % len(COLORS)]

        # Get detection data
        box = predictions['boxes'][i].cpu().numpy()
        score = predictions['scores'][i].cpu().item()
        mask = predictions['masks'][i][0].cpu().numpy()

        x1, y1, x2, y2 = box.astype(int)

        # Resize the soft (0..1) mask to match the image size.
        mask_resized = Image.fromarray((mask * 255).astype(np.uint8), mode='L').resize(
            result_img.size, Image.NEAREST
        )

        # Scale the mask into an alpha channel capped at OVERLAY_ALPHA so
        # the overlay stays semi-transparent.  (Calling putalpha with the
        # raw mask would make mask interiors fully opaque — alpha up to
        # 255 — hiding the beans underneath.)
        alpha = mask_resized.point(lambda a: a * OVERLAY_ALPHA // 255)
        colored_mask = Image.new('RGBA', result_img.size, (*color, 0))
        colored_mask.putalpha(alpha)

        # Composite the overlay; alpha_composite returns a new image, so
        # the draw handle must be recreated afterwards.
        result_img = Image.alpha_composite(result_img.convert('RGBA'), colored_mask).convert('RGB')
        draw = ImageDraw.Draw(result_img)

        # Draw bounding box
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        # Draw confidence score or bean number
        label = f"{score:.2f}" if show_confidence else f"#{i + 1}"

        if font:
            # Get text size for background
            bbox = draw.textbbox((0, 0), label, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]
        else:
            text_width, text_height = 30, 15  # Fallback size

        # Filled background behind the label so it stays readable.
        text_bg_coords = [x1, y1 - text_height - 4, x1 + text_width + 8, y1]
        draw.rectangle(text_bg_coords, fill=color)

        # Draw text
        draw.text((x1 + 4, y1 - text_height - 2), label, fill='white', font=font)

    return result_img
112
+
113
def predict_beans(image, confidence_threshold, nms_threshold, max_detections, show_confidence):
    """Run Mask R-CNN inference on an uploaded image.

    Applies NMS, a confidence-threshold filter, and a detection cap to
    the raw model output, then renders the surviving detections.

    Returns:
        (annotated image, markdown summary) matching the Gradio wiring.
    """
    if image is None:
        return None, "Please upload an image first."

    # Normalise the input to an RGB PIL image.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert('RGB')

    # Forward pass on a batch of one, with gradients disabled.
    batch = F.to_tensor(image).unsqueeze(0)
    with torch.no_grad():
        predictions = model(batch)[0]

    # Non-maximum suppression drops heavily overlapping boxes.
    keep = ops.nms(predictions['boxes'], predictions['scores'], nms_threshold)
    predictions = {key: value[keep] for key, value in predictions.items()}

    # Keep only detections above the confidence threshold.
    confident = predictions['scores'] > confidence_threshold
    filtered_predictions = {key: value[confident] for key, value in predictions.items()}

    # Cap the number of detections, keeping the highest-scoring ones.
    if len(filtered_predictions['boxes']) > max_detections:
        top_indices = torch.topk(filtered_predictions['scores'], max_detections).indices
        filtered_predictions = {key: value[top_indices] for key, value in filtered_predictions.items()}

    bean_count = len(filtered_predictions['boxes'])

    # Render detections, or hand back an untouched copy when nothing hit.
    result_image = (
        draw_detection_pil(image, filtered_predictions, bean_count, show_confidence)
        if bean_count > 0
        else image.copy()
    )

    if bean_count > 0:
        avg_confidence = filtered_predictions['scores'].mean().item()
        summary = f"**Detected {bean_count} coffee beans** with {avg_confidence:.1%} average confidence"
    else:
        summary = "**No beans detected.** Try lowering the confidence threshold or check image quality."

    return result_image, summary
167
+
168
# Example images: each row is (image path, confidence threshold,
# NMS threshold, max detections, show_confidence) — the same order as
# the inputs list passed to gr.Examples below.
examples = [
    ["examples/green_beans.png", 0.5, 0.5, 300, True],
    ["examples/roasted_beans.png", 0.5, 0.3, 300, True],
]
173
+
174
# Create Gradio interface: left column holds the input image and tuning
# controls, right column holds the annotated result and summary text.
with gr.Blocks(title="Coffee Bean Detection", theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown("""
    # ☕ Coffee Bean Detection with Mask R-CNN

    Upload an image of coffee beans to detect and segment individual beans using a fine-tuned Mask R-CNN model.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Input controls
            input_image = gr.Image(
                type="pil",
                label="Upload Coffee Bean Image",
                height=400
            )

            # Tuning knobs; their values are forwarded verbatim as the
            # predict_beans() arguments.
            with gr.Accordion("Advanced Settings", open=False):
                confidence_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.5,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Higher values = fewer but more confident detections"
                )

                nms_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.8,
                    value=0.5,
                    step=0.05,
                    label="NMS Threshold",
                    info="Lower values = less overlap between detections"
                )

                max_detections = gr.Slider(
                    minimum=10,
                    maximum=300,
                    value=300,
                    step=10,
                    label="Maximum Detections",
                    info="Limit total number of detections shown"
                )

                show_confidence = gr.Checkbox(
                    value=True,
                    label="Show Confidence Scores",
                    info="Show confidence scores instead of bean numbers"
                )

            detect_btn = gr.Button("🔍 Detect Beans", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Output
            output_image = gr.Image(label="Detection Results", height=400)
            results_text = gr.Markdown()

    # Event handlers: explicit button click runs detection.
    detect_btn.click(
        fn=predict_beans,
        inputs=[input_image, confidence_threshold, nms_threshold, max_detections, show_confidence],
        outputs=[output_image, results_text]
    )

    # Auto-detect when image is uploaded (same inputs/outputs as the button).
    input_image.change(
        fn=predict_beans,
        inputs=[input_image, confidence_threshold, nms_threshold, max_detections, show_confidence],
        outputs=[output_image, results_text]
    )

    # Examples section: cache_examples=True precomputes results at build time.
    gr.Markdown("## 📸 Try These Examples")
    gr.Examples(
        examples=examples,
        inputs=[input_image, confidence_threshold, nms_threshold, max_detections, show_confidence],
        outputs=[output_image, results_text],
        fn=predict_beans,
        cache_examples=True
    )

    # Footer
    gr.Markdown("""
    ---
    **Model Details:**
    - Architecture: Mask R-CNN with ResNet-50 FPN backbone
    - Framework: PyTorch/TorchVision
    - Fine-tuned on 128 coffee bean images
    - Model size: 176MB (SafeTensors format)

    **Links:**
    - 🤗 [Model on Hugging Face](https://huggingface.co/Kunitomi/coffee-bean-maskrcnn)

    Built by [Mark Kunitomi](https://huggingface.co/Kunitomi)
    """)

# Launch the app only when run as a script (Spaces imports and runs this file).
if __name__ == "__main__":
    demo.launch()
examples/green_beans.png ADDED

Git LFS Details

  • SHA256: e3a5142de33d011debf2828a49c614eb8c14fedce53fc4483ea25a2383a58369
  • Pointer size: 132 Bytes
  • Size of remote file: 7.95 MB
examples/roasted_beans.png ADDED

Git LFS Details

  • SHA256: 58bfc58ec9bfee8b350d99cf3fee7bf4ee2fc394707fa4a5c228119ca03b65b1
  • Pointer size: 132 Bytes
  • Size of remote file: 8.94 MB
pyproject.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "coffee-bean-detection-space"
3
+ version = "0.1.0"
4
+ description = "Coffee Bean Detection Gradio Space"
5
+ dependencies = [
6
+ "gradio>=4.44.0",
7
+ "torch>=2.0.0",
8
+ "torchvision>=0.15.0",
9
+ "pillow>=9.0.0",
10
+ "numpy>=1.21.0",
11
+ "safetensors>=0.3.0",
12
+ "huggingface-hub>=0.16.0",
13
+ ]
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio==4.44.0
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ pillow>=9.0.0
5
+ numpy>=1.21.0
6
+ safetensors>=0.3.0
7
+ huggingface-hub>=0.16.0
uv.lock ADDED
The diff for this file is too large to render. See raw diff