# Hugging Face Spaces page header (Space status: Sleeping) — scraped page chrome, not code.
# Standard library
import os
import random

# Third-party
import gradio as gr
import kornia.augmentation as K
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional as F
from PIL import Image
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
def get_model_instance_segmentation(num_classes):
    """Build a Mask R-CNN model with its heads resized for `num_classes` classes.

    Args:
        num_classes: Total number of classes, including background.

    Returns:
        A torchvision Mask R-CNN (ResNet-50 FPN) model whose box and mask
        predictor heads are replaced to output `num_classes` classes.
    """
    # Start from the instance segmentation model pre-trained on COCO.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # Replace the box classification head to match our class count.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask prediction head likewise.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask, hidden_layer, num_classes
    )
    return model
# Parameter ranges from which kornia samples the fisheye distortion
# center (x, y) and gamma for each application.
center_x = torch.tensor([-0.3, 0.3])
center_y = torch.tensor([-0.3, 0.3])
gamma = torch.tensor([0.9, 1.0])

# Always-applied (p=1.0) fisheye augmentation; keepdim=True preserves the
# input tensor's shape convention.
fisheye_transform = K.RandomFisheye(
    center_x=center_x.unsqueeze(1),
    center_y=center_y.unsqueeze(1),
    gamma=gamma.unsqueeze(1),
    p=1.0,
    same_on_batch=True,
    keepdim=True,
)
# --- Setup ---
# Fail fast if the trained weights or the dataset images are missing.
if not os.path.exists("maskrcnn_pennfudan.pth"):
    raise FileNotFoundError(
        "Model file 'maskrcnn_pennfudan.pth' not found. Please place it in the root directory."
    )

image_dir = "data/PennFudanPed/PNGImages"
if not os.path.isdir(image_dir):
    raise FileNotFoundError(
        f"Image directory '{image_dir}' not found. Please ensure the data is structured correctly."
    )

# Device selection and model loading.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# PennFudanPed has 2 classes: background and person.
num_classes = 2
model = get_model_instance_segmentation(num_classes)
model.load_state_dict(torch.load("maskrcnn_pennfudan.pth", map_location=device))
model.to(device)
model.eval()

# Sorted list of dataset image paths, for a deterministic browsing order.
image_files = sorted(
    os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(".png")
)
def predict_on_image(img):
    """Apply the fisheye augmentation and run Mask R-CNN on a PIL image.

    Args:
        img: Input PIL image (any mode; converted to RGB).

    Returns:
        A PIL image with high-confidence boxes (red) and masks (blue) drawn.
    """
    img = img.convert("RGB")
    img_tensor = F.to_tensor(img)

    # The fisheye transform expects a batch dimension; add and remove one.
    img_tensor = fisheye_transform(img_tensor.unsqueeze(0)).squeeze(0)

    with torch.no_grad():
        prediction = model([img_tensor.to(device)])
    pred = prediction[0]

    # Keep only confident detections. Move boxes/masks to CPU so the drawing
    # utilities see tensors on the same device as the (CPU) image tensor —
    # without this, running on CUDA raises a device-mismatch error.
    score_threshold = 0.7
    high_conf_indices = pred["scores"] > score_threshold
    boxes = pred["boxes"][high_conf_indices].cpu()
    labels = [f"person: {score:.2f}" for score in pred["scores"][high_conf_indices]]
    masks = pred["masks"][high_conf_indices].cpu()

    # The draw_* utilities expect a uint8 image in [0, 255].
    img_to_draw = (img_tensor * 255).to(torch.uint8)

    # Draw bounding boxes, if any detections survived the threshold.
    if len(boxes) > 0:
        img_with_boxes = draw_bounding_boxes(
            img_to_draw, boxes=boxes, labels=labels, colors="red", width=2
        )
    else:
        img_with_boxes = img_to_draw

    # Draw segmentation masks: soft (N, 1, H, W) probabilities, binarized at 0.5.
    if len(masks) > 0:
        masks_bool = masks.squeeze(1) > 0.5
        img_with_masks = draw_segmentation_masks(
            img_with_boxes, masks=masks_bool, alpha=0.5, colors="blue"
        )
    else:
        img_with_masks = img_with_boxes

    # Convert the tensor back to a PIL Image for Gradio display.
    return F.to_pil_image(img_with_masks.cpu())
def predict_and_draw(image_index):
    """Run prediction on the dataset image at `image_index`.

    Args:
        image_index: Any integer; normalized into range via modulo, so
            out-of-range values wrap around the dataset.

    Returns:
        A `(annotated PIL image, info string, normalized index)` triple, or
        `(None, error message, 0)` when no dataset images are available.
    """
    if not image_files:
        return None, "No images found in data/PennFudanPed/PNGImages", 0
    image_index = image_index % len(image_files)
    image_path = image_files[image_index]
    img = Image.open(image_path)
    final_image = predict_on_image(img)
    info_text = f"Displaying image {image_index + 1}/{len(image_files)}: {os.path.basename(image_path)}"
    return final_image, info_text, image_index
# --- Gradio App ---
with gr.Blocks() as demo:
    gr.Markdown(
        "# Mask R-CNN Pedestrian Detection on PennFudanPed with Fish Eye Augmentation"
    )
    gr.Markdown("### Browse Dataset Images")

    # Index of the currently displayed dataset image (-1 = none yet).
    current_index = gr.State(value=-1)

    with gr.Row():
        prev_btn = gr.Button("Previous")
        next_btn = gr.Button("Next")
        random_btn = gr.Button("Random")

    output_image = gr.Image(label="Image with Predictions")
    info_text = gr.Textbox(label="Image Info")

    def next_image(index):
        # predict_and_draw wraps the index modulo len(image_files).
        new_index = index + 1
        return predict_and_draw(new_index)

    def prev_image(index):
        new_index = index - 1
        if new_index < 0:
            new_index = len(image_files) - 1  # Wrap around
        return predict_and_draw(new_index)

    def random_image():
        new_index = random.randint(0, len(image_files) - 1)
        return predict_and_draw(new_index)

    next_btn.click(
        next_image,
        inputs=current_index,
        outputs=[output_image, info_text, current_index],
    )
    prev_btn.click(
        prev_image,
        inputs=current_index,
        outputs=[output_image, info_text, current_index],
    )
    random_btn.click(
        random_image, inputs=None, outputs=[output_image, info_text, current_index]
    )

    gr.Markdown("---")
    gr.Markdown("### Or upload your own image")
    input_image = gr.Image(type="pil", label="Upload Image")
    upload_btn = gr.Button("Predict on Uploaded Image")

    def handle_upload(img):
        # Index -1 marks "not a dataset image" in the current_index state.
        if img is None:
            return None, "Please upload an image.", -1
        result = predict_on_image(img)
        return result, "Prediction for uploaded image.", -1

    upload_btn.click(
        handle_upload,
        inputs=input_image,
        outputs=[output_image, info_text, current_index],
    )

    # Load the first image on startup.
    demo.load(
        lambda: next_image(-1),
        inputs=None,
        outputs=[output_image, info_text, current_index],
    )

if __name__ == "__main__":
    demo.launch()