# EPECscan / app.py
# Author: chk03042010 — last commit c5d6c86 ("Update app.py", verified)
import os

import gradio as gr
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
from ultralytics import YOLO
import matplotlib.pyplot as plt

from gradcam import extract_gradcam
# Load the trained YOLO weights once at import time; every request reuses this model.
# The weights path is resolved next to this file (not the current working directory),
# so the app also starts when launched from elsewhere. No device is requested here,
# so device selection is left to Ultralytics (CUDA is not required).
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "best.pt")
model = YOLO(model_path)
# Image processing: YOLO object detection plus a Grad-CAM visualization.
def _blank_image(size=224):
    """Return a black ``size`` x ``size`` RGB placeholder image."""
    return Image.fromarray(np.zeros((size, size, 3), dtype=np.uint8))


def _annotate_detections(image):
    """Run YOLO on *image* and return a copy with red boxes and labels drawn on it.

    The PIL image is passed to the model as-is. Ultralytics treats a raw
    ``np.ndarray`` source as BGR, so ``np.array(image)`` (which is RGB) would
    swap the red and blue channels before inference.
    """
    results = model(image)
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    for result in results:
        if result.boxes is None:  # results without a detection head carry no boxes
            continue
        for box in result.boxes:
            # Box attributes are 1-element tensors (possibly on GPU). Convert them to
            # plain Python numbers so the dict lookup, format spec and PIL drawing work.
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            label = result.names[int(box.cls.item())]
            confidence = float(box.conf.item())
            draw.rectangle((x1, y1, x2, y2), outline="red", width=2)
            # Keep the caption on the canvas for boxes that touch the top edge.
            draw.text((x1, max(y1 - 10, 0)), f"{label} {confidence:.2f}", fill="red")
    return annotated


def _render_gradcam(image):
    """Compute a Grad-CAM map for *image* and render it, with a color bar, as a PIL image."""
    from io import BytesIO

    gradcam_map = extract_gradcam(image)
    fig, ax = plt.subplots()
    try:
        mappable = ax.imshow(gradcam_map, cmap="jet")
        fig.colorbar(mappable)
        buf = BytesIO()
        # Save this specific figure, not pyplot's global "current figure".
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)  # release the figure even if rendering fails
    buf.seek(0)
    rendered = Image.open(buf)
    rendered.load()  # decode eagerly instead of lazily on first access
    return rendered


def process_image(image):
    """Detect objects in *image* and build a Grad-CAM visualization.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was uploaded.

    Returns:
        A tuple ``(annotated, gradcam)`` of PIL images. If no image was given,
        or if any step fails, both are black 224x224 placeholders so the UI
        still renders.
    """
    if image is None:
        return _blank_image(), _blank_image()
    try:
        rgb = image.convert("RGB")
        annotated = _annotate_detections(rgb)
        gradcam_img = _render_gradcam(rgb)
        return annotated, gradcam_img
    except Exception as e:
        # Top-level UI boundary: report the failure with its traceback, then fall
        # back to placeholders instead of crashing the request.
        import traceback

        print(f"Error processing image: {e}")
        traceback.print_exc()
        return _blank_image(), _blank_image()
# Gradio callback wired into the Interface below.
def upload_image(image):
    """Handle one Gradio submission by delegating to ``process_image``.

    Returns the pair ``(annotated detections, Grad-CAM rendering)``, in the
    order the two output components expect.
    """
    annotated, heatmap = process_image(image)
    return (annotated, heatmap)
# Configure the Gradio interface: one PIL image in, two PIL images out
# (the annotated detections and the Grad-CAM rendering, in that order).
iface = gr.Interface(
    fn=upload_image,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil"), gr.Image(type="pil")],
    title="YOLO Object Detection with GradCAM Visualization",
    description="Upload an image to detect objects and visualize with GradCAM.",
    # Hides the "Flag" button (this is not an NSFW filter).
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`; confirm pinned version.
    allow_flagging="never",
)
# Launch the Gradio interface.
iface.launch()