File size: 3,895 Bytes
4e2babc
 
 
 
 
 
 
 
 
 
2c72ce6
4e2babc
 
 
 
 
2c72ce6
 
 
 
 
4e2babc
 
 
 
 
 
 
6ceca29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c72ce6
4e2babc
e9e1b64
8e98acf
6ceca29
 
 
 
8e98acf
 
 
 
 
 
 
 
 
 
fed820d
 
 
 
 
8e98acf
fee145c
8e98acf
 
 
fed820d
 
 
 
8e98acf
 
 
6ceca29
8e98acf
 
 
58d79c2
 
 
 
 
8e98acf
 
 
 
 
2c72ce6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import torch
from torchvision import transforms
import torchvision.models.detection as detection
import gradio as gr
from PIL import Image
import numpy as np
import cv2

# Build the detector.
# weights='DEFAULT' loads torchvision's pretrained weights (downloaded on first use);
# everything below is then overridden by final_model.pth when that file loads cleanly.
model = detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')  # Use 'weights' instead of 'pretrained'
# torchvision's COCO detection models use 91 label slots: index 0 is background and
# 1-90 are the original COCO category ids, 10 of which are unused (80 real classes).
num_classes = 91
in_features = model.roi_heads.box_predictor.cls_score.in_features
# Swap in a fresh box predictor sized for `num_classes`; its weights are expected to
# come from final_model.pth.
model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)

# Load the fine-tuned weights.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust
# (newer torch versions accept weights_only=True for plain state dicts).
try:
    model.load_state_dict(torch.load('final_model.pth', map_location=torch.device('cpu')))
except Exception as e:
    # Keep the app up, but make it obvious that detections are unreliable: the
    # replaced box predictor above is untrained unless the checkpoint loaded.
    print(f"Error loading model weights: {e}")
    print("Warning: continuing without final_model.pth; predictions will be unreliable.")

# Always switch to inference mode. A freshly built model is in training mode, where
# torchvision detection models demand targets and fail when called on images alone;
# previously this was skipped whenever weight loading failed.
model.eval()

# Preprocessing applied to every uploaded image before inference: resize it to a
# fixed (height, width) of (600, 600) -- aspect ratio is NOT preserved -- and then
# turn the PIL image into a float tensor with ToTensor().
transform = transforms.Compose(
    [
        transforms.Resize(size=(600, 600)),
        transforms.ToTensor(),
    ]
)

# COCO class names, indexed by the label id a torchvision COCO-layout detector emits:
# index 0 is background and indices 1-90 are the original COCO category ids. Ids
# that COCO leaves unused are "N/A" placeholders, so `COCO_CLASSES[label]` yields
# the right name without any offset.
# NOTE(review): assumes final_model.pth was trained with COCO category ids as
# labels (consistent with num_classes = 91) -- confirm against the training script.
COCO_CLASSES = [
    "__background__", "person", "bicycle", "car", "motorcycle", "airplane",  # 0-5
    "bus", "train", "truck", "boat", "traffic light", "fire hydrant",  # 6-11
    "N/A", "stop sign", "parking meter", "bench", "bird", "cat",  # 12-17
    "dog", "horse", "sheep", "cow", "elephant", "bear",  # 18-23
    "zebra", "giraffe", "N/A", "backpack", "umbrella", "N/A",  # 24-29
    "N/A", "handbag", "tie", "suitcase", "frisbee", "skis",  # 30-35
    "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",  # 36-41
    "surfboard", "tennis racket", "bottle", "N/A", "wine glass", "cup",  # 42-47
    "fork", "knife", "spoon", "bowl", "banana", "apple",  # 48-53
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",  # 54-59
    "donut", "cake", "chair", "couch", "potted plant", "bed",  # 60-65
    "N/A", "dining table", "N/A", "N/A", "toilet", "N/A",  # 66-71
    "TV", "laptop", "mouse", "remote", "keyboard", "cell phone",  # 72-77
    "microwave", "oven", "toaster", "sink", "refrigerator", "N/A",  # 78-83
    "book", "clock", "vase", "scissors", "teddy bear", "hair drier",  # 84-89
    "toothbrush",  # 90
]

# Prediction function
def predict(image, threshold=0.3):
    """Detect objects in *image* and return a copy annotated with boxes and labels.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.
        threshold: Detections whose score is not strictly greater than this
            value are dropped. Defaults to 0.3 (the previously hard-coded value).

    Returns:
        An RGB ``PIL.Image.Image`` at the input's original resolution, with each
        kept detection drawn as a red rectangle and, when its label id is a valid
        index into ``COCO_CLASSES``, captioned with that class name.

    Raises:
        gr.Error: If no image was supplied or any step of inference/drawing
            fails. Gradio shows the message in the UI; returning a string (as
            this function used to) cannot be rendered by the ``gr.Image`` output.
    """
    print("Predict function called")  # Debugging line
    if image is None:
        raise gr.Error("Please upload an image first.")

    try:
        # Convert to RGB if the image has an alpha channel (or is grayscale/palette)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Transform the image
        image_tensor = transform(image).unsqueeze(0)  # Add batch dimension
        with torch.no_grad():
            predictions = model(image_tensor)

        # Process predictions
        boxes = predictions[0]['boxes'].cpu().numpy()
        scores = predictions[0]['scores'].cpu().numpy()
        labels = predictions[0]['labels'].cpu().numpy()

        # Debugging: Print predictions
        print("Boxes:", boxes)
        print("Scores:", scores)
        print("Labels:", labels)

        # Filter out low-confidence predictions (one mask applied to both arrays)
        keep = scores > threshold
        boxes = boxes[keep]
        labels = labels[keep]

        # Debugging: Print filtered predictions
        print("Filtered Boxes:", boxes)
        print("Filtered Labels:", labels)

        # The detector reports boxes in the frame of `image_tensor`, which
        # `transform` may have resized; map them back onto the original image
        # so the drawing lines up with the objects.
        boxes = _rescale_boxes(boxes, image_tensor.shape[-2:], image.size)

        # Convert the input image to a NumPy array OpenCV can draw on
        image_np = np.array(image)

        # Draw boxes and labels on the image
        for box, label in zip(boxes, labels):
            x1, y1, x2, y2 = (int(v) for v in np.rint(box))
            image_np = cv2.rectangle(image_np, (x1, y1), (x2, y2), (255, 0, 0), 2)
            if label < len(COCO_CLASSES):  # Ensure label is within bounds
                class_name = COCO_CLASSES[label]  # Get the class name
                # Keep the caption inside the frame when a box touches the top edge.
                text_y = max(y1 - 10, 10)
                image_np = cv2.putText(image_np, class_name, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
            else:
                print(f"Warning: Label {label} is out of bounds for COCO_CLASSES.")

        # Ensure the output is in the correct format
        return Image.fromarray(image_np.astype(np.uint8))

    except Exception as e:
        # Surface the failure in the Gradio UI rather than returning a string
        # that the gr.Image output cannot display.
        raise gr.Error(f"Error: {str(e)}") from e


def _rescale_boxes(boxes, tensor_hw, image_wh):
    """Scale ``(x1, y1, x2, y2)`` boxes from the model-input frame to the image frame.

    Args:
        boxes: Array of shape (N, 4) in the pixel coordinates of the tensor fed
            to the model.
        tensor_hw: ``(height, width)`` of that tensor.
        image_wh: ``(width, height)`` of the original image (``PIL.Image.size`` order).

    Returns:
        A float array of shape (N, 4) in original-image pixel coordinates.
    """
    tensor_h, tensor_w = (int(v) for v in tensor_hw)
    image_w, image_h = image_wh
    # Same factor for x1/x2 and for y1/y2, hence the [sx, sy] pattern repeated twice.
    scale = np.array([image_w / tensor_w, image_h / tensor_h] * 2, dtype=np.float32)
    return boxes * scale

# Gradio UI: one PIL image goes in, `predict` runs on it, and the annotated PIL
# image comes back out. `launch()` starts the server when this module runs.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Object Detection with Faster R-CNN",
)
iface.launch()