Kunitomi's picture
fix: optimize memory usage - force CPU mode, clean up tensors, use numpy arrays
2633efc
import gradio as gr
import torch
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.transforms import functional as F
import torchvision.ops as ops
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import colorsys
from huggingface_hub import hf_hub_download
import json
import gc
# Download model from Hugging Face Hub
@torch.no_grad()
def load_model():
    """Build the fine-tuned coffee-bean Mask R-CNN and return it on CPU in eval mode.

    The weights file ``maskrcnn_coffeebeans_v1.safetensors`` is fetched from the
    ``Kunitomi/coffee-bean-maskrcnn`` Hub repo with ``hf_hub_download``, read
    with safetensors' ``load_file`` and loaded (strictly) into a torchvision
    ``maskrcnn_resnet50_fpn``.

    NOTE(review): with torchvision's default constructor arguments the backbone
    may be initialised from downloaded ImageNet weights that ``load_state_dict``
    then overwrites; passing ``weights_backbone=None`` could avoid that fetch
    but also changes the backbone's norm-layer type — verify against the
    training setup before changing it.
    """
    weights_file = hf_hub_download(
        repo_id="Kunitomi/coffee-bean-maskrcnn",
        filename="maskrcnn_coffeebeans_v1.safetensors",
    )
    # Two classes: background + bean. The detection cap is raised from the
    # default of 100 so densely packed photos are not truncated.
    net = maskrcnn_resnet50_fpn(num_classes=2, box_detections_per_img=300)

    from safetensors.torch import load_file

    net.load_state_dict(load_file(weights_file))
    # Force CPU mode to reduce memory usage; eval() returns the module itself.
    return net.cpu().eval()


# Load model once at startup so every request reuses the same instance.
model = load_model()
# Pre-generate colors for visualization
def generate_colors(n=20):
    """Return ``n`` RGB tuples (ints in 0-255) spaced evenly around the hue wheel.

    Even indices use saturation 0.8 / value 1.0 and odd indices use
    saturation 1.0 / value 0.8, so neighbouring colors also differ in
    brightness and saturation, not only in hue.
    """
    def _to_rgb255(index):
        saturation, value = (1.0, 0.8) if index % 2 == 1 else (0.8, 1.0)
        channels = colorsys.hsv_to_rgb(index / n, saturation, value)
        return tuple(int(255 * channel) for channel in channels)

    return [_to_rgb255(index) for index in range(n)]


COLORS = generate_colors(20)
def draw_detection_pil(image, predictions, bean_count, show_confidence=True,
                       mask_alpha=120, mask_threshold=0.5):
    """Fast PIL-based visualization of detected beans (no matplotlib).

    All mask fills are painted into a single RGBA overlay that is composited
    onto the image exactly once. This replaces a full-image convert/composite
    round-trip per detection. Labels are drawn afterwards, so later masks
    cannot tint earlier labels.

    Args:
        image: PIL image to annotate; it is not modified.
        predictions: dict of per-detection numpy arrays. Reads
            ``predictions['boxes'][i]`` (x1, y1, x2, y2),
            ``predictions['scores'][i]`` and ``predictions['masks'][i][0]``
            (a 2-D mask).
        bean_count: number of leading detections to draw.
        show_confidence: label each bean with its score when True, otherwise
            with ``#<1-based index>``.
        mask_alpha: opacity (0-255) of the colored mask fill. Previously the
            mask probabilities replaced this alpha, making overlays opaque.
        mask_threshold: mask values strictly above this are filled. 0.5 is
            the same cutoff used for ``mask_area`` in the JSON output.

    Returns:
        A new RGB PIL image.
    """
    annotated = image.convert('RGBA')  # convert() always returns a new image
    width, height = annotated.size
    overlay = np.zeros((height, width, 4), dtype=np.uint8)

    # Try to load a font; arial.ttf is often missing (e.g. Linux hosts) and
    # FreeType support may be absent, so fall back to Pillow's default font.
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except (OSError, ImportError):
        try:
            font = ImageFont.load_default()
        except (OSError, ImportError):
            font = None

    boxes = predictions['boxes']
    scores = predictions['scores']
    masks = predictions['masks']

    # Pass 1: accumulate every mask fill (later detections win on overlap).
    for i in range(bean_count):
        color = COLORS[i % len(COLORS)]
        region = np.asarray(masks[i][0]) > mask_threshold
        if region.shape != (height, width):
            # Bring the mask to image resolution; nearest keeps it binary.
            resized = Image.fromarray(region.astype(np.uint8) * 255).resize(
                (width, height), Image.NEAREST
            )
            region = np.asarray(resized) > 127
        overlay[region] = (*color, mask_alpha)

    annotated = Image.alpha_composite(annotated, Image.fromarray(overlay)).convert('RGB')
    draw = ImageDraw.Draw(annotated)

    # Pass 2: a filled label at each detection's top-left corner (no bounding box).
    for i in range(bean_count):
        color = COLORS[i % len(COLORS)]
        x1, y1 = (int(v) for v in boxes[i][:2])
        label = f"{float(scores[i]):.2f}" if show_confidence else f"#{i + 1}"
        if font:
            left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
            text_width, text_height = right - left, bottom - top
        else:
            text_width, text_height = 30, 15  # Fallback size
        # Clamp so labels of beans touching the top edge stay on-canvas.
        label_top = max(y1 - text_height - 4, 0)
        draw.rectangle(
            [x1, label_top, x1 + text_width + 8, label_top + text_height + 4],
            fill=color,
        )
        draw.text((x1 + 4, label_top + 2), label, fill='white', font=font)

    return annotated
def create_json_output(predictions, bean_count):
    """Serialize the first ``bean_count`` detections as a pretty-printed JSON string.

    Args:
        predictions: dict of per-detection numpy arrays. Reads
            ``predictions['scores'][i]``, ``predictions['boxes'][i]`` and
            ``predictions['masks'][i][0]``.
        bean_count: number of leading detections to report.

    Returns:
        JSON text with ``bean_count`` and ``detections`` (``bean_id``,
        ``confidence``, ``bbox``, ``mask_area`` = pixels with mask > 0.5).
        When at least one detection is reported, it also contains
        ``average_confidence``, computed over exactly those reported
        detections.
    """
    boxes = predictions['boxes']
    scores = predictions['scores']
    masks = predictions['masks']
    detections = [
        {
            "bean_id": i + 1,
            "confidence": float(scores[i]),
            "bbox": np.asarray(boxes[i]).tolist(),
            "mask_area": float(np.count_nonzero(np.asarray(masks[i][0]) > 0.5)),
        }
        for i in range(bean_count)
    ]
    # int() so numpy integer counts remain JSON-serializable.
    json_data = {"bean_count": int(bean_count), "detections": detections}
    if bean_count > 0:
        # Average only what is listed in "detections", not every score passed in.
        json_data["average_confidence"] = float(np.mean(scores[:bean_count]))
    return json.dumps(json_data, indent=2)
def predict_beans(image, confidence_threshold, nms_threshold, max_detections, show_confidence):
    """Run inference on an uploaded image and build the three UI outputs.

    Args:
        image: PIL image or array-like from the Gradio input, or None.
        confidence_threshold: keep detections whose score is strictly above this.
        nms_threshold: IoU threshold for an extra NMS pass over the model output.
            NOTE(review): torchvision's Mask R-CNN already applies its own NMS
            (``box_nms_thresh``, 0.5 by default), so values above that likely
            have no extra effect. Confirm before relying on the slider range.
        max_detections: maximum number of detections to keep, highest scores
            first. Cast to int because slider values may arrive as floats.
        show_confidence: forwarded to ``draw_detection_pil``.

    Returns:
        ``(result_image, summary_markdown, json_string)``. When ``image`` is
        None, returns ``(None, "Please upload an image first.", None)``.
    """
    if image is None:
        return None, "Please upload an image first.", None

    # Accept array input as well as PIL, and normalize to 3-channel RGB.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert('RGB')

    # Preprocess image - ensure CPU tensor (the model is kept on CPU).
    image_tensor = F.to_tensor(image).unsqueeze(0).cpu()
    with torch.no_grad():
        predictions = model(image_tensor)[0]
    del image_tensor

    # Choose the surviving detections BEFORE converting anything to numpy.
    # ops.nms returns kept indices sorted by decreasing score, so after the
    # confidence filter the first `max_detections` entries are the top scorers.
    scores = predictions['scores']
    keep = ops.nms(predictions['boxes'], scores, nms_threshold)
    keep = keep[scores[keep] > confidence_threshold]
    keep = keep[:max(int(max_detections), 0)]

    # Masks are the largest output (one mask per detection, presumably at
    # image resolution). Copying only the selected rows keeps peak memory
    # proportional to the beans actually shown, not to every raw detection.
    filtered_predictions = {
        key: predictions[key][keep].cpu().numpy()
        for key in ('boxes', 'labels', 'scores', 'masks')
    }
    del predictions, scores, keep
    bean_count = len(filtered_predictions['boxes'])

    # Visualization and summary text.
    if bean_count > 0:
        result_image = draw_detection_pil(image, filtered_predictions, bean_count, show_confidence)
        avg_confidence = np.mean(filtered_predictions['scores'])
        summary = f"**Detected {bean_count} coffee beans** with {avg_confidence:.1%} average confidence"
    else:
        result_image = image.copy()
        summary = "**No beans detected.** Try lowering the confidence threshold or check image quality."

    # Create JSON output for download
    json_output = create_json_output(filtered_predictions, bean_count)

    # Clean up memory
    del filtered_predictions
    gc.collect()
    # Clear any PyTorch cache (only relevant when CUDA is present).
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return result_image, summary, json_output
# Example images
# Each row follows the `inputs` order used by gr.Examples:
# [image path, confidence threshold, NMS threshold, max detections, show confidence]
examples = [
    [f"examples/{stem}.png", 0.5, nms, 300, True]
    for stem, nms in (("green_beans", 0.5), ("roasted_beans", 0.3))
]
# Create Gradio interface
with gr.Blocks(title="Coffee Bean Detection", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# ☕ Coffee Bean Detection with Mask R-CNN
Upload an image of coffee beans to detect and segment individual beans using a fine-tuned Mask R-CNN model.
**Default: Up to 300 beans detected**
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Input controls
            input_image = gr.Image(
                type="pil",
                label="Upload Coffee Bean Image",
                height=400
            )
            with gr.Accordion("Advanced Settings", open=False):
                confidence_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.5,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Higher values = fewer but more confident detections"
                )
                nms_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.8,
                    value=0.5,
                    step=0.05,
                    label="NMS Threshold",
                    info="Lower values = less overlap between detections"
                )
                max_detections = gr.Slider(
                    minimum=10,
                    maximum=300,
                    value=300,
                    step=10,
                    label="Maximum Detections",
                    info="Limit total number of detections shown"
                )
                show_confidence = gr.Checkbox(
                    value=True,
                    label="Show Confidence Scores",
                    info="Show confidence scores instead of bean numbers"
                )
            detect_btn = gr.Button("🔍 Detect Beans", variant="primary", size="lg")
        with gr.Column(scale=1):
            # Output
            output_image = gr.Image(label="Detection Results", height=400)
            results_text = gr.Markdown()
            json_download = gr.JSON(label="📥 Detection Data (Copy or Download)", visible=True)

    # One shared wiring for every trigger, in predict_beans' argument order.
    detect_inputs = [input_image, confidence_threshold, nms_threshold, max_detections, show_confidence]
    detect_outputs = [output_image, results_text, json_download]

    # Event handlers
    detect_btn.click(fn=predict_beans, inputs=detect_inputs, outputs=detect_outputs)
    # Auto-detect on user upload only. `.change` would also fire when a cached
    # example is loaded into the input, re-running inference the cache already
    # provides.
    input_image.upload(fn=predict_beans, inputs=detect_inputs, outputs=detect_outputs)
    # Clearing the input resets the outputs, matching the previous behaviour
    # of predict_beans(None, ...).
    input_image.clear(
        fn=lambda: (None, "Please upload an image first.", None),
        inputs=None,
        outputs=detect_outputs,
    )

    # Examples section
    gr.Markdown("## 📸 Try These Examples")
    gr.Examples(
        examples=examples,
        inputs=detect_inputs,
        outputs=detect_outputs,
        fn=predict_beans,
        cache_examples=True
    )

    # Footer
    gr.Markdown("""
---
**Model Details:**
- Architecture: Mask R-CNN with ResNet-50 FPN backbone
- Framework: PyTorch/TorchVision
- Fine-tuned on 128 coffee bean images
- Model size: 176MB (SafeTensors format)
**Links:**
- 🤗 [Model on Hugging Face](https://huggingface.co/Kunitomi/coffee-bean-maskrcnn)
Built by [Mark Kunitomi](https://huggingface.co/Kunitomi)
""")

if __name__ == "__main__":
    demo.launch()