abdrabo01 commited on
Commit
c742416
·
verified ·
1 Parent(s): f827bf3
Files changed (1) hide show
  1. app.py +90 -51
app.py CHANGED
@@ -1,69 +1,108 @@
1
  import os
2
- # os.system('')
3
  import cv2
4
  import gradio as gr
5
  import torch
6
  from ultralytics import YOLO
7
- import random
8
- import numpy as np
9
- random.seed(42)
10
- np.random.seed(42)
11
 
 
 
 
 
 
 
12
  try:
13
- model = YOLO("last.pt")
14
  print("YOLO model loaded successfully.")
15
- except FileNotFoundError:
16
- print("Error: 'yolo_modeln11_1502.pt' not found.")
17
- model = None
18
  except Exception as e:
19
- print(f"An error occurred while loading the model: {e}")
20
- model = None
21
-
22
- # Function to predict and show bounding boxes
23
- def predict_and_show_bounding_boxes(image_path):
24
- if model is None:
25
- return None, "Error: Model not loaded."
26
-
27
- try:
28
- # Load the image using cv2
29
- img = cv2.imread(image_path)
30
- if img is None:
31
- print(f"Error: Could not load image at {image_path}")
32
- return None, "Error: Could not load image"
33
 
34
- # Perform inference using the YOLO model
35
- results = model(image_path,conf=0.5)[0]
36
- boxes = results.boxes
 
 
 
 
 
 
 
 
37
 
38
- if len(boxes) == 0:
39
- # No defects found, show the zero defects image
40
- zero_defects_img = cv2.imread('zero_defects.png')
41
- if zero_defects_img is not None:
42
- return zero_defects_img
43
- else:
44
- return None, "Error: Could not load zero defects image"
45
 
46
- for box in boxes:
47
- xyxy = box.xyxy[0].tolist()
48
- x_min, y_min, x_max, y_max = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
49
- conf = box.conf[0].item()
50
- cls = int(box.cls[0])
51
- cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
52
- label = f"{results.names[cls]}: {conf:.2f}"
53
- cv2.putText(img, label, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
54
 
55
- return img
56
- except Exception as e:
57
- print(f"An error occurred during prediction: {e}")
58
- return None, str(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # Create Gradio interface
61
  iface = gr.Interface(
62
  fn=predict_and_show_bounding_boxes,
63
- inputs=gr.Image(type="filepath"),
64
- outputs=[gr.Image()],
65
- title="Defect Detection",
66
- description="Upload an image to detect defects"
 
 
 
 
67
  )
68
 
69
- iface.launch(share=True)
 
 
 
1
  import os
 
2
  import cv2
3
  import gradio as gr
4
  import torch
5
  from ultralytics import YOLO
6
+ from sahi import AutoDetectionModel
7
+ from sahi.predict import get_sliced_prediction
8
+ import time
 
9
 
10
# Configuration
# MODEL_PATH can be overridden via the environment; defaults to the bundled weights.
MODEL_PATH = os.getenv("MODEL_PATH", "last.pt")
# Prefer GPU when available; both models are placed on the same device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Extensions accepted by predict_and_show_bounding_boxes (compared lowercase).
VALID_EXTENSIONS = [".jpg", ".jpeg", ".png"]

# Load models
# Each loader is wrapped in its own try/except so a failure in one model
# leaves the other usable; the predict function checks for None at call time.
try:
    yolo_model = YOLO(MODEL_PATH).to(DEVICE)
    print("YOLO model loaded successfully.")
except Exception as e:
    print(f"Error loading YOLO model: {e}")
    yolo_model = None

try:
    # SAHI wraps the same weights for sliced (tiled) inference.
    # NOTE(review): confidence_threshold is fixed at load time here — the
    # UI slider only takes effect if the predict function overrides it.
    sahi_model = AutoDetectionModel.from_pretrained(
        model_type="ultralytics",
        model_path=MODEL_PATH,
        confidence_threshold=0.5,
        device=DEVICE,
    )
    print("SAHI model loaded successfully.")
except Exception as e:
    print(f"Error loading SAHI model: {e}")
    sahi_model = None
34
 
35
def predict_and_show_bounding_boxes(image_path, model_choice, conf_threshold=0.5):
    """Detect defects in an image and return an annotated copy.

    Args:
        image_path: Filesystem path to the uploaded image.
        model_choice: "YOLO" (plain inference, green boxes) or
            "SAHI" (sliced inference, red boxes).
        conf_threshold: Minimum confidence for a detection to be kept.

    Returns:
        Tuple of (annotated RGB image array or None, status message string).
    """
    # Reject missing paths and unsupported extensions up front.
    # str.endswith accepts a tuple, so one call covers every extension.
    if not image_path or not image_path.lower().endswith(tuple(VALID_EXTENSIONS)):
        return None, "Error: Invalid or unsupported image format."

    img = cv2.imread(image_path)
    if img is None:
        return None, "Error: Could not load image."
    # NOTE(review): square resize distorts non-square inputs — confirm intended.
    img = cv2.resize(img, (640, 640))  # Resize to 640x640 for faster processing

    if model_choice == "YOLO":
        if yolo_model is None:
            return None, "Error: YOLO model not loaded."
        try:
            start_time = time.time()
            results = yolo_model(img, conf=conf_threshold)[0]
            print(f"YOLO inference time: {time.time() - start_time:.2f} seconds")
            boxes = results.boxes
            if len(boxes) == 0:
                return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), "No defects detected."
            for box in boxes:
                x_min, y_min, x_max, y_max = map(int, box.xyxy[0].tolist()[:4])
                conf = box.conf[0].item()
                cls = int(box.cls[0])
                cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
                label = f"{results.names[cls]}: {conf:.2f}"
                cv2.putText(img, label, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), "Detection complete."
        except Exception as e:
            return None, f"Error during YOLO prediction: {e}"
    elif model_choice == "SAHI":
        if sahi_model is None:
            return None, "Error: SAHI model not loaded."
        try:
            # Bug fix: apply the user-selected threshold. Previously the
            # slider value was ignored on this path — the 0.5 baked in at
            # model-load time was always used for SAHI predictions.
            sahi_model.confidence_threshold = conf_threshold
            start_time = time.time()
            result = get_sliced_prediction(
                img,  # use the resized image
                sahi_model,
                slice_height=512,
                slice_width=512,
                overlap_height_ratio=0.1,
                overlap_width_ratio=0.1,
            )
            print(f"SAHI inference time: {time.time() - start_time:.2f} seconds")
            if len(result.object_prediction_list) == 0:
                return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), "No defects detected."
            for pred in result.object_prediction_list:
                x_min, y_min, x_max, y_max = map(int, pred.bbox.to_xyxy())
                label = f"{pred.category.name}: {pred.score.value:.2f}"
                cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (255, 0, 0), 2)
                cv2.putText(img, label, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 0, 0), 2)
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), "Detection complete."
        except Exception as e:
            return None, f"Error during SAHI prediction: {e}"
    return None, "Invalid model choice."
92
 
93
# Gradio interface
# Input widgets: image upload, detector selector, and confidence slider,
# matching the (image_path, model_choice, conf_threshold) signature.
_inputs = [
    gr.Image(type="filepath", label="Upload Image"),
    gr.Radio(choices=["YOLO", "SAHI"], label="Choose Detection Mode", value="YOLO"),
    gr.Slider(minimum=0.1, maximum=0.9, value=0.5, label="Confidence Threshold"),
]
# Output widgets mirror the function's (image, message) return tuple.
_outputs = [
    gr.Image(label="Result", width=640, height=640),
    gr.Textbox(label="Message"),
]
iface = gr.Interface(
    fn=predict_and_show_bounding_boxes,
    inputs=_inputs,
    outputs=_outputs,
    title="PCB Defect Detection",
    description="Upload a PCB image and choose YOLO (green boxes) or SAHI (red boxes) for defect detection. Adjust confidence threshold for sensitivity.",
)

if __name__ == "__main__":
    # Public sharing is opt-in: only enabled when HF_SHARE=true in the env.
    share = os.getenv("HF_SHARE", "False").lower() == "true"
    iface.launch(share=share)