# vorkna's picture
# Update app.py
# 074fce9 verified
import os
import random
import torch
import torchvision
import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms.functional as F
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
import kornia.augmentation as K
def get_model_instance_segmentation(num_classes):
    """Build a Mask R-CNN (ResNet-50 FPN backbone, COCO weights) whose box and
    mask heads are replaced with fresh predictors sized for `num_classes`.

    Args:
        num_classes: total number of output classes, background included.

    Returns:
        A torchvision Mask R-CNN module ready for fine-tuning or inference.
    """
    base = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # Swap the COCO box-classification head for one with the right class count.
    box_in = base.roi_heads.box_predictor.cls_score.in_features
    base.roi_heads.box_predictor = FastRCNNPredictor(box_in, num_classes)

    # Likewise replace the mask head; 256 hidden channels mirrors the default.
    mask_in = base.roi_heads.mask_predictor.conv5_mask.in_channels
    base.roi_heads.mask_predictor = MaskRCNNPredictor(mask_in, 256, num_classes)

    return base
# Fisheye distortion parameters; each tensor holds a (low, high) pair that
# NOTE(review): presumably defines the sampling range for the augmentation —
# confirm against kornia's RandomFisheye documentation.
center_x = torch.tensor([-0.3, 0.3])  # horizontal distortion-center range
center_y = torch.tensor([-0.3, 0.3])  # vertical distortion-center range
gamma = torch.tensor([0.9, 1.0])      # distortion-strength range

# Define fisheye augmentation with given parameters.
# p=1.0 applies the distortion to every image passed through the transform.
fisheye_transform = K.RandomFisheye(
    center_x=center_x.unsqueeze(1),  # reshaped to (2, 1) — presumably the shape kornia expects; verify
    center_y=center_y.unsqueeze(1),
    gamma=gamma.unsqueeze(1),
    p=1.0,
    same_on_batch=True,
    keepdim=True,
)
# --- Setup ---
# Fail fast at import time if the trained weights or dataset are missing.
if not os.path.exists("maskrcnn_pennfudan.pth"):
    raise FileNotFoundError(
        "Model file 'maskrcnn_pennfudan.pth' not found. Please place it in the root directory."
    )
image_dir = "data/PennFudanPed/PNGImages"
if not os.path.isdir(image_dir):
    raise FileNotFoundError(
        f"Image directory '{image_dir}' not found. Please ensure the data is structured correctly."
    )
# Device and model loading: prefer GPU when available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# PennFudanPed has 2 classes: background and person
num_classes = 2
model = get_model_instance_segmentation(num_classes)
# map_location keeps CPU-only machines working with GPU-trained checkpoints.
model.load_state_dict(torch.load("maskrcnn_pennfudan.pth", map_location=device))
model.to(device)
model.eval()  # inference mode: fixes batch-norm/dropout behavior
# Sorted list of dataset image paths that the UI browses by index.
image_files = sorted(
    [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(".png")]
)
def predict_on_image(img, score_threshold=0.7):
    """
    Run the module-level Mask R-CNN on a PIL image and return a new PIL image
    with high-confidence boxes and instance masks drawn on top.

    The image is first passed through the module-level fisheye augmentation,
    so predictions are made on the distorted view.

    Args:
        img: input PIL.Image (any mode; converted to RGB internally).
        score_threshold: minimum detection confidence to keep. Defaults to
            0.7, matching the previously hard-coded value.

    Returns:
        PIL.Image with red bounding boxes and blue masks overlaid.
    """
    img = img.convert("RGB")
    img_tensor = F.to_tensor(img)  # float32 in [0, 1], shape (C, H, W)
    # Apply the fisheye distortion; the transform works on batches, hence the
    # unsqueeze/squeeze pair around the call.
    img_tensor = fisheye_transform(img_tensor.unsqueeze(0)).squeeze(0)

    with torch.no_grad():
        prediction = model([img_tensor.to(device)])
    pred = prediction[0]

    # Filter predictions by confidence. Move the kept tensors to CPU: on a
    # CUDA device they would otherwise live on the GPU while `img_to_draw`
    # stays on CPU, and draw_segmentation_masks requires matching devices.
    high_conf_indices = pred["scores"] > score_threshold
    boxes = pred["boxes"][high_conf_indices].cpu()
    labels = [f"person: {score:.2f}" for score in pred["scores"][high_conf_indices]]
    masks = pred["masks"][high_conf_indices].cpu()

    # torchvision's drawing utilities expect uint8 images in [0, 255].
    img_to_draw = (img_tensor * 255).to(torch.uint8)

    # Draw bounding boxes (skip cleanly when nothing was detected).
    if len(boxes) > 0:
        img_with_boxes = draw_bounding_boxes(
            img_to_draw, boxes=boxes, labels=labels, colors="red", width=2
        )
    else:
        img_with_boxes = img_to_draw

    # Draw segmentation masks; the model emits soft masks, binarized at 0.5.
    if len(masks) > 0:
        masks_bool = masks.squeeze(1) > 0.5
        img_with_masks = draw_segmentation_masks(
            img_with_boxes, masks=masks_bool, alpha=0.5, colors="blue"
        )
    else:
        img_with_masks = img_with_boxes

    # Convert the tensor back to a PIL Image for Gradio display.
    final_image = F.to_pil_image(img_with_masks.cpu())
    return final_image
def predict_and_draw(image_index):
    """
    Run prediction on the dataset image at `image_index` (wrapped modulo the
    dataset size) and return the annotated image, an info string, and the
    resolved index.

    Returns:
        (PIL.Image | None, str, int): annotated image (None if the dataset is
        empty), a human-readable status line, and the index actually shown.
    """
    if not image_files:
        return None, "No images found in data/PennFudanPed/PNGImages", 0
    image_index = image_index % len(image_files)
    image_path = image_files[image_index]
    # Open inside a context manager so the file handle is released promptly
    # (Image.open is lazy; prediction forces the pixel load while the file is
    # still open, so closing afterwards is safe).
    with Image.open(image_path) as img:
        final_image = predict_on_image(img)
    info_text = f"Displaying image {image_index + 1}/{len(image_files)}: {os.path.basename(image_path)}"
    return final_image, info_text, image_index
# --- Gradio App ---
with gr.Blocks() as demo:
    gr.Markdown(
        "# Mask R-CNN Pedestrian Detection on PennFudanPed with Fish Eye Augmentation"
    )
    gr.Markdown("### Browse Dataset Images")
    # State to keep track of the current image index (-1 = nothing shown yet).
    current_index = gr.State(value=-1)
    with gr.Row():
        prev_btn = gr.Button("Previous")
        next_btn = gr.Button("Next")
        random_btn = gr.Button("Random")
    output_image = gr.Image(label="Image with Predictions")
    info_text = gr.Textbox(label="Image Info")

    def next_image(index):
        # Advance one image; predict_and_draw wraps the index modulo the
        # dataset size, so no upper-bound check is needed here.
        new_index = index + 1
        return predict_and_draw(new_index)

    def prev_image(index):
        # Step back one image, wrapping to the last image from the first.
        new_index = index - 1
        if new_index < 0:
            new_index = len(image_files) - 1  # Wrap around
        return predict_and_draw(new_index)

    def random_image():
        # Jump to a uniformly random dataset image.
        new_index = random.randint(0, len(image_files) - 1)
        return predict_and_draw(new_index)

    # Wire the browse buttons; each handler returns (image, info, new index).
    next_btn.click(
        next_image,
        inputs=current_index,
        outputs=[output_image, info_text, current_index],
    )
    prev_btn.click(
        prev_image,
        inputs=current_index,
        outputs=[output_image, info_text, current_index],
    )
    random_btn.click(
        random_image, inputs=None, outputs=[output_image, info_text, current_index]
    )
    gr.Markdown("---")
    gr.Markdown("### Or upload your own image")
    input_image = gr.Image(type="pil", label="Upload Image")
    upload_btn = gr.Button("Predict on Uploaded Image")

    def handle_upload(img):
        # Run inference on a user-supplied image; returns -1 as the index so
        # the browse state no longer points at a dataset image.
        if img is None:
            return None, "Please upload an image.", -1
        result = predict_on_image(img)
        return result, "Prediction for uploaded image.", -1

    upload_btn.click(
        handle_upload,
        inputs=input_image,
        outputs=[output_image, info_text, current_index],
    )
    # Load the first image on startup (next_image(-1) resolves to index 0).
    demo.load(
        lambda: next_image(-1),
        inputs=None,
        outputs=[output_image, info_text, current_index],
    )
# Standard script entry point: start the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch()