Spaces:
Build error
Build error
| import os | |
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image, ImageDraw | |
| from ultralytics import YOLO | |
| import matplotlib.pyplot as plt | |
| from gradcam.py import extract_gradcam | |
# Location of the trained YOLO weights file, resolved relative to the
# process's working directory.
model_path = "best.pt"

# Build the detector once at import time so every request reuses it.
# NOTE(review): no device is chosen here; whether inference runs on a GPU
# is left to the library's defaults, so CUDA is not required by this line.
model = YOLO(model_path)
# YOLO object detection plus GradCAM visualization for a single image.
def _draw_detections(image, results):
    """Return a copy of *image* with a red box and caption per detection.

    Args:
        image: RGB PIL image the detections refer to.
        results: iterable of YOLO result objects exposing ``boxes`` and
            ``names``.

    Returns:
        A new PIL image; *image* itself is not modified.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    for result in results:
        boxes = result.boxes
        if boxes is None:
            # Nothing to draw for results that carry no box container.
            continue
        # Convert tensors to plain Python numbers once per result: tensors
        # cannot index ``names`` and do not accept a ``.2f`` format spec.
        rows = zip(boxes.xyxy.tolist(), boxes.cls.tolist(), boxes.conf.tolist())
        for (x1, y1, x2, y2), class_id, confidence in rows:
            label = result.names[int(class_id)]
            draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
            # Clamp so the caption of a box touching the top edge stays
            # inside the image.
            draw.text(
                (x1, max(y1 - 10, 0)),
                f"{label} {confidence:.2f}",
                fill="red",
            )
    return annotated


def _render_heatmap(heatmap):
    """Render *heatmap* with a 'jet' colormap and color bar as a PIL image."""
    from io import BytesIO

    fig, ax = plt.subplots()
    try:
        mappable = ax.imshow(heatmap, cmap="jet")
        fig.colorbar(mappable)
        buf = BytesIO()
        # Save through the figure object rather than pyplot's global
        # "current figure", which concurrent requests could change.
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even when rendering fails, so a
        # long-running server does not accumulate open figures.
        plt.close(fig)
    buf.seek(0)
    rendered = Image.open(buf)
    rendered.load()  # decode now instead of lazily on first access
    return rendered


def process_image(image):
    """Detect objects in *image* and build a GradCAM figure for it.

    Args:
        image: PIL image supplied by the UI; ``None`` is treated as an
            error.

    Returns:
        A 2-tuple ``(annotated, gradcam_figure)`` of PIL images. If any step
        fails, the error is printed and two black 224x224 RGB images are
        returned instead, so the caller always receives two images.
    """
    try:
        if image is None:
            raise ValueError("no image was provided")
        image = image.convert("RGB")
        # Pass the PIL image itself: a raw numpy array would be interpreted
        # as BGR by the detector, swapping the red and blue channels.
        results = model(image)
        annotated = _draw_detections(image, results)
        # NOTE(review): extract_gradcam's return value is handed straight to
        # imshow; presumably a 2-D or HxWx3 array -- confirm against gradcam.
        gradcam_figure = _render_heatmap(extract_gradcam(image))
        return annotated, gradcam_figure
    except Exception as e:
        # Deliberate best-effort boundary: report and fall back to blanks.
        print(f"Error processing image: {e}")
        blank = np.zeros((224, 224, 3), dtype=np.uint8)
        return Image.fromarray(blank), Image.fromarray(blank)
# Callback wired into the Gradio interface.
def upload_image(image):
    """Run the detection/GradCAM pipeline on an uploaded image.

    Args:
        image: the image object handed over by the Gradio input component.

    Returns:
        The pair of images produced by ``process_image``, in the same order.
    """
    annotated, heatmap = process_image(image)
    return annotated, heatmap
# Configure the Gradio interface: one PIL image in, two PIL images out.
iface = gr.Interface(
    title="YOLO Object Detection with GradCAM Visualization",
    description="Upload an image to detect objects and visualize with GradCAM.",
    fn=upload_image,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil") for _ in range(2)],
    # Turns off Gradio's flagging feature; this is unrelated to any
    # content (NSFW) filtering.
    allow_flagging="never",
)

# Start serving the interface.
iface.launch()