Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| import torchvision.transforms as T | |
| from PIL import Image, ImageDraw, ImageOps | |
| import numpy as np | |
| from torchvision.models.detection import maskrcnn_resnet50_fpn | |
| from torchvision.models.detection.faster_rcnn import FastRCNNPredictor | |
| from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor | |
| import os | |
# Prefer a CUDA GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fine-tuned LEGO checkpoint, looked up relative to the current working directory.
model_path = "mask_rcnn_lego.pth"
if not os.path.exists(model_path):
    raise FileNotFoundError(
        "The model file 'mask_rcnn_lego.pth' was not found in the directory."
    )


def _build_lego_model(checkpoint, target_device, num_classes=2, mask_hidden=256):
    """Return an eval-mode Mask R-CNN whose heads predict *num_classes* classes.

    Starts from torchvision's default pretrained ``maskrcnn_resnet50_fpn``,
    swaps the box and mask predictor heads for *num_classes*-way ones
    (background + LEGO by default), then overwrites the weights with the
    state dict stored at *checkpoint*, mapped onto *target_device*.
    """
    net = maskrcnn_resnet50_fpn(weights="DEFAULT")
    heads = net.roi_heads
    # Replacement heads keep the input widths of the heads they replace.
    box_in = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in, num_classes=num_classes)
    mask_in = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = MaskRCNNPredictor(
        mask_in, mask_hidden, num_classes=num_classes
    )
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(checkpoint, map_location=target_device)
    net.load_state_dict(state)
    net.to(target_device)
    net.eval()
    return net


model = _build_lego_model(model_path, device)

# Single-step pipeline: PIL image -> float tensor via ToTensor.
transform = T.Compose([T.ToTensor()])
def create_pseudo_mask(image, box):
    """Return a box-shaped stand-in for a segmentation mask.

    The result is a new ``"L"``-mode image with the same size as *image*:
    pixels inside *box* (as accepted by ``ImageDraw.rectangle``) are 255,
    every other pixel is 0. *image* itself is only read for its size.
    """
    canvas = Image.new("L", image.size, 0)
    ImageDraw.Draw(canvas).rectangle(box, fill=255)
    return canvas
def _overlay_pseudo_masks(canvas, boxes, opacity):
    """Blend a translucent blue tint over every box region of *canvas*, in place.

    All boxes are drawn into one shared alpha mask, so overlapping boxes are
    tinted once rather than compounding.
    """
    # Clamp so out-of-range opacities still give a valid 8-bit alpha value.
    alpha_level = max(0, min(255, round(255 * opacity)))
    alpha = Image.new("L", canvas.size, 0)
    painter = ImageDraw.Draw(alpha)
    for box in boxes:
        painter.rectangle(box, fill=alpha_level)
    tint = Image.new("RGB", canvas.size, "blue")
    # The single-channel alpha is the paste mask: pixels outside every box
    # have alpha 0 and are left untouched. (Previously an all-opaque RGBA
    # overlay was used as its own mask, which painted the whole image blue.)
    canvas.paste(tint, (0, 0), alpha)


def detect_legos(image, use_pseudo_masks=True, score_threshold=0.5, mask_opacity=0.4):
    """Detect LEGO pieces in *image* and return an annotated copy plus a summary.

    Args:
        image: PIL image to analyse; any mode is accepted and converted to RGB.
        use_pseudo_masks: when true, tint each detected box region blue.
        score_threshold: minimum detection score for a box to be kept
            (defaults to the previously hard-coded 0.5).
        mask_opacity: opacity in [0, 1] of the blue pseudo-mask tint.

    Returns:
        ``(annotated_image, title)``. ``annotated_image`` is an RGB copy of
        *image* with the pseudo-masks and yellow 3px box outlines drawn on it.
        ``title`` is ``"Detected LEGO pieces: N"``.
    """
    # Uploads may be RGBA/greyscale/palette; the detector is fed this RGB
    # version, and the coloured overlays are drawn onto it. Pillow's
    # ``convert`` returns a converted copy, so the caller's image is unchanged.
    image_with_masks = image.convert("RGB")
    img_tensor = transform(image_with_masks).unsqueeze(0).to(device)
    with torch.no_grad():
        prediction = model(img_tensor)[0]

    boxes = prediction["boxes"].cpu().numpy()
    scores = prediction["scores"].cpu().numpy()
    # Boolean-mask filtering keeps boxes aligned with their scores; tolist()
    # yields plain Python floats for PIL.
    boxes = boxes[scores >= score_threshold].tolist()
    num_legos_detected = len(boxes)

    if use_pseudo_masks and boxes:
        _overlay_pseudo_masks(image_with_masks, boxes, mask_opacity)

    # Outlines go on after the tint so they stay fully visible.
    draw = ImageDraw.Draw(image_with_masks)
    for box in boxes:
        draw.rectangle(box, outline="yellow", width=3)

    title = f"Detected LEGO pieces: {num_legos_detected}"
    return image_with_masks, title
def gradio_interface(image):
    """Gradio callback: annotate *image* with pseudo-masks and bounding boxes.

    Returns ``(annotated_image, summary_text)`` for the two output components.
    Gradio passes ``None`` for an empty image input (submit without upload).
    In that case this returns no image and a prompt, instead of letting the
    detector fail on a missing image.
    """
    if image is None:
        return None, "Please upload an image to detect LEGO pieces."
    image_with_masks, title = detect_legos(image, use_pseudo_masks=True)
    return image_with_masks, title
# Web UI: one PIL image in; the annotated image and a text summary out.
_input_component = gr.Image(type="pil")
_output_components = [
    gr.Image(type="pil"),
    gr.Textbox(label="Detection Summary"),
]
interface = gr.Interface(
    fn=gradio_interface,
    inputs=_input_component,
    outputs=_output_components,
    title="LEGO Detection with Mask R-CNN",
    description="Upload an image to detect and count LEGO pieces with bounding boxes and simulated masks.",
)

# Start the Gradio server as soon as this module runs.
interface.launch()