Mask_R_CNN / app.py
KRayRay's picture
Create app.py
dba911c verified
import torch
from torchvision import transforms
from PIL import Image, ImageDraw, ImageEnhance
import requests
from torchvision.models.detection import maskrcnn_resnet50_fpn
import random
# Pick the compute device (CUDA when available, CPU otherwise) and load a
# pretrained Mask R-CNN, moved to that device and switched to eval mode.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in
# favour of the `weights=` argument — confirm the pinned version before changing.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = maskrcnn_resnet50_fpn(pretrained=True)
model.to(device)
model.eval()
def preprocess_image(image_path):
    """Load an image and convert it into a batched tensor for the model.

    Args:
        image_path: Location of the image, passed straight to
            ``PIL.Image.open``.

    Returns:
        A ``(tensor, image)`` tuple. ``tensor`` is the RGB image converted
        with ``transforms.ToTensor``, given a leading batch dimension and
        moved to the module-level ``device``. ``image`` is the RGB
        ``PIL.Image`` the tensor was built from.
    """
    # Use a context manager so the underlying file handle is released as
    # soon as the pixel data has been copied out by convert(), instead of
    # lingering until the image object is garbage collected.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    tensor = transforms.ToTensor()(image)
    # Add the batch dimension and send to the active device.
    return tensor.unsqueeze(0).to(device), image
def detect_objects(image_path, threshold=0.5):
    """Run the Mask R-CNN model on one image and keep confident detections.

    Args:
        image_path: Location of the image, forwarded to ``preprocess_image``.
        threshold: Minimum score (inclusive) a detection needs to be kept.

    Returns:
        A ``(detections, image)`` tuple. ``detections`` is a list of
        ``(mask, label, score)`` tuples in model output order, where ``mask``
        is a NumPy array holding the model's mask values multiplied by 255
        and cast to bytes (a soft mask, not a thresholded one), and ``label``
        and ``score`` are Python scalars. ``image`` is the PIL image returned
        by ``preprocess_image``.
    """
    batch, pil_image = preprocess_image(image_path)
    with torch.no_grad():
        prediction = model(batch)[0]  # single image in the batch

    detections = []
    per_object = zip(
        prediction["masks"], prediction["labels"], prediction["scores"]
    )
    for mask, label, score in per_object:
        if score >= threshold:
            # First channel of the mask, scaled to the 0-255 byte range.
            mask_array = mask[0].mul(255).byte().cpu().numpy()
            detections.append((mask_array, label.item(), score.item()))
    return detections, pil_image
def apply_instance_masks(image_path, threshold=0.5):
    """Overlay a semi-transparent colored mask on every detected object.

    Args:
        image_path: Location of the image, forwarded to ``detect_objects``.
        threshold: Minimum confidence score forwarded to ``detect_objects``.
            Defaults to 0.5, which is the value that was previously always
            used, so existing single-argument callers are unaffected.

    Returns:
        An RGB ``PIL.Image`` with the masks composited over the input image.
    """
    detections, image = detect_objects(image_path, threshold)

    # Work in RGBA so the masks can be blended with transparency.
    base = image.convert("RGBA")
    overlay = Image.new("RGBA", base.size, (0, 0, 0, 0))

    # One color per object category, chosen the first time the label is seen.
    color_map = {}
    for mask, label, _score in detections:
        if label not in color_map:
            # Red and blue are fixed at 50 and only green varies (225-255),
            # so every category gets a shade of green with alpha 150.
            color_map[label] = (50, random.randint(225, 255), 50, 150)
        # The byte mask doubles as the paste mask: higher mask values let
        # more of the category color through.
        mask_pil = Image.fromarray(mask, mode="L")
        colored_mask = Image.new("RGBA", mask_pil.size, color_map[label])
        overlay.paste(colored_mask, (0, 0), mask_pil)

    # Blend the overlay onto the original and drop the alpha channel again.
    return Image.alpha_composite(base, overlay).convert("RGB")


import gradio as gr


def _run_detection(img_path, thresh):
    """Click handler: segment the uploaded image at the chosen threshold.

    Returns ``None`` when no image has been provided, which leaves the
    output component empty.
    """
    if not img_path:
        return None
    # Pass the slider value through so the UI control actually takes effect.
    return apply_instance_masks(img_path, thresh)


with gr.Blocks() as demo:
    gr.Markdown("## Object Detection with Mask R-CNN")
    gr.Markdown("This demo applies instance segmentation to an image using Mask R-CNN.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="filepath")
            threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Confidence Threshold")
            detect_button = gr.Button("Detect Objects")
        with gr.Column():
            output_image = gr.Image(label="Output Image with Masks")
    detect_button.click(
        fn=_run_detection,
        inputs=[input_image, threshold],
        outputs=output_image,
    )

demo.launch()