import os

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.models.detection as detection
from PIL import Image
from torchvision import transforms

# Build the detector: a Faster R-CNN (ResNet-50 FPN) whose box-predictor head
# is replaced by one sized for `num_classes` outputs.
model = detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')  # Use 'weights' instead of 'pretrained'
num_classes = 91  # COCO category ids run 1..90 (with gaps); index 0 is background
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)

# Load the fine-tuned weights.
try:
    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    model.load_state_dict(torch.load('final_model.pth', map_location=torch.device('cpu')))
except Exception as e:
    print(f"Error loading model weights: {e}")

# Always switch to inference mode, even when loading the checkpoint failed:
# in training mode the detection model expects targets and cannot be called
# with an image alone.
model.eval()

# Size the image is resized to before it is fed to the model, as (height, width).
INPUT_SIZE = (600, 600)

# Detections with a score at or below this value are discarded.
SCORE_THRESHOLD = 0.3

# Define transformations
transform = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
])

# COCO class names, indexed directly by the label id the model emits.
# Index 0 is the background class and "N/A" marks category ids that COCO
# leaves unused, so the list has exactly `num_classes` (91) entries.
# NOTE(review): assumes final_model.pth was trained with the standard
# torchvision/COCO label ids -- confirm against the training pipeline.
COCO_CLASSES = [
    "__background__", "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "N/A",
    "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse",
    "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "N/A", "backpack",
    "umbrella", "N/A", "N/A", "handbag", "tie", "suitcase", "frisbee", "skis",
    "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
    "skateboard", "surfboard", "tennis racket", "bottle", "N/A", "wine glass",
    "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich",
    "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
    "chair", "couch", "potted plant", "bed", "N/A", "dining table", "N/A",
    "N/A", "toilet", "N/A", "TV", "laptop", "mouse", "remote", "keyboard",
    "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator",
    "N/A", "book", "clock", "vase", "scissors", "teddy bear", "hair drier",
    "toothbrush",
]


# Prediction function
def predict(image):
    """Run object detection on a PIL image and return it with boxes drawn.

    Args:
        image: The PIL image supplied by the Gradio input component, or
            ``None`` when the user submitted nothing.

    Returns:
        A new RGB PIL image, at the input's original size, with a rectangle
        and class name drawn for every detection scoring above
        ``SCORE_THRESHOLD``.

    Raises:
        gr.Error: If no image was provided or inference/drawing failed, so
            the message is shown in the UI instead of being handed to the
            image output component as a string.
    """
    print("Predict function called")  # Debugging line

    if image is None:
        raise gr.Error("No image provided.")

    try:
        # Convert to RGB if the image has an alpha channel (or any other mode)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Transform the image and add the batch dimension
        image_tensor = transform(image).unsqueeze(0)

        with torch.no_grad():
            predictions = model(image_tensor)

        # Process predictions
        boxes = predictions[0]['boxes'].cpu().numpy()
        scores = predictions[0]['scores'].cpu().numpy()
        labels = predictions[0]['labels'].cpu().numpy()

        # Debugging: Print predictions
        print("Boxes:", boxes)
        print("Scores:", scores)
        print("Labels:", labels)

        # Filter out low-confidence predictions (mask computed once)
        keep = scores > SCORE_THRESHOLD
        boxes = boxes[keep]
        labels = labels[keep]

        # Debugging: Print filtered predictions
        print("Filtered Boxes:", boxes)
        print("Filtered Labels:", labels)

        # The boxes are expressed in the coordinates of the resized model
        # input; map them back onto the original image before drawing.
        orig_w, orig_h = image.size
        in_h, in_w = INPUT_SIZE
        scale_x = orig_w / in_w
        scale_y = orig_h / in_h
        boxes = boxes * np.array([scale_x, scale_y, scale_x, scale_y])

        # Convert the input image to a NumPy array (a writable copy)
        image_np = np.array(image)

        # Draw boxes and labels on the image
        for box, label in zip(boxes, labels):
            # Plain Python ints keep OpenCV's argument parsing happy.
            x1, y1, x2, y2 = np.rint(box).astype(int).tolist()
            label = int(label)
            image_np = cv2.rectangle(image_np, (x1, y1), (x2, y2), (255, 0, 0), 2)
            if 0 <= label < len(COCO_CLASSES):  # Ensure label is within bounds
                class_name = COCO_CLASSES[label]  # Get the class name
                # Keep the caption inside the image when the box touches the top edge.
                text_y = max(y1 - 10, 15)
                image_np = cv2.putText(image_np, class_name, (x1, text_y),
                                       cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
            else:
                print(f"Warning: Label {label} is out of bounds for COCO_CLASSES.")

        # Ensure the output is in the correct format
        return Image.fromarray(image_np.astype(np.uint8))
    except Exception as e:
        # An image output cannot display a string, so surface the failure
        # through Gradio's error mechanism instead of returning it.
        raise gr.Error(f"Error: {str(e)}") from e


# Gradio interface
iface = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs=gr.Image(type="pil"), title="Object Detection with Faster R-CNN")
iface.launch()