KRayRay committed on
Commit
dba911c
·
verified ·
1 Parent(s): ff6334f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image, ImageDraw, ImageEnhance
4
+ import requests
5
+ from torchvision.models.detection import maskrcnn_resnet50_fpn
6
+ import random
7
+
8
# Model setup: pick the compute device and load a COCO-pretrained Mask R-CNN.
# Select the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained detector, moved to the chosen device and switched to inference
# mode (eval() disables dropout/batch-norm training behavior and returns self).
model = maskrcnn_resnet50_fpn(pretrained=True).to(device)
model.eval()
12
+
13
# Image loading / preprocessing helper.
def preprocess_image(image_path):
    """Load *image_path* and prepare it for the detector.

    Returns a tuple ``(tensor, pil_image)`` where ``tensor`` is a
    float image tensor with a leading batch dimension (1, C, H, W)
    placed on the module-level ``device``, and ``pil_image`` is the
    original image as an RGB PIL Image (kept for later compositing).
    """
    # Force 3-channel RGB so the tensor shape is predictable.
    pil_img = Image.open(image_path).convert("RGB")

    # Single-step pipeline: PIL image -> float tensor in [0, 1].
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
    ])

    batch = to_tensor(pil_img).unsqueeze(0).to(device)
    return batch, pil_img
23
+
24
# Inference: run the detector and keep only confident instances.
def detect_objects(image_path, threshold=0.5):
    """Run Mask R-CNN on the image at *image_path*.

    Instances whose confidence score is below *threshold* are dropped.
    Returns ``(detections, pil_image)`` where ``detections`` is a list of
    ``(mask, label, score)`` tuples — ``mask`` a uint8 numpy array in
    0..255, ``label`` an int class id, ``score`` a float — and
    ``pil_image`` is the original RGB PIL image.
    """
    batch, pil_img = preprocess_image(image_path)

    # Inference only: no gradients needed.
    with torch.no_grad():
        pred = model(batch)[0]

    kept = []
    # Walk masks, labels and scores in lockstep.
    for mask_t, label_t, score_t in zip(pred["masks"], pred["labels"], pred["scores"]):
        if score_t < threshold:
            continue
        # Soft mask (1, H, W) in [0, 1] -> uint8 (H, W) array in 0..255.
        binary = mask_t[0].mul(255).byte().cpu().numpy()
        kept.append((binary, label_t.item(), score_t.item()))

    return kept, pil_img
44
+
45
# Visualization: paint each detected instance with a translucent color.
def apply_instance_masks(image_path, threshold=0.5):
    """Overlay semi-transparent category-colored masks on the image.

    Args:
        image_path: path to the input image on disk.
        threshold: minimum confidence score for an instance to be drawn
            (forwarded to ``detect_objects``; default 0.5 preserves the
            original behavior).

    Returns:
        An RGB PIL Image: the original picture with one translucent
        color per object category composited on top.
    """
    masks, image = detect_objects(image_path, threshold)

    # Work in RGBA so the overlay can carry per-pixel transparency.
    img = image.convert("RGBA")
    overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))

    # One color per object category, reused across instances of the
    # same category.
    color_map = {}

    for mask, label, score in masks:
        if label not in color_map:
            # Fixed red/blue at 50, random bright green, alpha 150.
            # (The original used random.randint(50, 50) for red/blue,
            # which always yields 50 — written here as the literal.)
            color_map[label] = (50, random.randint(225, 255), 50, 150)

        # Grayscale mask doubles as the paste alpha: only masked pixels
        # receive the category color.
        mask_pil = Image.fromarray(mask, mode="L")
        colored_mask = Image.new("RGBA", mask_pil.size, color_map[label])
        overlay.paste(colored_mask, (0, 0), mask_pil)

    # Blend the translucent overlay onto the original image.
    result_image = Image.alpha_composite(img, overlay)

    return result_image.convert("RGB")  # Back to plain RGB for display
71
+
72
import gradio as gr


def _on_detect(img_path, thresh):
    """Click handler: segment the uploaded image, or return None if empty.

    NOTE(review): the slider value ``thresh`` is received but never
    forwarded — ``apply_instance_masks`` always runs with its internal
    0.5 default, so moving the slider has no effect; confirm intent.
    """
    return apply_instance_masks(img_path) if img_path else None


# Two-column layout: controls on the left, result on the right.
with gr.Blocks() as demo:
    gr.Markdown("## Object Detection with Mask R-CNN")
    gr.Markdown("This demo applies instance segmentation to an image using Mask R-CNN.")

    with gr.Row():
        with gr.Column():
            img_in = gr.Image(label="Input Image", type="filepath")
            conf_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Confidence Threshold")
            run_btn = gr.Button("Detect Objects")

        with gr.Column():
            img_out = gr.Image(label="Output Image with Masks")

    run_btn.click(
        fn=lambda img_path, thresh: _on_detect(img_path, thresh),
        inputs=[img_in, conf_slider],
        outputs=img_out,
    )

demo.launch()