import random

import gradio as gr
import requests
import torch
from PIL import Image, ImageDraw, ImageEnhance
from torchvision import transforms
from torchvision.models.detection import maskrcnn_resnet50_fpn

# Load the Mask R-CNN model once at import time; every request reuses it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = maskrcnn_resnet50_fpn(pretrained=True).to(device).eval()


def preprocess_image(image_path):
    """Load an image from disk and convert it into a model-ready tensor.

    Args:
        image_path: Path to the image file.

    Returns:
        A tuple ``(tensor, image)``. ``tensor`` is the ``ToTensor`` conversion
        of the RGB image with a leading batch dimension added, moved to
        ``device``. ``image`` is the RGB ``PIL.Image``.
    """
    # Context manager closes the file handle; convert() returns an
    # independent copy, so ``image`` stays usable after the file is closed.
    with Image.open(image_path) as raw:
        image = raw.convert("RGB")
    transform = transforms.Compose([
        transforms.ToTensor(),  # PIL image -> CHW float tensor
    ])
    # unsqueeze(0) adds the batch dimension expected by the model call.
    return transform(image).unsqueeze(0).to(device), image


def detect_objects(image_path, threshold=0.5, mask_threshold=0.5):
    """Run Mask R-CNN on an image and keep the confident detections.

    Args:
        image_path: Path to the image file.
        threshold: Minimum confidence score (inclusive) for a detection to be
            kept.
        mask_threshold: Per-pixel mask probability at or above which a pixel
            is treated as part of the object.

    Returns:
        A tuple ``(detections, image)``.

        ``detections`` is a list of ``(mask, label, score)`` tuples:
        ``mask`` is a 2-D uint8 numpy array containing only 0 or 255,
        ``label`` is the integer class id, and ``score`` is the float
        confidence.

        ``image`` is the RGB ``PIL.Image`` that was analysed.
    """
    image_tensor, image_pil = preprocess_image(image_path)
    with torch.no_grad():
        outputs = model(image_tensor)[0]  # single input image -> first output dict

    # Filter all detections in one vectorized step. Comparing device tensors
    # one element at a time in Python forces a host sync per detection.
    keep = outputs["scores"] >= threshold
    masks = outputs["masks"][keep, 0]  # drop the singleton channel dimension
    labels = outputs["labels"][keep].tolist()
    scores = outputs["scores"][keep].tolist()

    # The model emits soft per-pixel probabilities; threshold them so each
    # mask really is binary (0 or 255) before it is used as a paste mask.
    binary_masks = (masks >= mask_threshold).to(torch.uint8).mul(255).cpu().numpy()

    filtered_masks = list(zip(binary_masks, labels, scores))
    return filtered_masks, image_pil


def apply_instance_masks(image_path, threshold=0.5):
    """Overlay a translucent colour on every confidently detected object.

    Args:
        image_path: Path to the image file.
        threshold: Minimum detection confidence, forwarded to
            ``detect_objects``.

    Returns:
        An RGB ``PIL.Image`` with the coloured masks composited on top.
    """
    masks, image = detect_objects(image_path, threshold=threshold)

    img = image.convert("RGBA")  # RGBA so the overlay can be alpha-composited
    overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))  # fully transparent layer

    # One colour per object category, so instances of the same class match.
    color_map = {}
    for mask, label, _score in masks:
        if label not in color_map:
            color_map[label] = (
                random.randint(50, 255),
                random.randint(50, 255),
                random.randint(50, 255),
                150,  # fixed alpha: partially transparent overlay
            )
        mask_pil = Image.fromarray(mask)  # 2-D uint8 array -> "L"-mode image
        colored_mask = Image.new("RGBA", mask_pil.size, color_map[label])
        # Only pixels where mask_pil is non-zero receive the colour.
        overlay.paste(colored_mask, (0, 0), mask_pil)

    result_image = Image.alpha_composite(img, overlay)
    return result_image.convert("RGB")


def _segment(img_path, thresh):
    """Gradio callback: return the masked image, or None when no image is given."""
    if not img_path:
        return None
    return apply_instance_masks(img_path, threshold=thresh)


with gr.Blocks() as demo:
    gr.Markdown("## Object Detection with Mask R-CNN")
    gr.Markdown("This demo applies instance segmentation to an image using Mask R-CNN.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="filepath")
            threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Confidence Threshold")
            detect_button = gr.Button("Detect Objects")
        with gr.Column():
            output_image = gr.Image(label="Output Image with Masks")

    # The slider value is passed through so it actually filters detections.
    detect_button.click(
        fn=_segment,
        inputs=[input_image, threshold],
        outputs=output_image,
    )

if __name__ == "__main__":
    demo.launch()