Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| from torchvision import transforms | |
| import torchvision.models.detection as detection | |
| import gradio as gr | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
# Load the trained model.
# Start from torchvision's COCO-pretrained Faster R-CNN ('weights' replaces the
# deprecated 'pretrained' flag).
model = detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')
# torchvision's COCO label space: index 0 is background and indices 1-90 are
# COCO category ids (ten of them unused), i.e. 80 real classes in 91 slots.
num_classes = 91
in_features = model.roi_heads.box_predictor.cls_score.in_features
# Only replace the box predictor when the class count actually differs. A fresh
# predictor is randomly initialised, so keeping the existing 91-way pretrained
# head means the app still produces sensible COCO detections if the fine-tuned
# weights below cannot be loaded.
if model.roi_heads.box_predictor.cls_score.out_features != num_classes:
    model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(
        in_features, num_classes
    )

# Load the fine-tuned model weights (path overridable via MODEL_WEIGHTS_PATH).
MODEL_WEIGHTS_PATH = os.environ.get("MODEL_WEIGHTS_PATH", "final_model.pth")
try:
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files
    # (or pass weights_only=True on PyTorch versions that support it).
    model.load_state_dict(
        torch.load(MODEL_WEIGHTS_PATH, map_location=torch.device('cpu'))
    )
except Exception as e:
    print(f"Error loading model weights: {e}")
    print("Continuing without the fine-tuned weights.")
# Always switch to inference mode, even if loading failed: in training mode
# torchvision detection models require targets and raise when given only images.
model.eval()
# Preprocessing applied to every input image before it reaches the detector:
# a fixed 600x600 resize (aspect ratio is not preserved) followed by conversion
# of the PIL image to a float tensor.
transform = transforms.Compose(
    transforms=[
        transforms.Resize(size=(600, 600)),
        transforms.ToTensor(),
    ]
)
# COCO class names, indexed directly by the label ids the detector predicts.
# torchvision's COCO detectors use the 91-slot COCO category ids: index 0 is the
# background class and ten ids are unused by COCO's annotations (marked "N/A").
# Indexing a plain 80-name list with these ids shifts the names (label 1,
# "person", would be shown as "bicycle").
COCO_CLASSES = [
    "__background__", "person", "bicycle", "car", "motorcycle", "airplane", "bus",
    "train", "truck", "boat", "traffic light", "fire hydrant", "N/A", "stop sign",
    "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "N/A", "backpack", "umbrella", "N/A",
    "N/A", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
    "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
    "surfboard", "tennis racket", "bottle", "N/A", "wine glass", "cup", "fork",
    "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli",
    "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
    "potted plant", "bed", "N/A", "dining table", "N/A", "N/A", "toilet", "N/A",
    "TV", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
    "oven", "toaster", "sink", "refrigerator", "N/A", "book", "clock", "vase",
    "scissors", "teddy bear", "hair drier", "toothbrush",
]
# Prediction function
def _detect(image, threshold):
    """Return (boxes, labels) scoring above *threshold*, boxes in *image*'s pixel coordinates."""
    image_tensor = transform(image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        prediction = model(image_tensor)[0]
    boxes = prediction['boxes'].cpu().numpy()
    scores = prediction['scores'].cpu().numpy()
    labels = prediction['labels'].cpu().numpy()

    # Filter out low-confidence predictions (mask computed once, applied to both).
    keep = scores > threshold
    boxes, labels = boxes[keep], labels[keep]

    # The detector returns (x1, y1, x2, y2) boxes relative to the tensor it was
    # given (resized by `transform`), but they are drawn on the original image,
    # so map them back with per-axis scale factors.
    tensor_h, tensor_w = image_tensor.shape[-2:]
    image_w, image_h = image.size
    scale = np.array(
        [image_w / tensor_w, image_h / tensor_h] * 2, dtype=np.float32
    )
    return boxes * scale, labels


def _draw_detections(image_np, boxes, labels):
    """Draw each box and its COCO class name onto the RGB array *image_np*."""
    for box, label in zip(boxes, labels):
        # Plain Python ints keep OpenCV's point arguments unambiguous.
        x1, y1, x2, y2 = (int(v) for v in box)
        image_np = cv2.rectangle(image_np, (x1, y1), (x2, y2), (255, 0, 0), 2)
        if 0 <= label < len(COCO_CLASSES):  # Ensure label is within bounds
            class_name = COCO_CLASSES[label]
            # Keep the text on-canvas when the box touches the top edge.
            text_y = max(y1 - 10, 10)
            image_np = cv2.putText(
                image_np, class_name, (x1, text_y),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2,
            )
        else:
            print(f"Warning: Label {label} is out of bounds for COCO_CLASSES.")
    return image_np


def predict(image, threshold=0.3):
    """Detect objects in *image* and return a copy annotated with boxes and labels.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input; ``None`` when
            the user submits without uploading an image.
        threshold: Minimum detection score for a box to be drawn.

    Returns:
        A new RGB ``PIL.Image`` of the same size as the input, with a red box
        and class-name label for every kept detection.

    Raises:
        gr.Error: If no image was provided or inference fails. Raising (rather
            than returning the message) lets Gradio show it in the UI; a string
            returned to an image output cannot be rendered.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    try:
        # Convert to RGB if the image has an alpha channel or another mode.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        boxes, labels = _detect(image, threshold)
        image_np = _draw_detections(np.array(image), boxes, labels)
        # Ensure the output is in the correct format
        return Image.fromarray(image_np.astype(np.uint8))
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e
# Gradio interface: a single PIL image in, the annotated PIL image out.
_image_input = gr.Image(type="pil")
_image_output = gr.Image(type="pil")
iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_image_output,
    title="Object Detection with Faster R-CNN",
)
iface.launch()