"""Gradio demo: YOLO object detection with a GradCAM heatmap side by side."""

import os
from io import BytesIO

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
from ultralytics import YOLO

# `gradcam` is the module name; `from gradcam.py import ...` would look for a
# submodule called `py` inside a package called `gradcam`.
from gradcam import extract_gradcam

# Load the trained YOLO weights once at import time. No device is selected
# explicitly here; device placement is left to the library.
model_path = "best.pt"
model = YOLO(model_path)

# Size of the black placeholder images returned when processing fails.
FALLBACK_SIZE = (224, 224)


def _blank_image():
    """Return a black RGB placeholder image of FALLBACK_SIZE (height, width)."""
    height, width = FALLBACK_SIZE
    return Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8))


def _draw_detections(image, results):
    """Return a copy of *image* with one red box and label per detection."""
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    for result in results:
        for box in result.boxes:
            # Each per-box attribute is a one-row tensor; convert to plain
            # Python numbers before using them as a dict key, in a format
            # spec, or as PIL coordinates.
            class_id = int(box.cls[0])
            confidence = float(box.conf[0])
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            label = result.names[class_id]

            draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
            # Keep the label inside the image for boxes touching the top edge.
            draw.text(
                (x1, max(0.0, y1 - 10)),
                f"{label} {confidence:.2f}",
                fill="red",
            )
    return annotated


def _render_heatmap(heatmap):
    """Render *heatmap* with the 'jet' colormap and a color bar as a PIL image."""
    fig, ax = plt.subplots()
    try:
        mappable = ax.imshow(heatmap, cmap="jet")
        fig.colorbar(mappable)
        buf = BytesIO()
        # Save this specific figure rather than pyplot's implicit current one.
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even if rendering or saving failed.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def process_image(image):
    """Run YOLO detection and GradCAM on a PIL image.

    Args:
        image: The uploaded PIL image.

    Returns:
        A tuple ``(annotated, gradcam)`` of PIL images: the input with
        detections drawn on it, and the GradCAM heatmap rendered with a
        color bar. If anything fails, the error is printed and two black
        placeholder images are returned instead so the UI keeps working.
    """
    try:
        image = image.convert("RGB")

        # Pass the PIL image directly: the predictor treats PIL input as RGB,
        # whereas a raw numpy array would be interpreted as BGR.
        results = model(image)
        annotated = _draw_detections(image, results)

        gradcam_img = _render_heatmap(extract_gradcam(image))
        return annotated, gradcam_img
    except Exception as e:
        # Deliberate best-effort boundary for the UI: report and fall back.
        print(f"Error processing image: {e}")
        return _blank_image(), _blank_image()


def upload_image(image):
    """Gradio callback: return the annotated image and the GradCAM image."""
    img_draw, gradcam_img = process_image(image)
    return img_draw, gradcam_img


# Configure the Gradio interface
iface = gr.Interface(
    fn=upload_image,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil"), gr.Image(type="pil")],
    title="YOLO Object Detection with GradCAM Visualization",
    description="Upload an image to detect objects and visualize with GradCAM.",
    allow_flagging="never",  # Hide the flagging button
)

# Launch the Gradio interface
iface.launch()