abdrabo01 commited on
Commit
1ab2f3b
·
verified ·
1 Parent(s): 952a581

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -96
app.py CHANGED
@@ -1,117 +1,70 @@
1
  import os
 
 
2
  import cv2
3
  import gradio as gr
4
  import torch
5
  from ultralytics import YOLO
6
- from sahi import AutoDetectionModel
7
- from sahi.predict import get_sliced_prediction
8
  import random
9
  import numpy as np
10
-
11
  random.seed(42)
12
  np.random.seed(42)
13
 
14
- # Configuration
15
- MODEL_PATH = os.getenv("MODEL_PATH", "last.pt")
16
- ZERO_DEFECTS_PATH = "zero_defects.png"
17
- VALID_EXTENSIONS = [".jpg", ".jpeg", ".png"]
18
-
19
- # Load YOLO model
20
  try:
21
- if not os.path.exists(MODEL_PATH):
22
- raise FileNotFoundError(f"Model file {MODEL_PATH} not found.")
23
- yolo_model = YOLO(MODEL_PATH)
24
  print("YOLO model loaded successfully.")
 
 
 
25
  except Exception as e:
26
- print(f"Error loading YOLO model: {e}")
27
- yolo_model = None
28
 
29
- # Load SAHI model
30
- try:
31
- device = "cuda" if torch.cuda.is_available() else "cpu"
32
- sahi_model = AutoDetectionModel.from_pretrained(
33
- model_type="ultralytics",
34
- model_path=MODEL_PATH,
35
- confidence_threshold=0.5,
36
- device=device,
37
- )
38
- print("SAHI model loaded successfully.")
39
- except Exception as e:
40
- print(f"Error loading SAHI model: {e}")
41
- sahi_model = None
 
 
 
 
 
 
 
 
 
 
42
 
43
- def predict_and_show_bounding_boxes(image_path, model_choice, conf_threshold=0.5):
44
- # Validate image path
45
- if not image_path or not any(image_path.lower().endswith(ext) for ext in VALID_EXTENSIONS):
46
- return None, "Error: Invalid or unsupported image format."
 
 
 
 
47
 
48
- if model_choice == "YOLO":
49
- if yolo_model is None:
50
- return None, "Error: YOLO model not loaded."
51
- try:
52
- img = cv2.imread(image_path)
53
- if img is None:
54
- return None, "Error: Could not load image."
55
- results = yolo_model(image_path, conf=conf_threshold)[0]
56
- boxes = results.boxes
57
- if len(boxes) == 0:
58
- if os.path.exists(ZERO_DEFECTS_PATH):
59
- return cv2.imread(ZERO_DEFECTS_PATH), "No defects detected."
60
- return img, "No defects detected (zero_defects.png not found)."
61
- for box in boxes:
62
- xyxy = box.xyxy[0].tolist()
63
- x_min, y_min, x_max, y_max = map(int, xyxy[:4])
64
- conf = box.conf[0].item()
65
- cls = int(box.cls[0])
66
- cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
67
- label = f"{results.names[cls]}: {conf:.2f}"
68
- cv2.putText(img, label, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
69
- return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), "Detection complete."
70
- except Exception as e:
71
- return None, f"Error during YOLO prediction: {e}"
72
- elif model_choice == "SAHI":
73
- if sahi_model is None:
74
- return None, "Error: SAHI model not loaded."
75
- try:
76
- img = cv2.imread(image_path)
77
- if img is None:
78
- return None, "Error: Could not load image."
79
- result = get_sliced_prediction(
80
- image_path,
81
- sahi_model,
82
- slice_height=256,
83
- slice_width=256,
84
- overlap_height_ratio=0.2,
85
- overlap_width_ratio=0.2,
86
- )
87
- if len(result.object_prediction_list) == 0:
88
- if os.path.exists(ZERO_DEFECTS_PATH):
89
- return cv2.imread(ZERO_DEFECTS_PATH), "No defects detected."
90
- return img, "No defects detected (zero_defects.png not found)."
91
- for pred in result.object_prediction_list:
92
- box = pred.bbox.to_xyxy()
93
- x_min, y_min, x_max, y_max = map(int, box)
94
- label = f"{pred.category.name}: {pred.score.value:.2f}"
95
- cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (255, 0, 0), 2)
96
- cv2.putText(img, label, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 0, 0), 2)
97
- return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), "Detection complete."
98
- except Exception as e:
99
- return None, f"Error during SAHI prediction: {e}"
100
- return None, "Invalid model choice."
101
 
102
- # Gradio interface
103
  iface = gr.Interface(
104
  fn=predict_and_show_bounding_boxes,
105
- inputs=[
106
- gr.Image(type="filepath", label="Upload Image"),
107
- gr.Radio(choices=["YOLO", "SAHI"], label="Choose Detection Mode", value="YOLO"),
108
- gr.Slider(minimum=0.1, maximum=0.9, value=0.5, label="Confidence Threshold"),
109
- ],
110
- outputs=[gr.Image(label="Result"), gr.Textbox(label="Message")],
111
- title="PCB Defect Detection",
112
- description="Upload a PCB image and choose YOLO (green boxes) or SAHI (red boxes) for defect detection. Adjust confidence threshold for sensitivity.",
113
  )
114
 
115
- if __name__ == "__main__":
116
- share = os.getenv("HF_SHARE", "False").lower() == "true"
117
- iface.launch(share=share)
 
1
  import os
2
+ os.system('pip install --upgrade gradio')
3
+ # os.system('')
4
  import cv2
5
  import gradio as gr
6
  import torch
7
  from ultralytics import YOLO
 
 
8
  import random
9
  import numpy as np
 
10
  random.seed(42)
11
  np.random.seed(42)
12
 
 
 
 
 
 
 
13
  try:
14
+ model = YOLO("last.pt")
 
 
15
  print("YOLO model loaded successfully.")
16
+ except FileNotFoundError:
17
+ print("Error: 'yolo_modeln11_1502.pt' not found.")
18
+ model = None
19
  except Exception as e:
20
+ print(f"An error occurred while loading the model: {e}")
21
+ model = None
22
 
23
+ # Function to predict and show bounding boxes
24
+ def predict_and_show_bounding_boxes(image_path):
25
+ if model is None:
26
+ return None, "Error: Model not loaded."
27
+
28
+ try:
29
+ # Load the image using cv2
30
+ img = cv2.imread(image_path)
31
+ if img is None:
32
+ print(f"Error: Could not load image at {image_path}")
33
+ return None, "Error: Could not load image"
34
+
35
+ # Perform inference using the YOLO model
36
+ results = model(image_path,conf=0.5)[0]
37
+ boxes = results.boxes
38
+
39
+ if len(boxes) == 0:
40
+ # No defects found, show the zero defects image
41
+ zero_defects_img = cv2.imread('zero_defects.png')
42
+ if zero_defects_img is not None:
43
+ return zero_defects_img
44
+ else:
45
+ return None, "Error: Could not load zero defects image"
46
 
47
+ for box in boxes:
48
+ xyxy = box.xyxy[0].tolist()
49
+ x_min, y_min, x_max, y_max = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
50
+ conf = box.conf[0].item()
51
+ cls = int(box.cls[0])
52
+ cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
53
+ label = f"{results.names[cls]}: {conf:.2f}"
54
+ cv2.putText(img, label, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
55
 
56
+ return img
57
+ except Exception as e:
58
+ print(f"An error occurred during prediction: {e}")
59
+ return None, str(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
+ # Create Gradio interface
62
  iface = gr.Interface(
63
  fn=predict_and_show_bounding_boxes,
64
+ inputs=gr.Image(type="filepath"),
65
+ outputs=[gr.Image()],
66
+ title="Defect Detection",
67
+ description="Upload an image to detect defects"
 
 
 
 
68
  )
69
 
70
+ iface.launch(share=True)