# Hugging Face Space file view: app.py
# Author: Girishug
# Last commit: "Update app.py" (2c72ce6, verified)
# File size: 3.9 kB
import os
import torch
from torchvision import transforms
import torchvision.models.detection as detection
import gradio as gr
from PIL import Image
import numpy as np
import cv2
# Load the trained model.
# Start from torchvision's COCO-pretrained Faster R-CNN, then swap in a fresh box
# predictor sized for `num_classes` so the fine-tuned checkpoint can be loaded on top.
model = detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')  # Use 'weights' instead of 'pretrained'
# torchvision's COCO label space: index 0 is background and indices 1-90 are the
# original COCO category ids (80 real classes plus 10 unused ids).
num_classes = 91
in_features = model.roi_heads.box_predictor.cls_score.in_features
# Keep the pretrained head so we can fall back to it if the checkpoint cannot be
# loaded; otherwise the app would silently serve a randomly initialised head.
pretrained_predictor = model.roi_heads.box_predictor
model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
# Load the model weights (on CPU).
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
try:
    model.load_state_dict(torch.load('final_model.pth', map_location=torch.device('cpu')))
except Exception as e:
    print(f"Error loading model weights: {e}")
    print("Falling back to the COCO-pretrained box predictor.")
    # NOTE(review): if load_state_dict (rather than torch.load) raised, some
    # non-head layers may already have been overwritten from the checkpoint.
    model.roi_heads.box_predictor = pretrained_predictor
# Always switch to inference mode: in training mode torchvision detection models
# require targets and fail on a plain forward pass.
model.eval()
# Define transformations.
# Preprocessing applied to every input PIL image before inference: resize to a
# fixed 600x600 (aspect ratio is not preserved), then convert to a float tensor
# laid out as (C, H, W) with values scaled to [0, 1].
_preprocessing_steps = [
    transforms.Resize((600, 600)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)
# COCO class names, indexed by the label ids torchvision's COCO detectors emit.
# Index 0 is the background class and indices 1-90 follow the original COCO
# category ids, which skip 10 ids; those slots hold "N/A" placeholders so that
# COCO_CLASSES[label] lines up with the model output (91 entries in total).
COCO_CLASSES = [
    "__background__",                                                      # 0
    "person", "bicycle", "car", "motorcycle", "airplane",                  # 1-5
    "bus", "train", "truck", "boat", "traffic light",                      # 6-10
    "fire hydrant", "N/A", "stop sign", "parking meter", "bench",          # 11-15
    "bird", "cat", "dog", "horse", "sheep",                                # 16-20
    "cow", "elephant", "bear", "zebra", "giraffe",                         # 21-25
    "N/A", "backpack", "umbrella", "N/A", "N/A",                           # 26-30
    "handbag", "tie", "suitcase", "frisbee", "skis",                       # 31-35
    "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",  # 36-40
    "skateboard", "surfboard", "tennis racket", "bottle", "N/A",           # 41-45
    "wine glass", "cup", "fork", "knife", "spoon",                         # 46-50
    "bowl", "banana", "apple", "sandwich", "orange",                       # 51-55
    "broccoli", "carrot", "hot dog", "pizza", "donut",                     # 56-60
    "cake", "chair", "couch", "potted plant", "bed",                       # 61-65
    "N/A", "dining table", "N/A", "N/A", "toilet",                         # 66-70
    "N/A", "TV", "laptop", "mouse", "remote",                              # 71-75
    "keyboard", "cell phone", "microwave", "oven", "toaster",              # 76-80
    "sink", "refrigerator", "N/A", "book", "clock",                        # 81-85
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",          # 86-90
]
# Prediction function
def _draw_detections(image_np, boxes, labels):
    """Draw a red rectangle and class-name caption for each detection.

    Args:
        image_np: RGB image as a uint8 NumPy array, as produced by ``np.array``
            on an RGB PIL image.
        boxes: iterable of ``[x1, y1, x2, y2]`` boxes in image_np's pixel coordinates.
        labels: iterable of integer class ids used to index ``COCO_CLASSES``.

    Returns:
        The annotated array.
    """
    color = (255, 0, 0)  # red, because image_np is in RGB channel order
    for box, label in zip(boxes, labels):
        # Plain Python ints: cv2 point arguments must be integers.
        x1, y1, x2, y2 = (int(coord) for coord in box)
        image_np = cv2.rectangle(image_np, (x1, y1), (x2, y2), color, 2)
        if 0 <= label < len(COCO_CLASSES):  # Ensure label is within bounds
            class_name = COCO_CLASSES[label]
            # Keep the caption inside the image when the box touches the top edge.
            text_y = max(y1 - 10, 10)
            image_np = cv2.putText(image_np, class_name, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        else:
            print(f"Warning: Label {label} is out of bounds for COCO_CLASSES.")
    return image_np


def predict(image, score_threshold=0.3):
    """Detect objects in *image* and return a copy annotated with boxes and labels.

    Args:
        image: PIL image from the Gradio input component, or None when the user
            submits without uploading anything.
        score_threshold: detections whose score is not strictly greater than this
            value are discarded. Defaults to 0.3.

    Returns:
        A new RGB PIL image, the same size as the input, with a red box and class
        name drawn for every kept detection.

    Raises:
        gr.Error: if no image was provided or inference/drawing fails. Gradio shows
            the message in the UI; a plain error string returned to a ``gr.Image``
            output component would not be displayed.
    """
    print("Predict function called")  # Debugging line
    if image is None:
        raise gr.Error("Please upload an image.")
    try:
        # Convert to RGB if the image has an alpha channel (or any other non-RGB mode)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Transform the image and add a batch dimension
        image_tensor = transform(image).unsqueeze(0)
        with torch.no_grad():
            predictions = model(image_tensor)

        detections = predictions[0]
        boxes = detections['boxes'].cpu().numpy()
        scores = detections['scores'].cpu().numpy()
        labels = detections['labels'].cpu().numpy()

        # Filter out low-confidence predictions; one shared mask keeps boxes and labels aligned.
        keep = scores > score_threshold
        boxes = boxes[keep]
        labels = labels[keep]
        print(f"Kept {len(boxes)} of {len(scores)} detections with score > {score_threshold}")

        # The model reports boxes in the coordinate frame of image_tensor, which
        # `transform` may have resized. Scale them back to the original image so
        # they line up with the pixels drawn on below.
        orig_width, orig_height = image.size
        tensor_height, tensor_width = image_tensor.shape[-2:]
        scale_x = orig_width / tensor_width
        scale_y = orig_height / tensor_height
        boxes = boxes * np.array([scale_x, scale_y, scale_x, scale_y])

        image_np = _draw_detections(np.array(image), boxes, labels)
        # Ensure the output is in the correct format
        return Image.fromarray(image_np.astype(np.uint8))
    except Exception as e:
        print(f"Prediction failed: {e}")
        raise gr.Error(f"Error: {e}") from e
# Gradio interface: one PIL image in, the annotated PIL image out.
image_input = gr.Image(type="pil")
annotated_output = gr.Image(type="pil")
iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=annotated_output,
    title="Object Detection with Faster R-CNN",
)
iface.launch()