"""Gradio demo: Mask R-CNN pedestrian detection on PennFudanPed images.

Every image (dataset or uploaded) is first distorted with a kornia fisheye
augmentation, then run through a Mask R-CNN fine-tuned for 2 classes
(background + person). Kept detections are drawn as red boxes and blue masks.
"""
import os
import random
import torch
import torchvision
import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms.functional as F
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
import kornia.augmentation as K


def get_model_instance_segmentation(num_classes):
    """Build a COCO-pretrained Mask R-CNN with heads resized for ``num_classes``.

    Args:
        num_classes: Number of output classes, including the background class.

    Returns:
        A ``torchvision`` Mask R-CNN whose box and mask predictors have been
        replaced so they predict ``num_classes`` classes.
    """
    # Start from an instance segmentation model pre-trained on COCO.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # Swap the box classifier head for one with the requested class count.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Swap the mask head as well, keeping its input width from the backbone.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask, hidden_layer, num_classes
    )
    return model


# --- Fisheye augmentation ---
# Sampling ranges (min, max) for the distortion centre and strength.
center_x = torch.tensor([-0.3, 0.3])
center_y = torch.tensor([-0.3, 0.3])
gamma = torch.tensor([0.9, 1.0])

# NOTE(review): kornia documents these range tensors as shape (2,); they are
# passed here as (2, 1) via unsqueeze(1). Kept as-is because it is the
# configuration this demo was built with — confirm against the kornia version.
fisheye_transform = K.RandomFisheye(
    center_x=center_x.unsqueeze(1),
    center_y=center_y.unsqueeze(1),
    gamma=gamma.unsqueeze(1),
    p=1.0,  # always apply, so every displayed image is distorted
    same_on_batch=True,
    keepdim=True,
)

# --- Setup ---
# Fail fast at startup if the trained weights or the dataset are missing.
if not os.path.exists("maskrcnn_pennfudan.pth"):
    raise FileNotFoundError(
        "Model file 'maskrcnn_pennfudan.pth' not found. Please place it in the root directory."
    )

image_dir = "data/PennFudanPed/PNGImages"
if not os.path.isdir(image_dir):
    raise FileNotFoundError(
        f"Image directory '{image_dir}' not found. Please ensure the data is structured correctly."
    )

# Device and model loading
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# PennFudanPed has 2 classes: background and person
num_classes = 2
model = get_model_instance_segmentation(num_classes)
model.load_state_dict(torch.load("maskrcnn_pennfudan.pth", map_location=device))
model.to(device)
model.eval()

# Sorted so Previous/Next browse the dataset in a stable order.
image_files = sorted(
    os.path.join(image_dir, name)
    for name in os.listdir(image_dir)
    if name.endswith(".png")
)


def predict_on_image(img, score_threshold=0.7, mask_threshold=0.5):
    """Run the model on a fisheye-distorted copy of ``img`` and draw the results.

    Args:
        img: Input PIL image of any mode; it is converted to RGB first.
        score_threshold: Detections with a score strictly greater than this
            are kept (default 0.7, the value this demo always used).
        mask_threshold: Mask probability above which a pixel is painted as
            part of the instance (default 0.5).

    Returns:
        A PIL image of the *distorted* input with red labelled boxes and blue
        semi-transparent masks for every kept detection.
    """
    img = img.convert("RGB")
    img_tensor = F.to_tensor(img)
    # kornia augmentations work on batches: add a batch dim, then drop it.
    img_tensor = fisheye_transform(img_tensor.unsqueeze(0)).squeeze(0)

    with torch.no_grad():
        prediction = model([img_tensor.to(device)])

    # Bring outputs back to the CPU: the drawing utilities below operate on the
    # CPU image tensor and index it with the masks, which fails for GPU masks.
    pred = {key: value.cpu() for key, value in prediction[0].items()}

    keep = pred["scores"] > score_threshold
    boxes = pred["boxes"][keep]
    scores = pred["scores"][keep]
    masks = pred["masks"][keep]
    labels = [f"person: {score:.2f}" for score in scores.tolist()]

    # Drawing helpers need uint8; clamp first so the cast can never wrap.
    canvas = (img_tensor * 255).clamp(0, 255).to(torch.uint8)

    if len(boxes) > 0:
        canvas = draw_bounding_boxes(
            canvas, boxes=boxes, labels=labels, colors="red", width=2
        )

    if len(masks) > 0:
        # Masks carry a singleton channel dim; drop it and binarize.
        masks_bool = masks.squeeze(1) > mask_threshold
        canvas = draw_segmentation_masks(
            canvas, masks=masks_bool, alpha=0.5, colors="blue"
        )

    return F.to_pil_image(canvas)


def predict_and_draw(image_index):
    """Predict on the dataset image at ``image_index`` (wrapped modulo the count).

    Args:
        image_index: Any integer; it is reduced modulo ``len(image_files)``.

    Returns:
        ``(image, info_text, normalized_index)``. When no dataset images exist,
        returns ``(None, <message>, 0)``.
    """
    if not image_files:
        return None, f"No images found in {image_dir}", 0

    image_index %= len(image_files)
    image_path = image_files[image_index]
    # Close the file handle once the pixels have been read.
    with Image.open(image_path) as img:
        final_image = predict_on_image(img)

    info_text = (
        f"Displaying image {image_index + 1}/{len(image_files)}: "
        f"{os.path.basename(image_path)}"
    )
    return final_image, info_text, image_index


# --- Gradio App ---
with gr.Blocks() as demo:
    gr.Markdown(
        "# Mask R-CNN Pedestrian Detection on PennFudanPed with Fish Eye Augmentation"
    )
    gr.Markdown("### Browse Dataset Images")

    # Index of the dataset image on display; -1 means "none" (startup/upload).
    current_index = gr.State(value=-1)

    with gr.Row():
        prev_btn = gr.Button("Previous")
        next_btn = gr.Button("Next")
        random_btn = gr.Button("Random")

    output_image = gr.Image(label="Image with Predictions")
    info_text = gr.Textbox(label="Image Info")

    def next_image(index):
        """Advance to the next dataset image (wrapping handled downstream)."""
        return predict_and_draw(index + 1)

    def prev_image(index):
        """Step back one image; from index 0 or the -1 sentinel, go to the last."""
        new_index = index - 1
        if new_index < 0:
            new_index = len(image_files) - 1  # Wrap around
        return predict_and_draw(new_index)

    def random_image():
        """Show a uniformly random dataset image, if any exist."""
        if not image_files:
            # Let predict_and_draw report the empty dataset instead of crashing.
            return predict_and_draw(0)
        return predict_and_draw(random.randrange(len(image_files)))

    next_btn.click(
        next_image,
        inputs=current_index,
        outputs=[output_image, info_text, current_index],
    )
    prev_btn.click(
        prev_image,
        inputs=current_index,
        outputs=[output_image, info_text, current_index],
    )
    random_btn.click(
        random_image, inputs=None, outputs=[output_image, info_text, current_index]
    )

    gr.Markdown("---")
    gr.Markdown("### Or upload your own image")

    input_image = gr.Image(type="pil", label="Upload Image")
    upload_btn = gr.Button("Predict on Uploaded Image")

    def handle_upload(img):
        """Predict on an uploaded image and reset the browse index to -1."""
        if img is None:
            return None, "Please upload an image.", -1
        result = predict_on_image(img)
        return result, "Prediction for uploaded image.", -1

    upload_btn.click(
        handle_upload,
        inputs=input_image,
        outputs=[output_image, info_text, current_index],
    )

    # Load the first image on startup
    demo.load(
        lambda: next_image(-1),
        inputs=None,
        outputs=[output_image, info_text, current_index],
    )

if __name__ == "__main__":
    demo.launch()