jassminvo1 committed on
Commit
486a300
·
verified ·
1 Parent(s): d6339b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -55
app.py CHANGED
@@ -1,104 +1,190 @@
1
- import gradio as gr
2
  import cv2
 
3
  import requests
4
- import os
5
-
6
  from ultralytics import YOLO
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  file_urls = [
9
- 'https://www.dropbox.com/s/b5g97xo901zb3ds/pothole_example.jpg?dl=1',
10
- 'https://www.dropbox.com/s/86uxlxxlm1iaexa/pothole_screenshot.png?dl=1',
11
- 'https://www.dropbox.com/s/7sjfwncffg8xej2/video_7.mp4?dl=1'
 
12
  ]
13
 
 
14
  def download_file(url, save_name):
15
- url = url
16
  if not os.path.exists(save_name):
17
  file = requests.get(url)
18
- open(save_name, 'wb').write(file.content)
 
19
 
20
  for i, url in enumerate(file_urls):
21
- if 'mp4' in file_urls[i]:
22
- download_file(
23
- file_urls[i],
24
- f"video.mp4"
25
- )
26
  else:
27
- download_file(
28
- file_urls[i],
29
- f"image_{i}.jpg"
30
- )
31
 
32
- model = YOLO('best.pt')
33
- path = [['image_0.jpg'], ['image_1.jpg']]
34
- video_path = [['video.mp4']]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- def show_preds_image(image_path):
37
- image = cv2.imread(image_path)
38
- outputs = model.predict(source=image_path)
39
- results = outputs[0].cpu().numpy()
40
- for i, det in enumerate(results.boxes.xyxy):
41
  cv2.rectangle(
42
  image,
43
- (int(det[0]), int(det[1])),
44
- (int(det[2]), int(det[3])),
45
- color=(0, 0, 255),
46
  thickness=2,
47
- lineType=cv2.LINE_AA
 
 
 
 
 
 
 
 
 
 
48
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
50
 
 
51
  inputs_image = [
52
- gr.components.Image(type="filepath", label="Input Image"),
53
  ]
54
  outputs_image = [
55
- gr.components.Image(type="numpy", label="Output Image"),
56
  ]
 
57
  interface_image = gr.Interface(
58
  fn=show_preds_image,
59
  inputs=inputs_image,
60
  outputs=outputs_image,
61
- title="Pothole detector",
62
- examples=path,
63
  cache_examples=False,
64
  )
65
 
 
 
66
  def show_preds_video(video_path):
67
  cap = cv2.VideoCapture(video_path)
68
- while(cap.isOpened()):
 
69
  ret, frame = cap.read()
70
- if ret:
71
- frame_copy = frame.copy()
72
- outputs = model.predict(source=frame)
73
- results = outputs[0].cpu().numpy()
74
- for i, det in enumerate(results.boxes.xyxy):
75
- cv2.rectangle(
76
- frame_copy,
77
- (int(det[0]), int(det[1])),
78
- (int(det[2]), int(det[3])),
79
- color=(0, 0, 255),
80
- thickness=2,
81
- lineType=cv2.LINE_AA
82
- )
83
- yield cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
84
 
85
- inputs_video = [
86
- gr.components.Video(type="filepath", label="Input Video"),
 
 
 
 
87
 
 
 
 
 
 
 
 
88
  ]
89
  outputs_video = [
90
- gr.components.Image(type="numpy", label="Output Image"),
91
  ]
 
92
  interface_video = gr.Interface(
93
  fn=show_preds_video,
94
  inputs=inputs_video,
95
  outputs=outputs_video,
96
- title="Pothole detector",
97
- examples=video_path,
98
  cache_examples=False,
99
  )
100
 
 
 
101
  gr.TabbedInterface(
102
  [interface_image, interface_video],
103
- tab_names=['Image inference', 'Video inference']
104
- ).queue().launch()
 
1
+ import os
2
  import cv2
3
+ import gradio as gr
4
  import requests
 
 
5
  from ultralytics import YOLO
6
 
7
# ==== DROWSINESS DETECTION SETTINGS ====
# Normalized class names (lower-case, underscores) that count as "drowsy".
SLEEPY_CLASS_NAMES = {
    "buonngu",
    "closed_eyes",
    "drowsy",
    "ngủ",
    "sleep",
    "sleepy",
}

# Minimum detection confidence required to report the driver as asleep.
SLEEP_CONF_THRESHOLD = 0.4

# ==== DEMO FILES (optional; can be removed) ====
# Order matters: the position of each URL becomes part of the local file name.
# Replace these with your own driver images / video if desired.
file_urls = [
    "https://www.dropbox.com/s/b5g97xo901zb3ds/pothole_example.jpg?dl=1",
    "https://www.dropbox.com/s/86uxlxxlm1iaexa/pothole_screenshot.png?dl=1",
    "https://www.dropbox.com/s/7sjfwncffg8xej2/video_7.mp4?dl=1",
]
28
 
29
+
30
def download_file(url, save_name, timeout=30):
    """Download *url* to *save_name* unless that file already exists.

    Args:
        url: Address of the file to fetch.
        save_name: Local path to write the downloaded bytes to.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.RequestException: On connection problems, timeouts, or an
            HTTP error status.
    """
    if os.path.exists(save_name):
        return

    response = requests.get(url, timeout=timeout)
    # Fail loudly instead of saving an HTTP error page under the target name;
    # the exists-check above would otherwise keep that bad file forever.
    response.raise_for_status()

    # The context manager guarantees the handle is flushed and closed.
    with open(save_name, "wb") as fh:
        fh.write(response.content)
34
+
35
 
36
# Fetch the demo assets: a URL containing "mp4" is stored as video.mp4, every
# other URL as image_<position in file_urls>.jpg.
for i, url in enumerate(file_urls):
    download_file(url, "video.mp4" if "mp4" in url else f"image_{i}.jpg")

# ==== LOAD THE YOLO MODEL (weights trained for drowsiness detection) ====
model = YOLO("best.pt")

# Example rows offered in the UI: one single-element row per demo file.
image_examples = [[f"image_{k}.jpg"] for k in range(2)]
video_examples = [["video.mp4"]]
47
+
48
+
49
def _normalize_name(name: str) -> str:
    """Return *name* lower-cased with every space replaced by an underscore."""
    lowered = name.lower()
    return "_".join(lowered.split(" "))
51
+
52
+
53
def _ascii_overlay_text(text):
    """Return *text* with accents stripped so it contains only ASCII."""
    import unicodedata  # local import keeps this block self-contained

    # NFKD splits accented letters into base letter + combining marks; the
    # ASCII encode step then drops the marks (and anything else non-ASCII).
    decomposed = unicodedata.normalize("NFKD", text)
    return decomposed.encode("ascii", "ignore").decode("ascii")


def draw_and_decide_state(image, results):
    """Draw detections on *image* and decide whether the driver is drowsy.

    Every detection gets a bounding box and a "<class> (<confidence>)" label.
    A status line is then written at the top-left corner of the image.

    Args:
        image: Image array to annotate; it is modified in place.
        results: Detection result exposing ``names`` (class id -> name) and
            ``boxes`` with ``xyxy``, ``conf`` and ``cls`` sequences.

    Returns:
        Tuple ``(image, sleepy_detected)``. ``sleepy_detected`` is True when
        at least one detection has a class listed in ``SLEEPY_CLASS_NAMES``
        and a confidence of at least ``SLEEP_CONF_THRESHOLD``.
    """
    names = results.names
    detections = results.boxes
    sleepy_detected = False

    for box, conf, cls_id in zip(detections.xyxy, detections.conf, detections.cls):
        x1, y1, x2, y2 = map(int, box)
        conf = float(conf)
        cls_name = names[int(cls_id)]

        is_sleepy_class = _normalize_name(cls_name) in SLEEPY_CLASS_NAMES
        if is_sleepy_class and conf >= SLEEP_CONF_THRESHOLD:
            sleepy_detected = True

        # Colour follows the class alone, not the confidence threshold.
        color = (0, 0, 255) if is_sleepy_class else (0, 255, 0)

        cv2.rectangle(
            image,
            (x1, y1),
            (x2, y2),
            color=color,
            thickness=2,
            lineType=cv2.LINE_AA,
        )
        cv2.putText(
            image,
            # Hershey fonts draw '?' for non-ASCII characters, so fold first.
            _ascii_overlay_text(f"{cls_name} ({conf:.2f})"),
            (x1, max(y1 - 10, 10)),  # keep the label inside the top edge
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            color,
            2,
            lineType=cv2.LINE_AA,
        )

    # Overall status line.
    if sleepy_detected:
        state_text = "NGỦ GỤC / DROWSY"
        state_color = (0, 0, 255)
    else:
        state_text = "TỈNH TÁO / ALERT"
        state_color = (0, 255, 0)

    cv2.putText(
        image,
        # Rendered as "NGU GUC / DROWSY" or "TINH TAO / ALERT".
        _ascii_overlay_text(state_text),
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        1.0,
        state_color,
        2,
        lineType=cv2.LINE_AA,
    )

    return image, sleepy_detected
117
+
118
+
119
+ # ==== 1. STILL IMAGE ====
120
def show_preds_image(image_path):
    """Run the detector on one image file and return the annotated picture.

    Args:
        image_path: Path of the image to analyse.

    Returns:
        The annotated image converted from BGR to RGB channel order.

    Raises:
        gr.Error: If no path was supplied or the file cannot be decoded.
    """
    if not image_path:
        raise gr.Error("No image provided.")

    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise gr.Error(f"Could not read image file: {image_path}")

    # conf=0.25 is the YOLO detection cut-off; lower it for more sensitivity.
    outputs = model.predict(source=image_path, conf=0.25)
    results = outputs[0].cpu().numpy()

    image, _ = draw_and_decide_state(image, results)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
129
 
130
+
131
# Image tab: receives a file path and displays the annotated array.
inputs_image = [gr.Image(type="filepath", label="Ảnh đầu vào (driver image)")]
outputs_image = [gr.Image(type="numpy", label="Kết quả nhận diện")]

interface_image = gr.Interface(
    title="Drowsy Driver Detector - Image",
    fn=show_preds_image,
    inputs=inputs_image,
    outputs=outputs_image,
    examples=image_examples,
    cache_examples=False,
)
146
 
147
+
148
+ # ==== 2. VIDEO ====
149
def show_preds_video(video_path):
    """Yield annotated frames for the video at *video_path*, one by one.

    Args:
        video_path: Path of the video file to analyse.

    Yields:
        Each frame with detections drawn on it, converted from BGR to RGB.

    Raises:
        gr.Error: If no path was supplied or the video cannot be opened.
    """
    if not video_path:
        raise gr.Error("No video provided.")

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise gr.Error(f"Could not open video file: {video_path}")

        while True:
            ret, frame = cap.read()
            if not ret:  # end of stream or read failure
                break

            frame_copy = frame.copy()

            outputs = model.predict(source=frame, conf=0.25, verbose=False)
            results = outputs[0].cpu().numpy()

            frame_copy, _ = draw_and_decide_state(frame_copy, results)

            yield cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
    finally:
        # Also runs when the generator is closed early or inference raises,
        # so the capture handle is always released.
        cap.release()
167
+
168
+
169
# Video tab: the uploaded clip goes to show_preds_video, whose yielded frames
# are shown in an image component.
inputs_video = [
    # gr.Video has no `type` option, so the former type="filepath" argument
    # was dropped: Gradio 3 flags it as an unused kwarg and newer releases
    # may reject it outright.
    gr.Video(label="Video đầu vào (driver camera)"),
]
outputs_video = [
    gr.Image(type="numpy", label="Kết quả từng frame"),
]

interface_video = gr.Interface(
    fn=show_preds_video,
    inputs=inputs_video,
    outputs=outputs_video,
    title="Drowsy Driver Detector - Video",
    examples=video_examples,
    cache_examples=False,
)


# ==== Tabbed interface ====
gr.TabbedInterface(
    [interface_image, interface_video],
    tab_names=["Ảnh", "Video"],
).queue().launch()