File size: 3,374 Bytes
dba911c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
from torchvision import transforms
from PIL import Image, ImageDraw, ImageEnhance
import requests
from torchvision.models.detection import maskrcnn_resnet50_fpn
import random

# Load the Mask R-CNN model
# Prefer GPU when available; the model and all input tensors are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained COCO weights are downloaded on first use; eval() disables
# dropout/batch-norm updates for inference.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases
# in favor of `weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT` — confirm the
# installed torchvision version before changing.
model = maskrcnn_resnet50_fpn(pretrained=True).to(device).eval()

# Function to preprocess the image
def preprocess_image(image_path):
    """Load an image from disk and prepare it for the model.

    Returns a tuple of (1xCxHxW float tensor on `device`, original PIL image).
    """
    pil_image = Image.open(image_path).convert("RGB")
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
    ])
    tensor = to_tensor(pil_image)
    # Prepend the batch dimension, then move onto the chosen device.
    batched = tensor.unsqueeze(0).to(device)
    return batched, pil_image

# Run object detection
def detect_objects(image_path, threshold=0.5):
    """Run Mask R-CNN on an image and keep detections above `threshold`.

    Returns a tuple of:
      - list of (mask_uint8_array, label_id, score) per kept detection
      - the original PIL image
    """
    image_tensor, image_pil = preprocess_image(image_path)
    with torch.no_grad():
        predictions = model(image_tensor)[0]

    all_masks = predictions["masks"]
    all_labels = predictions["labels"]
    all_scores = predictions["scores"]

    # Keep only confident detections; the mask channel is scaled to 0-255 uint8.
    kept = [
        (
            all_masks[idx, 0].mul(255).byte().cpu().numpy(),
            all_labels[idx].item(),
            all_scores[idx].item(),
        )
        for idx in range(len(all_masks))
        if all_scores[idx] >= threshold
    ]
    return kept, image_pil

# Apply color masks to detected objects
def apply_instance_masks(image_path, threshold=0.5):
    """Overlay a semi-transparent colored mask on every detected object.

    Args:
        image_path: Path to the input image file.
        threshold: Minimum confidence score for a detection to be drawn.
            Forwarded to `detect_objects` (previously hard-coded to its
            default of 0.5; the default here preserves old behavior).

    Returns:
        A PIL Image in RGB mode with the masks composited on top.
    """
    masks, image = detect_objects(image_path, threshold)

    # Convert to RGBA to support transparency
    img = image.convert("RGBA")
    # Create a transparent layer to draw the colored masks on
    overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
    # (The unused ImageDraw.Draw(overlay) handle from the original was removed.)

    # One color per object category, assigned lazily and reused for every
    # instance of that category.
    color_map = {}

    for mask, label, score in masks:
        if label not in color_map:
            # Bug fix: the original used random.randint(50, 50) for the red
            # and blue channels, which always returns 50 — every category got
            # the same greenish color. Draw each channel from a real range so
            # categories are visually distinct; alpha 150 keeps the overlay
            # semi-transparent.
            color_map[label] = (
                random.randint(50, 255),
                random.randint(50, 255),
                random.randint(50, 255),
                150,
            )

        mask_pil = Image.fromarray(mask, mode="L")  # Grayscale paste mask
        colored_mask = Image.new("RGBA", mask_pil.size, color_map[label])
        overlay.paste(colored_mask, (0, 0), mask_pil)  # Paint through the mask

    # Combine the original image with the overlay
    result_image = Image.alpha_composite(img, overlay)

    return result_image.convert("RGB")  # Convert back to RGB mode

import gradio as gr

# Gradio UI: upload an image, choose a confidence threshold, and view the
# instance-segmentation overlay produced by apply_instance_masks.
with gr.Blocks() as demo:
    gr.Markdown("## Object Detection with Mask R-CNN")
    gr.Markdown("This demo applies instance segmentation to an image using Mask R-CNN.")

    with gr.Row():
        with gr.Column():
            # type="filepath" makes the callback receive a path string, not pixels.
            input_image = gr.Image(label="Input Image", type="filepath")
            threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Confidence Threshold")
            detect_button = gr.Button("Detect Objects")

        with gr.Column():
            output_image = gr.Image(label="Output Image with Masks")

    # NOTE(review): the slider value `thresh` is received by the lambda but
    # never forwarded — apply_instance_masks always runs at its default
    # threshold, so moving the slider currently has no effect. Forward it
    # once apply_instance_masks accepts a threshold argument.
    detect_button.click(
        fn=lambda img_path, thresh: apply_instance_masks(img_path) if img_path else None,
        inputs=[input_image, threshold],
        outputs=output_image
    )

demo.launch()