# Spaces:
# Sleeping
# Sleeping
# Standard library
import os
import sys

# Third-party: model + image handling
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights

# Web UI
import gradio as gr
# Instantiate the detector once at import time so every request reuses the
# same weights instead of reloading them per call.
_weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=_weights)
model.eval()  # inference mode: freezes dropout / batch-norm statistics
# COCO class names, indexed by the label ids the torchvision detection models
# emit. 'N/A' entries are ids that the COCO dataset skips; index 0 is background.
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__',
    # people & vehicles
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
    'truck', 'boat',
    # street furniture
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter',
    'bench',
    # animals
    'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
    'zebra', 'giraffe', 'N/A',
    # accessories
    'backpack', 'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase',
    # sports
    'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    # kitchen
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    # food
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog',
    'pizza', 'donut', 'cake',
    # furniture
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A',
    # electronics
    'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    # appliances
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    # indoor objects
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush',
]
# Gradio-compatible detection function
def detect_objects(image, threshold=0.5):
    """Run the module-level Faster R-CNN on *image* and save an annotated PNG.

    Parameters
    ----------
    image : PIL.Image.Image or None
        Picture from the Gradio widget; ``None`` when nothing was uploaded.
    threshold : float, optional
        Minimum confidence score for a detection to be drawn (default 0.5).
        ``None`` or string values from the UI are coerced to float.

    Returns
    -------
    str
        Path of the saved PNG: "output.png" on success, "blank_output.png"
        when no image was provided, "error_output.png" on failure.
    """
    if image is None:
        print("Image is None, returning empty output", file=sys.stderr)
        return _render_message("No image provided", "blank_output.png", fontsize=20)
    try:
        print(f"Processing image of type {type(image)} and threshold {threshold}", file=sys.stderr)
        # Make sure threshold is a valid number (the slider can deliver None).
        if threshold is None:
            threshold = 0.5
            print("Threshold was None, using default 0.5", file=sys.stderr)
        threshold = float(threshold)
        # Normalize the mode: RGBA / palette / grayscale uploads would otherwise
        # break the weights transform or the matplotlib overlay.
        image = image.convert("RGB")
        transform = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
        image_tensor = transform(image).unsqueeze(0)
        with torch.no_grad():
            prediction = model(image_tensor)[0]
        boxes = prediction['boxes'].cpu().numpy()
        labels = prediction['labels'].cpu().numpy()
        scores = prediction['scores'].cpu().numpy()
        image_np = np.array(image)
        plt.figure(figsize=(10, 10))
        plt.imshow(image_np)
        ax = plt.gca()
        for box, label, score in zip(boxes, labels, scores):
            # Explicit debug prints to trace the comparison issue
            print(f"Score: {score}, Threshold: {threshold}, Type: {type(score)}/{type(threshold)}", file=sys.stderr)
            if score >= threshold:
                x1, y1, x2, y2 = box
                ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                           fill=False, color='red', linewidth=2))
                class_name = COCO_INSTANCE_CATEGORY_NAMES[int(label)]
                ax.text(x1, y1, f'{class_name}: {score:.2f}',
                        bbox=dict(facecolor='yellow', alpha=0.5),
                        fontsize=12, color='black')
        plt.axis('off')
        plt.tight_layout()
        # Save the figure to return
        output_path = "output.png"
        plt.savefig(output_path)
        plt.close()
        return output_path
    except Exception as e:
        print(f"Error in detect_objects: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc(file=sys.stderr)
        # Close any half-drawn figure left open by the try-branch so it
        # does not leak memory across requests.
        plt.close('all')
        return _render_message(f"Error: {str(e)}", "error_output.png", fontsize=12)


def _render_message(text, output_path, fontsize):
    """Save a 400x400 white placeholder with *text* centered; return its path."""
    placeholder = Image.new('RGB', (400, 400), color='white')
    plt.figure(figsize=(10, 10))
    plt.imshow(placeholder)
    plt.text(0.5, 0.5, text,
             horizontalalignment='center', verticalalignment='center',
             transform=plt.gca().transAxes, fontsize=fontsize, wrap=True)
    plt.axis('off')
    plt.savefig(output_path)
    plt.close()
    return output_path
# Absolute paths of the demo images bundled with the Space.
# These exact filenames match what's visible in the repository.
_APP_DIR = "/home/user/app"
_EXAMPLE_FILES = ("TEST_IMG_1.jpg", "TEST_IMG_2.JPG", "TEST_IMG_3.jpg", "TEST_IMG_4.jpg")
examples = [os.path.join(_APP_DIR, name) for name in _EXAMPLE_FILES]

# Gradio expects examples as a list of per-input lists; keep only the files
# that actually exist on disk so the UI never shows a broken thumbnail.
example_list = [[path] for path in examples if os.path.exists(path)]
print(f"Found {len(example_list)} valid examples: {example_list}", file=sys.stderr)
# Build the web UI. The two inputs map positionally onto
# detect_objects(image, threshold); the output is the saved PNG path.
_image_input = gr.Image(type="pil", label="Input Image")
_threshold_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                             label="Confidence Threshold")

interface = gr.Interface(
    fn=detect_objects,
    inputs=[_image_input, _threshold_input],
    outputs=gr.Image(type="filepath", label="Detected Objects"),
    title="Faster R-CNN Object Detection",
    description="Upload an image to detect objects using a pretrained Faster R-CNN model.",
    examples=example_list,
    cache_examples=False,  # Disable caching to avoid potential issues
)
# Entry point for local runs and Hugging Face Spaces.
if __name__ == "__main__":
    # debug=True surfaces server-side tracebacks in the UI / logs.
    interface.launch(debug=True)