|
|
import torch |
|
|
import gradio as gr |
|
|
import torchvision.transforms as T |
|
|
from PIL import Image, ImageDraw |
|
|
from torchvision.models.detection import maskrcnn_resnet50_fpn |
|
|
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor |
|
|
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor |
|
|
import requests |
|
|
import os |
|
|
|
|
|
|
|
|
# Local filename under which the fine-tuned Mask R-CNN checkpoint is cached.
model_path = "mask_rcnn_lego.pth"

# Fetch the checkpoint on first run only; later runs reuse the local file.
# NOTE(review): `model_url` is not defined in the visible part of this file --
# confirm it is assigned before this point.
if not os.path.exists(model_path):
    # Download into a sibling ".part" file and rename it once complete, so an
    # interrupted transfer can never leave a truncated checkpoint at
    # `model_path` (which would pass the existence check on the next start).
    partial_path = model_path + ".part"
    try:
        # stream=True avoids holding the whole checkpoint in memory; the
        # (connect, read) timeout keeps start-up from hanging on a dead link.
        with requests.get(model_url, stream=True, timeout=(10, 60)) as response:
            # Fail loudly on 4xx/5xx instead of saving an error page as weights.
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(partial_path, model_path)
    finally:
        # Only present here if the download or the rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
|
|
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Width of the hidden layer inside the replacement mask head.
hidden_layer = 256

# Start from torchvision's Mask R-CNN and read the input sizes of both
# prediction heads before swapping them out.
model = maskrcnn_resnet50_fpn(weights="DEFAULT")
in_features = model.roi_heads.box_predictor.cls_score.in_features
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels

# Replace both heads with two-class versions.
# NOTE(review): presumably background + LEGO -- confirm against the training setup.
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=2)
model.roi_heads.mask_predictor = MaskRCNNPredictor(
    in_features_mask,
    hidden_layer,
    num_classes=2,
)

# Load the fine-tuned weights, mapping tensors onto the selected device.
# NOTE(review): torch.load unpickles the file; because the checkpoint is
# fetched over the network, consider `weights_only=True` once it is confirmed
# to be a plain state dict and the installed PyTorch supports the flag.
model.load_state_dict(torch.load(model_path, map_location=device))

# Move the network to the device and switch it to inference mode.
model.to(device).eval()
|
|
|
|
|
|
|
|
# Preprocessing pipeline applied to each input image before it is fed to the
# model: a single step that converts the image to a tensor.
transform = T.Compose(transforms=[T.ToTensor()])
|
|
|
|
|
|
|
|
def detect_legos(
    image: "Image.Image | None",
    score_threshold: float = 0.5,
) -> "tuple[Image.Image | None, str]":
    """Detect LEGO pieces in *image* and draw a red box around each one.

    Args:
        image: PIL image to analyse, or ``None`` when no image was supplied
            (Gradio passes ``None`` if the user submits an empty input).
        score_threshold: Minimum detection score a box needs in order to be
            drawn and counted. The detector also emits low-confidence
            candidates; counting them would inflate the piece total.

    Returns:
        A ``(annotated_image, summary)`` tuple. ``annotated_image`` is a new
        RGB image with the kept boxes drawn on it (the input is not modified),
        or ``None`` when no image was supplied. ``summary`` is a one-line,
        human-readable detection count.
    """
    if image is None:
        return None, "No image provided."

    # convert() returns a new image, so drawing below never mutates the
    # caller's object, and RGBA / greyscale / palette inputs are reduced to
    # the three channels the model expects.
    canvas = image.convert("RGB")

    batch = transform(canvas).unsqueeze(0).to(device)
    with torch.no_grad():
        # The model returns one result dict per input image; the batch has one.
        prediction = model(batch)[0]

    # Keep confident detections only, then hand PIL plain Python floats.
    keep = prediction["scores"] >= score_threshold
    boxes = prediction["boxes"][keep].cpu().tolist()

    draw = ImageDraw.Draw(canvas)
    for x1, y1, x2, y2 in boxes:
        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)

    return canvas, f"Detected LEGO pieces: {len(boxes)}"
|
|
|
|
|
|
|
|
def gradio_interface(image):
    """Gradio callback: run LEGO detection on the uploaded image.

    Delegates to ``detect_legos`` and passes its two results straight through,
    in the order the interface's output components expect them: the annotated
    image first, the text summary second.
    """
    annotated, summary = detect_legos(image)
    return annotated, summary
|
|
|
|
|
|
|
|
# UI components: one uploaded image in; the annotated image and a text
# summary out. Both image components exchange PIL images with the callback.
_image_input = gr.Image(type="pil")
_result_outputs = [
    gr.Image(type="pil"),
    gr.Textbox(label="Detection Summary"),
]

# Wire the callback and the components into a Gradio app.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=_image_input,
    outputs=_result_outputs,
    title="LEGO Detection with Mask R-CNN",
    description="Upload an image to detect and count LEGO pieces with bounding boxes.",
)

# Launch the app.
# NOTE(review): this runs whenever the module is executed or imported;
# consider an `if __name__ == "__main__":` guard if the module is ever
# imported from elsewhere.
interface.launch()
|
|
|
|
|
|