abdrabo01 commited on
Commit
952a581
·
verified ·
1 Parent(s): a03779a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -43
app.py CHANGED
@@ -11,48 +11,53 @@ import numpy as np
11
  random.seed(42)
12
  np.random.seed(42)
13
 
14
- # Load default YOLO model
 
 
 
 
 
15
  try:
16
- yolo_model = YOLO("last.pt")
 
 
17
  print("YOLO model loaded successfully.")
18
- except FileNotFoundError:
19
- print("Error: 'last.pt' not found.")
20
- yolo_model = None
21
  except Exception as e:
22
- print(f"An error occurred while loading the YOLO model: {e}")
23
  yolo_model = None
24
 
25
  # Load SAHI model
26
  try:
 
27
  sahi_model = AutoDetectionModel.from_pretrained(
28
  model_type="ultralytics",
29
- model_path="last.pt", # same model used for consistency
30
  confidence_threshold=0.5,
31
- device="cpu",
32
  )
33
  print("SAHI model loaded successfully.")
34
  except Exception as e:
35
- print(f"An error occurred while loading the SAHI model: {e}")
36
  sahi_model = None
37
 
38
- # Prediction function
39
- def predict_and_show_bounding_boxes(image_path, model_choice):
 
 
 
40
  if model_choice == "YOLO":
41
  if yolo_model is None:
42
  return None, "Error: YOLO model not loaded."
43
-
44
  try:
45
  img = cv2.imread(image_path)
46
  if img is None:
47
- return None, "Error: Could not load image"
48
-
49
- results = yolo_model(image_path, conf=0.5)[0]
50
  boxes = results.boxes
51
-
52
  if len(boxes) == 0:
53
- zero_defects_img = cv2.imread('zero_defects.png')
54
- return zero_defects_img if zero_defects_img is not None else (None, "Error: Could not load zero defects image")
55
-
56
  for box in boxes:
57
  xyxy = box.xyxy[0].tolist()
58
  x_min, y_min, x_max, y_max = map(int, xyxy[:4])
@@ -61,16 +66,16 @@ def predict_and_show_bounding_boxes(image_path, model_choice):
61
  cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
62
  label = f"{results.names[cls]}: {conf:.2f}"
63
  cv2.putText(img, label, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
64
-
65
- return img
66
  except Exception as e:
67
- return None, str(e)
68
-
69
  elif model_choice == "SAHI":
70
  if sahi_model is None:
71
  return None, "Error: SAHI model not loaded."
72
-
73
  try:
 
 
 
74
  result = get_sliced_prediction(
75
  image_path,
76
  sahi_model,
@@ -79,29 +84,20 @@ def predict_and_show_bounding_boxes(image_path, model_choice):
79
  overlap_height_ratio=0.2,
80
  overlap_width_ratio=0.2,
81
  )
82
-
83
- img = cv2.imread(image_path)
84
- if img is None:
85
- return None, "Error: Could not load image"
86
-
87
  for pred in result.object_prediction_list:
88
  box = pred.bbox.to_xyxy()
89
  x_min, y_min, x_max, y_max = map(int, box)
90
  label = f"{pred.category.name}: {pred.score.value:.2f}"
91
  cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (255, 0, 0), 2)
92
  cv2.putText(img, label, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 0, 0), 2)
93
-
94
- if len(result.object_prediction_list) == 0:
95
- zero_defects_img = cv2.imread('zero_defects.png')
96
- return zero_defects_img if zero_defects_img is not None else (None, "Error: Could not load zero defects image")
97
-
98
- return img
99
  except Exception as e:
100
- return None, str(e)
101
-
102
- else:
103
- return None, "Invalid model choice."
104
-
105
 
106
  # Gradio interface
107
  iface = gr.Interface(
@@ -109,10 +105,13 @@ iface = gr.Interface(
109
  inputs=[
110
  gr.Image(type="filepath", label="Upload Image"),
111
  gr.Radio(choices=["YOLO", "SAHI"], label="Choose Detection Mode", value="YOLO"),
 
112
  ],
113
  outputs=[gr.Image(label="Result"), gr.Textbox(label="Message")],
114
- title="Defect Detection",
115
- description="Upload an image and choose YOLO or SAHI for defect detection.",
116
  )
117
 
118
- iface.launch(share=True)
 
 
 
11
  random.seed(42)
12
  np.random.seed(42)
13
 
14
+ # Configuration
15
+ MODEL_PATH = os.getenv("MODEL_PATH", "last.pt")
16
+ ZERO_DEFECTS_PATH = "zero_defects.png"
17
+ VALID_EXTENSIONS = [".jpg", ".jpeg", ".png"]
18
+
19
+ # Load YOLO model
20
  try:
21
+ if not os.path.exists(MODEL_PATH):
22
+ raise FileNotFoundError(f"Model file {MODEL_PATH} not found.")
23
+ yolo_model = YOLO(MODEL_PATH)
24
  print("YOLO model loaded successfully.")
 
 
 
25
  except Exception as e:
26
+ print(f"Error loading YOLO model: {e}")
27
  yolo_model = None
28
 
29
  # Load SAHI model
30
  try:
31
+ device = "cuda" if torch.cuda.is_available() else "cpu"
32
  sahi_model = AutoDetectionModel.from_pretrained(
33
  model_type="ultralytics",
34
+ model_path=MODEL_PATH,
35
  confidence_threshold=0.5,
36
+ device=device,
37
  )
38
  print("SAHI model loaded successfully.")
39
  except Exception as e:
40
+ print(f"Error loading SAHI model: {e}")
41
  sahi_model = None
42
 
43
+ def predict_and_show_bounding_boxes(image_path, model_choice, conf_threshold=0.5):
44
+ # Validate image path
45
+ if not image_path or not any(image_path.lower().endswith(ext) for ext in VALID_EXTENSIONS):
46
+ return None, "Error: Invalid or unsupported image format."
47
+
48
  if model_choice == "YOLO":
49
  if yolo_model is None:
50
  return None, "Error: YOLO model not loaded."
 
51
  try:
52
  img = cv2.imread(image_path)
53
  if img is None:
54
+ return None, "Error: Could not load image."
55
+ results = yolo_model(image_path, conf=conf_threshold)[0]
 
56
  boxes = results.boxes
 
57
  if len(boxes) == 0:
58
+ if os.path.exists(ZERO_DEFECTS_PATH):
59
+ return cv2.imread(ZERO_DEFECTS_PATH), "No defects detected."
60
+ return img, "No defects detected (zero_defects.png not found)."
61
  for box in boxes:
62
  xyxy = box.xyxy[0].tolist()
63
  x_min, y_min, x_max, y_max = map(int, xyxy[:4])
 
66
  cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
67
  label = f"{results.names[cls]}: {conf:.2f}"
68
  cv2.putText(img, label, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
69
+ return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), "Detection complete."
 
70
  except Exception as e:
71
+ return None, f"Error during YOLO prediction: {e}"
 
72
  elif model_choice == "SAHI":
73
  if sahi_model is None:
74
  return None, "Error: SAHI model not loaded."
 
75
  try:
76
+ img = cv2.imread(image_path)
77
+ if img is None:
78
+ return None, "Error: Could not load image."
79
  result = get_sliced_prediction(
80
  image_path,
81
  sahi_model,
 
84
  overlap_height_ratio=0.2,
85
  overlap_width_ratio=0.2,
86
  )
87
+ if len(result.object_prediction_list) == 0:
88
+ if os.path.exists(ZERO_DEFECTS_PATH):
89
+ return cv2.imread(ZERO_DEFECTS_PATH), "No defects detected."
90
+ return img, "No defects detected (zero_defects.png not found)."
 
91
  for pred in result.object_prediction_list:
92
  box = pred.bbox.to_xyxy()
93
  x_min, y_min, x_max, y_max = map(int, box)
94
  label = f"{pred.category.name}: {pred.score.value:.2f}"
95
  cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (255, 0, 0), 2)
96
  cv2.putText(img, label, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 0, 0), 2)
97
+ return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), "Detection complete."
 
 
 
 
 
98
  except Exception as e:
99
+ return None, f"Error during SAHI prediction: {e}"
100
+ return None, "Invalid model choice."
 
 
 
101
 
102
  # Gradio interface
103
  iface = gr.Interface(
 
105
  inputs=[
106
  gr.Image(type="filepath", label="Upload Image"),
107
  gr.Radio(choices=["YOLO", "SAHI"], label="Choose Detection Mode", value="YOLO"),
108
+ gr.Slider(minimum=0.1, maximum=0.9, value=0.5, label="Confidence Threshold"),
109
  ],
110
  outputs=[gr.Image(label="Result"), gr.Textbox(label="Message")],
111
+ title="PCB Defect Detection",
112
+ description="Upload a PCB image and choose YOLO (green boxes) or SAHI (red boxes) for defect detection. Adjust confidence threshold for sensitivity.",
113
  )
114
 
115
+ if __name__ == "__main__":
116
+ share = os.getenv("HF_SHARE", "False").lower() == "true"
117
+ iface.launch(share=share)