ckcl commited on
Commit
9ad85ee
·
verified ·
1 Parent(s): f6392db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -108
app.py CHANGED
@@ -6,8 +6,6 @@ import cv2
6
  from PIL import Image
7
  import io
8
  import os
9
- import sys
10
- import time
11
 
12
  class DrowsinessDetector:
13
  def __init__(self):
@@ -18,11 +16,12 @@ class DrowsinessDetector:
18
  self.id2label = {0: "notdrowsy", 1: "drowsy"}
19
  self.label2id = {"notdrowsy": 0, "drowsy": 1}
20
 
21
- def load_model(self, model_path):
22
- """Load the ViT model and processor from the specified path or directory"""
23
  try:
 
24
  self.model = ViTForImageClassification.from_pretrained(
25
- model_path, # 直接給資料夾路徑
26
  num_labels=2,
27
  id2label=self.id2label,
28
  label2id=self.label2id,
@@ -30,7 +29,7 @@ class DrowsinessDetector:
30
  )
31
  self.model.eval()
32
  self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
33
- print(f"ViT model loaded successfully from {model_path}")
34
  except Exception as e:
35
  print(f"Error loading ViT model: {str(e)}")
36
  raise
@@ -80,88 +79,54 @@ class DrowsinessDetector:
80
  # Initialize detector
81
  detector = DrowsinessDetector()
82
 
83
- def find_model_file():
84
- """Find the model directory or file in common locations"""
85
- possible_paths = [
86
- "huggingface_model", # 優先資料夾
87
- "pytorch_model.bin",
88
- "model_weights.h5",
89
- "drowsiness_model.h5",
90
- "model/drowsiness_model.h5",
91
- "models/drowsiness_model.h5",
92
- "huggingface_model/model_weights.h5",
93
- "huggingface_model/drowsiness_model.h5",
94
- "../model_weights.h5",
95
- "../drowsiness_model.h5"
96
- ]
97
- for path in possible_paths:
98
- if os.path.exists(path):
99
- return path
100
- return None
101
-
102
- def load_model():
103
- """Load the model"""
104
- model_path = find_model_file()
105
-
106
- if model_path is None:
107
- print("\nError: Model file not found!")
108
- print("\nPlease ensure one of the following files exists:")
109
- print("1. model_weights.h5")
110
- print("2. drowsiness_model.h5")
111
- print("3. model/drowsiness_model.h5")
112
- print("4. models/drowsiness_model.h5")
113
- print("\nYou can download the model from Hugging Face Hub or train it using train_model.py")
114
- sys.exit(1)
115
 
116
  try:
117
- detector.load_model(model_path)
118
- except Exception as e:
119
- print(f"\nError loading model: {str(e)}")
120
- sys.exit(1)
121
-
122
- def process_frame(frame):
123
- """Process a single frame"""
124
- if frame is None:
125
- return None
126
 
127
- try:
128
  # Convert frame to RGB if needed
129
- if len(frame.shape) == 2:
130
- frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
131
- elif frame.shape[2] == 4:
132
- frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
133
 
134
  # Make prediction
135
- drowsy_prob, face_coords, error = detector.predict(frame)
136
 
137
  if error:
138
- return frame
139
 
140
  if face_coords is not None:
141
  x, y, w, h = face_coords
142
  # Draw rectangle around face
143
  color = (0, 0, 255) if drowsy_prob > 0.7 else (0, 255, 0)
144
- cv2.rectangle(frame, (x, y), (x+w, y+h), color, 2)
145
 
146
  # Add text
147
  status = "DROWSY" if drowsy_prob > 0.7 else "ALERT"
148
- cv2.putText(frame, f"{status} ({drowsy_prob:.2%})",
149
  (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
150
-
151
- return frame
 
 
152
 
153
  except Exception as e:
154
- print(f"Error processing frame: {str(e)}")
155
- return frame
156
 
157
- def process_video(video_input):
158
  """Process video input"""
159
- if video_input is None:
160
- return None
161
 
162
  try:
163
  # Get input video properties
164
- cap = cv2.VideoCapture(video_input)
165
  fps = cap.get(cv2.CAP_PROP_FPS)
166
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
167
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
@@ -176,7 +141,7 @@ def process_video(video_input):
176
  if not ret:
177
  break
178
 
179
- processed_frame = process_frame(frame)
180
  if processed_frame is not None:
181
  out.write(processed_frame)
182
 
@@ -186,14 +151,12 @@ def process_video(video_input):
186
 
187
  # Check if video was created
188
  if os.path.exists(temp_output) and os.path.getsize(temp_output) > 0:
189
- return temp_output
190
  else:
191
- print("Error: Failed to create output video")
192
- return None
193
 
194
  except Exception as e:
195
- print(f"Error processing video: {str(e)}")
196
- return None
197
  finally:
198
  # Clean up temporary file
199
  if 'out' in locals():
@@ -201,27 +164,8 @@ def process_video(video_input):
201
  if 'cap' in locals():
202
  cap.release()
203
 
204
- def webcam_feed():
205
- """Process webcam feed"""
206
- try:
207
- cap = cv2.VideoCapture(0)
208
- while True:
209
- ret, frame = cap.read()
210
- if not ret:
211
- break
212
-
213
- processed_frame = process_frame(frame)
214
- if processed_frame is not None:
215
- yield processed_frame
216
-
217
- except Exception as e:
218
- print(f"Error processing webcam feed: {str(e)}")
219
- yield None
220
- finally:
221
- cap.release()
222
-
223
  # Load the model at startup
224
- load_model()
225
 
226
  # Create interface
227
  with gr.Blocks(title="Driver Drowsiness Detection") as demo:
@@ -231,34 +175,38 @@ with gr.Blocks(title="Driver Drowsiness Detection") as demo:
231
  This system detects driver drowsiness using computer vision and deep learning.
232
 
233
  ## Features:
234
- - Real-time webcam monitoring
235
- - Video file processing
236
- - Single image analysis
237
  - Face detection and drowsiness prediction
238
  """)
239
 
240
  with gr.Tabs():
241
- with gr.Tab("Webcam"):
242
- gr.Markdown("Real-time drowsiness detection using your webcam")
243
- webcam_output = gr.Image(label="Live Detection")
244
- webcam_button = gr.Button("Start Webcam")
245
- webcam_button.click(fn=webcam_feed, inputs=None, outputs=webcam_output)
 
 
 
 
 
 
 
246
 
247
  with gr.Tab("Video"):
248
  gr.Markdown("Upload a video file for drowsiness detection")
249
  with gr.Row():
250
  video_input = gr.Video(label="Input Video")
251
- video_output = gr.Video(label="Detection Result")
252
- video_button = gr.Button("Process Video")
253
- video_button.click(fn=process_video, inputs=video_input, outputs=video_output)
254
-
255
- with gr.Tab("Image"):
256
- gr.Markdown("Upload an image for drowsiness detection")
257
  with gr.Row():
258
- image_input = gr.Image(type="numpy", label="Input Image")
259
- image_output = gr.Image(label="Detection Result")
260
- image_button = gr.Button("Process Image")
261
- image_button.click(fn=process_frame, inputs=image_input, outputs=image_output)
 
 
262
 
 
263
  if __name__ == "__main__":
264
  demo.launch()
 
6
  from PIL import Image
7
  import io
8
  import os
 
 
9
 
10
  class DrowsinessDetector:
11
  def __init__(self):
 
16
  self.id2label = {0: "notdrowsy", 1: "drowsy"}
17
  self.label2id = {"notdrowsy": 0, "drowsy": 1}
18
 
19
+ def load_model(self):
20
+ """Load the ViT model and processor from Hugging Face Hub"""
21
  try:
22
+ model_id = "ckcl/driver-drowsiness-detector" # 使用你的模型ID
23
  self.model = ViTForImageClassification.from_pretrained(
24
+ model_id,
25
  num_labels=2,
26
  id2label=self.id2label,
27
  label2id=self.label2id,
 
29
  )
30
  self.model.eval()
31
  self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
32
+ print(f"ViT model loaded successfully from {model_id}")
33
  except Exception as e:
34
  print(f"Error loading ViT model: {str(e)}")
35
  raise
 
79
  # Initialize detector
80
  detector = DrowsinessDetector()
81
 
82
+ def process_image(image):
83
+ """Process a single image"""
84
+ if image is None:
85
+ return None, "No image provided"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  try:
88
+ # Convert image to numpy array if it's a PIL Image
89
+ if isinstance(image, Image.Image):
90
+ image = np.array(image)
 
 
 
 
 
 
91
 
 
92
  # Convert frame to RGB if needed
93
+ if len(image.shape) == 2:
94
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
95
+ elif image.shape[2] == 4:
96
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
97
 
98
  # Make prediction
99
+ drowsy_prob, face_coords, error = detector.predict(image)
100
 
101
  if error:
102
+ return image, error
103
 
104
  if face_coords is not None:
105
  x, y, w, h = face_coords
106
  # Draw rectangle around face
107
  color = (0, 0, 255) if drowsy_prob > 0.7 else (0, 255, 0)
108
+ cv2.rectangle(image, (x, y), (x+w, y+h), color, 2)
109
 
110
  # Add text
111
  status = "DROWSY" if drowsy_prob > 0.7 else "ALERT"
112
+ cv2.putText(image, f"{status} ({drowsy_prob:.2%})",
113
  (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
114
+
115
+ return image, f"Status: {status} (Confidence: {drowsy_prob:.2%})"
116
+ else:
117
+ return image, "No face detected"
118
 
119
  except Exception as e:
120
+ return image, f"Error processing image: {str(e)}"
 
121
 
122
+ def process_video(video):
123
  """Process video input"""
124
+ if video is None:
125
+ return None, "No video provided"
126
 
127
  try:
128
  # Get input video properties
129
+ cap = cv2.VideoCapture(video)
130
  fps = cap.get(cv2.CAP_PROP_FPS)
131
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
132
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
141
  if not ret:
142
  break
143
 
144
+ processed_frame = process_image(frame)[0]
145
  if processed_frame is not None:
146
  out.write(processed_frame)
147
 
 
151
 
152
  # Check if video was created
153
  if os.path.exists(temp_output) and os.path.getsize(temp_output) > 0:
154
+ return temp_output, "Video processed successfully"
155
  else:
156
+ return None, "Error: Failed to create output video"
 
157
 
158
  except Exception as e:
159
+ return None, f"Error processing video: {str(e)}"
 
160
  finally:
161
  # Clean up temporary file
162
  if 'out' in locals():
 
164
  if 'cap' in locals():
165
  cap.release()
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  # Load the model at startup
168
+ detector.load_model()
169
 
170
  # Create interface
171
  with gr.Blocks(title="Driver Drowsiness Detection") as demo:
 
175
  This system detects driver drowsiness using computer vision and deep learning.
176
 
177
  ## Features:
178
+ - Image analysis
179
+ - Video processing
 
180
  - Face detection and drowsiness prediction
181
  """)
182
 
183
  with gr.Tabs():
184
+ with gr.Tab("Image"):
185
+ gr.Markdown("Upload an image for drowsiness detection")
186
+ with gr.Row():
187
+ image_input = gr.Image(label="Input Image", type="numpy")
188
+ image_output = gr.Image(label="Processed Image")
189
+ with gr.Row():
190
+ status_output = gr.Textbox(label="Status")
191
+ image_input.change(
192
+ fn=process_image,
193
+ inputs=[image_input],
194
+ outputs=[image_output, status_output]
195
+ )
196
 
197
  with gr.Tab("Video"):
198
  gr.Markdown("Upload a video file for drowsiness detection")
199
  with gr.Row():
200
  video_input = gr.Video(label="Input Video")
201
+ video_output = gr.Video(label="Processed Video")
 
 
 
 
 
202
  with gr.Row():
203
+ video_status = gr.Textbox(label="Status")
204
+ video_input.change(
205
+ fn=process_video,
206
+ inputs=[video_input],
207
+ outputs=[video_output, video_status]
208
+ )
209
 
210
+ # Launch the app
211
  if __name__ == "__main__":
212
  demo.launch()