ckcl commited on
Commit
0ce4306
·
verified ·
1 Parent(s): dd8e6c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -77
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import ViTForImageClassification, ViTImageProcessor
4
  import numpy as np
5
  import cv2
6
  from PIL import Image
@@ -10,28 +10,49 @@ import os
10
  class DrowsinessDetector:
11
  def __init__(self):
12
  self.model = None
13
- self.processor = None
14
- self.input_shape = (224, 224, 3)
15
  self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
16
- self.id2label = {0: "Notdrowsy", 1: "drowsy"}
17
- self.label2id = {"Notdrowsy": 0, "drowsy": 1}
18
 
19
  def load_model(self):
20
- """Load the ViT model and processor from Hugging Face Hub"""
21
  try:
22
- model_id = "ckcl/driver-drowsiness-detector" # 使用你的模型ID
23
- self.model = ViTForImageClassification.from_pretrained(
24
- model_id,
25
- num_labels=2,
26
- id2label=self.id2label,
27
- label2id=self.label2id,
28
- ignore_mismatched_sizes=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  )
 
 
 
30
  self.model.eval()
31
- self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
32
- print(f"ViT model loaded successfully from {model_id}")
33
  except Exception as e:
34
- print(f"Error loading ViT model: {str(e)}")
35
  raise
36
 
37
  def detect_face(self, frame):
@@ -45,30 +66,35 @@ class DrowsinessDetector:
45
  return None, None
46
 
47
  def preprocess_image(self, image):
48
- """Preprocess the input image for ViT"""
49
  if image is None:
50
  return None
51
- pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
52
- inputs = self.processor(images=pil_img, return_tensors="pt")
53
- return inputs
 
 
 
 
 
 
54
 
55
  def predict(self, image):
56
- """Make prediction on the input image using ViT"""
57
- if self.model is None or self.processor is None:
58
- raise ValueError("未加载模型. 首先呼用 load_model().")
59
  # Detect face
60
  face, face_coords = self.detect_face(image)
61
  if face is None:
62
- return None, None, "没有检测到的脸"
63
  # Preprocess the face image
64
  inputs = self.preprocess_image(face)
65
  if inputs is None:
66
- return None, None, "错误处理图像"
67
  # Make prediction
68
  with torch.no_grad():
69
- outputs = self.model(**inputs)
70
- logits = outputs.logits
71
- probs = torch.softmax(logits, dim=1)
72
  pred_class = torch.argmax(probs, dim=1).item()
73
  pred_label = self.id2label[pred_class]
74
  pred_prob = probs[0, pred_class].item()
@@ -76,53 +102,42 @@ class DrowsinessDetector:
76
  drowsy_prob = probs[0, 1].item()
77
  return drowsy_prob, face_coords, None
78
 
79
- # Initialize detector
80
  detector = DrowsinessDetector()
81
 
82
  def process_image(image):
83
- """Process a single image"""
84
  if image is None:
85
- return None, "没有提供图像"
86
-
87
  try:
88
- # Convert image to numpy array if it's a PIL Image
89
- if isinstance(image, Image.Image):
90
- image = np.array(image)
91
-
92
- # Convert frame to RGB if needed
93
- if len(image.shape) == 2:
94
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
95
- elif image.shape[2] == 4:
96
- image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
97
-
98
  # Make prediction
99
  drowsy_prob, face_coords, error = detector.predict(image)
100
 
101
  if error:
102
- return image, error
103
-
104
- if face_coords is not None:
105
- x, y, w, h = face_coords
106
- # Draw rectangle around face
107
- color = (0, 0, 255) if drowsy_prob > 0.7 else (0, 255, 0)
108
- cv2.rectangle(image, (x, y), (x+w, y+h), color, 2)
109
 
110
- # Add text
111
- status = "昏昏欲睡" if drowsy_prob > 0.7 else "警报"
112
- cv2.putText(image, f"{status} ({drowsy_prob:.2%})",
113
- (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
114
 
115
- return image, f"Status: {status} (Confidence: {drowsy_prob:.2%})"
116
- else:
117
- return image, "没有检测到的脸"
 
 
 
 
 
 
 
118
 
119
  except Exception as e:
120
- return image, f"Error processing image: {str(e)}"
121
 
122
  def process_video(video):
123
  """Process video input"""
124
  if video is None:
125
- return None, "没有提供视频"
126
 
127
  try:
128
  # Get input video properties
@@ -151,9 +166,9 @@ def process_video(video):
151
 
152
  # Check if video was created
153
  if os.path.exists(temp_output) and os.path.getsize(temp_output) > 0:
154
- return temp_output, "视频成功处理"
155
  else:
156
- return None, "错误:无法创建输出视频"
157
 
158
  except Exception as e:
159
  return None, f"Error processing video: {str(e)}"
@@ -168,39 +183,39 @@ def process_video(video):
168
  detector.load_model()
169
 
170
  # Create interface
171
- with gr.Blocks(title="驾驶员嗜睡检测") as demo:
172
  gr.Markdown("""
173
- # 🚗 驾驶员嗜睡检测系统
174
 
175
- 该系统使用计算机视觉和深度学习来检测驾驶员的嗜睡。
176
 
177
- ## 特征:
178
- - 图像分析
179
- - 视频处理
180
- - 面部检测和嗜睡预测
181
  """)
182
 
183
  with gr.Tabs():
184
- with gr.Tab("图像"):
185
- gr.Markdown("上传图像以进行嗜睡检测")
186
  with gr.Row():
187
- image_input = gr.Image(label="输入图像", type="numpy")
188
- image_output = gr.Image(label="处理的图像")
189
  with gr.Row():
190
- status_output = gr.Textbox(label="状态")
191
  image_input.change(
192
  fn=process_image,
193
  inputs=[image_input],
194
  outputs=[image_output, status_output]
195
  )
196
 
197
- with gr.Tab("视频"):
198
- gr.Markdown("上传视频文件以进行嗜睡检测")
199
  with gr.Row():
200
- video_input = gr.Video(label="输入视频")
201
- video_output = gr.Video(label="处理的视频")
202
  with gr.Row():
203
- video_status = gr.Textbox(label="状态")
204
  video_input.change(
205
  fn=process_video,
206
  inputs=[video_input],
 
1
  import gradio as gr
2
  import torch
3
+ import torch.nn as nn
4
  import numpy as np
5
  import cv2
6
  from PIL import Image
 
10
  class DrowsinessDetector:
11
  def __init__(self):
12
  self.model = None
13
+ self.input_shape = (64, 64, 3)
 
14
  self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
15
+ self.id2label = {0: "notdrowsy", 1: "drowsy"}
16
+ self.label2id = {"notdrowsy": 0, "drowsy": 1}
17
 
18
  def load_model(self):
19
+ """Load the CNN model from Hugging Face Hub"""
20
  try:
21
+ model_id = "ckcl/driver-drowsiness-detector"
22
+ # Load the model configuration
23
+ config = torch.load(f"{model_id}/config.json")
24
+
25
+ # Create CNN model
26
+ self.model = nn.Sequential(
27
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
28
+ nn.BatchNorm2d(32),
29
+ nn.ReLU(),
30
+ nn.MaxPool2d(2),
31
+
32
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
33
+ nn.BatchNorm2d(64),
34
+ nn.ReLU(),
35
+ nn.MaxPool2d(2),
36
+
37
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
38
+ nn.BatchNorm2d(128),
39
+ nn.ReLU(),
40
+ nn.MaxPool2d(2),
41
+
42
+ nn.Flatten(),
43
+ nn.Linear(128 * 8 * 8, 128),
44
+ nn.BatchNorm1d(128),
45
+ nn.ReLU(),
46
+ nn.Dropout(0.5),
47
+ nn.Linear(128, 2)
48
  )
49
+
50
+ # Load the model weights
51
+ self.model.load_state_dict(torch.load(f"{model_id}/pytorch_model.bin"))
52
  self.model.eval()
53
+ print(f"CNN model loaded successfully from {model_id}")
 
54
  except Exception as e:
55
+ print(f"Error loading CNN model: {str(e)}")
56
  raise
57
 
58
  def detect_face(self, frame):
 
66
  return None, None
67
 
68
  def preprocess_image(self, image):
69
+ """Preprocess the input image for CNN"""
70
  if image is None:
71
  return None
72
+ # Convert to RGB
73
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
74
+ # Resize to model input size
75
+ image = cv2.resize(image, (self.input_shape[0], self.input_shape[1]))
76
+ # Normalize
77
+ image = image.astype(np.float32) / 255.0
78
+ # Convert to tensor and add batch dimension
79
+ image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
80
+ return image
81
 
82
  def predict(self, image):
83
+ """Make prediction on the input image using CNN"""
84
+ if self.model is None:
85
+ raise ValueError("Model not loaded. Call load_model() first.")
86
  # Detect face
87
  face, face_coords = self.detect_face(image)
88
  if face is None:
89
+ return None, None, "No face detected"
90
  # Preprocess the face image
91
  inputs = self.preprocess_image(face)
92
  if inputs is None:
93
+ return None, None, "Error processing image"
94
  # Make prediction
95
  with torch.no_grad():
96
+ outputs = self.model(inputs)
97
+ probs = torch.softmax(outputs, dim=1)
 
98
  pred_class = torch.argmax(probs, dim=1).item()
99
  pred_label = self.id2label[pred_class]
100
  pred_prob = probs[0, pred_class].item()
 
102
  drowsy_prob = probs[0, 1].item()
103
  return drowsy_prob, face_coords, None
104
 
105
+ # Create a global instance
106
  detector = DrowsinessDetector()
107
 
108
  def process_image(image):
109
+ """Process image input"""
110
  if image is None:
111
+ return None, "No image provided"
112
+
113
  try:
 
 
 
 
 
 
 
 
 
 
114
  # Make prediction
115
  drowsy_prob, face_coords, error = detector.predict(image)
116
 
117
  if error:
118
+ return None, error
 
 
 
 
 
 
119
 
120
+ if face_coords is None:
121
+ return image, "No face detected"
 
 
122
 
123
+ # Draw bounding box
124
+ x, y, w, h = face_coords
125
+ color = (0, 255, 0) if drowsy_prob < 0.5 else (0, 0, 255)
126
+ cv2.rectangle(image, (x, y), (x+w, y+h), color, 2)
127
+
128
+ # Add text
129
+ text = f"{'Drowsy' if drowsy_prob >= 0.5 else 'Not Drowsy'} ({drowsy_prob:.2f})"
130
+ cv2.putText(image, text, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
131
+
132
+ return image, f"Processed successfully. Drowsiness probability: {drowsy_prob:.2f}"
133
 
134
  except Exception as e:
135
+ return None, f"Error processing image: {str(e)}"
136
 
137
  def process_video(video):
138
  """Process video input"""
139
  if video is None:
140
+ return None, "No video provided"
141
 
142
  try:
143
  # Get input video properties
 
166
 
167
  # Check if video was created
168
  if os.path.exists(temp_output) and os.path.getsize(temp_output) > 0:
169
+ return temp_output, "Video processed successfully"
170
  else:
171
+ return None, "Error: Failed to create output video"
172
 
173
  except Exception as e:
174
  return None, f"Error processing video: {str(e)}"
 
183
  detector.load_model()
184
 
185
  # Create interface
186
+ with gr.Blocks(title="Driver Drowsiness Detection") as demo:
187
  gr.Markdown("""
188
+ # 🚗 Driver Drowsiness Detection System
189
 
190
+ This system detects driver drowsiness using computer vision and deep learning.
191
 
192
+ ## Features:
193
+ - Image analysis
194
+ - Video processing
195
+ - Face detection and drowsiness prediction
196
  """)
197
 
198
  with gr.Tabs():
199
+ with gr.Tab("Image"):
200
+ gr.Markdown("Upload an image for drowsiness detection")
201
  with gr.Row():
202
+ image_input = gr.Image(label="Input Image", type="numpy")
203
+ image_output = gr.Image(label="Processed Image")
204
  with gr.Row():
205
+ status_output = gr.Textbox(label="Status")
206
  image_input.change(
207
  fn=process_image,
208
  inputs=[image_input],
209
  outputs=[image_output, status_output]
210
  )
211
 
212
+ with gr.Tab("Video"):
213
+ gr.Markdown("Upload a video file for drowsiness detection")
214
  with gr.Row():
215
+ video_input = gr.Video(label="Input Video")
216
+ video_output = gr.Video(label="Processed Video")
217
  with gr.Row():
218
+ video_status = gr.Textbox(label="Status")
219
  video_input.change(
220
  fn=process_video,
221
  inputs=[video_input],