"""Gradio app that detects and counts LEGO pieces with a fine-tuned Mask R-CNN."""

import os

import gradio as gr
import requests
import torch
import torchvision.transforms as T
from PIL import Image, ImageDraw
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# Model setup
model_path = "mask_rcnn_lego.pth"
# The download location is supplied through the environment because no URL is
# defined in this file; it is only needed when the weights are not on disk yet.
model_url = os.environ.get("LEGO_MODEL_URL")


def _download_weights(url, destination, timeout=60):
    """Download the checkpoint at *url* and store it at *destination*.

    Raises:
        RuntimeError: if *url* is empty or ``None``.
        requests.HTTPError: if the server answers with an error status.
    """
    if not url:
        raise RuntimeError(
            f"Model weights '{destination}' not found and no download URL is "
            "configured. Set the LEGO_MODEL_URL environment variable or place "
            "the weights file next to this script."
        )
    response = requests.get(url, timeout=timeout)
    # Fail loudly instead of saving an HTTP error page as the weights file.
    response.raise_for_status()
    # Write to a temporary name first so an interrupted write never leaves a
    # truncated file that a later run would mistake for valid weights.
    partial_path = f"{destination}.part"
    with open(partial_path, "wb") as f:
        f.write(response.content)
    os.replace(partial_path, destination)


if not os.path.exists(model_path):
    _download_weights(model_url, model_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the architecture with two-class heads (num_classes=2) so that the
# checkpoint's box and mask predictor shapes match before loading.
model = maskrcnn_resnet50_fpn(weights="DEFAULT")
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=2)
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = MaskRCNNPredictor(
    in_features_mask, hidden_layer, num_classes=2
)
# NOTE(security): torch.load unpickles the file, and the file may have been
# downloaded from model_url. Only point LEGO_MODEL_URL at a trusted source;
# consider torch.load(..., weights_only=True) if the installed PyTorch
# version supports it.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

# Set up transformations
transform = T.Compose([T.ToTensor()])


# Function for image processing and bounding box detection
def detect_legos(image, *, score_threshold=0.0):
    """Run the detector on *image* and draw a red box around each detection.

    Args:
        image: a PIL image, or ``None`` when no image was supplied.
        score_threshold: detections whose score is below this value are
            dropped. The default of 0.0 keeps every box the model returns.

    Returns:
        A ``(annotated_image, summary)`` tuple. ``annotated_image`` is an RGB
        copy of the input with the boxes drawn on it (the input itself is not
        modified), or ``None`` when *image* is ``None``. ``summary`` is a
        human-readable string.
    """
    if image is None:
        return None, "No image provided."

    # convert() returns a new image, so drawing below never touches the
    # caller's object, and the tensor fed to the model always has 3 channels.
    annotated = image.convert("RGB")
    img_tensor = transform(annotated).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)

    prediction = outputs[0]
    keep = prediction["scores"] >= score_threshold
    # tolist() yields plain Python floats, which Pillow's drawing API accepts
    # without relying on NumPy scalar coercion.
    boxes = prediction["boxes"][keep].cpu().tolist()

    draw = ImageDraw.Draw(annotated)
    for x1, y1, x2, y2 in boxes:
        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)

    title = f"Detected LEGO pieces: {len(boxes)}"
    return annotated, title


# Gradio interface function
def gradio_interface(image):
    """Gradio callback: return the annotated image and the detection summary."""
    image_with_boxes, title = detect_legos(image)
    return image_with_boxes, title


# Set up Gradio Interface
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil"), gr.Textbox(label="Detection Summary")],
    title="LEGO Detection with Mask R-CNN",
    description="Upload an image to detect and count LEGO pieces with bounding boxes.",
)

# Launch interface (no share=True needed for Gradio hosted or Hugging Face)
interface.launch()