# Source: Hugging Face Space app.py by abdrabo01 (commit d6bec6f, verified)
import os
import cv2
import gradio as gr
import torch
from ultralytics import YOLO
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import random
import numpy as np
# Seed both RNGs so any stochastic behavior is reproducible across runs.
random.seed(42)
np.random.seed(42)

# Configuration
MODEL_PATH = os.getenv("MODEL_PATH", "last.pt")  # weights file; overridable via env
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
VALID_EXTENSIONS = [".jpg", ".jpeg", ".png"]  # accepted upload formats
# Load YOLO model
# Loaded eagerly at import time. On any failure (missing weights file,
# incompatible checkpoint) the app still starts: yolo_model is set to None
# and the predict function reports the problem per-request instead.
try:
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file {MODEL_PATH} not found.")
    # Move the model to GPU when available (DEVICE is "cuda" or "cpu").
    yolo_model = YOLO(MODEL_PATH).to(DEVICE)
    print("YOLO model loaded successfully.")
except Exception as e:
    print(f"Error loading YOLO model: {e}")
    yolo_model = None
# Load SAHI model
# Separate SAHI wrapper around the same weights, used for sliced inference
# on large images. As above, failure leaves sahi_model = None so the app
# still starts and errors are reported per-request.
try:
    sahi_model = AutoDetectionModel.from_pretrained(
        model_type="ultralytics",
        model_path=MODEL_PATH,
        confidence_threshold=0.5,  # initial threshold; see predict function
        device=DEVICE,
    )
    print("SAHI model loaded successfully.")
except Exception as e:
    print(f"Error loading SAHI model: {e}")
    sahi_model = None
def predict_and_show_bounding_boxes(image_path, model_choice, conf_threshold=0.5):
    """Run PCB defect detection on an image and draw labeled bounding boxes.

    Args:
        image_path: Filesystem path to the uploaded image.
        model_choice: "YOLO" (full-image inference, green boxes) or
            "SAHI" (sliced inference, red boxes).
        conf_threshold: Minimum detection confidence in [0.1, 0.9].

    Returns:
        The annotated image as an RGB numpy array.

    Raises:
        gr.Error: On invalid input, an unloaded model, or inference failure.
            Previously errors were returned as a (None, message) tuple, but
            the interface has a single output component, so every error path
            crashed the request with "too many output values"; gr.Error is
            surfaced to the user as a popup instead.
    """
    if not image_path or not any(image_path.lower().endswith(ext) for ext in VALID_EXTENSIONS):
        raise gr.Error("Invalid or unsupported image format.")
    img = cv2.imread(image_path)
    if img is None:
        raise gr.Error("Could not load image.")
    # Downscale to width 640 while preserving the original aspect ratio.
    original_height, original_width = img.shape[:2]
    img = cv2.resize(img, (640, int(640 * original_height / original_width)))

    if model_choice == "YOLO":
        if yolo_model is None:
            raise gr.Error("YOLO model not loaded.")
        try:
            results = yolo_model(img, conf=conf_threshold)[0]
            # Empty detections simply fall through to the bare image.
            for box in results.boxes:
                x_min, y_min, x_max, y_max = map(int, box.xyxy[0].tolist()[:4])
                conf = box.conf[0].item()
                cls = int(box.cls[0])
                cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (0, 255, 0), 1)  # thin green box
                label = f"{results.names[cls]}: {conf:.2f}"
                cv2.putText(img, label, (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        except Exception as e:
            raise gr.Error(f"Error during YOLO prediction: {e}") from e

    elif model_choice == "SAHI":
        if sahi_model is None:
            raise gr.Error("SAHI model not loaded.")
        try:
            # Bug fix: the slider was previously ignored in SAHI mode because
            # the threshold was frozen at 0.5 when the model was loaded.
            sahi_model.confidence_threshold = conf_threshold
            result = get_sliced_prediction(
                img,
                sahi_model,
                slice_height=512,
                slice_width=512,
                overlap_height_ratio=0.1,
                overlap_width_ratio=0.1,
            )
            for pred in result.object_prediction_list:
                x_min, y_min, x_max, y_max = map(int, pred.bbox.to_xyxy())
                label = f"{pred.category.name}: {pred.score.value:.2f}"
                cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (255, 0, 0), 1)  # thin red box
                cv2.putText(img, label, (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        except Exception as e:
            raise gr.Error(f"Error during SAHI prediction: {e}") from e

    raise gr.Error("Invalid model choice.")
# Gradio interface: one image in, mode selector, confidence slider; one image out.
_input_components = [
    gr.Image(type="filepath", label="Upload Image"),
    gr.Radio(choices=["YOLO", "SAHI"], label="Choose Detection Mode", value="YOLO"),
    gr.Slider(minimum=0.1, maximum=0.9, value=0.5, label="Confidence Threshold"),
]
_output_components = [gr.Image(label="Result", image_mode="keep")]
iface = gr.Interface(
    fn=predict_and_show_bounding_boxes,
    inputs=_input_components,
    outputs=_output_components,
    title="PCB Defect Detection",
    description="Upload a PCB image and choose YOLO (green boxes) or SAHI (red boxes) for defect detection. Adjust confidence threshold for sensitivity.",
)
if __name__ == "__main__":
    # Set HF_SHARE=true in the environment to expose a public Gradio link.
    share_enabled = os.environ.get("HF_SHARE", "False").lower() == "true"
    iface.launch(share=share_enabled)