Spaces:
ckcl
/
Build error

ckcl commited on
Commit
5e2919a
·
verified ·
1 Parent(s): 15fdf57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -1013
app.py CHANGED
@@ -1,1082 +1,264 @@
1
  import gradio as gr
 
 
2
  import numpy as np
3
  import cv2
4
  from PIL import Image
5
  import io
6
  import os
7
- import json
8
  import time
9
- import argparse
10
- import tensorflow as tf
11
- from tensorflow import keras
12
- import math
13
- from collections import deque
14
- from mtcnn import MTCNN
15
-
16
- class SpeedDetector:
17
- def __init__(self, history_size=30):
18
- self.speed_history = deque(maxlen=history_size)
19
- self.last_update_time = None
20
- self.current_speed = 0
21
- self.speed_change_threshold = 5 # km/h
22
- self.abnormal_speed_changes = 0
23
- self.speed_deviation_sum = 0
24
- self.speed_change_score = 0
25
-
26
- # For optical flow speed estimation
27
- self.prev_gray = None
28
- self.prev_points = None
29
- self.frame_idx = 0
30
- self.speed_estimate = 60 # Initial estimate
31
-
32
- def update_speed(self, speed_km_h):
33
- """Update with current speed in km/h"""
34
- current_time = time.time()
35
-
36
- # Add to history
37
- self.speed_history.append(speed_km_h)
38
- self.current_speed = speed_km_h
39
-
40
- # Not enough data yet
41
- if len(self.speed_history) < 5:
42
- return 0
43
-
44
- # Calculate speed variation metrics
45
- speed_arr = np.array(self.speed_history)
46
-
47
- # 1. Standard deviation of speed
48
- speed_std = np.std(speed_arr)
49
-
50
- # 2. Detect abrupt changes
51
- for i in range(1, len(speed_arr)):
52
- change = abs(speed_arr[i] - speed_arr[i-1])
53
- if change >= self.speed_change_threshold:
54
- self.abnormal_speed_changes += 1
55
-
56
- # 3. Calculate average rate of change
57
- changes = np.abs(np.diff(speed_arr))
58
- avg_change = np.mean(changes) if len(changes) > 0 else 0
59
-
60
- # Combine into a score (0-1 range)
61
- self.speed_deviation_sum = min(5, speed_std) / 5 # Normalize to 0-1
62
- abnormal_change_factor = min(1, self.abnormal_speed_changes / 5)
63
- avg_change_factor = min(1, avg_change / self.speed_change_threshold)
64
-
65
- # Weighted combination
66
- self.speed_change_score = (
67
- 0.4 * self.speed_deviation_sum +
68
- 0.4 * abnormal_change_factor +
69
- 0.2 * avg_change_factor
70
- )
71
-
72
- return self.speed_change_score
73
-
74
- def detect_speed_from_frame(self, frame):
75
- """Detect speed from video frame using optical flow"""
76
- if frame is None:
77
- return self.current_speed
78
-
79
- # Convert frame to grayscale
80
- gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
81
-
82
- # For the first frame, initialize points to track
83
- if self.prev_gray is None or self.frame_idx % 30 == 0: # Reset tracking points every 30 frames
84
- # Detect good features to track
85
- mask = np.zeros_like(gray)
86
- # Focus on the lower portion of the frame (road)
87
- h, w = gray.shape
88
- mask[h//2:, :] = 255
89
-
90
- corners = cv2.goodFeaturesToTrack(gray, maxCorners=100, qualityLevel=0.01, minDistance=10, mask=mask)
91
- if corners is not None and len(corners) > 0:
92
- self.prev_points = corners
93
- self.prev_gray = gray.copy()
94
- else:
95
- # No good points to track
96
- self.frame_idx += 1
97
- return self.current_speed
98
-
99
- # Calculate optical flow if we have previous points
100
- if self.prev_gray is not None and self.prev_points is not None:
101
- # Calculate optical flow
102
- new_points, status, _ = cv2.calcOpticalFlowPyrLK(self.prev_gray, gray, self.prev_points, None)
103
-
104
- # Filter only valid points
105
- if new_points is not None and status is not None:
106
- good_new = new_points[status == 1]
107
- good_old = self.prev_points[status == 1]
108
-
109
- # Calculate flow magnitude
110
- if len(good_new) > 0 and len(good_old) > 0:
111
- flow_magnitudes = np.sqrt(
112
- np.sum((good_new - good_old)**2, axis=1)
113
- )
114
- avg_flow = np.mean(flow_magnitudes) if len(flow_magnitudes) > 0 else 0
115
-
116
- # Map optical flow to speed change
117
- # Higher flow = faster movement
118
- # This is a simplified mapping and would need calibration for real-world use
119
- flow_threshold = 1.0 # Adjust based on testing
120
-
121
- if avg_flow > flow_threshold:
122
- # Movement detected, estimate acceleration
123
- speed_change = min(5, max(-5, (avg_flow - flow_threshold) * 2))
124
-
125
- # Add some temporal smoothing to avoid sudden changes
126
- speed_change = speed_change * 0.3 # Reduce magnitude for smoother change
127
- else:
128
- # Minimal movement, slight deceleration (coasting)
129
- speed_change = -0.1
130
-
131
- # Update speed with detected change
132
- self.speed_estimate += speed_change
133
- # Keep speed in reasonable range
134
- self.speed_estimate = max(40, min(120, self.speed_estimate))
135
-
136
- # Update tracking points
137
- self.prev_points = good_new.reshape(-1, 1, 2)
138
-
139
- # Update previous gray frame
140
- self.prev_gray = gray.copy()
141
-
142
- self.frame_idx += 1
143
-
144
- # Check for dashboard speedometer (would require more sophisticated OCR in a real system)
145
- # For now, just use our estimated speed
146
- detected_speed = self.speed_estimate
147
-
148
- # Update current speed and trigger speed change detection
149
- self.update_speed(detected_speed)
150
-
151
- return detected_speed
152
-
153
- def get_speed_change_score(self):
154
- """Return a score from 0-1 indicating abnormal speed changes"""
155
- return self.speed_change_score
156
-
157
- def reset(self):
158
- """Reset the detector state"""
159
- self.speed_history.clear()
160
- self.abnormal_speed_changes = 0
161
- self.speed_deviation_sum = 0
162
- self.speed_change_score = 0
163
- self.prev_gray = None
164
- self.prev_points = None
165
- self.frame_idx = 0
166
- self.speed_estimate = 60 # Reset to initial estimate
167
 
168
  class DrowsinessDetector:
169
  def __init__(self):
170
  self.model = None
171
- self.input_shape = (224, 224, 3) # Updated to match model's expected input shape
 
172
  self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
173
  self.id2label = {0: "notdrowsy", 1: "drowsy"}
174
  self.label2id = {"notdrowsy": 0, "drowsy": 1}
175
-
176
- # Speed detector
177
- self.speed_detector = SpeedDetector()
178
- self.SPEED_CHANGE_WEIGHT = 0.15 # Weight for speed changes in drowsiness calculation
179
-
180
- # Yawn detection parameters
181
- self.MAR_THRESHOLD = 0.5 # Mouth aspect ratio threshold for yawn detection
182
- self.yawn_counter = 0
183
- self.CONSECUTIVE_YAWN_FRAMES = 10 # Number of consecutive frames to confirm yawn
184
- self.last_yawn_time = 0
185
- self.YAWN_COOLDOWN = 3 # Seconds between yawn alerts
186
-
187
- # 嘗試動態 import dlib,並設置 fallback
188
- self.landmark_detection_enabled = False
189
- try:
190
- import dlib
191
- self.detector = dlib.get_frontal_face_detector()
192
- predictor_path = "shape_predictor_68_face_landmarks.dat"
193
- if not os.path.exists(predictor_path):
194
- print(f"Warning: {predictor_path} not found. Downloading...")
195
- import urllib.request
196
- urllib.request.urlretrieve(
197
- "https://github.com/italojs/facial-landmarks-recognition/raw/master/shape_predictor_68_face_landmarks.dat",
198
- predictor_path
199
- )
200
- self.predictor = dlib.shape_predictor(predictor_path)
201
- self.landmark_detection_enabled = True
202
- print("Facial landmark detection enabled")
203
- except Exception as e:
204
- print(f"Warning: Facial landmark detection disabled: {e}")
205
- print("The system will use a simpler detection method. For better accuracy, install CMake and dlib.")
206
-
207
- # Constants for drowsiness detection
208
- self.EAR_THRESHOLD = 0.25 # Eye aspect ratio threshold
209
- self.CONSECUTIVE_FRAMES = 20
210
- self.ear_counter = 0
211
- self.GAZE_THRESHOLD = 0.2 # Gaze direction threshold
212
- self.HEAD_POSE_THRESHOLD = 0.3 # Head pose threshold
213
-
214
- # Parameters for weighted ensemble
215
- self.MODEL_WEIGHT = 0.45 # Reduced to accommodate speed factor
216
- self.EAR_WEIGHT = 0.2
217
- self.GAZE_WEIGHT = 0.1
218
- self.HEAD_POSE_WEIGHT = 0.1
219
-
220
- # For tracking across frames
221
- self.prev_drowsy_count = 0
222
- self.drowsy_history = []
223
- self.current_speed = 0 # Current speed in km/h
224
- self.mtcnn_detector = MTCNN()
225
 
226
- def update_speed(self, speed_km_h):
227
- """Update the current speed"""
228
- self.current_speed = speed_km_h
229
- return self.speed_detector.update_speed(speed_km_h)
230
-
231
- def reset_speed_detector(self):
232
- """Reset the speed detector"""
233
- self.speed_detector.reset()
234
-
235
- def load_model(self):
236
- """Load the CNN model from local files"""
237
  try:
238
- # Use local model files
239
- config_path = "huggingface_model/config.json"
240
- model_path = "drowsiness_model.h5"
241
-
242
- # Load config
243
- with open(config_path, 'r') as f:
244
- config = json.load(f)
245
-
246
- # Load the Keras model directly
247
- self.model = keras.models.load_model(model_path)
248
-
249
- # Print model summary for debugging
250
- print("Model loaded successfully")
251
- print(f"Model input shape: {self.model.input_shape}")
252
- self.model.summary()
253
-
254
  except Exception as e:
255
- print(f"Error loading CNN model: {str(e)}")
256
  raise
257
-
258
- def eye_aspect_ratio(self, eye):
259
- """Calculate the eye aspect ratio"""
260
- # Compute the euclidean distances between the two sets of vertical eye landmarks
261
- A = dist.euclidean(eye[1], eye[5])
262
- B = dist.euclidean(eye[2], eye[4])
263
-
264
- # Compute the euclidean distance between the horizontal eye landmarks
265
- C = dist.euclidean(eye[0], eye[3])
266
-
267
- # Calculate the eye aspect ratio
268
- ear = (A + B) / (2.0 * C)
269
- return ear
270
-
271
- def calculate_gaze(self, eye_points, facial_landmarks):
272
- """Calculate gaze direction"""
273
- left_eye_region = np.array([(facial_landmarks.part(i).x, facial_landmarks.part(i).y) for i in range(36, 42)])
274
- right_eye_region = np.array([(facial_landmarks.part(i).x, facial_landmarks.part(i).y) for i in range(42, 48)])
275
-
276
- # Compute eye centers
277
- left_eye_center = left_eye_region.mean(axis=0).astype("int")
278
- right_eye_center = right_eye_region.mean(axis=0).astype("int")
279
-
280
- # Compute the angle between eye centers
281
- dY = right_eye_center[1] - left_eye_center[1]
282
- dX = right_eye_center[0] - left_eye_center[0]
283
- angle = np.degrees(np.arctan2(dY, dX))
284
-
285
- # Normalize the angle
286
- return abs(angle) / 180.0
287
-
288
- def get_head_pose(self, shape):
289
- """Calculate the head pose"""
290
- # Get specific facial landmarks for head pose estimation
291
- image_points = np.array([
292
- (shape.part(30).x, shape.part(30).y), # Nose tip
293
- (shape.part(8).x, shape.part(8).y), # Chin
294
- (shape.part(36).x, shape.part(36).y), # Left eye left corner
295
- (shape.part(45).x, shape.part(45).y), # Right eye right corner
296
- (shape.part(48).x, shape.part(48).y), # Left mouth corner
297
- (shape.part(54).x, shape.part(54).y) # Right mouth corner
298
- ], dtype="double")
299
-
300
- # A simple head pose estimation using the angle of the face
301
- # Calculate center of the face
302
- center_x = np.mean([p[0] for p in image_points])
303
- center_y = np.mean([p[1] for p in image_points])
304
-
305
- # Calculate angle with respect to vertical
306
- angle = 0
307
- if len(image_points) > 2:
308
- point1 = image_points[0] # Nose
309
- point2 = image_points[1] # Chin
310
- angle = abs(math.atan2(point2[1] - point1[1], point2[0] - point1[0]))
311
-
312
- # Normalize to 0-1 range where 0 is upright and 1 is drooping
313
- normalized_pose = min(1.0, abs(angle) / (math.pi/2))
314
- return normalized_pose
315
-
316
- def enhance_image(self, frame):
317
- # Apply CLAHE to improve contrast
318
- gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
319
- clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
320
- enhanced = clahe.apply(gray)
321
- enhanced_bgr = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
322
- return enhanced_bgr
323
 
324
  def detect_face(self, frame):
325
- # Enhance image before detection
326
- enhanced_frame = self.enhance_image(frame)
327
- # Try MTCNN
328
- try:
329
- results = self.mtcnn_detector.detect_faces(cv2.cvtColor(enhanced_frame, cv2.COLOR_BGR2RGB))
330
- print('MTCNN results:', results)
331
- if results:
332
- # 選擇最右側的臉(x+w最大者)
333
- rightmost = max(results, key=lambda r: r['box'][0] + r['box'][2])
334
- x, y, w, h = rightmost['box']
335
- x, y = max(0, x), max(0, y)
336
- w, h = max(0, w), max(0, h)
337
- if x+w > frame.shape[1] or y+h > frame.shape[0] or w == 0 or h == 0:
338
- print('MTCNN box out of bounds or zero size')
339
- else:
340
- face = frame[y:y+h, x:x+w]
341
- return face, (x, y, w, h)
342
- except Exception as e:
343
- print(f"MTCNN detection error: {e}")
344
- # Fallback to haarcascade
345
- gray = cv2.cvtColor(enhanced_frame, cv2.COLOR_BGR2GRAY)
346
  faces = self.face_cascade.detectMultiScale(gray, 1.1, 4)
347
- print('Haar results:', faces)
348
  if len(faces) > 0:
349
- # 選擇最右側的臉
350
- rightmost_idx = np.argmax([x+w for (x, y, w, h) in faces])
351
- (x, y, w, h) = faces[rightmost_idx]
352
- if w > 0 and h > 0:
353
- face = frame[y:y+h, x:x+w]
354
- return face, (x, y, w, h)
355
  return None, None
356
 
357
  def preprocess_image(self, image):
358
- """Preprocess the input image for CNN"""
359
  if image is None:
360
  return None
361
- # Convert to RGB
362
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
363
- # Resize to model input size (224x224)
364
- image = cv2.resize(image, (self.input_shape[0], self.input_shape[1]))
365
- # Normalize
366
- image = image.astype(np.float32) / 255.0
367
- # Add batch dimension
368
- image = np.expand_dims(image, axis=0)
369
- return image
370
-
371
- def mouth_aspect_ratio(self, mouth_points):
372
- """Calculate the mouth aspect ratio"""
373
- # Compute the euclidean distances between the vertical mouth landmarks
374
- A = np.linalg.norm(mouth_points[1] - mouth_points[7])
375
- B = np.linalg.norm(mouth_points[2] - mouth_points[6])
376
- C = np.linalg.norm(mouth_points[3] - mouth_points[5])
377
-
378
- # Compute the euclidean distance between the horizontal mouth landmarks
379
- D = np.linalg.norm(mouth_points[0] - mouth_points[4])
380
-
381
- # Calculate the mouth aspect ratio
382
- mar = (A + B + C) / (2.0 * D)
383
- return mar
384
-
385
- def detect_yawn(self, shape):
386
- """Detect if the person is yawning using mouth aspect ratio"""
387
- if not self.landmark_detection_enabled:
388
- return False, 0
389
-
390
- # Get mouth landmarks (points 48-68)
391
- mouth_points = np.array([(shape.part(i).x, shape.part(i).y) for i in range(48, 68)])
392
-
393
- # Calculate mouth aspect ratio
394
- mar = self.mouth_aspect_ratio(mouth_points)
395
-
396
- # Check if mouth is open wide enough to be considered a yawn
397
- current_time = time.time()
398
- if mar > self.MAR_THRESHOLD:
399
- self.yawn_counter += 1
400
- if self.yawn_counter >= self.CONSECUTIVE_YAWN_FRAMES:
401
- # Check if enough time has passed since last yawn alert
402
- if current_time - self.last_yawn_time > self.YAWN_COOLDOWN:
403
- self.last_yawn_time = current_time
404
- return True, mar
405
- else:
406
- self.yawn_counter = 0
407
-
408
- return False, mar
409
 
410
  def predict(self, image):
411
- """Predict drowsiness using multiple features"""
412
- try:
413
- # Convert image to numpy array if it's not already
414
- if isinstance(image, Image.Image):
415
- image = np.array(image)
416
-
417
- # Convert to RGB if image is in BGR format
418
- if len(image.shape) == 3 and image.shape[2] == 3:
419
- if image.dtype == np.uint8:
420
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
421
-
422
- # Detect face
423
- face, face_coords = self.detect_face(image)
424
- if face is None or face_coords is None:
425
- return 0, "No face detected", None, 0, 0, 0, 0, 0, False
426
-
427
- # Initialize feature scores
428
- model_score = 0
429
- ear_score = 0
430
- gaze_score = 0
431
- head_pose_score = 0
432
- yawn_detected = False
433
- mar = 0
434
-
435
- # Get facial landmarks if available
436
- if self.landmark_detection_enabled:
437
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
438
- rect = dlib.rectangle(face_coords[0], face_coords[1], face_coords[0] + face_coords[2], face_coords[1] + face_coords[3])
439
- shape = self.predictor(gray, rect)
440
- # Calculate EAR
441
- left_eye = np.array([(shape.part(i).x, shape.part(i).y) for i in range(36, 42)])
442
- right_eye = np.array([(shape.part(i).x, shape.part(i).y) for i in range(42, 48)])
443
- ear = (self.eye_aspect_ratio(left_eye) + self.eye_aspect_ratio(right_eye)) / 2.0
444
- ear_score = 1.0 if ear < self.EAR_THRESHOLD else 0.0
445
- # Calculate gaze direction
446
- gaze = self.calculate_gaze([left_eye, right_eye], shape)
447
- gaze_score = 1.0 if abs(gaze[0]) > self.GAZE_THRESHOLD or abs(gaze[1]) > self.GAZE_THRESHOLD else 0.0
448
- # Calculate head pose
449
- head_pose = self.get_head_pose(shape)
450
- head_pose_score = 1.0 if abs(head_pose[0]) > self.HEAD_POSE_THRESHOLD or abs(head_pose[1]) > self.HEAD_POSE_THRESHOLD else 0.0
451
- # Detect yawn
452
- yawn_detected, mar = self.detect_yawn(shape)
453
- else:
454
- # Fallback: simple EAR/MAR estimation using grayscale intensity
455
- # Estimate eye region based on face proportions
456
- face_gray = cv2.cvtColor(face, cv2.COLOR_BGR2GRAY)
457
- fh, fw = face_gray.shape[:2]
458
- # Approximate left/right eye regions
459
- left_eye_region = face_gray[int(fh*0.25):int(fh*0.45), int(fw*0.13):int(fw*0.37)]
460
- right_eye_region = face_gray[int(fh*0.25):int(fh*0.45), int(fw*0.63):int(fw*0.87)]
461
- # Use average intensity: lower means more likely closed
462
- if left_eye_region.size > 0 and right_eye_region.size > 0:
463
- left_eye_avg = np.mean(left_eye_region) / 255.0
464
- right_eye_avg = np.mean(right_eye_region) / 255.0
465
- # Invert so that darker regions (potentially closed eyes) have higher values
466
- left_eye_closed = 1.0 - left_eye_avg
467
- right_eye_closed = 1.0 - right_eye_avg
468
- # Combine into a simple eye closure metric (0-1 range, higher means more closed)
469
- eye_closure = (left_eye_closed + right_eye_closed) / 2.0
470
- # Convert to a rough approximation of EAR
471
- estimated_ear = max(0.15, 0.4 - (eye_closure * 0.25))
472
- ear_score = 1.0 if estimated_ear < self.EAR_THRESHOLD else 0.0
473
- # Fallback MAR: use mouth region intensity
474
- mouth_region = face_gray[int(fh*0.65):int(fh*0.90), int(fw*0.25):int(fw*0.75)]
475
- if mouth_region.size > 0:
476
- mar = np.mean(mouth_region) / 255.0
477
- yawn_detected = False # fallback下不判斷yawn,避免誤判
478
- # Get CNN model prediction
479
- processed_image = self.preprocess_image(face)
480
- if self.model is not None:
481
- model_pred = self.model.predict(processed_image, verbose=0)
482
- if len(model_pred.shape) == 2:
483
- if model_pred.shape[1] == 1:
484
- model_score = float(model_pred[0][0])
485
- else:
486
- model_score = float(model_pred[0][1])
487
- else:
488
- model_score = float(model_pred[0])
489
- # Calculate weighted ensemble score
490
- ensemble_score = (
491
- self.MODEL_WEIGHT * model_score +
492
- self.EAR_WEIGHT * ear_score +
493
- self.GAZE_WEIGHT * gaze_score +
494
- self.HEAD_POSE_WEIGHT * head_pose_score
495
- )
496
- # Add speed factor if available
497
- if self.current_speed > 0:
498
- speed_score = self.speed_detector.get_speed_change_score()
499
- ensemble_score = (1 - self.SPEED_CHANGE_WEIGHT) * ensemble_score + self.SPEED_CHANGE_WEIGHT * speed_score
500
- # Update drowsy history
501
- self.drowsy_history.append(ensemble_score)
502
- if len(self.drowsy_history) > 30: # Keep last 30 frames
503
- self.drowsy_history.pop(0)
504
- # Calculate average drowsiness over recent frames
505
- avg_drowsiness = np.mean(self.drowsy_history) if self.drowsy_history else 0
506
- # Determine final drowsiness state
507
- is_drowsy = avg_drowsiness > 0.5
508
- # Debug output
509
- print(f"[DEBUG] Model score: {model_score:.2f}, EAR: {ear_score:.2f}, MAR: {mar:.2f}, Drowsy: {is_drowsy}, Yawn: {yawn_detected}")
510
- # 強化EAR判斷:若模型分數高但EAR也高,強制標註為Alert
511
- if metrics['model_prob'] > 0.7 and metrics['ear'] > 0.25:
512
- is_drowsy = False
513
- alert_level = "Alert"
514
- color = (0, 255, 0)
515
- elif avg_drowsiness > 0.5:
516
- alert_level = "Drowsy"
517
- color = (0, 0, 255)
518
- else:
519
- alert_level = "Not Drowsy"
520
- color = (0, 255, 0)
521
- return (
522
- ensemble_score,
523
- alert_level,
524
- face_coords,
525
- ear_score,
526
- gaze_score,
527
- head_pose_score,
528
- model_score,
529
- mar,
530
- yawn_detected
531
- )
532
- except Exception as e:
533
- print(f"Error in predict: {str(e)}")
534
- return 0, "Error in prediction", None, 0, 0, 0, 0, 0, False
535
 
536
- # Create a global instance
537
  detector = DrowsinessDetector()
538
 
539
- def process_image(image):
540
- """Process image input"""
541
- if image is None:
542
- return None, "No image provided"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
  try:
544
- if image.size == 0 or image.shape[0] == 0 or image.shape[1] == 0:
545
- return None, "Invalid image dimensions"
546
- processed_image = image.copy()
547
- result = detector.predict(processed_image)
548
- if len(result) == 9:
549
- drowsy_prob, status, face_coords, ear_score, gaze_score, head_pose_score, model_score, mar, yawn_detected = result
550
- metrics = {
551
- 'model_prob': model_score,
552
- 'ear': ear_score,
553
- 'gaze': gaze_score,
554
- 'head_pose': head_pose_score,
555
- 'mar': mar,
556
- 'yawn_detected': yawn_detected
557
- }
558
- error = None
559
- elif len(result) == 4:
560
- drowsy_prob, face_coords, error, metrics = result
561
- elif len(result) == 2:
562
- return result
563
- else:
564
- return None, "Unknown error in prediction"
565
- if error:
566
- return None, error
567
- if face_coords is None:
568
- cv2.putText(processed_image, "Face detection error", (30, 30),
569
- cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2)
570
- return processed_image, "Face detection error"
571
- if not (isinstance(face_coords, (tuple, list)) and len(face_coords) == 4):
572
- cv2.putText(processed_image, "Face detection error", (30, 60),
573
- cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2)
574
- return processed_image, "Face detection error: invalid coordinates"
575
- x, y, w, h = face_coords
576
- # 強化EAR判斷:若模型分數高但EAR也高,強制標註為Alert
577
- is_drowsy = drowsy_prob >= 0.7
578
- if metrics['model_prob'] > 0.7 and metrics['ear'] > 0.25:
579
- is_drowsy = False
580
- alert_level = "Alert"
581
- color = (0, 255, 0)
582
- elif drowsy_prob >= 0.85:
583
- alert_level = "High Risk"
584
- color = (0, 0, 255)
585
- elif drowsy_prob >= 0.7:
586
- alert_level = "Medium Risk"
587
- color = (0, 165, 255)
588
- else:
589
- alert_level = "Alert"
590
- color = (0, 255, 0)
591
- cv2.rectangle(processed_image, (x, y), (x+w, y+h), color, 2)
592
- y_offset = 25
593
- cv2.putText(processed_image, f"{'Drowsy' if is_drowsy else 'Alert'} ({drowsy_prob:.2f})",
594
- (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
595
- cv2.putText(processed_image, alert_level, (x, y-35),
596
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
597
- cv2.putText(processed_image, f"Model: {metrics['model_prob']:.2f}", (10, processed_image.shape[0]-10-y_offset*3),
598
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
599
- cv2.putText(processed_image, f"Eye Ratio: {metrics['ear']:.2f}", (10, processed_image.shape[0]-10-y_offset*2),
600
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
601
- cv2.putText(processed_image, f"Head Pose: {metrics['head_pose']:.2f}", (10, processed_image.shape[0]-10-y_offset),
602
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
603
- if 'mar' in metrics:
604
- cv2.putText(processed_image, f"MAR: {metrics['mar']:.2f}", (10, processed_image.shape[0]-10-y_offset*4),
605
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
606
- if metrics.get('yawn_detected'):
607
- cv2.putText(processed_image, "YAWN DETECTED!", (x, y-60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
608
- if metrics['model_prob'] > 0.9 and metrics['ear'] > 0.25:
609
- cv2.putText(processed_image, "Model conflict - verify manually",
610
- (10, processed_image.shape[0]-10-y_offset*5),
611
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 165, 255), 1)
612
- return processed_image, f"Processed successfully. Drowsiness: {drowsy_prob:.2f}, Alert level: {alert_level}"
613
  except Exception as e:
614
- import traceback
615
- error_details = traceback.format_exc()
616
- print(f"Error processing image: {str(e)}\n{error_details}")
617
- return None, f"Error processing image: {str(e)}"
618
-
619
- def annotate_no_face(frame, head_moving=False):
620
- annotated = frame.copy()
621
- msg = "未偵測到臉部,請調整姿勢"
622
- color = (0, 0, 255)
623
- if head_moving:
624
- msg = "頭部晃動,請注意安全"
625
- color = (0, 165, 255)
626
- cv2.putText(annotated, msg, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
627
- return annotated
628
 
629
- def process_video(video, initial_speed=60):
630
- """Process video input"""
631
- if video is None:
632
- return None, "No video provided"
633
-
634
  try:
635
- # 创建内存缓冲区而不是临时文件
636
- temp_input = None
 
 
 
637
 
638
- # Handle video input (can be file path or video data)
639
- if isinstance(video, str):
640
- print(f"Processing video from path: {video}")
641
- # 直接读取原始文件,不复制到临时目录
642
- cap = cv2.VideoCapture(video)
643
- else:
644
- print(f"Processing video from uploaded data")
645
- # 读取上传的视频数据到内存
646
- import tempfile
647
- temp_input = tempfile.NamedTemporaryFile(suffix='.avi', delete=False)
648
- temp_input_path = temp_input.name
649
- with open(temp_input_path, "wb") as f:
650
- f.write(video)
651
- cap = cv2.VideoCapture(temp_input_path)
652
 
653
- if not cap.isOpened():
654
- return None, "Error: Could not open video"
 
 
 
 
655
 
 
 
 
 
 
 
 
 
 
 
656
  # Get input video properties
 
657
  fps = cap.get(cv2.CAP_PROP_FPS)
658
- if fps <= 0:
659
- fps = 30 # Default to 30fps if invalid
660
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
661
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
662
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
663
-
664
- print(f"Video properties: {width}x{height} at {fps}fps, total frames: {total_frames}")
665
-
666
- # 使用临时文件来存储处理后的视频(处理完毕后会删除)
667
- import tempfile
668
- temp_output = tempfile.NamedTemporaryFile(suffix='.avi', delete=False)
669
- temp_output_path = temp_output.name
670
 
671
- # 使用XVID编码并输出为AVI格式
672
- fourcc = cv2.VideoWriter_fourcc(*'XVID')
673
- out = cv2.VideoWriter(temp_output_path, fourcc, fps, (width, height))
674
- if not out.isOpened():
675
- return None, "Error: Could not create output video file"
676
-
677
- # Reset speed detector at the start of each video
678
- detector.reset_speed_detector()
679
-
680
- # Initialize speed value with the provided initial speed
681
- current_speed = initial_speed
682
- detector.speed_detector.speed_estimate = initial_speed
683
-
684
- # Process each frame
685
- frame_count = 0
686
- processed_count = 0
687
- face_detected_count = 0
688
- drowsy_count = 0
689
- high_risk_count = 0
690
- ear_sum = 0
691
- model_prob_sum = 0
692
- yawn_count = 0
693
-
694
- # Calculate frames to skip for 2 FPS processing
695
- frames_to_skip = max(1, int(fps / 2))
696
- print(f"Processing at 2 FPS: skipping {frames_to_skip-1} frames between processed frames")
697
 
698
  while True:
699
  ret, frame = cap.read()
700
  if not ret:
701
- print(f"End of video or error reading frame at frame {frame_count}")
702
  break
703
-
704
- frame_count += 1
705
-
706
- # Skip frames to maintain 2 FPS processing
707
- if frame_count % frames_to_skip != 0:
708
- # 仍然要標註狀態,不能直接複製原圖
709
- # 嘗試用光流判斷頭部是否晃動
710
- head_moving = False
711
- try:
712
- # 使用SpeedDetector的optical flow估算頭部移動
713
- # 這裡只用flow magnitude判斷
714
- gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
715
- if detector.speed_detector.prev_gray is not None:
716
- flow = cv2.absdiff(gray, detector.speed_detector.prev_gray)
717
- mean_flow = np.mean(flow)
718
- head_moving = mean_flow > 8 # 閾值可調
719
- detector.speed_detector.prev_gray = gray.copy()
720
- except Exception as e:
721
- pass
722
- annotated = annotate_no_face(frame, head_moving=head_moving)
723
- out.write(annotated)
724
- continue
725
-
726
- # Detect speed from the current frame
727
- current_speed = detector.speed_detector.detect_speed_from_frame(frame)
728
-
729
- try:
730
- # Try to process the frame
731
- processed_frame, message = process_image(frame)
732
 
733
- # Add speed info to the frame
734
- if processed_frame is not None:
735
- speed_text = f"Speed: {current_speed:.1f} km/h"
736
- cv2.putText(processed_frame, speed_text, (10, processed_frame.shape[0]-45),
737
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
738
-
739
- # Add speed change score
740
- speed_change_score = detector.speed_detector.get_speed_change_score()
741
- cv2.putText(processed_frame, f"Speed Variation: {speed_change_score:.2f}",
742
- (10, processed_frame.shape[0]-70),
743
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
744
-
745
- # 確保每一幀尺寸正確
746
- if processed_frame is not None:
747
- if processed_frame.shape[1] != width or processed_frame.shape[0] != height:
748
- processed_frame = cv2.resize(processed_frame, (width, height))
749
- # 若無臉,則標註未偵測到臉或頭部晃動
750
- if "Face detection error" in message or "No face detected" in message or (isinstance(processed_frame, np.ndarray) and np.all(processed_frame == frame)):
751
- # 嘗試用光流判斷頭部是否晃動
752
- head_moving = False
753
- try:
754
- gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
755
- if detector.speed_detector.prev_gray is not None:
756
- flow = cv2.absdiff(gray, detector.speed_detector.prev_gray)
757
- mean_flow = np.mean(flow)
758
- head_moving = mean_flow > 8
759
- detector.speed_detector.prev_gray = gray.copy()
760
- except Exception as e:
761
- pass
762
- processed_frame = annotate_no_face(frame, head_moving=head_moving)
763
- out.write(processed_frame)
764
- processed_count += 1
765
- if "No face detected" not in message:
766
- face_detected_count += 1
767
- if "Drowsiness" in message:
768
- # Extract drowsiness probability
769
- try:
770
- drowsy_text = message.split("Drowsiness: ")[1].split(",")[0]
771
- drowsy_prob = float(drowsy_text)
772
-
773
- # Track drowsiness stats
774
- if drowsy_prob >= 0.7:
775
- drowsy_count += 1
776
- if drowsy_prob >= 0.85:
777
- high_risk_count += 1
778
- # Get metrics from the frame
779
- result = detector.predict(frame)
780
- if len(result) == 9:
781
- _, _, _, ear_score, _, _, model_score, _, yawn_detected = result
782
- ear_sum += ear_score
783
- model_prob_sum += model_score
784
- if yawn_detected:
785
- yawn_count += 1
786
- elif len(result) == 4:
787
- _, _, _, metrics = result
788
- if 'ear' in metrics:
789
- ear_sum += metrics['ear']
790
- if 'model_prob' in metrics:
791
- model_prob_sum += metrics['model_prob']
792
- if 'yawn_detected' in metrics and metrics['yawn_detected']:
793
- yawn_count += 1
794
- except:
795
- pass
796
- else:
797
- # Fallback: If processing fails, just用annotate_no_face標註
798
- head_moving = False
799
- try:
800
- gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
801
- if detector.speed_detector.prev_gray is not None:
802
- flow = cv2.absdiff(gray, detector.speed_detector.prev_gray)
803
- mean_flow = np.mean(flow)
804
- head_moving = mean_flow > 8
805
- detector.speed_detector.prev_gray = gray.copy()
806
- except Exception as e:
807
- pass
808
- processed_frame = annotate_no_face(frame, head_moving=head_moving)
809
- out.write(processed_frame)
810
- processed_count += 1
811
- print(f"Frame {frame_count}: Processing failed - {message}")
812
- except Exception as e:
813
- # If any error occurs during processing, use original frame
814
- cv2.putText(frame, f"Error: {str(e)[:30]}", (30, 30),
815
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
816
- out.write(frame)
817
- processed_count += 1
818
- print(f"Frame {frame_count}: Exception - {str(e)}")
819
-
820
- # Print progress for every 10th frame
821
- if frame_count % 10 == 0:
822
- print(f"Processed {frame_count}/{total_frames} frames")
823
 
824
  # Release resources
825
  cap.release()
826
  out.release()
827
 
828
- # Calculate statistics
829
- drowsy_percentage = (drowsy_count / face_detected_count * 100) if face_detected_count > 0 else 0
830
- high_risk_percentage = (high_risk_count / face_detected_count * 100) if face_detected_count > 0 else 0
831
- avg_ear = ear_sum / face_detected_count if face_detected_count > 0 else 0
832
- avg_model_prob = model_prob_sum / face_detected_count if face_detected_count > 0 else 0
833
- speed_score = detector.speed_detector.get_speed_change_score()
834
- yawn_percentage = (yawn_count / face_detected_count * 100) if face_detected_count > 0 else 0
835
-
836
- # Check if video was created successfully and return it directly
837
- if os.path.exists(temp_output_path) and os.path.getsize(temp_output_path) > 0:
838
- print(f"Video processed successfully with {processed_count} frames")
839
- print(f"Drowsy frames: {drowsy_count} ({drowsy_percentage:.1f}%), High risk frames: {high_risk_count} ({high_risk_percentage:.1f}%)")
840
- print(f"Average eye ratio: {avg_ear:.2f}, Average model probability: {avg_model_prob:.2f}")
841
- print(f"Speed change score: {speed_score:.2f}")
842
- print(f"Yawn frames: {yawn_count} ({yawn_percentage:.1f}%)")
843
-
844
- false_positive_warning = ""
845
- if avg_model_prob > 0.8 and avg_ear > 0.25:
846
- false_positive_warning = " ⚠️ Possible false positive (eyes open but model detects drowsiness)"
847
-
848
- result_message = (f"Video processed successfully. Frames: {frame_count}, faces detected: {face_detected_count}, "
849
- f"drowsy: {drowsy_count} ({drowsy_percentage:.1f}%), high risk: {high_risk_count} ({high_risk_percentage:.1f}%), "
850
- f"yawn: {yawn_count} ({yawn_percentage:.1f}%). "
851
- f"Avg eye ratio: {avg_ear:.2f}, Speed score: {speed_score:.2f}{false_positive_warning}")
852
-
853
- video_result = temp_output_path
854
-
855
- return video_result, result_message
856
  else:
857
- print(f"Failed to create output video. Frames read: {frame_count}, processed: {processed_count}")
858
- return None, f"Error: Failed to create output video. Frames read: {frame_count}, processed: {processed_count}"
859
 
860
  except Exception as e:
861
- import traceback
862
- error_details = traceback.format_exc()
863
- print(f"Error processing video: {str(e)}\n{error_details}")
864
- return None, f"Error processing video: {str(e)}"
865
  finally:
866
- if 'out' in locals() and out is not None:
 
867
  out.release()
868
- if 'cap' in locals() and cap is not None:
869
  cap.release()
870
- if temp_input is not None:
871
- try:
872
- os.unlink(temp_input.name)
873
- except:
874
- pass
875
 
876
- def process_webcam(image):
877
- """Process webcam input"""
878
  try:
879
- # Convert image to numpy array if it's not already
880
- if isinstance(image, Image.Image):
881
- image = np.array(image)
882
-
883
- # Convert to RGB if image is in BGR format
884
- if len(image.shape) == 3 and image.shape[2] == 3:
885
- if image.dtype == np.uint8:
886
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
887
-
888
- # Get prediction
889
- drowsy_prob, status, face_coords, ear_score, gaze_score, head_pose_score, model_score, mar, yawn_detected = detector.predict(image)
890
-
891
- # Draw results on image
892
- if face_coords is not None:
893
- x, y, w, h = face_coords
894
- # Draw face rectangle
895
- cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 0), 2)
896
-
897
- # Add status text
898
- status_color = (0, 0, 255) if status == "Drowsy" else (0, 255, 0)
899
- cv2.putText(image, f"Status: {status}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2)
900
-
901
- # Add yawn detection text if yawn is detected
902
- if yawn_detected:
903
- cv2.putText(image, "YAWN DETECTED!", (x, y - 40), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
904
- # Play alert sound
905
- try:
906
- import winsound
907
- winsound.Beep(1000, 500) # Frequency: 1000Hz, Duration: 500ms
908
- except:
909
- print("Beep!")
910
-
911
- # Add metrics
912
- cv2.putText(image, f"EAR: {ear_score:.2f}", (x, y + h + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
913
- cv2.putText(image, f"Gaze: {gaze_score:.2f}", (x, y + h + 40), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
914
- cv2.putText(image, f"Head: {head_pose_score:.2f}", (x, y + h + 60), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
915
- cv2.putText(image, f"MAR: {mar:.2f}", (x, y + h + 80), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
916
-
917
- return image
918
-
919
  except Exception as e:
920
- print(f"Error in process_webcam: {str(e)}")
921
- return image
 
 
922
 
923
- # Launch the app
924
- if __name__ == "__main__":
925
- # Parse command line arguments
926
- parser = argparse.ArgumentParser(description="Driver Drowsiness Detection App")
927
- parser.add_argument("--share", action="store_true", help="Create a public link (may trigger security warnings)")
928
- parser.add_argument("--port", type=int, default=7860, help="Port to run the app on")
929
- args = parser.parse_args()
930
-
931
- # Print warning if share is enabled
932
- if args.share:
933
- print("WARNING: Running with --share may trigger security warnings on some systems.")
934
- print("The app will be accessible from the internet through a temporary URL.")
935
-
936
- # 注册退出时的清理函数
937
- import atexit
938
- import glob
939
- import shutil
940
 
941
- def cleanup_temp_files():
942
- """Clean up all temporary files"""
943
- try:
944
- # 删除所有可能留下的临时文件
945
- import tempfile
946
- temp_dir = tempfile.gettempdir()
947
- pattern = os.path.join(temp_dir, "tmp*")
948
- for file in glob.glob(pattern):
949
- try:
950
- if os.path.isfile(file):
951
- os.remove(file)
952
- except Exception as e:
953
- print(f"Failed to delete {file}: {e}")
954
-
955
- # 确保没有留下.mp4或.avi文件
956
- for ext in [".mp4", ".avi"]:
957
- pattern = os.path.join(temp_dir, f"*{ext}")
958
- for file in glob.glob(pattern):
959
- try:
960
- os.remove(file)
961
- except Exception as e:
962
- print(f"Failed to delete {file}: {e}")
963
-
964
- print("Cleaned up temporary files")
965
- except Exception as e:
966
- print(f"Error during cleanup: {e}")
967
 
968
- # 注册清理函数
969
- atexit.register(cleanup_temp_files)
 
 
 
 
970
 
971
- # Load the model at startup
972
- detector.load_model()
973
-
974
- # Create interface
975
- with gr.Blocks(title="Driver Drowsiness Detection") as demo:
976
- gr.Markdown("""
977
- # 🚗 Driver Drowsiness Detection System
978
-
979
- This system detects driver drowsiness using computer vision and deep learning.
980
-
981
- ## Features:
982
- - Image analysis
983
- - Video processing with speed monitoring
984
- - Webcam detection (PC and mobile)
985
- - Multi-factor drowsiness prediction (face, eyes, head pose, speed changes)
986
- """)
987
-
988
- with gr.Tabs():
989
- with gr.Tab("Image"):
990
- gr.Markdown("Upload an image for drowsiness detection")
991
- with gr.Row():
992
- image_input = gr.Image(label="Input Image", type="numpy")
993
- image_output = gr.Image(label="Processed Image")
994
- with gr.Row():
995
- status_output = gr.Textbox(label="Status")
996
- image_input.change(
997
- fn=process_image,
998
- inputs=[image_input],
999
- outputs=[image_output, status_output]
1000
- )
1001
-
1002
- with gr.Tab("Video"):
1003
- gr.Markdown("""
1004
- ### ### Upload driving videos for sleepy detection
1005
-
1006
- The system will automatically detect the following content from the video:
1007
- - Driver's facial expressions and eye status
1008
- - Vehicle speed changes (by optical flow analysis in video)
1009
- - When the vehicle speed changes more than ±5 km/h, it will be considered abnormal driving behavior
1010
-
1011
- ** Note: ** The processed videos will not be saved to the local folder.
1012
- Please use the download button in the upper right corner of the interface to save the results.
1013
- """)
1014
- with gr.Row():
1015
- video_input = gr.Video(label="Enter video")
1016
- video_output = gr.Video(label="Processed video (Click on the upper right corner to download)")
1017
- with gr.Row():
1018
- initial_speed = gr.Slider(minimum=10, maximum=120, value=60, label="Initial speed estimate (km/h)",
1019
- info="As initial estimate only, The system will automatically detect the actual speed changes from the video")
1020
- with gr.Row():
1021
- video_status = gr.Textbox(label="Processing status")
1022
- with gr.Row():
1023
- process_btn = gr.Button("Processing videos")
1024
- clear_btn = gr.Button("Clear")
1025
-
1026
- process_btn.click(
1027
- fn=process_video,
1028
- inputs=[video_input, initial_speed],
1029
- outputs=[video_output, video_status]
1030
- )
1031
-
1032
- clear_btn.click(
1033
- fn=lambda: (None, "Cleared results"),
1034
- inputs=[],
1035
- outputs=[video_output, video_status]
1036
- )
1037
-
1038
- with gr.Tab("Webcam"):
1039
- gr.Markdown("Use your webcam or mobile camera for real-time drowsiness detection")
1040
- with gr.Row():
1041
- webcam_input = gr.Image(label="Camera Feed", type="numpy", streaming=True)
1042
- webcam_output = gr.Image(label="Processed Feed")
1043
- with gr.Row():
1044
- speed_input = gr.Slider(minimum=0, maximum=150, value=60, label="Current Speed (km/h)")
1045
- update_speed_btn = gr.Button("Update Speed")
1046
- with gr.Row():
1047
- webcam_status = gr.Textbox(label="Status")
1048
-
1049
- def process_webcam_with_speed(image, speed):
1050
- detector.update_speed(speed)
1051
- return process_webcam(image)
1052
-
1053
- update_speed_btn.click(
1054
- fn=lambda speed: f"Speed updated to {speed} km/h",
1055
- inputs=[speed_input],
1056
- outputs=[webcam_status]
1057
- )
1058
-
1059
- webcam_input.change(
1060
- fn=process_webcam_with_speed,
1061
- inputs=[webcam_input, speed_input],
1062
- outputs=[webcam_output, webcam_status]
1063
- )
1064
-
1065
- gr.Markdown("""
1066
- ## How It Works
1067
- This system detects drowsiness using multiple factors:
1068
- 1. **Facial features** - Using a trained CNN model
1069
- 2. **Eye openness** - Measuring eye aspect ratio (EAR)
1070
- 3. **Head position** - Detecting head drooping
1071
- 4. **Automatic speed detection** - Using optical flow analysis to track vehicle movement and detect irregular speed changes
1072
-
1073
- The system automatically detects speed changes from the video frames using computer vision techniques:
1074
- - **Optical flow** is used to track movement between frames
1075
- - **Irregular speed changes** (±5 km/h) are detected as potential signs of drowsy driving
1076
- - **No external speed data required** - everything is analyzed directly from the video content
1077
-
1078
- Combining these factors provides more reliable drowsiness detection than using facial features alone.
1079
- """)
1080
 
1081
- # Launch the app
1082
- demo.launch(share=args.share, server_port=args.port)
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import ViTForImageClassification, ViTImageProcessor
4
  import numpy as np
5
  import cv2
6
  from PIL import Image
7
  import io
8
  import os
9
+ import sys
10
  import time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
class DrowsinessDetector:
    """Detects driver drowsiness in single frames with a ViT classifier.

    Pipeline: Haar-cascade face detection -> crop to the face -> ViT image
    classification into {notdrowsy, drowsy}.
    """

    def __init__(self):
        # Model and processor are populated lazily by load_model().
        self.model = None
        self.processor = None
        # Expected input size of the ViT backbone (H, W, C); kept for callers.
        self.input_shape = (224, 224, 3)
        self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        self.id2label = {0: "notdrowsy", 1: "drowsy"}
        self.label2id = {"notdrowsy": 0, "drowsy": 1}

    def load_model(self, model_path):
        """Load the ViT model and processor from the specified path or directory.

        Args:
            model_path: Hugging Face model id, directory, or checkpoint path.

        Raises:
            Exception: re-raises whatever `from_pretrained` fails with, after
                logging it, so startup code can abort cleanly.
        """
        try:
            self.model = ViTForImageClassification.from_pretrained(
                model_path,  # a local directory path is accepted directly
                num_labels=2,
                id2label=self.id2label,
                label2id=self.label2id,
                ignore_mismatched_sizes=True
            )
            self.model.eval()
            # Processor (resize/normalize) always comes from the base ViT repo.
            self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
            print(f"ViT model loaded successfully from {model_path}")
        except Exception as e:
            print(f"Error loading ViT model: {str(e)}")
            raise

    def detect_face(self, frame):
        """Detect a face in the frame.

        Returns:
            (face_crop, (x, y, w, h)) for the first detection, or (None, None)
            when no face is found.
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        faces = self.face_cascade.detectMultiScale(gray, 1.1, 4)
        if len(faces) > 0:
            (x, y, w, h) = faces[0]  # Get the first face
            face = frame[y:y+h, x:x+w]
            return face, (x, y, w, h)
        return None, None

    def preprocess_image(self, image):
        """Preprocess the input image (assumed BGR) into ViT tensors.

        Returns None when *image* is None; otherwise the processor's
        `BatchFeature` with a `pixel_values` tensor.
        """
        if image is None:
            return None
        pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        inputs = self.processor(images=pil_img, return_tensors="pt")
        return inputs

    def predict(self, image):
        """Run drowsiness prediction on one frame.

        Returns:
            (drowsy_prob, face_coords, error): probability of the "drowsy"
            class plus face box on success, or (None, None, message) when
            no face is found or preprocessing fails.

        Raises:
            ValueError: if load_model() has not been called yet.
        """
        if self.model is None or self.processor is None:
            raise ValueError("Model not loaded. Call load_model() first.")
        # Detect face
        face, face_coords = self.detect_face(image)
        if face is None:
            return None, None, "No face detected"
        # Preprocess the face image
        inputs = self.preprocess_image(face)
        if inputs is None:
            return None, None, "Error processing image"
        # Make prediction
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
        # Only the probability of class 1 ("drowsy") is used by callers;
        # the original also computed argmax/label locals that were dead code.
        drowsy_prob = probs[0, 1].item()
        return drowsy_prob, face_coords, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
# Initialize detector
# Module-level singleton shared by all Gradio handlers; its model/processor
# are populated later by the module-level load_model() call at startup.
detector = DrowsinessDetector()
82
 
83
def find_model_file(search_paths=None):
    """Find the model directory or file in common locations.

    Args:
        search_paths: Optional iterable of candidate paths to probe in order.
            Defaults to the app's well-known locations (the full Hugging Face
            model directory first, then legacy weight files).

    Returns:
        The first path that exists on disk, or None if none do.
    """
    if search_paths is None:
        search_paths = [
            "huggingface_model",  # preferred: full HF model directory
            "pytorch_model.bin",
            "model_weights.h5",
            "drowsiness_model.h5",
            "model/drowsiness_model.h5",
            "models/drowsiness_model.h5",
            "huggingface_model/model_weights.h5",
            "huggingface_model/drowsiness_model.h5",
            "../model_weights.h5",
            "../drowsiness_model.h5",
        ]
    # os.path.exists is True for both files and directories, matching the
    # mixed directory/file candidates above.
    return next((p for p in search_paths if os.path.exists(p)), None)
101
+
102
def load_model():
    """Locate the model on disk and load it into the global detector.

    Exits the process with status 1 when no model can be found or when
    loading it fails.
    """
    located = find_model_file()

    if located is None:
        # Tell the user exactly which locations were checked before bailing.
        for line in (
            "\nError: Model file not found!",
            "\nPlease ensure one of the following files exists:",
            "1. model_weights.h5",
            "2. drowsiness_model.h5",
            "3. model/drowsiness_model.h5",
            "4. models/drowsiness_model.h5",
            "\nYou can download the model from Hugging Face Hub or train it using train_model.py",
        ):
            print(line)
        sys.exit(1)

    try:
        detector.load_model(located)
    except Exception as e:
        print(f"\nError loading model: {str(e)}")
        sys.exit(1)
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
def process_frame(frame):
    """Annotate one frame with the face box and drowsiness verdict.

    Returns the (possibly annotated) frame; None input passes through as
    None, and on any failure the frame is returned unmodified.
    """
    if frame is None:
        return None

    try:
        # Normalise colour layout: expand grayscale, drop an alpha channel.
        if frame.ndim == 2:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        elif frame.shape[2] == 4:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)

        drowsy_prob, face_coords, error = detector.predict(frame)
        if error or face_coords is None:
            # No face (or preprocessing failed) -> nothing to draw.
            return frame

        x, y, w, h = face_coords
        is_drowsy = drowsy_prob > 0.7
        color = (0, 0, 255) if is_drowsy else (0, 255, 0)
        status = "DROWSY" if is_drowsy else "ALERT"

        cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
        cv2.putText(frame, f"{status} ({drowsy_prob:.2%})",
                    (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
        return frame

    except Exception as e:
        print(f"Error processing frame: {str(e)}")
        return frame
156
+
157
def process_video(video_input):
    """Run drowsiness detection over a whole video file.

    Args:
        video_input: Path to the input video (as supplied by gr.Video).

    Returns:
        Path to the annotated output video, or None on failure.
    """
    if video_input is None:
        return None

    import tempfile

    cap = None
    out = None
    try:
        # Get input video properties
        cap = cv2.VideoCapture(video_input)
        if not cap.isOpened():
            print(f"Error: cannot open video {video_input}")
            return None

        # Fall back to a sane frame rate when the container reports 0/NaN,
        # which would otherwise produce an unplayable output file.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            fps = 25.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Unique temp path per call: the original hard-coded "temp_output.mp4",
        # so concurrent requests clobbered each other's output.
        fd, temp_output = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            processed_frame = process_frame(frame)
            if processed_frame is not None:
                out.write(processed_frame)

        # Release resources
        cap.release()
        out.release()

        # Check if video was created
        if os.path.exists(temp_output) and os.path.getsize(temp_output) > 0:
            return temp_output
        print("Error: Failed to create output video")
        return None

    except Exception as e:
        print(f"Error processing video: {str(e)}")
        return None
    finally:
        # Release on the exception path too; a second release is a no-op.
        if out is not None:
            out.release()
        if cap is not None:
            cap.release()
 
 
 
 
 
203
 
204
def webcam_feed():
    """Yield annotated frames from the default webcam until capture stops.

    Yields:
        Annotated frames from process_frame(); yields a single None if the
        feed fails with an exception.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(0)
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            processed_frame = process_frame(frame)
            if processed_frame is not None:
                yield processed_frame

    except Exception as e:
        print(f"Error processing webcam feed: {str(e)}")
        yield None
    finally:
        # cap may be None if VideoCapture itself raised; the original's
        # bare cap.release() turned that into a NameError masking the
        # real failure.
        if cap is not None:
            cap.release()
222
 
223
# Load the model at startup
# (runs at import time so the first request never races initialisation)
load_model()

# Create interface
with gr.Blocks(title="Driver Drowsiness Detection") as demo:
    gr.Markdown("""
    # 🚗 Driver Drowsiness Detection System
    
    This system detects driver drowsiness using computer vision and deep learning.
    
    ## Features:
    - Real-time webcam monitoring
    - Video file processing
    - Single image analysis
    - Face detection and drowsiness prediction
    """)
    
    with gr.Tabs():
        with gr.Tab("Webcam"):
            gr.Markdown("Real-time drowsiness detection using your webcam")
            webcam_output = gr.Image(label="Live Detection")
            webcam_button = gr.Button("Start Webcam")
            # webcam_feed is a generator; Gradio consumes successive yields.
            webcam_button.click(fn=webcam_feed, inputs=None, outputs=webcam_output)
        
        with gr.Tab("Video"):
            gr.Markdown("Upload a video file for drowsiness detection")
            with gr.Row():
                video_input = gr.Video(label="Input Video")
                video_output = gr.Video(label="Detection Result")
            video_button = gr.Button("Process Video")
            # process_video returns a file path, which gr.Video plays back.
            video_button.click(fn=process_video, inputs=video_input, outputs=video_output)
        
        with gr.Tab("Image"):
            gr.Markdown("Upload an image for drowsiness detection")
            with gr.Row():
                image_input = gr.Image(type="numpy", label="Input Image")
                image_output = gr.Image(label="Detection Result")
            image_button = gr.Button("Process Image")
            # Single-frame path reuses the same annotator as video/webcam.
            image_button.click(fn=process_frame, inputs=image_input, outputs=image_output)

if __name__ == "__main__":
    demo.launch()