File size: 4,434 Bytes
d74b66f
 
 
 
 
 
976162b
d74b66f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
976162b
 
 
 
 
 
 
 
 
 
 
 
 
d74b66f
976162b
 
 
 
d74b66f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
976162b
 
 
 
 
d74b66f
 
 
 
976162b
d74b66f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
import torchvision
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
import os

# Load pretrained Mask R-CNN model
def _load_maskrcnn():
    """Build torchvision's Mask R-CNN (ResNet-50 FPN) with COCO-pretrained weights.

    torchvision >= 0.13 deprecates ``pretrained=True`` (it warns on use) in
    favour of the ``weights=`` enum API; ``MaskRCNN_ResNet50_FPN_Weights.DEFAULT``
    selects the same COCO weights. Older torchvision releases lack the enum,
    so fall back to the legacy flag there.

    Returns:
        The model, already switched to eval mode.
    """
    detection = torchvision.models.detection
    weights_enum = getattr(detection, "MaskRCNN_ResNet50_FPN_Weights", None)
    if weights_enum is None:
        net = detection.maskrcnn_resnet50_fpn(pretrained=True)
    else:
        net = detection.maskrcnn_resnet50_fpn(weights=weights_enum.DEFAULT)
    # Inference only: segment_objects calls the model with images and no targets.
    net.eval()
    return net


model = _load_maskrcnn()

# COCO labels.
# Position i is the name for model label id i (segment_objects indexes this
# list directly with the predicted label). The 'N/A' entries are placeholders
# that keep later ids aligned — presumably COCO ids with no category assigned.
COCO_INSTANCE_CATEGORY_NAMES = [
    # ids 0-9
    '__background__', 'person', 'bicycle', 'car', 'motorcycle',
    'airplane', 'bus', 'train', 'truck', 'boat',
    # ids 10-19
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter',
    'bench', 'bird', 'cat', 'dog', 'horse',
    # ids 20-29
    'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
    # ids 30-39
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    # ids 40-49
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'N/A', 'wine glass', 'cup', 'fork', 'knife',
    # ids 50-59
    'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    # ids 60-69
    'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'N/A', 'dining table', 'N/A', 'N/A',
    # ids 70-79
    'toilet', 'N/A', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # ids 80-89
    'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    # id 90
    'toothbrush',
]

# Function to load example image or use uploaded image
def load_image(uploaded_image, example_image):
    """Return the image to segment, preferring an upload over an example.

    Args:
        uploaded_image: PIL image from the upload widget, or ``None`` when
            nothing was uploaded.
        example_image: File name of an example inside ``examples/``. A falsy
            value or the literal string ``"None"`` means no example was chosen.

    Returns:
        A PIL image in RGB mode.

    Raises:
        ValueError: if neither input is usable, or if ``example_image`` is not
            a plain file name (contains a directory part or is absolute).
        FileNotFoundError: if the selected example is not an existing file.
    """
    if uploaded_image is not None:
        # Normalise to 3 channels: an RGBA upload would otherwise become a
        # 4-channel tensor that the RGB-pretrained model cannot consume.
        return uploaded_image.convert("RGB")

    if not example_image or example_image == "None":
        raise ValueError("Please upload an image or select an example image.")

    # The name arrives from the web UI. os.path.join discards "examples" when
    # given an absolute path and happily follows "../", so only accept a bare
    # file name to keep lookups inside the examples directory.
    if os.path.basename(example_image) != example_image:
        raise ValueError(f"Invalid example image name: {example_image!r}")

    example_path = os.path.join("examples", example_image)
    if not os.path.isfile(example_path):
        raise FileNotFoundError(f"Example image {example_path} not found.")

    # convert() returns an independent, fully loaded copy, so the source file
    # can be closed as soon as the with-block exits.
    with Image.open(example_path) as img:
        return img.convert("RGB")

# Detection and segmentation function
def segment_objects(uploaded_image, example_image, threshold=0.5):
    """Run Mask R-CNN on an image and save an annotated rendering to disk.

    Every detection whose score is >= ``threshold`` has its binarised mask
    blended 50/50 with a random color. It also gets a bounding box in the
    same color and a ``"<label>: <score>"`` caption.

    Args:
        uploaded_image: Uploaded PIL image, or ``None``.
        example_image: Example file name (or ``"None"``); used by
            ``load_image`` only when nothing was uploaded.
        threshold: Minimum detection score for an instance to be drawn.

    Returns:
        Path of the PNG file the annotated figure was saved to.

    Raises:
        Whatever ``load_image`` raises when no usable image is supplied.
    """
    # Force 3 channels so RGBA/greyscale inputs neither break the model nor
    # the per-pixel RGB blend below.
    image = load_image(uploaded_image, example_image).convert("RGB")

    img_tensor = torchvision.transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        output = model(img_tensor)[0]

    # Keep confident detections only. Masks come back as [N, 1, H, W] soft
    # masks; drop the channel axis and binarise at 0.5.
    keep = output['scores'] >= threshold
    masks = output['masks'][keep][:, 0].cpu().numpy() > 0.5
    boxes = output['boxes'][keep].cpu().numpy()
    labels = output['labels'][keep].cpu().tolist()
    scores = output['scores'][keep].cpu().tolist()

    image_np = np.array(image)  # np.array copies, so in-place edits are safe
    colors = []
    for mask in masks:
        # Random color for each mask
        color = np.random.rand(3)
        colors.append(color)
        # astype truncates exactly like int(color * 255) per channel.
        rgb = (color * 255).astype(np.uint8)
        # Blend only the masked pixels; the rest of the image is untouched.
        image_np[mask] = (0.5 * image_np[mask] + 0.5 * rgb).astype(np.uint8)

    # NOTE(review): fixed output path — concurrent requests would overwrite
    # each other's file; consider a per-request temp file if that matters.
    output_path = "output_maskrcnn_with_masks.png"
    fig, ax = plt.subplots(1, figsize=(10, 10))
    try:
        ax.imshow(image_np)
        for color, (x1, y1, x2, y2), label_id, score in zip(colors, boxes, labels, scores):
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       fill=False, color=color, linewidth=2))
            label = COCO_INSTANCE_CATEGORY_NAMES[label_id]
            ax.text(x1, y1, f"{label}: {score:.2f}",
                    bbox={'facecolor': 'yellow', 'alpha': 0.5}, fontsize=10)
        ax.axis('off')
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Close this specific figure (not pyplot's "current" one), even on
        # error, so pyplot does not retain a figure per failed request.
        plt.close(fig)
    return output_path

# Gradio interface.
# The three input widgets are handed to segment_objects positionally as
# (uploaded_image, example_image, threshold), so their order must match
# that function's signature.
interface = gr.Interface(
    title="Mask R-CNN Instance Segmentation",
    description="Upload an image or select an example image to detect and segment objects using a pretrained Mask R-CNN model (TorchVision).",
    fn=segment_objects,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),
        gr.Dropdown(
            label="Select Example Image",
            choices=["None"] + [f"example{n}.jpg" for n in range(1, 5)],
            value="None",
        ),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Confidence Threshold"),
    ],
    outputs=gr.Image(label="Segmented Output", type="filepath"),
)

def main():
    """Start the Gradio app defined above, with ``debug=True``."""
    interface.launch(debug=True)


if __name__ == "__main__":
    main()