File size: 4,529 Bytes
b82277e
 
c511cb2
5a0e173
 
c742416
 
d93354f
 
 
 
 
1069b12
c742416
 
 
 
 
d93354f
b82277e
d93354f
 
c742416
b82277e
 
c742416
 
1ab2f3b
d93354f
c742416
 
 
 
 
 
 
 
 
 
 
1ab2f3b
c742416
 
 
b82277e
d93354f
c742416
 
 
d93354f
 
952a581
c742416
 
 
 
 
 
 
d6bec6f
c742416
 
 
 
 
d6bec6f
c742416
d6bec6f
d93354f
c742416
 
 
 
 
 
 
d93354f
c742416
 
 
 
 
 
 
d6bec6f
c742416
 
 
 
d6bec6f
 
d93354f
c742416
 
 
1069b12
c742416
c511cb2
b82277e
c742416
 
 
 
 
d93354f
c742416
 
c511cb2
c40bb88
c742416
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import cv2
import gradio as gr
import torch
from ultralytics import YOLO
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import random
import numpy as np

# Seed both RNGs so any stochastic behavior (e.g. augmentation, color picks)
# is reproducible across runs.
random.seed(42)
np.random.seed(42)

# Configuration
MODEL_PATH = os.getenv("MODEL_PATH", "last.pt")  # weights file; override via env var
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when present
VALID_EXTENSIONS = [".jpg", ".jpeg", ".png"]  # accepted upload formats (lowercase)

# Load YOLO model.
# Best-effort: on any failure the app still starts, and the predict function
# reports "YOLO model not loaded" when that mode is selected.
try:
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file {MODEL_PATH} not found.")
    yolo_model = YOLO(MODEL_PATH).to(DEVICE)
    print("YOLO model loaded successfully.")
except Exception as e:
    print(f"Error loading YOLO model: {e}")
    yolo_model = None  # sentinel checked at prediction time

# Load SAHI model (wraps the same weights for sliced inference).
# Best-effort like the YOLO loader above: failure leaves sahi_model = None
# and the predict function reports the error per-request.
try:
    sahi_model = AutoDetectionModel.from_pretrained(
        model_type="ultralytics",
        model_path=MODEL_PATH,
        confidence_threshold=0.5,  # baked-in default threshold
        device=DEVICE,
    )
    print("SAHI model loaded successfully.")
except Exception as e:
    print(f"Error loading SAHI model: {e}")
    sahi_model = None  # sentinel checked at prediction time

def predict_and_show_bounding_boxes(image_path, model_choice, conf_threshold=0.5):
    """Run defect detection on an uploaded image and draw bounding boxes.

    Args:
        image_path: Filesystem path to the uploaded image (from gr.Image
            with type="filepath"; may be None when nothing was uploaded).
        model_choice: "YOLO" for plain inference or "SAHI" for sliced inference.
        conf_threshold: Minimum detection confidence, from the UI slider.

    Returns:
        An RGB numpy array with boxes and class/confidence labels drawn.

    Raises:
        gr.Error: On invalid input, unreadable image, missing model, or a
            prediction failure. The Interface declares a single image output,
            so errors must be raised (Gradio shows them in the UI) rather
            than returned as a (None, message) tuple, which would not match
            the output signature.
    """
    # Validate the extension case-insensitively; endswith accepts a tuple.
    if not image_path or not image_path.lower().endswith(tuple(VALID_EXTENSIONS)):
        raise gr.Error("Invalid or unsupported image format.")

    img = cv2.imread(image_path)  # BGR, or None on decode failure
    if img is None:
        raise gr.Error("Could not load image.")
    # Downscale to width 640 preserving aspect ratio; clamp height to >= 1
    # so cv2.resize never receives a zero dimension for extreme aspect ratios.
    original_height, original_width = img.shape[:2]
    new_height = max(1, int(640 * original_height / original_width))
    img = cv2.resize(img, (640, new_height))

    if model_choice == "YOLO":
        if yolo_model is None:
            raise gr.Error("YOLO model not loaded.")
        try:
            results = yolo_model(img, conf=conf_threshold)[0]
            # Iterating an empty result set simply draws nothing.
            for box in results.boxes:
                x_min, y_min, x_max, y_max = map(int, box.xyxy[0].tolist()[:4])
                conf = box.conf[0].item()
                cls = int(box.cls[0])
                cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (0, 255, 0), 1)  # thin green box
                label = f"{results.names[cls]}: {conf:.2f}"
                cv2.putText(img, label, (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        except Exception as e:
            raise gr.Error(f"Error during YOLO prediction: {e}") from e
    elif model_choice == "SAHI":
        if sahi_model is None:
            raise gr.Error("SAHI model not loaded.")
        try:
            # Honor the UI slider: the SAHI model was constructed with a fixed
            # 0.5 threshold, so update it per request (was ignored before).
            sahi_model.confidence_threshold = conf_threshold
            # SAHI expects RGB numpy input; cv2.imread yields BGR, so convert
            # the copy passed in for prediction (img stays BGR for drawing).
            result = get_sliced_prediction(
                cv2.cvtColor(img, cv2.COLOR_BGR2RGB),
                sahi_model,
                slice_height=512,
                slice_width=512,
                overlap_height_ratio=0.1,
                overlap_width_ratio=0.1,
            )
            for pred in result.object_prediction_list:
                x_min, y_min, x_max, y_max = map(int, pred.bbox.to_xyxy())
                label = f"{pred.category.name}: {pred.score.value:.2f}"
                cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (255, 0, 0), 1)  # thin red box
                cv2.putText(img, label, (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        except Exception as e:
            raise gr.Error(f"Error during SAHI prediction: {e}") from e
    raise gr.Error("Invalid model choice.")

# Gradio interface: one function, three inputs, one image output.
iface = gr.Interface(
    fn=predict_and_show_bounding_boxes,
    inputs=[
        # type="filepath" hands the function a path string, not an array.
        gr.Image(type="filepath", label="Upload Image"),
        gr.Radio(choices=["YOLO", "SAHI"], label="Choose Detection Mode", value="YOLO"),
        gr.Slider(minimum=0.1, maximum=0.9, value=0.5, label="Confidence Threshold"),
    ],
    # NOTE(review): image_mode="keep" — confirm this is a valid value for the
    # installed Gradio version (standard values are PIL modes like "RGB").
    outputs=[gr.Image(label="Result", image_mode="keep")],
    title="PCB Defect Detection",
    description="Upload a PCB image and choose YOLO (green boxes) or SAHI (red boxes) for defect detection. Adjust confidence threshold for sensitivity.",
)

if __name__ == "__main__":
    # Expose a public Gradio share link only when HF_SHARE=true in the environment.
    iface.launch(share=os.getenv("HF_SHARE", "False").lower() == "true")