Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torchvision | |
| from torchvision.models.detection import maskrcnn_resnet50_fpn | |
| from torchvision.transforms import functional as F | |
| import torchvision.ops as ops | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFont | |
| import colorsys | |
| from huggingface_hub import hf_hub_download | |
| import json | |
| import gc | |
# Download model from Hugging Face Hub
def load_model():
    """Download the fine-tuned coffee-bean Mask R-CNN and return it ready for inference.

    Fetches ``maskrcnn_coffeebeans_v1.safetensors`` from the
    ``Kunitomi/coffee-bean-maskrcnn`` Hub repository, builds a two-class
    (background + bean) ``maskrcnn_resnet50_fpn`` allowed to emit up to 300
    detections per image, and loads the SafeTensors state dict into it.

    Returns:
        The model, moved to the CPU and switched to eval mode.
    """
    weights_file = hf_hub_download(
        repo_id="Kunitomi/coffee-bean-maskrcnn",
        filename="maskrcnn_coffeebeans_v1.safetensors",
    )
    # Detection cap raised from the default of 100 so dense images are not truncated.
    detector = maskrcnn_resnet50_fpn(num_classes=2, box_detections_per_img=300)

    from safetensors.torch import load_file

    detector.load_state_dict(load_file(weights_file))
    # Run on the CPU only, to keep memory usage low.
    detector = detector.cpu()
    detector.eval()
    return detector


# Build the detector once at import time; every request reuses this instance.
model = load_model()
# Pre-generate colors for visualization
def generate_colors(n=20):
    """Return ``n`` RGB tuples (ints in 0-255) spaced evenly around the HSV hue wheel.

    Even-indexed entries use saturation 0.8 / value 1.0 and odd-indexed entries
    use saturation 1.0 / value 0.8, so neighbouring colors differ in more than hue.
    """
    def _nth_color(idx):
        odd = idx % 2
        red, green, blue = colorsys.hsv_to_rgb(
            idx / n,
            0.8 + 0.2 * odd,
            0.8 + 0.2 * (1 - odd),
        )
        return (int(255 * red), int(255 * green), int(255 * blue))

    return [_nth_color(idx) for idx in range(n)]


# Shared palette, cycled through when drawing detections.
COLORS = generate_colors(20)
def _load_label_font(size=16):
    """Return a font for detection labels, or None if Pillow cannot provide one."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except (OSError, ImportError):
        # Arial may be absent on the host (or FreeType unavailable); fall back below.
        pass
    try:
        return ImageFont.load_default()
    except (OSError, ImportError):
        return None


def draw_detection_pil(image, predictions, bean_count, show_confidence=True, mask_alpha=120):
    """Render semi-transparent instance masks and per-bean labels onto a copy of ``image``.

    Args:
        image: PIL image the detections were computed on.
        predictions: dict of numpy arrays with keys ``'boxes'`` (rows starting
            ``x1, y1, ...``), ``'scores'`` and ``'masks'`` (indexed as ``masks[i][0]``;
            values expected in [0, 1]).
        bean_count: number of leading detections to draw.
        show_confidence: label each bean with its score (``"0.97"``) instead of its
            1-based index (``"#3"``).
        mask_alpha: peak overlay opacity on a 0-255 scale; each pixel's opacity is
            this value scaled by the mask value at that pixel.

    Returns:
        A new RGB PIL image (``image`` itself is not modified). When ``bean_count``
        is not positive, a plain copy of ``image`` is returned.
    """
    if bean_count <= 0:
        return image.copy()

    font = _load_label_font()
    canvas = np.array(image.convert('RGB'), dtype=np.float32)
    height, width = canvas.shape[:2]
    opacity = float(mask_alpha) / 255.0

    # Pass 1: blend every mask into one float canvas. Mask i lands on top of masks < i,
    # like successive alpha composites, but without converting the whole PIL image
    # RGB -> RGBA -> RGB once per detection.
    for i in range(bean_count):
        mask = np.clip(np.asarray(predictions['masks'][i][0], dtype=np.float32), 0.0, 1.0)
        if mask.shape != (height, width):
            mask_img = Image.fromarray((mask * 255).astype(np.uint8))
            mask_img = mask_img.resize((width, height), Image.NEAREST)
            mask = np.asarray(mask_img, dtype=np.float32) / 255.0
        # Scale the overlay opacity by the mask; the original putalpha() call replaced
        # the intended semi-transparent alpha with the raw mask, making masks opaque.
        alpha = mask * opacity

        # Only touch the rows/columns the mask actually covers.
        rows = np.flatnonzero(alpha.any(axis=1))
        if rows.size == 0:
            continue
        cols = np.flatnonzero(alpha.any(axis=0))
        top, bottom = rows[0], rows[-1] + 1
        left, right = cols[0], cols[-1] + 1

        region = canvas[top:bottom, left:right]  # view: updates write into canvas
        color = np.asarray(COLORS[i % len(COLORS)], dtype=np.float32)
        region += (color - region) * alpha[top:bottom, left:right, None]

    result_img = Image.fromarray(np.clip(np.rint(canvas), 0, 255).astype(np.uint8))
    draw = ImageDraw.Draw(result_img)

    # Pass 2: draw labels after all masks so later masks never paint over earlier labels.
    for i in range(bean_count):
        color = COLORS[i % len(COLORS)]
        box = np.asarray(predictions['boxes'][i])
        x1, y1 = int(box[0]), int(box[1])

        # Confidence score or bean number (no bounding box is drawn).
        if show_confidence:
            label = f"{float(predictions['scores'][i]):.2f}"
        else:
            label = f"#{i+1}"

        if font is not None:
            bbox = draw.textbbox((0, 0), label, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]
        else:
            text_width, text_height = 30, 15  # Fallback size when no font metrics exist

        # Put the tag just above the box, but never above the top edge of the image.
        tag_top = max(y1 - text_height - 4, 0)
        draw.rectangle([x1, tag_top, x1 + text_width + 8, tag_top + text_height + 4], fill=color)
        draw.text((x1 + 4, tag_top + 2), label, fill='white', font=font)

    return result_img
def create_json_output(predictions, bean_count):
    """Serialize the first ``bean_count`` detections as an indented JSON string.

    Each detection records a 1-based ``bean_id``, its ``confidence``, the ``bbox``
    as a plain list, and ``mask_area`` (count of mask pixels above 0.5, as a float).
    When at least one bean is present, ``average_confidence`` is the mean of every
    entry in ``predictions['scores']``.
    """
    detections = [
        {
            "bean_id": idx + 1,
            "confidence": float(predictions['scores'][idx]),
            "bbox": predictions['boxes'][idx].tolist(),
            "mask_area": float((predictions['masks'][idx][0] > 0.5).sum()),
        }
        for idx in range(bean_count)
    ]
    payload = {"bean_count": bean_count, "detections": detections}
    if bean_count > 0:
        payload["average_confidence"] = float(np.mean(predictions['scores']))
    return json.dumps(payload, indent=2)
def predict_beans(image, confidence_threshold, nms_threshold, max_detections, show_confidence):
    """Run inference on an uploaded image and build the three UI outputs.

    Args:
        image: PIL image or numpy array from the upload widget; None when empty.
        confidence_threshold: detections must score strictly above this value.
        nms_threshold: IoU threshold for the extra NMS pass over the model output.
        max_detections: keep at most this many detections, highest scores first.
            Coerced to int because slider widgets may deliver floats; values
            below 0 are treated as 0.
        show_confidence: label beans with their scores rather than 1-based numbers.

    Returns:
        ``(annotated_image, markdown_summary, json_string)``, or
        ``(None, prompt_message, None)`` when no image was supplied.
    """
    if image is None:
        return None, "Please upload an image first.", None

    # Convert to PIL if needed, then to RGB
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert('RGB')

    # Preprocess image - ensure CPU tensor
    image_tensor = F.to_tensor(image).unsqueeze(0).cpu()
    with torch.no_grad():
        predictions = model(image_tensor)[0]
    del image_tensor

    # Decide which detections survive using index tensors only, then copy each output
    # exactly once. The masks output can be large, so avoiding the previous
    # keep-copy followed by a threshold-copy lowers peak memory.
    keep = ops.nms(predictions['boxes'], predictions['scores'], nms_threshold)
    # ops.nms returns indices sorted by decreasing score and boolean selection keeps
    # that order, so the first ``limit`` survivors are the most confident ones.
    keep = keep[predictions['scores'][keep] > confidence_threshold]
    limit = max(int(max_detections), 0)
    keep = keep[:limit]

    filtered_predictions = {
        key: predictions[key][keep].cpu().numpy()
        for key in ('boxes', 'labels', 'scores', 'masks')
    }
    # Delete original predictions to free memory
    del predictions, keep

    bean_count = len(filtered_predictions['boxes'])

    # Create fast PIL-based visualization
    if bean_count > 0:
        result_image = draw_detection_pil(image, filtered_predictions, bean_count, show_confidence)
        avg_confidence = np.mean(filtered_predictions['scores'])
        summary = f"**Detected {bean_count} coffee beans** with {avg_confidence:.1%} average confidence"
    else:
        result_image = image.copy()
        summary = "**No beans detected.** Try lowering the confidence threshold or check image quality."

    # Create JSON output for download
    json_output = create_json_output(filtered_predictions, bean_count)

    # Clean up memory
    del filtered_predictions
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return result_image, summary, json_output
# Example images: [path, confidence threshold, NMS threshold, max detections, show confidence]
examples = [
    ["examples/green_beans.png", 0.5, 0.5, 300, True],
    ["examples/roasted_beans.png", 0.5, 0.3, 300, True],
]

# Create Gradio interface
with gr.Blocks(title="Coffee Bean Detection", theme=gr.themes.Soft()) as demo:
    # Blank lines separate the heading, the intro and the note; without them the
    # Markdown renderer merges the two paragraphs into one.
    gr.Markdown("""
    # ☕ Coffee Bean Detection with Mask R-CNN

    Upload an image of coffee beans to detect and segment individual beans using a fine-tuned Mask R-CNN model.

    **Default: Up to 300 beans detected**
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Input controls
            input_image = gr.Image(
                type="pil",
                label="Upload Coffee Bean Image",
                height=400,
            )
            with gr.Accordion("Advanced Settings", open=False):
                confidence_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.5,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Higher values = fewer but more confident detections",
                )
                nms_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.8,
                    value=0.5,
                    step=0.05,
                    label="NMS Threshold",
                    info="Lower values = less overlap between detections",
                )
                max_detections = gr.Slider(
                    minimum=10,
                    maximum=300,
                    value=300,
                    step=10,
                    label="Maximum Detections",
                    info="Limit total number of detections shown",
                )
                show_confidence = gr.Checkbox(
                    value=True,
                    label="Show Confidence Scores",
                    info="Show confidence scores instead of bean numbers",
                )
            detect_btn = gr.Button("🔍 Detect Beans", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Output
            output_image = gr.Image(label="Detection Results", height=400)
            results_text = gr.Markdown()
            json_download = gr.JSON(label="📥 Detection Data (Copy or Download)", visible=True)

    # Every trigger feeds predict_beans the same controls and fills the same outputs.
    detection_inputs = [input_image, confidence_threshold, nms_threshold, max_detections, show_confidence]
    detection_outputs = [output_image, results_text, json_download]

    # Event handlers
    detect_btn.click(fn=predict_beans, inputs=detection_inputs, outputs=detection_outputs)
    # Auto-detect when image is uploaded
    input_image.change(fn=predict_beans, inputs=detection_inputs, outputs=detection_outputs)

    # Examples section
    gr.Markdown("## 📸 Try These Examples")
    gr.Examples(
        examples=examples,
        inputs=detection_inputs,
        outputs=detection_outputs,
        fn=predict_beans,
        cache_examples=True,
    )

    # Footer. Blank lines stop "**Links:**" and "Built by ..." from being absorbed into
    # the preceding list items (Markdown lazy continuation).
    gr.Markdown("""
    ---
    **Model Details:**

    - Architecture: Mask R-CNN with ResNet-50 FPN backbone
    - Framework: PyTorch/TorchVision
    - Fine-tuned on 128 coffee bean images
    - Model size: 176MB (SafeTensors format)

    **Links:**

    - 🤗 [Model on Hugging Face](https://huggingface.co/Kunitomi/coffee-bean-maskrcnn)

    Built by [Mark Kunitomi](https://huggingface.co/Kunitomi)
    """)

if __name__ == "__main__":
    demo.launch()