abdrabo01 commited on
Commit
b82277e
·
verified ·
1 Parent(s): fa0d38c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -41
app.py CHANGED
@@ -1,55 +1,119 @@
 
 
 
1
  import gradio as gr
2
- from sahi.predict import get_sliced_prediction
3
- from sahi.model import Yolov5DetectionModel
4
- from PIL import Image
5
- import numpy as np
6
  import torch
7
  from ultralytics import YOLO
 
 
 
 
8
 
9
- # Initialize YOLOv11n model
10
- yolo_model = YOLO('yolov11n.pt')
11
- print("YOLO model loaded successfully.")
12
 
13
- # Wrap with SAHI
14
- sahi_model = Yolov5DetectionModel(
15
- model_path='yolov11n.pt',
16
- confidence_threshold=0.3,
17
- device='cuda' if torch.cuda.is_available() else 'cpu',
18
- )
19
- print("SAHI model loaded successfully.")
20
-
21
- # Inference function
22
- def detect_objects(img: Image.Image):
23
- image_np = np.array(img)
24
- result = get_sliced_prediction(
25
- image_np,
26
- detection_model=sahi_model,
27
- slice_height=512,
28
- slice_width=512,
29
- overlap_height_ratio=0.2,
30
- overlap_width_ratio=0.2
31
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- detections = result.object_prediction_list
 
 
34
 
35
- # Draw boxes
36
- from PIL import ImageDraw
37
- draw = ImageDraw.Draw(img)
38
 
39
- for det in detections:
40
- box = det.bbox.to_xyxy()
41
- label = det.category.name
42
- draw.rectangle(box, outline="red", width=2)
43
- draw.text((box[0], box[1]), label, fill="red")
44
 
45
- return img
46
 
47
- # Create Gradio interface
48
  iface = gr.Interface(
49
- fn=detect_objects,
50
- inputs=gr.Image(type="pil"),
51
- outputs=gr.Image(type="pil"),
52
- title="YOLOv11n SAHI Object Detection",
 
 
 
 
53
  )
54
 
55
- iface.launch()
 
1
+ import os
2
+ os.system('pip install --upgrade gradio sahi')
3
+ import cv2
4
  import gradio as gr
 
 
 
 
5
  import torch
6
  from ultralytics import YOLO
7
+ from sahi import AutoDetectionModel
8
+ from sahi.predict import get_sliced_prediction
9
+ import random
10
+ import numpy as np
11
 
12
+ random.seed(42)
13
+ np.random.seed(42)
 
14
 
15
+ # Load default YOLO model
16
+ try:
17
+ yolo_model = YOLO("last.pt")
18
+ print("YOLO model loaded successfully.")
19
+ except FileNotFoundError:
20
+ print("Error: 'last.pt' not found.")
21
+ yolo_model = None
22
+ except Exception as e:
23
+ print(f"An error occurred while loading the YOLO model: {e}")
24
+ yolo_model = None
25
+
26
+ # Load SAHI model
27
+ try:
28
+ sahi_model = AutoDetectionModel.from_pretrained(
29
+ model_type="ultralytics",
30
+ model_path="last.pt", # same model used for consistency
31
+ confidence_threshold=0.5,
32
+ device="cpu",
33
  )
34
+ print("SAHI model loaded successfully.")
35
+ except Exception as e:
36
+ print(f"An error occurred while loading the SAHI model: {e}")
37
+ sahi_model = None
38
+
39
+ # Prediction function
40
+ def predict_and_show_bounding_boxes(image_path, model_choice):
41
+ if model_choice == "YOLO":
42
+ if yolo_model is None:
43
+ return None, "Error: YOLO model not loaded."
44
+
45
+ try:
46
+ img = cv2.imread(image_path)
47
+ if img is None:
48
+ return None, "Error: Could not load image"
49
+
50
+ results = yolo_model(image_path, conf=0.5)[0]
51
+ boxes = results.boxes
52
+
53
+ if len(boxes) == 0:
54
+ zero_defects_img = cv2.imread('zero_defects.png')
55
+ return zero_defects_img if zero_defects_img is not None else (None, "Error: Could not load zero defects image")
56
+
57
+ for box in boxes:
58
+ xyxy = box.xyxy[0].tolist()
59
+ x_min, y_min, x_max, y_max = map(int, xyxy[:4])
60
+ conf = box.conf[0].item()
61
+ cls = int(box.cls[0])
62
+ cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
63
+ label = f"{results.names[cls]}: {conf:.2f}"
64
+ cv2.putText(img, label, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
65
+
66
+ return img
67
+ except Exception as e:
68
+ return None, str(e)
69
+
70
+ elif model_choice == "SAHI":
71
+ if sahi_model is None:
72
+ return None, "Error: SAHI model not loaded."
73
+
74
+ try:
75
+ result = get_sliced_prediction(
76
+ image_path,
77
+ sahi_model,
78
+ slice_height=256,
79
+ slice_width=256,
80
+ overlap_height_ratio=0.2,
81
+ overlap_width_ratio=0.2,
82
+ )
83
+
84
+ img = cv2.imread(image_path)
85
+ if img is None:
86
+ return None, "Error: Could not load image"
87
+
88
+ for pred in result.object_prediction_list:
89
+ box = pred.bbox.to_xyxy()
90
+ x_min, y_min, x_max, y_max = map(int, box)
91
+ label = f"{pred.category.name}: {pred.score.value:.2f}"
92
+ cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (255, 0, 0), 2)
93
+ cv2.putText(img, label, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 0, 0), 2)
94
 
95
+ if len(result.object_prediction_list) == 0:
96
+ zero_defects_img = cv2.imread('zero_defects.png')
97
+ return zero_defects_img if zero_defects_img is not None else (None, "Error: Could not load zero defects image")
98
 
99
+ return img
100
+ except Exception as e:
101
+ return None, str(e)
102
 
103
+ else:
104
+ return None, "Invalid model choice."
 
 
 
105
 
 
106
 
107
+ # Gradio interface
108
  iface = gr.Interface(
109
+ fn=predict_and_show_bounding_boxes,
110
+ inputs=[
111
+ gr.Image(type="filepath", label="Upload Image"),
112
+ gr.Radio(choices=["YOLO", "SAHI"], label="Choose Detection Mode", value="YOLO"),
113
+ ],
114
+ outputs=[gr.Image(label="Result"), gr.Textbox(label="Message")],
115
+ title="Defect Detection",
116
+ description="Upload an image and choose YOLO or SAHI for defect detection.",
117
  )
118
 
119
+ iface.launch(share=True)