ckcl commited on
Commit
56df4e3
·
verified ·
1 Parent(s): 7e775cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -263
app.py CHANGED
@@ -1,264 +1,212 @@
1
- import gradio as gr
2
- import torch
3
- from transformers import ViTForImageClassification, ViTImageProcessor
4
- import numpy as np
5
- import cv2
6
- from PIL import Image
7
- import io
8
- import os
9
- import sys
10
- import time
11
-
12
- class DrowsinessDetector:
13
- def __init__(self):
14
- self.model = None
15
- self.processor = None
16
- self.input_shape = (224, 224, 3)
17
- self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
18
- self.id2label = {0: "notdrowsy", 1: "drowsy"}
19
- self.label2id = {"notdrowsy": 0, "drowsy": 1}
20
-
21
- def load_model(self, model_path):
22
- """Load the ViT model and processor from the specified path or directory"""
23
- try:
24
- self.model = ViTForImageClassification.from_pretrained(
25
- model_path, # 直接給資料夾路徑
26
- num_labels=2,
27
- id2label=self.id2label,
28
- label2id=self.label2id,
29
- ignore_mismatched_sizes=True
30
- )
31
- self.model.eval()
32
- self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
33
- print(f"ViT model loaded successfully from {model_path}")
34
- except Exception as e:
35
- print(f"Error loading ViT model: {str(e)}")
36
- raise
37
-
38
- def detect_face(self, frame):
39
- """Detect face in the frame"""
40
- gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
41
- faces = self.face_cascade.detectMultiScale(gray, 1.1, 4)
42
- if len(faces) > 0:
43
- (x, y, w, h) = faces[0] # Get the first face
44
- face = frame[y:y+h, x:x+w]
45
- return face, (x, y, w, h)
46
- return None, None
47
-
48
- def preprocess_image(self, image):
49
- """Preprocess the input image for ViT"""
50
- if image is None:
51
- return None
52
- pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
53
- inputs = self.processor(images=pil_img, return_tensors="pt")
54
- return inputs
55
-
56
- def predict(self, image):
57
- """Make prediction on the input image using ViT"""
58
- if self.model is None or self.processor is None:
59
- raise ValueError("Model not loaded. Call load_model() first.")
60
- # Detect face
61
- face, face_coords = self.detect_face(image)
62
- if face is None:
63
- return None, None, "No face detected"
64
- # Preprocess the face image
65
- inputs = self.preprocess_image(face)
66
- if inputs is None:
67
- return None, None, "Error processing image"
68
- # Make prediction
69
- with torch.no_grad():
70
- outputs = self.model(**inputs)
71
- logits = outputs.logits
72
- probs = torch.softmax(logits, dim=1)
73
- pred_class = torch.argmax(probs, dim=1).item()
74
- pred_label = self.id2label[pred_class]
75
- pred_prob = probs[0, pred_class].item()
76
- # Return drowsy probability (class 1)
77
- drowsy_prob = probs[0, 1].item()
78
- return drowsy_prob, face_coords, None
79
-
80
- # Initialize detector
81
- detector = DrowsinessDetector()
82
-
83
- def find_model_file():
84
- """Find the model directory or file in common locations"""
85
- possible_paths = [
86
- "huggingface_model", # 優先資料夾
87
- "pytorch_model.bin",
88
- "model_weights.h5",
89
- "drowsiness_model.h5",
90
- "model/drowsiness_model.h5",
91
- "models/drowsiness_model.h5",
92
- "huggingface_model/model_weights.h5",
93
- "huggingface_model/drowsiness_model.h5",
94
- "../model_weights.h5",
95
- "../drowsiness_model.h5"
96
- ]
97
- for path in possible_paths:
98
- if os.path.exists(path):
99
- return path
100
- return None
101
-
102
- def load_model():
103
- """Load the model"""
104
- model_path = find_model_file()
105
-
106
- if model_path is None:
107
- print("\nError: Model file not found!")
108
- print("\nPlease ensure one of the following files exists:")
109
- print("1. model_weights.h5")
110
- print("2. drowsiness_model.h5")
111
- print("3. model/drowsiness_model.h5")
112
- print("4. models/drowsiness_model.h5")
113
- print("\nYou can download the model from Hugging Face Hub or train it using train_model.py")
114
- sys.exit(1)
115
-
116
- try:
117
- detector.load_model(model_path)
118
- except Exception as e:
119
- print(f"\nError loading model: {str(e)}")
120
- sys.exit(1)
121
-
122
- def process_frame(frame):
123
- """Process a single frame"""
124
- if frame is None:
125
- return None
126
-
127
- try:
128
- # Convert frame to RGB if needed
129
- if len(frame.shape) == 2:
130
- frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
131
- elif frame.shape[2] == 4:
132
- frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
133
-
134
- # Make prediction
135
- drowsy_prob, face_coords, error = detector.predict(frame)
136
-
137
- if error:
138
- return frame
139
-
140
- if face_coords is not None:
141
- x, y, w, h = face_coords
142
- # Draw rectangle around face
143
- color = (0, 0, 255) if drowsy_prob > 0.7 else (0, 255, 0)
144
- cv2.rectangle(frame, (x, y), (x+w, y+h), color, 2)
145
-
146
- # Add text
147
- status = "DROWSY" if drowsy_prob > 0.7 else "ALERT"
148
- cv2.putText(frame, f"{status} ({drowsy_prob:.2%})",
149
- (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
150
-
151
- return frame
152
-
153
- except Exception as e:
154
- print(f"Error processing frame: {str(e)}")
155
- return frame
156
-
157
- def process_video(video_input):
158
- """Process video input"""
159
- if video_input is None:
160
- return None
161
-
162
- try:
163
- # Get input video properties
164
- cap = cv2.VideoCapture(video_input)
165
- fps = cap.get(cv2.CAP_PROP_FPS)
166
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
167
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
168
-
169
- # Create temporary output video file
170
- temp_output = "temp_output.mp4"
171
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
172
- out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))
173
-
174
- while True:
175
- ret, frame = cap.read()
176
- if not ret:
177
- break
178
-
179
- processed_frame = process_frame(frame)
180
- if processed_frame is not None:
181
- out.write(processed_frame)
182
-
183
- # Release resources
184
- cap.release()
185
- out.release()
186
-
187
- # Check if video was created
188
- if os.path.exists(temp_output) and os.path.getsize(temp_output) > 0:
189
- return temp_output
190
- else:
191
- print("Error: Failed to create output video")
192
- return None
193
-
194
- except Exception as e:
195
- print(f"Error processing video: {str(e)}")
196
- return None
197
- finally:
198
- # Clean up temporary file
199
- if 'out' in locals():
200
- out.release()
201
- if 'cap' in locals():
202
- cap.release()
203
-
204
- def webcam_feed():
205
- """Process webcam feed"""
206
- try:
207
- cap = cv2.VideoCapture(0)
208
- while True:
209
- ret, frame = cap.read()
210
- if not ret:
211
- break
212
-
213
- processed_frame = process_frame(frame)
214
- if processed_frame is not None:
215
- yield processed_frame
216
-
217
- except Exception as e:
218
- print(f"Error processing webcam feed: {str(e)}")
219
- yield None
220
- finally:
221
- cap.release()
222
-
223
- # Load the model at startup
224
- load_model()
225
-
226
- # Create interface
227
- with gr.Blocks(title="Driver Drowsiness Detection") as demo:
228
- gr.Markdown("""
229
- # 🚗 Driver Drowsiness Detection System
230
-
231
- This system detects driver drowsiness using computer vision and deep learning.
232
-
233
- ## Features:
234
- - Real-time webcam monitoring
235
- - Video file processing
236
- - Single image analysis
237
- - Face detection and drowsiness prediction
238
- """)
239
-
240
- with gr.Tabs():
241
- with gr.Tab("Webcam"):
242
- gr.Markdown("Real-time drowsiness detection using your webcam")
243
- webcam_output = gr.Image(label="Live Detection")
244
- webcam_button = gr.Button("Start Webcam")
245
- webcam_button.click(fn=webcam_feed, inputs=None, outputs=webcam_output)
246
-
247
- with gr.Tab("Video"):
248
- gr.Markdown("Upload a video file for drowsiness detection")
249
- with gr.Row():
250
- video_input = gr.Video(label="Input Video")
251
- video_output = gr.Video(label="Detection Result")
252
- video_button = gr.Button("Process Video")
253
- video_button.click(fn=process_video, inputs=video_input, outputs=video_output)
254
-
255
- with gr.Tab("Image"):
256
- gr.Markdown("Upload an image for drowsiness detection")
257
- with gr.Row():
258
- image_input = gr.Image(type="numpy", label="Input Image")
259
- image_output = gr.Image(label="Detection Result")
260
- image_button = gr.Button("Process Image")
261
- image_button.click(fn=process_frame, inputs=image_input, outputs=image_output)
262
-
263
- if __name__ == "__main__":
264
  demo.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import ViTForImageClassification, ViTImageProcessor
4
+ import numpy as np
5
+ import cv2
6
+ from PIL import Image
7
+ import io
8
+ import os
9
+
10
+ class DrowsinessDetector:
11
+ def __init__(self):
12
+ self.model = None
13
+ self.processor = None
14
+ self.input_shape = (224, 224, 3)
15
+ self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
16
+ self.id2label = {0: "notdrowsy", 1: "drowsy"}
17
+ self.label2id = {"notdrowsy": 0, "drowsy": 1}
18
+
19
+ def load_model(self):
20
+ """Load the ViT model and processor from Hugging Face Hub"""
21
+ try:
22
+ model_id = "ckcl/driver-drowsiness-detector" # 使用你的模型ID
23
+ self.model = ViTForImageClassification.from_pretrained(
24
+ model_id,
25
+ num_labels=2,
26
+ id2label=self.id2label,
27
+ label2id=self.label2id,
28
+ ignore_mismatched_sizes=True
29
+ )
30
+ self.model.eval()
31
+ self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
32
+ print(f"ViT model loaded successfully from {model_id}")
33
+ except Exception as e:
34
+ print(f"Error loading ViT model: {str(e)}")
35
+ raise
36
+
37
+ def detect_face(self, frame):
38
+ """Detect face in the frame"""
39
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
40
+ faces = self.face_cascade.detectMultiScale(gray, 1.1, 4)
41
+ if len(faces) > 0:
42
+ (x, y, w, h) = faces[0] # Get the first face
43
+ face = frame[y:y+h, x:x+w]
44
+ return face, (x, y, w, h)
45
+ return None, None
46
+
47
+ def preprocess_image(self, image):
48
+ """Preprocess the input image for ViT"""
49
+ if image is None:
50
+ return None
51
+ pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
52
+ inputs = self.processor(images=pil_img, return_tensors="pt")
53
+ return inputs
54
+
55
+ def predict(self, image):
56
+ """Make prediction on the input image using ViT"""
57
+ if self.model is None or self.processor is None:
58
+ raise ValueError("Model not loaded. Call load_model() first.")
59
+ # Detect face
60
+ face, face_coords = self.detect_face(image)
61
+ if face is None:
62
+ return None, None, "No face detected"
63
+ # Preprocess the face image
64
+ inputs = self.preprocess_image(face)
65
+ if inputs is None:
66
+ return None, None, "Error processing image"
67
+ # Make prediction
68
+ with torch.no_grad():
69
+ outputs = self.model(**inputs)
70
+ logits = outputs.logits
71
+ probs = torch.softmax(logits, dim=1)
72
+ pred_class = torch.argmax(probs, dim=1).item()
73
+ pred_label = self.id2label[pred_class]
74
+ pred_prob = probs[0, pred_class].item()
75
+ # Return drowsy probability (class 1)
76
+ drowsy_prob = probs[0, 1].item()
77
+ return drowsy_prob, face_coords, None
78
+
79
+ # Initialize detector
80
+ detector = DrowsinessDetector()
81
+
82
+ def process_image(image):
83
+ """Process a single image"""
84
+ if image is None:
85
+ return None, "No image provided"
86
+
87
+ try:
88
+ # Convert image to numpy array if it's a PIL Image
89
+ if isinstance(image, Image.Image):
90
+ image = np.array(image)
91
+
92
+ # Convert frame to RGB if needed
93
+ if len(image.shape) == 2:
94
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
95
+ elif image.shape[2] == 4:
96
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
97
+
98
+ # Make prediction
99
+ drowsy_prob, face_coords, error = detector.predict(image)
100
+
101
+ if error:
102
+ return image, error
103
+
104
+ if face_coords is not None:
105
+ x, y, w, h = face_coords
106
+ # Draw rectangle around face
107
+ color = (0, 0, 255) if drowsy_prob > 0.7 else (0, 255, 0)
108
+ cv2.rectangle(image, (x, y), (x+w, y+h), color, 2)
109
+
110
+ # Add text
111
+ status = "DROWSY" if drowsy_prob > 0.7 else "ALERT"
112
+ cv2.putText(image, f"{status} ({drowsy_prob:.2%})",
113
+ (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
114
+
115
+ return image, f"Status: {status} (Confidence: {drowsy_prob:.2%})"
116
+ else:
117
+ return image, "No face detected"
118
+
119
+ except Exception as e:
120
+ return image, f"Error processing image: {str(e)}"
121
+
122
+ def process_video(video):
123
+ """Process video input"""
124
+ if video is None:
125
+ return None, "No video provided"
126
+
127
+ try:
128
+ # Get input video properties
129
+ cap = cv2.VideoCapture(video)
130
+ fps = cap.get(cv2.CAP_PROP_FPS)
131
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
132
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
133
+
134
+ # Create temporary output video file
135
+ temp_output = "temp_output.mp4"
136
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
137
+ out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))
138
+
139
+ while True:
140
+ ret, frame = cap.read()
141
+ if not ret:
142
+ break
143
+
144
+ processed_frame = process_image(frame)[0]
145
+ if processed_frame is not None:
146
+ out.write(processed_frame)
147
+
148
+ # Release resources
149
+ cap.release()
150
+ out.release()
151
+
152
+ # Check if video was created
153
+ if os.path.exists(temp_output) and os.path.getsize(temp_output) > 0:
154
+ return temp_output, "Video processed successfully"
155
+ else:
156
+ return None, "Error: Failed to create output video"
157
+
158
+ except Exception as e:
159
+ return None, f"Error processing video: {str(e)}"
160
+ finally:
161
+ # Clean up temporary file
162
+ if 'out' in locals():
163
+ out.release()
164
+ if 'cap' in locals():
165
+ cap.release()
166
+
167
+ # Load the model at startup
168
+ detector.load_model()
169
+
170
+ # Create interface
171
+ with gr.Blocks(title="Driver Drowsiness Detection") as demo:
172
+ gr.Markdown("""
173
+ # 🚗 Driver Drowsiness Detection System
174
+
175
+ This system detects driver drowsiness using computer vision and deep learning.
176
+
177
+ ## Features:
178
+ - Image analysis
179
+ - Video processing
180
+ - Face detection and drowsiness prediction
181
+ """)
182
+
183
+ with gr.Tabs():
184
+ with gr.Tab("Image"):
185
+ gr.Markdown("Upload an image for drowsiness detection")
186
+ with gr.Row():
187
+ image_input = gr.Image(label="Input Image", type="numpy")
188
+ image_output = gr.Image(label="Processed Image")
189
+ with gr.Row():
190
+ status_output = gr.Textbox(label="Status")
191
+ image_input.change(
192
+ fn=process_image,
193
+ inputs=[image_input],
194
+ outputs=[image_output, status_output]
195
+ )
196
+
197
+ with gr.Tab("Video"):
198
+ gr.Markdown("Upload a video file for drowsiness detection")
199
+ with gr.Row():
200
+ video_input = gr.Video(label="Input Video")
201
+ video_output = gr.Video(label="Processed Video")
202
+ with gr.Row():
203
+ video_status = gr.Textbox(label="Status")
204
+ video_input.change(
205
+ fn=process_video,
206
+ inputs=[video_input],
207
+ outputs=[video_output, video_status]
208
+ )
209
+
210
+ # Launch the app
211
+ if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  demo.launch()