Upload 12 files
Browse files- .gitattributes +5 -34
- .gitignore +44 -0
- Dockerfile +28 -0
- README.md +33 -36
- app.py +247 -141
- drowsiness_detector.py +122 -0
- drowsiness_model.h5 +3 -0
- face_analyzer.py +60 -0
- haarcascade_frontalface_default.xml +0 -0
- inference.py +93 -0
- requirements.txt +10 -12
- speed_detector.py +40 -0
.gitattributes
CHANGED
|
@@ -1,34 +1,5 @@
|
|
| 1 |
-
|
| 2 |
-
*
|
| 3 |
-
*.
|
| 4 |
-
*.
|
| 5 |
-
*.
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
# Auto detect text files and perform LF normalization
|
| 2 |
+
* text=auto
|
| 3 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
env/
|
| 8 |
+
build/
|
| 9 |
+
develop-eggs/
|
| 10 |
+
dist/
|
| 11 |
+
downloads/
|
| 12 |
+
eggs/
|
| 13 |
+
.eggs/
|
| 14 |
+
lib/
|
| 15 |
+
lib64/
|
| 16 |
+
parts/
|
| 17 |
+
sdist/
|
| 18 |
+
var/
|
| 19 |
+
wheels/
|
| 20 |
+
*.egg-info/
|
| 21 |
+
.installed.cfg
|
| 22 |
+
*.egg
|
| 23 |
+
|
| 24 |
+
# Virtual Environment
|
| 25 |
+
venv/
|
| 26 |
+
ENV/
|
| 27 |
+
|
| 28 |
+
# IDE
|
| 29 |
+
.idea/
|
| 30 |
+
.vscode/
|
| 31 |
+
*.swp
|
| 32 |
+
*.swo
|
| 33 |
+
|
| 34 |
+
# Project specific
|
| 35 |
+
temp_output.mp4
|
| 36 |
+
*.h5
|
| 37 |
+
*.bin
|
| 38 |
+
*.pth
|
| 39 |
+
*.pt
|
| 40 |
+
*.onnx
|
| 41 |
+
*.pkl
|
| 42 |
+
|
| 43 |
+
# Logs
|
| 44 |
+
*.log
|
Dockerfile
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
# 安裝系統依賴
|
| 4 |
+
RUN apt-get update && apt-get install -y \
|
| 5 |
+
libgl1-mesa-glx \
|
| 6 |
+
libglib2.0-0 \
|
| 7 |
+
libsm6 \
|
| 8 |
+
libxext6 \
|
| 9 |
+
libxrender-dev \
|
| 10 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
+
|
| 12 |
+
# 設置工作目錄
|
| 13 |
+
WORKDIR /app
|
| 14 |
+
|
| 15 |
+
# 複製依賴文件
|
| 16 |
+
COPY requirements.txt .
|
| 17 |
+
|
| 18 |
+
# 安裝 Python 依賴
|
| 19 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 20 |
+
|
| 21 |
+
# 複製應用程式文件
|
| 22 |
+
COPY . .
|
| 23 |
+
|
| 24 |
+
# 暴露端口
|
| 25 |
+
EXPOSE 8080
|
| 26 |
+
|
| 27 |
+
# 啟動應用
|
| 28 |
+
CMD ["python", "app.py"]
|
README.md
CHANGED
|
@@ -1,36 +1,33 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
primaryClass={cs.AI},
|
| 35 |
-
url={https://arxiv.org/abs/2410.15735},
|
| 36 |
-
}
|
|
|
|
| 1 |
+
# Driver Drowsiness Detection System
|
| 2 |
+
|
| 3 |
+
This is a real-time driver drowsiness detection system that uses computer vision and deep learning to detect signs of drowsiness in drivers. The system can process webcam feeds, video files, and single images.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
- Real-time webcam monitoring
|
| 8 |
+
- Video file processing
|
| 9 |
+
- Single image analysis
|
| 10 |
+
- Face detection and drowsiness prediction
|
| 11 |
+
- Visual feedback with bounding boxes and status indicators
|
| 12 |
+
|
| 13 |
+
## How to Use
|
| 14 |
+
|
| 15 |
+
1. **Webcam Mode**: Click the "Start Webcam" button to begin real-time monitoring
|
| 16 |
+
2. **Video Mode**: Upload a video file for processing
|
| 17 |
+
3. **Image Mode**: Upload a single image for analysis
|
| 18 |
+
|
| 19 |
+
The system will display the results with:
|
| 20 |
+
- Green box: Alert (not drowsy)
|
| 21 |
+
- Red box: Drowsy
|
| 22 |
+
- Probability score for drowsiness
|
| 23 |
+
|
| 24 |
+
## Technical Details
|
| 25 |
+
|
| 26 |
+
- Built with PyTorch and Vision Transformer (ViT)
|
| 27 |
+
- Uses OpenCV for face detection
|
| 28 |
+
- Gradio interface for easy interaction
|
| 29 |
+
- Real-time processing capabilities
|
| 30 |
+
|
| 31 |
+
## Model
|
| 32 |
+
|
| 33 |
+
The system uses a Vision Transformer (ViT) model trained on driver drowsiness detection. The model is capable of detecting subtle signs of drowsiness in facial expressions.
|
|
|
|
|
|
|
|
|
app.py
CHANGED
|
@@ -1,158 +1,264 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
import cv2
|
| 3 |
-
import numpy as np
|
| 4 |
import torch
|
| 5 |
-
from transformers import
|
| 6 |
-
import
|
| 7 |
-
import
|
| 8 |
-
import shutil
|
| 9 |
from PIL import Image
|
|
|
|
|
|
|
|
|
|
| 10 |
import time
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
print(f"無法載入本地模型: {e}")
|
| 21 |
-
print("嘗試使用遠程模型...")
|
| 22 |
-
model_name = "ckcl/dnn_space2" # 遠程模型名稱
|
| 23 |
-
processor = AutoImageProcessor.from_pretrained(model_name)
|
| 24 |
-
model = AutoModelForImageClassification.from_pretrained(model_name)
|
| 25 |
-
print(f"使用遠程模型: {model_name}")
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
model =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
-
def
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
# 進行預測
|
| 42 |
-
with torch.no_grad():
|
| 43 |
-
outputs = model(**inputs)
|
| 44 |
-
logits = outputs.logits
|
| 45 |
-
probabilities = torch.nn.functional.softmax(logits, dim=-1)
|
| 46 |
-
prediction = torch.argmax(probabilities, dim=-1).item()
|
| 47 |
-
confidence = probabilities[0][prediction].item()
|
| 48 |
-
|
| 49 |
-
# 添加預測結果到圖像
|
| 50 |
-
label = "Alert" if prediction == 0 else "Drowsy"
|
| 51 |
-
color = (0, 255, 0) if prediction == 0 else (0, 0, 255)
|
| 52 |
-
|
| 53 |
-
cv2.putText(frame, f"{label}: {confidence:.2f}", (10, 30),
|
| 54 |
-
cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
|
| 55 |
-
|
| 56 |
-
return frame, label, confidence
|
| 57 |
|
| 58 |
-
def
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
if
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
-
|
| 120 |
-
report = f"""
|
| 121 |
-
Video Analysis Report:
|
| 122 |
-
---------------------
|
| 123 |
-
Total Frames: {total_frames}
|
| 124 |
-
Drowsy Frames: {drowsy_frames}
|
| 125 |
-
Drowsy Ratio: {drowsy_ratio:.2%}
|
| 126 |
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
inputs=video_input,
|
| 154 |
-
outputs=[video_output, report_output]
|
| 155 |
-
)
|
| 156 |
|
| 157 |
if __name__ == "__main__":
|
| 158 |
-
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
| 2 |
import torch
|
| 3 |
+
from transformers import ViTForImageClassification, ViTImageProcessor
|
| 4 |
+
import numpy as np
|
| 5 |
+
import cv2
|
|
|
|
| 6 |
from PIL import Image
|
| 7 |
+
import io
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
import time
|
| 11 |
|
| 12 |
+
class DrowsinessDetector:
    """Face-based drowsiness classifier built on a fine-tuned ViT model.

    Detects a face with an OpenCV Haar cascade, then classifies the face
    crop as "notdrowsy" (0) or "drowsy" (1) with a ViT image classifier.
    """

    def __init__(self):
        self.model = None
        self.processor = None
        self.input_shape = (224, 224, 3)  # ViT-base expects 224x224 RGB
        # Haar cascade shipped with OpenCV; no external cascade file needed.
        self.face_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        self.id2label = {0: "notdrowsy", 1: "drowsy"}
        self.label2id = {"notdrowsy": 0, "drowsy": 1}

    def load_model(self, model_path):
        """Load the ViT model and processor from the specified path or directory."""
        try:
            self.model = ViTForImageClassification.from_pretrained(
                model_path,  # a local model directory works directly
                num_labels=2,
                id2label=self.id2label,
                label2id=self.label2id,
                ignore_mismatched_sizes=True,
            )
            self.model.eval()
            # The processor only carries resize/normalization config, so the
            # stock ViT-base processor is used regardless of model_path.
            self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
            print(f"ViT model loaded successfully from {model_path}")
        except Exception as e:
            print(f"Error loading ViT model: {str(e)}")
            raise

    def detect_face(self, frame):
        """Detect the most prominent face in the frame.

        Returns (face_crop, (x, y, w, h)), or (None, None) when no face is
        found.  FIX: selects the largest detection by area instead of
        whichever one the cascade happened to return first, which is more
        robust when multiple faces or false positives are present.
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        faces = self.face_cascade.detectMultiScale(gray, 1.1, 4)
        if len(faces) > 0:
            (x, y, w, h) = max(faces, key=lambda f: f[2] * f[3])
            face = frame[y:y+h, x:x+w]
            return face, (x, y, w, h)
        return None, None

    def preprocess_image(self, image):
        """Convert a BGR ndarray into model input tensors; None passes through."""
        if image is None:
            return None
        pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        inputs = self.processor(images=pil_img, return_tensors="pt")
        return inputs

    def predict(self, image):
        """Classify drowsiness for the face found in *image*.

        Returns a (drowsy_probability, face_coords, error_message) triple;
        on failure the first two are None and the message is set.

        Raises:
            ValueError: if load_model() has not been called.
        """
        if self.model is None or self.processor is None:
            raise ValueError("Model not loaded. Call load_model() first.")
        # Detect face
        face, face_coords = self.detect_face(image)
        if face is None:
            return None, None, "No face detected"
        # Preprocess the face image
        inputs = self.preprocess_image(face)
        if inputs is None:
            return None, None, "Error processing image"
        # Make prediction (FIX: removed the unused pred_class/pred_label/
        # pred_prob locals the original computed and discarded).
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
            # Probability of the "drowsy" class (index 1).
            drowsy_prob = probs[0, 1].item()
        return drowsy_prob, face_coords, None
|
| 79 |
+
|
| 80 |
+
# Initialize detector
|
| 81 |
+
detector = DrowsinessDetector()
|
| 82 |
+
|
| 83 |
+
def find_model_file():
    """Find the model directory or file in common locations.

    Returns the first candidate path that exists on disk, or None when
    nothing is found.
    """
    candidates = (
        "huggingface_model",  # preferred: a full HF model directory
        "pytorch_model.bin",
        "model_weights.h5",
        "drowsiness_model.h5",
        "model/drowsiness_model.h5",
        "models/drowsiness_model.h5",
        "huggingface_model/model_weights.h5",
        "huggingface_model/drowsiness_model.h5",
        "../model_weights.h5",
        "../drowsiness_model.h5",
    )
    return next((p for p in candidates if os.path.exists(p)), None)
|
| 101 |
+
|
| 102 |
+
def load_model():
    """Locate the model on disk and load it into the global detector.

    Exits the process with status 1 when no model can be found or when
    loading fails, printing user guidance first.
    """
    model_path = find_model_file()

    if model_path is None:
        for line in (
            "\nError: Model file not found!",
            "\nPlease ensure one of the following files exists:",
            "1. model_weights.h5",
            "2. drowsiness_model.h5",
            "3. model/drowsiness_model.h5",
            "4. models/drowsiness_model.h5",
            "\nYou can download the model from Hugging Face Hub or train it using train_model.py",
        ):
            print(line)
        sys.exit(1)

    try:
        detector.load_model(model_path)
    except Exception as e:
        print(f"\nError loading model: {str(e)}")
        sys.exit(1)
|
| 121 |
+
|
| 122 |
+
def process_frame(frame):
    """Run drowsiness detection on one frame and draw the result onto it.

    Returns the annotated frame, the unmodified frame when no face is found
    or an error occurs, or None for a None input.
    """
    if frame is None:
        return None

    try:
        # Normalize to a 3-channel image.
        if len(frame.shape) == 2:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        elif frame.shape[2] == 4:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)

        drowsy_prob, face_coords, error = detector.predict(frame)

        # No face / processing error: hand back the frame untouched.
        if error or face_coords is None:
            return frame

        x, y, w, h = face_coords
        is_drowsy = drowsy_prob > 0.7  # same cutoff for box color and label
        color = (0, 0, 255) if is_drowsy else (0, 255, 0)
        status = "DROWSY" if is_drowsy else "ALERT"

        # Box around the face plus a status label above it.
        cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
        cv2.putText(frame, f"{status} ({drowsy_prob:.2%})",
                    (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
        return frame

    except Exception as e:
        print(f"Error processing frame: {str(e)}")
        return frame
|
| 156 |
+
|
| 157 |
+
def process_video(video_input):
    """Process a video file frame by frame and return the annotated output.

    Args:
        video_input: Path to the input video, or None.

    Returns:
        Path to the annotated mp4 ("temp_output.mp4"), or None on failure.
        The file must outlive this function because its path is returned
        to the caller (Gradio serves it from disk).
    """
    if video_input is None:
        return None

    # FIX: predefine handles so the finally block never touches an unbound
    # name, and release each handle exactly once (the original released
    # them inline AND again in its cleanup block).
    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(video_input)
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Temporary output video file.
        temp_output = "temp_output.mp4"
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            processed_frame = process_frame(frame)
            if processed_frame is not None:
                out.write(processed_frame)

        # Release the writer BEFORE checking the file so buffered frames
        # are flushed and the size check is meaningful.
        out.release()
        out = None
        cap.release()
        cap = None

        if os.path.exists(temp_output) and os.path.getsize(temp_output) > 0:
            return temp_output
        print("Error: Failed to create output video")
        return None

    except Exception as e:
        print(f"Error processing video: {str(e)}")
        return None
    finally:
        # Release any handle still open on the error path.  (Note: the
        # original comment claimed the temp file was cleaned up here; it
        # never was, and it must not be, since the path is returned.)
        if out is not None:
            out.release()
        if cap is not None:
            cap.release()
|
| 203 |
+
|
| 204 |
+
def webcam_feed():
    """Yield annotated frames from the default webcam until it stops.

    Generator used for Gradio streaming; yields None once if the feed
    fails outright.
    """
    # FIX: predefine cap so the finally block cannot raise NameError when
    # cv2.VideoCapture itself fails.
    cap = None
    try:
        cap = cv2.VideoCapture(0)
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            processed_frame = process_frame(frame)
            if processed_frame is not None:
                yield processed_frame
    except Exception as e:
        print(f"Error processing webcam feed: {str(e)}")
        yield None
    finally:
        if cap is not None:
            cap.release()
|
| 222 |
+
|
| 223 |
+
# Load the model at startup
load_model()

# Create interface
with gr.Blocks(title="Driver Drowsiness Detection") as demo:
    gr.Markdown("""
    # 🚗 Driver Drowsiness Detection System

    This system detects driver drowsiness using computer vision and deep learning.

    ## Features:
    - Real-time webcam monitoring
    - Video file processing
    - Single image analysis
    - Face detection and drowsiness prediction
    """)

    with gr.Tabs():
        # Tab 1: live webcam stream (webcam_feed is a generator, so the
        # image component is updated frame by frame).
        with gr.Tab("Webcam"):
            gr.Markdown("Real-time drowsiness detection using your webcam")
            webcam_output = gr.Image(label="Live Detection")
            webcam_button = gr.Button("Start Webcam")
            webcam_button.click(fn=webcam_feed, inputs=None, outputs=webcam_output)

        # Tab 2: offline processing of an uploaded video file.
        with gr.Tab("Video"):
            gr.Markdown("Upload a video file for drowsiness detection")
            with gr.Row():
                video_input = gr.Video(label="Input Video")
                video_output = gr.Video(label="Detection Result")
            video_button = gr.Button("Process Video")
            video_button.click(fn=process_video, inputs=video_input, outputs=video_output)

        # Tab 3: single-image analysis (numpy image in, annotated image out).
        with gr.Tab("Image"):
            gr.Markdown("Upload an image for drowsiness detection")
            with gr.Row():
                image_input = gr.Image(type="numpy", label="Input Image")
                image_output = gr.Image(label="Detection Result")
            image_button = gr.Button("Process Image")
            image_button.click(fn=process_frame, inputs=image_input, outputs=image_output)

if __name__ == "__main__":
    demo.launch()
|
drowsiness_detector.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
from speed_detector import SpeedDetector
|
| 5 |
+
from face_analyzer import FaceAnalyzer
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
class DrowsinessDetector:
    """Batch pipeline that pairs each scene frame with its face crop and
    produces a (vehicle speed, drowsiness flag) record per frame."""

    def __init__(self):
        # Heuristic speed estimator and cascade-based face analyzer.
        self.speed_detector = SpeedDetector()
        self.face_analyzer = FaceAnalyzer()

    def process_frame(self, frame_path, face_path):
        """
        Process a single frame.

        :param frame_path: path to the scene image
        :param face_path: path to the face-crop image
        :return: (speed, is_drowsy), or (None, None) on any failure
        """
        try:
            # Load both images from disk.
            frame = cv2.imread(frame_path)
            face = cv2.imread(face_path)

            if frame is None or face is None:
                print(f"處理 {os.path.basename(frame_path)} 時出錯: 無法讀取圖片")
                return None, None

            # Estimate vehicle speed from the scene image.
            speed = self.speed_detector.detect_speed(frame)

            # Classify drowsiness from the face image.
            is_drowsy = self.face_analyzer.is_drowsy(face)

            return speed, is_drowsy
        except Exception as e:
            print(f"處理 {os.path.basename(frame_path)} 時出錯: {str(e)}")
            return None, None

    def process_video_folder(self, folder_path):
        """
        Process every frame in one video folder.

        :param folder_path: path to the video folder
        :return: list of dicts with keys 'frame', 'speed', 'is_drowsy'
        """
        results = []

        # Scene frames are '*.jpg'; their face crops end in '_face.jpg'.
        frame_files = [f for f in os.listdir(folder_path) if f.endswith('.jpg') and not f.endswith('_face.jpg')]
        total_frames = len(frame_files)

        for i, frame_file in enumerate(frame_files, 1):
            # Build the matching scene/face file paths.
            frame_path = os.path.join(folder_path, frame_file)
            face_path = os.path.join(folder_path, frame_file.replace('.jpg', '_face.jpg'))

            # Progress indicator (carriage return keeps it on one line).
            print(f"\r處理進度: {i}/{total_frames} ({i/total_frames*100:.1f}%)", end="")

            try:
                speed, is_drowsy = self.process_frame(frame_path, face_path)
                if speed is not None and is_drowsy is not None:
                    results.append({
                        'frame': frame_file,
                        'speed': speed,
                        'is_drowsy': is_drowsy
                    })
            except KeyboardInterrupt:
                # Ctrl-C: return the partial results collected so far.
                print("\n檢測到中斷,保存當前結果...")
                return results
            except Exception as e:
                # Per-frame failures are logged and skipped.
                print(f"\n處理 {frame_file} 時出錯: {str(e)}")
                continue

        print()  # newline after the progress line
        return results
|
| 78 |
+
|
| 79 |
+
def main():
    """Walk every video folder under dataset/driver, run detection on each
    frame, and periodically flush the accumulated results to CSV."""
    # Initialize the detector.
    detector = DrowsinessDetector()

    # Collect all video folders.
    dataset_path = os.path.join('dataset', 'driver')
    video_folders = [f for f in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, f))]
    total_folders = len(video_folders)

    all_results = []
    batch_size = 100  # flush results to CSV every 100 folders

    try:
        # Process each video folder.
        for i, folder in enumerate(video_folders, 1):
            print(f"\n處理文件夾 {i}/{total_folders}: {folder}")
            folder_path = os.path.join(dataset_path, folder)
            results = detector.process_video_folder(folder_path)
            all_results.extend(results)

            # Flush a batch of results to disk.
            # NOTE(review): i//batch_size + 1 labels the first flushed batch
            # as "2" (100//100 + 1) — the numbering looks off by one, but
            # filenames stay unique, so behavior is left unchanged here.
            if i % batch_size == 0 or i == total_folders:
                print(f"\n保存第 {i//batch_size + 1} 批結果...")
                df = pd.DataFrame(all_results)
                df.to_csv(f'drowsiness_results_batch_{i//batch_size + 1}.csv', index=False)
                all_results = []  # reset the accumulator for the next batch

    except KeyboardInterrupt:
        # Ctrl-C: persist whatever has been collected so far.
        print("\n檢測到中斷,保存當前結果...")
        if all_results:
            df = pd.DataFrame(all_results)
            df.to_csv('drowsiness_results_final.csv', index=False)
            print("結果已保存到 drowsiness_results_final.csv")
    except Exception as e:
        # Unexpected failure: dump partial results for post-mortem.
        print(f"\n發生錯誤: {str(e)}")
        if all_results:
            df = pd.DataFrame(all_results)
            df.to_csv('drowsiness_results_error.csv', index=False)
            print("結果已保存到 drowsiness_results_error.csv")
    finally:
        print("\n處理完成")

if __name__ == "__main__":
    main()
|
drowsiness_model.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:33ed6e261f05e4d4be1493ed052502babd13f646198240093db910011b8b6797
|
| 3 |
+
size 532812672
|
face_analyzer.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
class FaceAnalyzer:
    """Heuristic eye-openness drowsiness check built on OpenCV Haar cascades."""

    def __init__(self):
        # Face and eye Haar cascades bundled with OpenCV.
        self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        self.eye_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_eye.xml')

    def _get_eye_aspect_ratio(self, eye_region):
        """
        Approximate an eye aspect ratio (EAR) from cascade eye boxes.

        :param eye_region: BGR image of the face region
        :return: mean width/height ratio of the two detected eye boxes,
                 or 0.0 when exactly two eyes are not found
        NOTE(review): this is the bounding-box w/h ratio, not the
        landmark-based EAR from the literature; closed eyes often simply
        fail detection (returning 0.0), which the threshold relies on.
        """
        # Convert the region to grayscale for the cascade.
        gray_eye = cv2.cvtColor(eye_region, cv2.COLOR_BGR2GRAY)

        # Detect eyes.
        eyes = self.eye_cascade.detectMultiScale(gray_eye)

        if len(eyes) != 2:  # need exactly two detected eyes
            return 0.0

        eye1 = eyes[0]
        eye2 = eyes[1]

        # Width / height of each detection box.
        ear1 = eye1[2] / eye1[3]
        ear2 = eye2[2] / eye2[3]

        # Average the two ratios.
        return (ear1 + ear2) / 2.0

    def is_drowsy(self, face_image):
        """
        Decide whether the person in *face_image* appears drowsy.

        :param face_image: BGR face image
        :return: True when the eye ratio falls below the threshold,
                 False when no face is found
        """
        gray = cv2.cvtColor(face_image, cv2.COLOR_BGR2GRAY)

        # Detect faces.
        faces = self.face_cascade.detectMultiScale(gray, 1.3, 5)

        if len(faces) == 0:
            return False

        # FIX: actually take the largest face region — the original comment
        # promised the largest but the code used whichever came first.
        (x, y, w, h) = max(faces, key=lambda f: f[2] * f[3])
        face_roi = face_image[y:y+h, x:x+w]

        # Compute the eye aspect ratio on the face region.
        ear = self._get_eye_aspect_ratio(face_roi)

        # Below the threshold is treated as drowsy.
        EAR_THRESHOLD = 0.25
        return ear < EAR_THRESHOLD
|
haarcascade_frontalface_default.xml
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
inference.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import io
|
| 6 |
+
import base64
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
class DrowsinessDetector:
    """Keras-based single-image drowsiness classifier."""

    def __init__(self):
        self.model = None
        self.input_shape = (64, 64, 3)  # model expects 64x64 RGB inputs

    def load_model(self, model_path):
        """Load the Keras model from *model_path*."""
        self.model = tf.keras.models.load_model(model_path)

    def preprocess_image(self, image):
        """Decode and normalize *image* into a (1, 64, 64, 3) float32 batch.

        Accepts a base64 string, raw encoded bytes, or an ndarray; pixel
        values are scaled to [0, 1].
        """
        if isinstance(image, str):
            # base64-encoded payload
            decoded = base64.b64decode(image)
            image = np.array(Image.open(io.BytesIO(decoded)))
        elif isinstance(image, bytes):
            # raw encoded image bytes
            image = np.array(Image.open(io.BytesIO(image)))

        # Force a 3-channel RGB layout.
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        # Resize, scale, and add the batch dimension.
        resized = cv2.resize(image, self.input_shape[:2])
        scaled = resized.astype(np.float32) / 255.0
        return np.expand_dims(scaled, axis=0)

    def predict(self, image):
        """Return {"drowsy_probability": float, "is_drowsy": bool} for *image*.

        Raises:
            ValueError: if load_model() has not been called yet.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")

        batch = self.preprocess_image(image)
        prediction = self.model.predict(batch)

        # A single sigmoid output: probability of the drowsy class.
        return {
            "drowsy_probability": float(prediction[0][0]),
            "is_drowsy": bool(prediction[0][0] > 0.5)
        }
|
| 59 |
+
|
| 60 |
+
# Create a global instance
detector = DrowsinessDetector()

def load_model():
    """Load the model into the global detector when the API starts.

    NOTE(review): the model path is hard-coded to "model_weights.h5" in the
    working directory — confirm this matches the deployed layout.
    """
    global detector
    detector.load_model("model_weights.h5")
|
| 67 |
+
|
| 68 |
+
def predict(image):
    """API endpoint for prediction.

    Wraps detector.predict() and never raises — failures come back as an
    error-status payload instead of an exception.
    """
    try:
        return {
            "status": "success",
            "prediction": detector.predict(image)
        }
    except Exception as e:
        return {
            "status": "error",
            "message": str(e)
        }
|
| 81 |
+
|
| 82 |
+
# For local testing
|
| 83 |
+
if __name__ == "__main__":
|
| 84 |
+
# Load model
|
| 85 |
+
load_model()
|
| 86 |
+
|
| 87 |
+
# Test with a sample image
|
| 88 |
+
test_image_path = "test_image.jpg" # Replace with your test image
|
| 89 |
+
if os.path.exists(test_image_path):
|
| 90 |
+
with open(test_image_path, "rb") as f:
|
| 91 |
+
image_data = f.read()
|
| 92 |
+
result = predict(image_data)
|
| 93 |
+
print("Prediction result:", result)
|
requirements.txt
CHANGED
|
@@ -1,12 +1,10 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
tqdm
|
| 11 |
-
pandas>=1.4.2
|
| 12 |
-
datasets>=2.11.0
|
|
|
|
| 1 |
+
gradio==3.50.2
|
| 2 |
+
numpy==1.26.4
|
| 3 |
+
opencv-python==4.8.0.76
|
| 4 |
+
Pillow==10.0.0
|
| 5 |
+
ffmpeg-python==0.2.0
|
| 6 |
+
huggingface-hub>=0.21.0
|
| 7 |
+
transformers==4.35.2
|
| 8 |
+
torch>=2.0.0
|
| 9 |
+
torchvision>=0.15.0
|
| 10 |
+
tqdm==4.66.1
|
|
|
|
|
|
speed_detector.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
class SpeedDetector:
    """Placeholder vehicle-speed estimator based on detected line density."""

    def __init__(self):
        # Initialize the speed-estimation "model" (placeholder for now).
        self.model = self._load_model()

    def _load_model(self):
        """
        Load the speed-detection model.

        A simple template-matching style heuristic stands in for now; a real
        deployment should use a proper learned model.
        """
        # TODO: implement actual model loading
        return None

    def detect_speed(self, frame):
        """
        Estimate vehicle speed from a single frame.

        :param frame: input BGR image
        :return: estimated speed in km/h, clamped to [0, 120]
        """
        grayscale = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

        # Canny edge detection followed by a probabilistic Hough transform.
        contour_map = cv2.Canny(grayscale, 50, 150)
        detected = cv2.HoughLinesP(contour_map, 1, np.pi/180, 100, minLineLength=100, maxLineGap=10)

        if detected is None:
            return 0

        # Crude heuristic: more detected lines -> higher apparent speed.
        # A real implementation needs a far better algorithm.
        estimate = len(detected) * 5

        return min(estimate, 120)  # cap at 120 km/h
|