ckcl committed on
Commit
2c82790
·
verified ·
1 Parent(s): c893a2e

Upload 12 files

Browse files
.gitattributes CHANGED
@@ -1,34 +1,5 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tflite filter=lfs diff=lfs merge=lfs -text
29
- *.tgz filter=lfs diff=lfs merge=lfs -text
30
- *.wasm filter=lfs diff=lfs merge=lfs -text
31
- *.xz filter=lfs diff=lfs merge=lfs -text
32
- *.zip filter=lfs diff=lfs merge=lfs -text
33
- *.zst filter=lfs diff=lfs merge=lfs -text
34
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ # Auto detect text files and perform LF normalization
2
+ * text=auto
3
+ *.h5 filter=lfs diff=lfs merge=lfs -text
4
+ *.bin filter=lfs diff=lfs merge=lfs -text
5
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ env/
8
+ build/
9
+ develop-eggs/
10
+ dist/
11
+ downloads/
12
+ eggs/
13
+ .eggs/
14
+ lib/
15
+ lib64/
16
+ parts/
17
+ sdist/
18
+ var/
19
+ wheels/
20
+ *.egg-info/
21
+ .installed.cfg
22
+ *.egg
23
+
24
+ # Virtual Environment
25
+ venv/
26
+ ENV/
27
+
28
+ # IDE
29
+ .idea/
30
+ .vscode/
31
+ *.swp
32
+ *.swo
33
+
34
+ # Project specific
35
+ temp_output.mp4
36
+ *.h5
37
+ *.bin
38
+ *.pth
39
+ *.pt
40
+ *.onnx
41
+ *.pkl
42
+
43
+ # Logs
44
+ *.log
Dockerfile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ # 安裝系統依賴
4
+ RUN apt-get update && apt-get install -y \
5
+ libgl1-mesa-glx \
6
+ libglib2.0-0 \
7
+ libsm6 \
8
+ libxext6 \
9
+ libxrender-dev \
10
+ && rm -rf /var/lib/apt/lists/*
11
+
12
+ # 設置工作目錄
13
+ WORKDIR /app
14
+
15
+ # 複製依賴文件
16
+ COPY requirements.txt .
17
+
18
+ # 安裝 Python 依賴
19
+ RUN pip install --no-cache-dir -r requirements.txt
20
+
21
+ # 複製應用程式文件
22
+ COPY . .
23
+
24
+ # 暴露端口
25
+ EXPOSE 8080
26
+
27
+ # 啟動應用
28
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,36 +1,33 @@
1
- ---
2
- title: Dnn Space
3
- emoji: 🚀
4
- colorFrom: blue
5
- colorTo: green
6
- sdk: docker
7
- pinned: false
8
- short_description: Create powerful AI models without code
9
- hf_oauth: true
10
- hf_oauth_expiration_minutes: 36000
11
- hf_oauth_scopes:
12
- - read-repos
13
- - write-repos
14
- - manage-repos
15
- - inference-api
16
- - read-billing
17
- tags:
18
- - autotrain
19
- license: mit
20
- ---
21
-
22
- # Docs
23
-
24
- https://huggingface.co/docs/autotrain
25
-
26
- # Citation
27
-
28
- @misc{thakur2024autotrainnocodetrainingstateoftheart,
29
- title={AutoTrain: No-code training for state-of-the-art models},
30
- author={Abhishek Thakur},
31
- year={2024},
32
- eprint={2410.15735},
33
- archivePrefix={arXiv},
34
- primaryClass={cs.AI},
35
- url={https://arxiv.org/abs/2410.15735},
36
- }
 
1
+ # Driver Drowsiness Detection System
2
+
3
+ This is a real-time driver drowsiness detection system that uses computer vision and deep learning to detect signs of drowsiness in drivers. The system can process webcam feeds, video files, and single images.
4
+
5
+ ## Features
6
+
7
+ - Real-time webcam monitoring
8
+ - Video file processing
9
+ - Single image analysis
10
+ - Face detection and drowsiness prediction
11
+ - Visual feedback with bounding boxes and status indicators
12
+
13
+ ## How to Use
14
+
15
+ 1. **Webcam Mode**: Click the "Start Webcam" button to begin real-time monitoring
16
+ 2. **Video Mode**: Upload a video file for processing
17
+ 3. **Image Mode**: Upload a single image for analysis
18
+
19
+ The system will display the results with:
20
+ - Green box: Alert (not drowsy)
21
+ - Red box: Drowsy
22
+ - Probability score for drowsiness
23
+
24
+ ## Technical Details
25
+
26
+ - Built with PyTorch and Vision Transformer (ViT)
27
+ - Uses OpenCV for face detection
28
+ - Gradio interface for easy interaction
29
+ - Real-time processing capabilities
30
+
31
+ ## Model
32
+
33
+ The system uses a Vision Transformer (ViT) model trained on driver drowsiness detection. The model is capable of detecting subtle signs of drowsiness in facial expressions.
 
 
 
app.py CHANGED
@@ -1,158 +1,264 @@
1
  import gradio as gr
2
- import cv2
3
- import numpy as np
4
  import torch
5
- from transformers import AutoImageProcessor, AutoModelForImageClassification
6
- import tempfile
7
- import os
8
- import shutil
9
  from PIL import Image
 
 
 
10
  import time
11
 
12
- # 加載模型和處理器
13
- # 嘗試使用本地模型,如果失敗則使用遠程模型
14
- try:
15
- model_path = "./huggingface_model" # 本地模型路徑
16
- processor = AutoImageProcessor.from_pretrained(model_path)
17
- model = AutoModelForImageClassification.from_pretrained(model_path)
18
- print(f"使用本地模型: {model_path}")
19
- except Exception as e:
20
- print(f"無法載入本地模型: {e}")
21
- print("嘗試使用遠程模型...")
22
- model_name = "ckcl/dnn_space2" # 遠程模型名稱
23
- processor = AutoImageProcessor.from_pretrained(model_name)
24
- model = AutoModelForImageClassification.from_pretrained(model_name)
25
- print(f"使用遠程模型: {model_name}")
26
 
27
- # 如果使用GPU
28
- device = "cuda" if torch.cuda.is_available() else "cpu"
29
- print(f"使用設備: {device}")
30
- model = model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- def process_frame(frame):
33
- """處理單個視頻幀"""
34
- # 轉換為RGB
35
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
36
-
37
- # 使用處理器處理圖像
38
- inputs = processor(images=frame_rgb, return_tensors="pt")
39
- inputs = {k: v.to(device) for k, v in inputs.items()}
40
-
41
- # 進行預測
42
- with torch.no_grad():
43
- outputs = model(**inputs)
44
- logits = outputs.logits
45
- probabilities = torch.nn.functional.softmax(logits, dim=-1)
46
- prediction = torch.argmax(probabilities, dim=-1).item()
47
- confidence = probabilities[0][prediction].item()
48
-
49
- # 添加預測結果到圖像
50
- label = "Alert" if prediction == 0 else "Drowsy"
51
- color = (0, 255, 0) if prediction == 0 else (0, 0, 255)
52
-
53
- cv2.putText(frame, f"{label}: {confidence:.2f}", (10, 30),
54
- cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
55
-
56
- return frame, label, confidence
57
 
58
- def process_video(video_file):
59
- """處理上傳的視頻文件"""
60
- # 創建臨時文件
61
- video_path = ""
62
-
63
- # 檢查 video_file 是字符串還是二進制數據
64
- if isinstance(video_file, str):
65
- # 如果是字符串(文件路徑),直接使用它
66
- video_path = video_file
67
- else:
68
- # 如果是二進制數據,寫入臨時文件
69
- with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
70
- tmp_file.write(video_file)
71
- video_path = tmp_file.name
72
-
73
- # 打開視頻
74
- cap = cv2.VideoCapture(video_path)
75
- if not cap.isOpened():
76
- if video_path != video_file: # 如果創建了臨時文件,需要刪除
77
- os.unlink(video_path)
78
- return None, "Error: Could not open video file"
79
-
80
- # 獲取視頻信息
81
- fps = cap.get(cv2.CAP_PROP_FPS)
82
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
83
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
84
-
85
- # 創建輸出視頻
86
- os.makedirs("output_videos", exist_ok=True)
87
- timestamp = int(time.time())
88
- output_path = f"output_videos/output_{timestamp}.mp4"
89
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
90
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
91
-
92
- drowsy_frames = 0
93
- total_frames = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- while True:
96
- ret, frame = cap.read()
97
- if not ret:
98
- break
 
 
 
 
 
99
 
100
- # 處理幀
101
- processed_frame, label, confidence = process_frame(frame)
102
- out.write(processed_frame)
 
 
 
 
 
 
 
103
 
104
- if label == "Drowsy":
105
- drowsy_frames += 1
106
- total_frames += 1
107
-
108
- # 釋放資源
109
- cap.release()
110
- out.release()
111
-
112
- # 如果創建了臨時文件,需要刪除
113
- if video_path != video_file:
114
- os.unlink(video_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
- # 計算困倦比例
117
- drowsy_ratio = drowsy_frames / total_frames if total_frames > 0 else 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- # 生成報告
120
- report = f"""
121
- Video Analysis Report:
122
- ---------------------
123
- Total Frames: {total_frames}
124
- Drowsy Frames: {drowsy_frames}
125
- Drowsy Ratio: {drowsy_ratio:.2%}
126
 
127
- Alert Level:
128
- {'⚠️ HIGH ALERT' if drowsy_ratio > 0.3 else '✅ Normal' if drowsy_ratio < 0.1 else '⚠️ Warning'}
129
- """
 
 
 
130
 
131
- return output_path, report
132
-
133
- # 創建Gradio界面
134
- with gr.Blocks(title="Driver Drowsiness Detection") as demo:
135
- gr.Markdown(
136
- """
137
- # 🚗 Driver Drowsiness Detection
138
- Upload a video of a driver. The system will analyze the video and detect drowsiness frame by frame.\n
139
- - **Green label:** Alert
140
- - **Red label:** Drowsy
141
- - **Report:** Shows drowsy ratio and alert level.
142
- """
143
- )
144
- with gr.Row():
145
- with gr.Column(scale=1):
146
- video_input = gr.Video(label="Upload Driver Video (MP4)")
147
- analyze_btn = gr.Button("Analyze Video")
148
- with gr.Column(scale=1):
149
- video_output = gr.Video(label="Processed Video with Drowsiness Labels")
150
- report_output = gr.Textbox(label="Analysis Report", lines=8, interactive=False)
151
- analyze_btn.click(
152
- fn=process_video,
153
- inputs=video_input,
154
- outputs=[video_output, report_output]
155
- )
156
 
157
  if __name__ == "__main__":
158
- demo.launch()
 
1
  import gradio as gr
 
 
2
  import torch
3
+ from transformers import ViTForImageClassification, ViTImageProcessor
4
+ import numpy as np
5
+ import cv2
 
6
  from PIL import Image
7
+ import io
8
+ import os
9
+ import sys
10
  import time
11
 
12
class DrowsinessDetector:
    """ViT-based drowsiness classifier with Haar-cascade face localisation."""

    def __init__(self):
        self.model = None       # set by load_model()
        self.processor = None   # set by load_model()
        self.input_shape = (224, 224, 3)
        self.face_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        self.id2label = {0: "notdrowsy", 1: "drowsy"}
        self.label2id = {"notdrowsy": 0, "drowsy": 1}

    def load_model(self, model_path):
        """Load the ViT classifier from *model_path* and a stock ViT processor.

        Raises whatever ``from_pretrained`` raises after logging the error.
        """
        try:
            self.model = ViTForImageClassification.from_pretrained(
                model_path,  # a local directory path is accepted directly
                num_labels=2,
                id2label=self.id2label,
                label2id=self.label2id,
                ignore_mismatched_sizes=True,
            )
            self.model.eval()
            self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
            print(f"ViT model loaded successfully from {model_path}")
        except Exception as e:
            print(f"Error loading ViT model: {str(e)}")
            raise

    def detect_face(self, frame):
        """Return (face_crop, (x, y, w, h)) for the first detected face, else (None, None)."""
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        detections = self.face_cascade.detectMultiScale(gray, 1.1, 4)
        if len(detections) == 0:
            return None, None
        x, y, w, h = detections[0]  # first detection only
        return frame[y:y + h, x:x + w], (x, y, w, h)

    def preprocess_image(self, image):
        """Convert a BGR image to ViT processor tensors; a None input passes through."""
        if image is None:
            return None
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return self.processor(images=Image.fromarray(rgb), return_tensors="pt")

    def predict(self, image):
        """Return (drowsy_probability, face_box, error_message) for one BGR frame.

        Exactly one of the probability or the error message is non-None.
        Raises ValueError when load_model() has not been called.
        """
        if self.model is None or self.processor is None:
            raise ValueError("Model not loaded. Call load_model() first.")
        face, face_coords = self.detect_face(image)
        if face is None:
            return None, None, "No face detected"
        inputs = self.preprocess_image(face)
        if inputs is None:
            return None, None, "Error processing image"
        with torch.no_grad():
            logits = self.model(**inputs).logits
        probs = torch.softmax(logits, dim=1)
        # Callers threshold on the probability of class 1 ("drowsy").
        return probs[0, 1].item(), face_coords, None

# Initialize detector
detector = DrowsinessDetector()
82
+
83
def find_model_file():
    """Return the first existing model path from a list of known locations, or None."""
    candidates = (
        "huggingface_model",  # preferred: full HF model directory
        "pytorch_model.bin",
        "model_weights.h5",
        "drowsiness_model.h5",
        "model/drowsiness_model.h5",
        "models/drowsiness_model.h5",
        "huggingface_model/model_weights.h5",
        "huggingface_model/drowsiness_model.h5",
        "../model_weights.h5",
        "../drowsiness_model.h5",
    )
    return next((p for p in candidates if os.path.exists(p)), None)
101
+
102
def load_model():
    """Locate the model on disk and load it into the global detector.

    Exits the process when no model can be found or loading fails, since the
    app cannot work without a model.
    """
    model_path = find_model_file()

    if model_path is None:
        print("\nError: Model file not found!")
        print("\nPlease ensure one of the following exists:")
        # BUGFIX: the original help text omitted the two highest-priority
        # candidates searched by find_model_file(); keep this list in sync.
        print("1. huggingface_model/ (Hugging Face model directory)")
        print("2. pytorch_model.bin")
        print("3. model_weights.h5")
        print("4. drowsiness_model.h5 (also searched under model/ and models/)")
        print("\nYou can download the model from Hugging Face Hub or train it using train_model.py")
        sys.exit(1)

    try:
        detector.load_model(model_path)
    except Exception as e:
        print(f"\nError loading model: {str(e)}")
        sys.exit(1)
121
+
122
def process_frame(frame):
    """Annotate a single frame with a face box and drowsiness status.

    Returns the (possibly annotated) frame, or None for a None input.  Any
    failure is logged and the unmodified frame is returned.
    """
    if frame is None:
        return None

    try:
        # Normalise to 3-channel colour before prediction.
        if len(frame.shape) == 2:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        elif frame.shape[2] == 4:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)

        drowsy_prob, face_coords, error = detector.predict(frame)
        if error:
            return frame

        if face_coords is not None:
            x, y, w, h = face_coords
            is_drowsy = drowsy_prob > 0.7  # fixed decision threshold
            color = (0, 0, 255) if is_drowsy else (0, 255, 0)
            status = "DROWSY" if is_drowsy else "ALERT"
            cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
            cv2.putText(frame, f"{status} ({drowsy_prob:.2%})",
                        (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)

        return frame

    except Exception as e:
        print(f"Error processing frame: {str(e)}")
        return frame
156
+
157
def process_video(video_input):
    """Run drowsiness detection over a video file.

    Writes an annotated copy to ``temp_output.mp4`` and returns its path, or
    None on any failure.
    """
    if video_input is None:
        return None

    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(video_input)
        # BUGFIX: the original never checked that the input opened; an
        # unreadable file silently produced an empty/invalid writer.
        if not cap.isOpened():
            print("Error: Could not open input video")
            return None

        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # NOTE(review): a fixed output name is not safe under concurrent
        # requests — kept for backward compatibility with callers/.gitignore.
        temp_output = "temp_output.mp4"
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            processed_frame = process_frame(frame)
            if processed_frame is not None:
                out.write(processed_frame)

        cap.release()
        out.release()

        if os.path.exists(temp_output) and os.path.getsize(temp_output) > 0:
            return temp_output
        print("Error: Failed to create output video")
        return None

    except Exception as e:
        print(f"Error processing video: {str(e)}")
        return None
    finally:
        # Release OS handles on every path (release() is idempotent).
        # Replaces the fragile `'out' in locals()` probing of the original.
        if out is not None:
            out.release()
        if cap is not None:
            cap.release()
203
+
204
def webcam_feed():
    """Yield annotated frames from the default webcam until it stops.

    Yields None once on failure so the UI gets a terminal value.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(0)
        # BUGFIX: check that the device opened instead of looping on a dead
        # capture; also guards the release below.
        if not cap.isOpened():
            print("Error: Could not open webcam")
            yield None
            return
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            processed_frame = process_frame(frame)
            if processed_frame is not None:
                yield processed_frame
    except Exception as e:
        print(f"Error processing webcam feed: {str(e)}")
        yield None
    finally:
        # BUGFIX: the original referenced `cap` unconditionally and could
        # NameError if VideoCapture construction itself raised.
        if cap is not None:
            cap.release()
222
+
223
# Load the model at startup; exits the process if none is found.
load_model()

# Create interface: three tabs sharing the same detector globals.
with gr.Blocks(title="Driver Drowsiness Detection") as demo:
    gr.Markdown("""
    # 🚗 Driver Drowsiness Detection System

    This system detects driver drowsiness using computer vision and deep learning.

    ## Features:
    - Real-time webcam monitoring
    - Video file processing
    - Single image analysis
    - Face detection and drowsiness prediction
    """)

    with gr.Tabs():
        with gr.Tab("Webcam"):
            gr.Markdown("Real-time drowsiness detection using your webcam")
            webcam_output = gr.Image(label="Live Detection")
            webcam_button = gr.Button("Start Webcam")
            # webcam_feed is a generator, streaming frames into the image.
            webcam_button.click(fn=webcam_feed, inputs=None, outputs=webcam_output)

        with gr.Tab("Video"):
            gr.Markdown("Upload a video file for drowsiness detection")
            with gr.Row():
                video_input = gr.Video(label="Input Video")
                video_output = gr.Video(label="Detection Result")
            video_button = gr.Button("Process Video")
            video_button.click(fn=process_video, inputs=video_input, outputs=video_output)

        with gr.Tab("Image"):
            gr.Markdown("Upload an image for drowsiness detection")
            with gr.Row():
                image_input = gr.Image(type="numpy", label="Input Image")
                image_output = gr.Image(label="Detection Result")
            image_button = gr.Button("Process Image")
            # Single images reuse the per-frame annotation path directly.
            image_button.click(fn=process_frame, inputs=image_input, outputs=image_output)

if __name__ == "__main__":
    demo.launch()
drowsiness_detector.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ from speed_detector import SpeedDetector
5
+ from face_analyzer import FaceAnalyzer
6
+ import pandas as pd
7
+ import time
8
+
9
class DrowsinessDetector:
    """Combine speed detection and face analysis over extracted video frames."""

    def __init__(self):
        self.speed_detector = SpeedDetector()
        self.face_analyzer = FaceAnalyzer()

    def process_frame(self, frame_path, face_path):
        """Process one frame pair.

        :param frame_path: path of the scene image
        :param face_path: path of the matching face crop image
        :return: (speed, is_drowsy), or (None, None) on any failure
        """
        try:
            scene = cv2.imread(frame_path)
            face_img = cv2.imread(face_path)

            if scene is None or face_img is None:
                print(f"處理 {os.path.basename(frame_path)} 時出錯: 無法讀取圖片")
                return None, None

            return (self.speed_detector.detect_speed(scene),
                    self.face_analyzer.is_drowsy(face_img))
        except Exception as e:
            print(f"處理 {os.path.basename(frame_path)} 時出錯: {str(e)}")
            return None, None

    def process_video_folder(self, folder_path):
        """Process every frame image inside one video folder.

        :param folder_path: folder containing ``*.jpg`` frames with matching
            ``*_face.jpg`` crops
        :return: list of per-frame result dicts (partial on interrupt)
        """
        results = []
        frame_files = [f for f in os.listdir(folder_path)
                       if f.endswith('.jpg') and not f.endswith('_face.jpg')]
        total_frames = len(frame_files)

        for i, frame_file in enumerate(frame_files, 1):
            frame_path = os.path.join(folder_path, frame_file)
            face_path = os.path.join(folder_path, frame_file.replace('.jpg', '_face.jpg'))

            # Same-line progress indicator.
            print(f"\r處理進度: {i}/{total_frames} ({i/total_frames*100:.1f}%)", end="")

            try:
                speed, is_drowsy = self.process_frame(frame_path, face_path)
            except KeyboardInterrupt:
                print("\n檢測到中斷,保存當前結果...")
                return results
            except Exception as e:
                print(f"\n處理 {frame_file} 時出錯: {str(e)}")
                continue
            if speed is not None and is_drowsy is not None:
                results.append({'frame': frame_file,
                                'speed': speed,
                                'is_drowsy': is_drowsy})

        print()  # newline after the in-place progress line
        return results
78
+
79
def main():
    """Batch-process every video folder under dataset/driver and save CSV results."""
    detector = DrowsinessDetector()

    dataset_path = os.path.join('dataset', 'driver')
    video_folders = [f for f in os.listdir(dataset_path)
                     if os.path.isdir(os.path.join(dataset_path, f))]
    total_folders = len(video_folders)

    all_results = []
    batch_size = 100  # save accumulated results every 100 folders

    try:
        for i, folder in enumerate(video_folders, 1):
            print(f"\n處理文件夾 {i}/{total_folders}: {folder}")
            folder_path = os.path.join(dataset_path, folder)
            results = detector.process_video_folder(folder_path)
            all_results.extend(results)

            if i % batch_size == 0 or i == total_folders:
                # BUGFIX: the original used i//batch_size + 1, which labels
                # the first full batch (i == 100) as batch 2 and can reuse a
                # number for the final partial batch, overwriting its CSV.
                batch_no = (i - 1) // batch_size + 1
                print(f"\n保存第 {batch_no} 批結果...")
                df = pd.DataFrame(all_results)
                df.to_csv(f'drowsiness_results_batch_{batch_no}.csv', index=False)
                all_results = []  # reset accumulator for the next batch

    except KeyboardInterrupt:
        print("\n檢測到中斷,保存當前結果...")
        if all_results:
            df = pd.DataFrame(all_results)
            df.to_csv('drowsiness_results_final.csv', index=False)
            print("結果已保存到 drowsiness_results_final.csv")
    except Exception as e:
        print(f"\n發生錯誤: {str(e)}")
        if all_results:
            df = pd.DataFrame(all_results)
            df.to_csv('drowsiness_results_error.csv', index=False)
            print("結果已保存到 drowsiness_results_error.csv")
    finally:
        print("\n處理完成")

if __name__ == "__main__":
    main()
drowsiness_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33ed6e261f05e4d4be1493ed052502babd13f646198240093db910011b8b6797
3
+ size 532812672
face_analyzer.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
class FaceAnalyzer:
    """Haar-cascade based drowsiness check using a crude eye aspect ratio."""

    # Below this mean eye width/height ratio the driver is treated as drowsy.
    # NOTE(review): detectMultiScale boxes are roughly square (ratio ~1), so
    # 0.25 is effectively only reached via the no-eyes fallback of 0.0 —
    # confirm the intended threshold against labelled data.
    EAR_THRESHOLD = 0.25

    def __init__(self):
        # OpenCV's bundled Haar cascades for frontal faces and eyes.
        self.face_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        self.eye_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_eye.xml')

    def _get_eye_aspect_ratio(self, eye_region):
        """Return the mean width/height ratio of the two detected eyes.

        :param eye_region: BGR image of the face region
        :return: mean ratio, or 0.0 when exactly two eyes are not found
            (closed eyes are usually undetectable, so 0.0 reads as drowsy)
        """
        gray_eye = cv2.cvtColor(eye_region, cv2.COLOR_BGR2GRAY)
        eyes = self.eye_cascade.detectMultiScale(gray_eye)
        if len(eyes) != 2:
            return 0.0
        # detectMultiScale entries are (x, y, w, h); ratio = w / h.
        return (eyes[0][2] / eyes[0][3] + eyes[1][2] / eyes[1][3]) / 2.0

    def is_drowsy(self, face_image):
        """Return True when the driver in *face_image* appears drowsy.

        :param face_image: BGR face image
        :return: bool (False when no face is detected)
        """
        gray = cv2.cvtColor(face_image, cv2.COLOR_BGR2GRAY)
        faces = self.face_cascade.detectMultiScale(gray, 1.3, 5)
        if len(faces) == 0:
            return False

        # BUGFIX: the original took faces[0] while its comment claimed the
        # largest face region; select the largest detection by area.
        x, y, w, h = max(faces, key=lambda box: box[2] * box[3])
        face_roi = face_image[y:y + h, x:x + w]

        ear = self._get_eye_aspect_ratio(face_roi)
        return ear < self.EAR_THRESHOLD
haarcascade_frontalface_default.xml ADDED
The diff for this file is too large to render. See raw diff
 
inference.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ import cv2
4
+ from PIL import Image
5
+ import io
6
+ import base64
7
+ import os
8
+
9
class DrowsinessDetector:
    """Keras-based drowsiness classifier operating on 64x64 RGB images."""

    def __init__(self):
        self.model = None  # populated by load_model()
        self.input_shape = (64, 64, 3)

    def load_model(self, model_path):
        """Load the Keras model from *model_path*."""
        self.model = tf.keras.models.load_model(model_path)

    def preprocess_image(self, image):
        """Decode/convert *image* and return a normalised (1, 64, 64, 3) batch.

        Accepts a base64 string, raw image bytes, or an ndarray.
        """
        if isinstance(image, str):
            image = np.array(Image.open(io.BytesIO(base64.b64decode(image))))
        elif isinstance(image, bytes):
            image = np.array(Image.open(io.BytesIO(image)))

        # Force 3-channel RGB.
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        resized = cv2.resize(image, self.input_shape[:2])
        scaled = resized.astype(np.float32) / 255.0
        return np.expand_dims(scaled, axis=0)

    def predict(self, image):
        """Return drowsiness probability and boolean decision for *image*.

        Raises ValueError when load_model() has not been called.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")
        batch = self.preprocess_image(image)
        prediction = self.model.predict(batch)
        score = float(prediction[0][0])
        return {
            "drowsy_probability": score,
            "is_drowsy": bool(score > 0.5),
        }
59
+
60
# Create a global instance
detector = DrowsinessDetector()

def load_model():
    """Load the model when the API starts"""
    global detector
    detector.load_model("model_weights.h5")

def predict(image):
    """API endpoint: wrap detector.predict in a status/error envelope."""
    try:
        return {"status": "success", "prediction": detector.predict(image)}
    except Exception as e:
        return {"status": "error", "message": str(e)}

# For local testing
if __name__ == "__main__":
    load_model()

    # Try a sample image when one sits next to the script.
    test_image_path = "test_image.jpg"  # Replace with your test image
    if os.path.exists(test_image_path):
        with open(test_image_path, "rb") as f:
            image_data = f.read()
        print("Prediction result:", predict(image_data))
requirements.txt CHANGED
@@ -1,12 +1,10 @@
1
- torch>=1.7.0
2
- transformers>=4.18.0
3
- huggingface-hub>=0.4.0
4
- opencv-python>=4.5.0
5
- gradio>=3.50.2
6
- numpy>=1.19.0
7
- Pillow>=8.0.0
8
- matplotlib>=3.5.1
9
- scikit-learn>=1.0.2
10
- tqdm>=4.64.0
11
- pandas>=1.4.2
12
- datasets>=2.11.0
 
1
+ gradio==3.50.2
2
+ numpy==1.26.4
3
+ opencv-python==4.8.0.76
4
+ Pillow==10.0.0
5
+ ffmpeg-python==0.2.0
6
+ huggingface-hub>=0.21.0
7
+ transformers==4.35.2
8
+ torch>=2.0.0
9
+ torchvision>=0.15.0
10
+ tqdm==4.66.1
 
 
speed_detector.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
class SpeedDetector:
    """Heuristic vehicle-speed estimator based on detected line features."""

    MAX_SPEED = 120  # km/h cap on the heuristic estimate

    def __init__(self):
        # Speed-detection "model" (currently a placeholder, see _load_model).
        self.model = self._load_model()

    def _load_model(self):
        """Placeholder model loader.

        A simple template-matching approach is assumed for now; a real
        deployment should swap in a learned model.
        """
        # TODO: implement actual model loading
        return None

    def detect_speed(self, frame):
        """Estimate vehicle speed (km/h) from a single scene image.

        :param frame: BGR image
        :return: heuristic speed estimate, capped at 120 km/h
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 50, 150)
        lines = cv2.HoughLinesP(edges, 1, np.pi / 180, 100,
                                minLineLength=100, maxLineGap=10)
        if lines is None:
            return 0
        # Crude heuristic: more detected lines -> higher apparent speed.
        return min(len(lines) * 5, self.MAX_SPEED)