import os
import cv2
import gradio as gr
import torch
from ultralytics import YOLO
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import random
import numpy as np
# Fix RNG seeds for reproducible behavior across runs.
random.seed(42)
np.random.seed(42)
# Configuration
MODEL_PATH = os.getenv("MODEL_PATH", "last.pt")  # weights path, overridable via env var
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when present
VALID_EXTENSIONS = [".jpg", ".jpeg", ".png"]  # accepted upload formats (case-insensitive match below)
# Load YOLO model.
# On failure the model is set to None and the app still starts; the predict
# function reports "model not loaded" at request time instead of crashing here.
try:
    # Fail fast with a clear message rather than letting YOLO() raise a vaguer one.
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file {MODEL_PATH} not found.")
    yolo_model = YOLO(MODEL_PATH).to(DEVICE)
    print("YOLO model loaded successfully.")
except Exception as e:
    print(f"Error loading YOLO model: {e}")
    yolo_model = None
# Load SAHI model (sliced-inference wrapper around the same weights).
try:
    sahi_model = AutoDetectionModel.from_pretrained(
        model_type="ultralytics",
        model_path=MODEL_PATH,
        confidence_threshold=0.5,  # SAHI drops detections below this score
        device=DEVICE,
    )
    print("SAHI model loaded successfully.")
except Exception as e:
    print(f"Error loading SAHI model: {e}")
    sahi_model = None
def predict_and_show_bounding_boxes(image_path, model_choice, conf_threshold=0.5):
    """Run defect detection on an uploaded image and draw bounding boxes.

    Args:
        image_path: Filesystem path to the uploaded image (Gradio ``type="filepath"``).
        model_choice: ``"YOLO"`` for plain inference (green boxes) or ``"SAHI"``
            for sliced inference (red boxes).
        conf_threshold: Minimum confidence for a detection to be kept/drawn.

    Returns:
        An RGB numpy array with boxes and labels drawn, matching the single
        ``gr.Image`` output declared by the interface.

    Raises:
        gr.Error: On invalid input, an unloaded model, or a prediction failure.
            (The original code returned a ``(None, message)`` tuple on errors,
            which does not fit the single declared output component; ``gr.Error``
            is Gradio's supported way to surface the message to the user.)
    """
    if not image_path or not any(image_path.lower().endswith(ext) for ext in VALID_EXTENSIONS):
        raise gr.Error("Invalid or unsupported image format.")
    img = cv2.imread(image_path)
    if img is None:
        raise gr.Error("Could not load image.")
    # Downscale to width 640 while preserving the aspect ratio.
    original_height, original_width = img.shape[:2]
    img = cv2.resize(img, (640, int(640 * original_height / original_width)))
    if model_choice == "YOLO":
        if yolo_model is None:
            raise gr.Error("YOLO model not loaded.")
        try:
            results = yolo_model(img, conf=conf_threshold)[0]
            for box in results.boxes:
                x_min, y_min, x_max, y_max = map(int, box.xyxy[0].tolist()[:4])
                conf = box.conf[0].item()
                cls = int(box.cls[0])
                cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (0, 255, 0), 1)  # thin green box
                label = f"{results.names[cls]}: {conf:.2f}"
                cv2.putText(img, label, (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        except Exception as e:
            raise gr.Error(f"Error during YOLO prediction: {e}")
    elif model_choice == "SAHI":
        if sahi_model is None:
            raise gr.Error("SAHI model not loaded.")
        try:
            # BUG FIX: the slider value was previously ignored in SAHI mode —
            # the model kept the 0.5 threshold fixed at load time. Apply the
            # requested threshold before predicting, and filter again below as
            # a safeguard.
            sahi_model.confidence_threshold = conf_threshold
            result = get_sliced_prediction(
                img,
                sahi_model,
                slice_height=512,
                slice_width=512,
                overlap_height_ratio=0.1,
                overlap_width_ratio=0.1,
            )
            for pred in result.object_prediction_list:
                if pred.score.value < conf_threshold:
                    continue
                x_min, y_min, x_max, y_max = map(int, pred.bbox.to_xyxy())
                label = f"{pred.category.name}: {pred.score.value:.2f}"
                cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (255, 0, 0), 1)  # thin red box
                cv2.putText(img, label, (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        except gr.Error:
            raise  # don't re-wrap our own error messages
        except Exception as e:
            raise gr.Error(f"Error during SAHI prediction: {e}")
    raise gr.Error("Invalid model choice.")
# Gradio interface: three inputs (image path, mode, threshold) mapped onto the
# prediction function, one annotated-image output.
iface = gr.Interface(
    fn=predict_and_show_bounding_boxes,
    inputs=[
        # type="filepath": the fn receives a path string, not a numpy array.
        gr.Image(type="filepath", label="Upload Image"),
        gr.Radio(choices=["YOLO", "SAHI"], label="Choose Detection Mode", value="YOLO"),
        gr.Slider(minimum=0.1, maximum=0.9, value=0.5, label="Confidence Threshold"),
    ],
    # NOTE(review): confirm "keep" is an accepted image_mode value for this
    # Gradio version; gr.Image normally expects a PIL mode string like "RGB".
    outputs=[gr.Image(label="Result", image_mode="keep")],
    title="PCB Defect Detection",
    description="Upload a PCB image and choose YOLO (green boxes) or SAHI (red boxes) for defect detection. Adjust confidence threshold for sensitivity.",
)
if __name__ == "__main__":
    # Opt into a public share link by setting HF_SHARE=true (any casing).
    share = os.getenv("HF_SHARE", "False").lower() == "true"
    iface.launch(share=share)