Spaces: Runtime error

added stuff
Files changed:
- .gitignore +76 -0
- Dockerfile +7 -0
- README.md +4 -2
- app/app.py +69 -0
- docker-compose.yml +13 -0
- model/data/mp_process.py +139 -0
- model/data/process_data.py +167 -0
- model/training/model_training.ipynb +292 -0
- model/training/saved_models/README.md +2 -0
- requirements.txt +10 -0
.gitignore
ADDED
@@ -0,0 +1,76 @@
```gitignore
# auth-tokens
*.json

# gradio
app/flagged
*.h5
*.pkl

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
```
Dockerfile
ADDED
@@ -0,0 +1,7 @@
```dockerfile
FROM python:3.10.11-slim-bullseye
WORKDIR /app
COPY requirements.txt .
RUN apt-get update && apt-get upgrade -y
RUN pip install -r requirements.txt
# "RUN cd app" does not persist across layers; WORKDIR is the way to change directory
WORKDIR /app/app
CMD ["gradio", "app.py"]
```
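Note that this image never COPYs the application source: the docker-compose.yml below bind-mounts `./app`, `./model/data`, and `./model/training/saved_models` into the container, so the image is only useful when started through Compose (or with equivalent `docker run -v` flags) so that `gradio app.py` has something to serve.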
README.md
CHANGED
@@ -1,13 +1,15 @@
```diff
 ---
-title:
+title: antisomnus
 emoji: 🔥
 colorFrom: pink
 colorTo: indigo
 sdk: gradio
 sdk_version: 3.27.0
 app_file: app.py
-pinned:
+pinned: true
 license: mit
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+# antisomnus - Driver Drowsiness Detection
```
app/app.py
ADDED
@@ -0,0 +1,69 @@
```python
"""
app.py
"""
import os
import sys
from pathlib import Path
import gradio as gr
import numpy as np
import mediapipe as mp
import tensorflow as tf
import cv2

# Add the path to the model directory
path = Path(os.getcwd())
sys.path.insert(0, str(path.parent.absolute()) + "/model/data")
from mp_process import process_mp_img


model = tf.keras.models.load_model(str(path.parent.absolute()) + "/model/training/saved_models/en_model_v0.h5")


def preprocess_frame(frame):
    """
    Preprocess the frame to be compatible with the model
    """
    frame = cv2.resize(frame, (224, 224), interpolation=cv2.INTER_AREA)
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = frame / 255.0
    return np.expand_dims(frame, axis=0)


def detect_drowsiness(frame):
    """
    returns features and/or processed image
    """
    annotated_img, eye_feature, mouth_feature, mp_drowsy = process_mp_img(frame)
    # Preprocess the frame
    preprocessed_frame = preprocess_frame(frame)
    # Make predictions using the model
    prediction = model.predict(preprocessed_frame)
    # Threshold the prediction to classify drowsiness
    model_drowsy = prediction[0][0] >= 0.5

    # Return the result (each label follows its boolean flag directly)
    return (annotated_img,
            "Drowsy" if model_drowsy else "Awake",
            "Drowsy" if mp_drowsy else "Awake",
            eye_feature, mouth_feature)


# Define the input component as an Image component
input_image = gr.inputs.Image(shape=(480, 640), source="webcam", label="live feed")

# Define the output components as Image, Label, and Textbox components
output_image = gr.components.Image(label="Drowsiness Detection")
output_model = gr.components.Label(label="Drowsiness Status - en_model_v0.h5")
output_mp = gr.components.Label(label="Drowsiness Status - MediaPipe")
output_eye = gr.components.Textbox(label="Eye Aspect Ratio")
output_mouth = gr.components.Textbox(label="Mouth Aspect Ratio")


iface = gr.Interface(
    fn=detect_drowsiness,
    inputs=input_image,
    title="antisomnus - driver drowsiness detection",
    outputs=[output_image, output_model, output_mp, output_eye, output_mouth],
    capture_session=True,  # legacy Gradio parameter, deprecated in Gradio 3.x
)

# Launch the Gradio interface
iface.launch(share=True)
```
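For a quick check of the per-frame pipeline without a webcam, a sketch like the following can drive `detect_drowsiness` directly. It assumes the `iface.launch(...)` call has been moved under an `if __name__ == "__main__":` guard so that importing `app` does not start the server, and the image path is purely illustrative:

```python
# Hypothetical smoke test for the per-frame pipeline. Assumes app.py's
# launch call is guarded by __main__ and that a face photo exists at the
# illustrative path below.
import cv2
from app import detect_drowsiness

frame = cv2.imread("face_sample.jpg")  # any BGR face image
annotated, model_label, mp_label, ear, mar = detect_drowsiness(frame)
print(f"model: {model_label} | mediapipe: {mp_label} | EAR: {ear:.3f} | MAR: {mar:.3f}")
```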
docker-compose.yml
ADDED
@@ -0,0 +1,13 @@
```yaml
name: driver_drowsiness_detection
# the top-level "version" key is obsolete in the Compose spec ('0.1' is not
# a valid schema version), so it is omitted here

services:
  app:
    image: driver_drowsiness_detection
    build: .
    ports:
      # Gradio serves on 7860 by default, so map the container's 7860
      - "127.0.0.1:7860:7860"
    volumes:
      - ./app:/app/app
      - ./model/data:/app/model/data
      - ./model/training/saved_models:/app/model/training/saved_models
```
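With this file in place, `docker compose up --build` builds the image and serves the interface on http://127.0.0.1:7860; because the code and model directories are bind-mounted, edits on the host are picked up without rebuilding the image.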
model/data/mp_process.py
ADDED
@@ -0,0 +1,139 @@
```python
import cv2
import math
import random
import numpy as np
import mediapipe as mp

from scipy.spatial.distance import euclidean as dist
from sklearn.preprocessing import LabelEncoder, MinMaxScaler

# feature definitions

DIMS = (224, 224, 3)  # dimensions of the image
RIGHT = [[33, 133], [160, 144], [159, 145], [158, 153]]  # right eye landmark positions
LEFT = [[263, 362], [387, 373], [386, 374], [385, 380]]  # left eye landmark positions
MOUTH = [[61, 291], [39, 181], [0, 17], [269, 405]]  # mouth landmark coordinates

EYE_AR_THRESH = 0.45
PROB_THRESH = 0.3
EYE_AR_CONSEC_FRAMES = 15

MOUTH_AR_THRESH = 0.33
MOUTH_AR_CONSEC_FRAMES = 20

MP_FACE_DETECTION = mp.solutions.face_detection
MP_DRAWING = mp.solutions.drawing_utils
MP_DRAWING_STYLES = mp.solutions.drawing_styles
MP_FACE_MESH = mp.solutions.face_mesh
DRAWING_SPEC = MP_DRAWING.DrawingSpec(thickness=1, circle_radius=1)


def get_ear(landmarks, eye):
    '''Calculate the ratio of the eye length to eye width.
    :param landmarks: face landmarks returned from the MediaPipe FaceMesh model
    :param eye: list of landmark-index pairs which correspond to the eye
    :return: eye aspect ratio value
    '''
    N1 = dist(landmarks[eye[1][0]], landmarks[eye[1][1]])
    N2 = dist(landmarks[eye[2][0]], landmarks[eye[2][1]])
    N3 = dist(landmarks[eye[3][0]], landmarks[eye[3][1]])
    D = dist(landmarks[eye[0][0]], landmarks[eye[0][1]])
    return (N1 + N2 + N3) / (3 * D)


def get_eye_feature(landmarks):
    '''Calculate the eye feature as the sum of the eye aspect ratios of the
    two eyes (the EYE_AR_THRESH comparison below uses this sum, not the average)
    :param landmarks: face landmarks returned from the MediaPipe FaceMesh model
    :return: eye feature value
    '''
    return get_ear(landmarks, LEFT) + get_ear(landmarks, RIGHT)


def get_mouth_feature(landmarks):
    '''Calculate the mouth feature as the ratio of the mouth length to mouth width
    :param landmarks: face landmarks returned from the MediaPipe FaceMesh model
    :return: mouth feature value
    '''
    n_1 = dist(landmarks[MOUTH[1][0]], landmarks[MOUTH[1][1]])
    n_2 = dist(landmarks[MOUTH[2][0]], landmarks[MOUTH[2][1]])
    n_3 = dist(landmarks[MOUTH[3][0]], landmarks[MOUTH[3][1]])
    dst = dist(landmarks[MOUTH[0][0]], landmarks[MOUTH[0][1]])
    return (n_1 + n_2 + n_3) / (3 * dst)

# image processing


def process_mp_img(frame):
    """
    Returns the annotated image, the eye and mouth features, and the
    rule-based drowsiness flag for a single frame.
    """
    with MP_FACE_MESH.FaceMesh(
            min_detection_confidence=0.3,
            min_tracking_confidence=0.8) as face_mesh:
        # convert the img to RGB and process it with MediaPipe Face Mesh
        results = face_mesh.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        if results.multi_face_landmarks is None:
            # no face detected: return safe defaults so callers can always
            # unpack four values
            return frame, 0.0, 0.0, False

        landmark_pos = []
        for i, data in enumerate(results.multi_face_landmarks[0].landmark):
            landmark_pos.append([data.x, data.y, data.z])
        landmark_pos = np.array(landmark_pos)

        # draw face detections of each face
        annotated_img = frame.copy()
        for face_landmarks in results.multi_face_landmarks:
            # Calculate eye and mouth features
            eye_feature = get_eye_feature(landmark_pos)
            mouth_feature = get_mouth_feature(landmark_pos)

            # Binary classification: drowsy (1) or non-drowsy (0)
            drowsy = (eye_feature <= EYE_AR_THRESH) or (mouth_feature > MOUTH_AR_THRESH)
            # face mesh
            MP_DRAWING.draw_landmarks(
                image=annotated_img,
                landmark_list=face_landmarks,
                connections=MP_FACE_MESH.FACEMESH_TESSELATION,
                landmark_drawing_spec=None,
                connection_drawing_spec=MP_DRAWING_STYLES
                .get_default_face_mesh_tesselation_style()
            )
            # eyes and mouth regions
            MP_DRAWING.draw_landmarks(
                image=annotated_img,
                landmark_list=face_landmarks,
                connections=MP_FACE_MESH.FACEMESH_CONTOURS,
                landmark_drawing_spec=None,
                connection_drawing_spec=MP_DRAWING_STYLES
                .get_default_face_mesh_contours_style()
            )
        return annotated_img, eye_feature, mouth_feature, drowsy


def mediapipe_process(frames):
    """
    Process all frames using MediaPipe and return a dictionary with the
    eye and mouth features in the format
    {frame_number: {"eye_feature": 0, "mouth_feature": 0, "drowsy": 0}}
    """
    mp_features = {}
    eye_features_all = []
    mouth_features_all = []
    # Extract eye and mouth features for all frames; key by frame index as
    # in the docstring (numpy arrays are not hashable dict keys)
    for i, frame in enumerate(frames):
        _, eye_feature, mouth_feature, drowsy = process_mp_img(frame)
        mp_features[i] = {"eye_feature": eye_feature,
                          "mouth_feature": mouth_feature,
                          "drowsy": drowsy}
        eye_features_all.append(eye_feature)
        mouth_features_all.append(mouth_feature)

    # Calculate mean and standard deviation for normalization
    eye_mean, eye_std = np.mean(eye_features_all), np.std(eye_features_all)
    mouth_mean, mouth_std = np.mean(mouth_features_all), np.std(mouth_features_all)

    # Normalize eye and mouth features for all frames
    for frame, features in mp_features.items():
        features["eye_feature"] = (features["eye_feature"] - eye_mean) / eye_std
        features["mouth_feature"] = (features["mouth_feature"] - mouth_mean) / mouth_std

    return mp_features
```
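The EAR computation above is just three vertical eyelid distances averaged and divided by the horizontal eye width. Here is a tiny self-contained check with made-up landmark coordinates (only the six right-eye ids are populated; all values are illustrative):

```python
# Synthetic EAR check: fake only the landmark ids used for the right eye.
import numpy as np
from scipy.spatial.distance import euclidean as dist

RIGHT = [[33, 133], [160, 144], [159, 145], [158, 153]]
landmarks = np.zeros((478, 3))                                       # FaceMesh-sized array
landmarks[33], landmarks[133] = [0.30, 0.50, 0], [0.40, 0.50, 0]     # eye corners
landmarks[160], landmarks[144] = [0.33, 0.49, 0], [0.33, 0.52, 0]    # vertical pair 1
landmarks[159], landmarks[145] = [0.35, 0.485, 0], [0.35, 0.525, 0]  # vertical pair 2
landmarks[158], landmarks[153] = [0.37, 0.49, 0], [0.37, 0.52, 0]    # vertical pair 3

N = sum(dist(landmarks[a], landmarks[b]) for a, b in RIGHT[1:])
D = dist(landmarks[RIGHT[0][0]], landmarks[RIGHT[0][1]])
print(N / (3 * D))  # ~0.333 for this open-eye geometry; shrinks as the lids close
```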
model/data/process_data.py
ADDED
@@ -0,0 +1,167 @@
```python
"""
Module: process_data.py

This module provides functionality for turning the raw drowsiness videos
stored in Google Cloud Storage into labeled frame images. The `Image` class
wraps per-frame preprocessing; `DriverDrowsinessDataset` downloads the
videos, splits them into frames, pairs each frame with its label, and
uploads the pickled result back to the bucket.

Classes:
    - Image
    - DriverDrowsinessDataset

Author: Rohit Nair
License: MIT License
Date: 2023-03-22
Version: 1.0.0
"""

import os
import tempfile
import pickle
import numpy as np
import cv2
import tensorflow as tf  # used by Image.load_and_prep_image
from google.cloud import storage

# Initialize Google Cloud Storage client
storage_client = storage.Client()

# Set the bucket name
BUCKET_NAME = "antisomnus-bucket"
bucket = storage_client.get_bucket(BUCKET_NAME)


class Image:
    def __init__(self, frame, dimensions: tuple):
        self.frame = frame
        self.height, self.width, self.depth = dimensions

    def load_and_prep_image(self, scale=False):
        frame_rgb = cv2.cvtColor(self.frame, cv2.COLOR_BGR2RGB)
        _, encoded_frame = cv2.imencode('.png', frame_rgb)
        encoded_frame_bytes = encoded_frame.tobytes()
        tensor_frame = tf.io.decode_image(encoded_frame_bytes)
        tensor_frame = tf.image.resize(tensor_frame, (self.height, self.width))
        if scale:
            return tensor_frame / 255.
        else:
            return tensor_frame


class DriverDrowsinessDataset:
    """
    DriverDrowsinessDataset
    """
    def __init__(self, _data_dir, _label_dir):
        self.data_dir = _data_dir
        self.label_dir = _label_dir

    def get_labels(self, vid_name):
        """
        retrieves the labels for a video file
        """
        vid_name = vid_name.split("/")[-1].split(".")[0]
        label_file_name = self.label_dir + "/" + vid_name + "_drowsiness.txt"

        # get the blob
        label_blob = bucket.blob(label_file_name)

        # download the blob to a temporary file
        label_file = tempfile.NamedTemporaryFile(delete=False)
        label_blob.download_to_filename(label_file.name)

        # read the label file (one digit per frame)
        labels = np.genfromtxt(label_file.name, delimiter=1, dtype=int)

        # clean up
        label_file.close()
        os.unlink(label_file.name)

        return labels

    def unpkl_data(self):
        """download the pickled data file from the storage bucket; return True on success"""
        # get the blob
        try:
            blob = bucket.blob("training_data/training_data.pkl")
            blob.download_to_filename("data.pkl")
        except Exception as download_error:
            print(download_error)
            return False

        return True

    def show_data(self, file):
        """
        shows data
        """
        with open(file, 'rb') as pkl:
            data_dict = pickle.load(pkl)
        return data_dict

    def get_all_data(self) -> bool:
        """
        retrieves all the data in the form of a dictionary mapping frame keys
        to their corresponding labels
        format: {(video_name, frame_number): (image, label)}
        """
        img_label_data = {}
        # get a list of all files in the folder that end with .avi
        blobs = [blob for blob in
                 storage_client.list_blobs(BUCKET_NAME, prefix=self.data_dir)
                 if blob.name.endswith(".avi")]

        blob_count = len(blobs)

        if blob_count == 0:
            print("No video files found in the bucket.")
            return False
        else:
            print(f"Found {blob_count} video files in the bucket.")

        for blob in blobs:
            print(f"Processing video file {blob.name}...{blob_count} more to go")
            # Download the video to a temporary file
            video_file = tempfile.NamedTemporaryFile(delete=False)
            blob.download_to_filename(video_file.name)
            labels = self.get_labels(blob.name)

            # Read the video and split it into frames
            cap = cv2.VideoCapture(video_file.name)
            frame_number = 0
            while frame_number < len(labels):
                ret, frame = cap.read()
                if not ret:
                    break
                print(f"Processing frame {frame_number}...")

                # Save the frame in the dictionary; include the video name in
                # the key so frames from different videos do not collide
                img_label_data[(blob.name, frame_number)] = (frame, labels[frame_number])
                frame_number += 1

            # Clean up
            video_file.close()
            os.unlink(video_file.name)
            cap.release()
            # cv2.destroyAllWindows()
            blob_count -= 1

            # Delete the video file from Google Cloud Storage
            # print(f"Deleting video file {blob.name}...")
            # blob.delete()
            # blob_count -= 1

        # save img_label_data as a pickle file to the bucket
        with open('data.pkl', 'wb') as file:
            pickle.dump(img_label_data, file, protocol=pickle.HIGHEST_PROTOCOL)
        img_label_data_blob = bucket.blob("training_data/training_data.pkl")
        img_label_data_blob.upload_from_filename('data.pkl')

        print("Done processing all video files.")

        return True


if __name__ == "__main__":
    data = DriverDrowsinessDataset('training_data', 'training_data/labels')
    data.get_all_data()
```
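The pickle written by `get_all_data` maps frame keys to `(image, label)` tuples, so a consumer only needs the values; this mirrors what the training notebook does:

```python
# Load the dataset produced by get_all_data() and split it back into
# parallel frame/label sequences, as the training notebook does.
import pickle

with open("data.pkl", "rb") as f:
    dataset = pickle.load(f)

frames, labels = zip(*dataset.values())
print(f"{len(frames)} frames, {sum(labels)} labeled drowsy")  # labels assumed 0/1
```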
model/training/model_training.ipynb
ADDED
@@ -0,0 +1,292 @@
```python
# %% [cell 1]  (recorded output: "zsh:1: command not found: nvidia-smi")
!nvidia-smi

# %% [cell 2]
import datetime
import pickle
import numpy as np
import matplotlib.pyplot as plt
import sys

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

sys.path.insert(0, "/home/ubuntu/code/DDD/model/data")

# %% [cell 3]
from process_data import Image, DriverDrowsinessDataset

# %% [cell 4]
def plot_images(images, classes, class_true, class_pred=None, smooth=True):
    assert len(images) == len(class_true)
    fig, axes = plt.subplots(3, 3)
    if class_pred is None:
        hspace = 0.3
    else:
        hspace = 0.6
    fig.subplots_adjust(hspace=hspace, wspace=0.3)
    if smooth:
        interpolation = 'spline16'
    else:
        interpolation = 'nearest'
    for i, ax in enumerate(axes.flat):
        if i < len(images):
            ax.imshow(images[i], interpolation=interpolation)
            cls_true_name = classes[class_true[i]]
            if class_pred is None:
                xlabel = "True: {0}".format(cls_true_name)
            else:
                class_pred_name = classes[class_pred[i]]
                xlabel = "True: {0}\nPred: {1}".format(cls_true_name, class_pred_name)
            ax.set_xlabel(xlabel)
            ax.set_xticks([])
            ax.set_yticks([])
    plt.show()

def print_confusion_matrix(classes, class_test, class_pred):
    """
    prints the confusion matrix. class_pred is the array of all predicted classes of each image.
    """
    cm = confusion_matrix(y_true=class_test, y_pred=class_pred)
    print("Confusion matrix:")
    print(cm)
    for i, class_name in enumerate(classes):
        print("({0}) {1}".format(i, class_name))

def plot_training_history(history):
    # Get the classification accuracy and loss-value
    # for the training-set.
    acc = history.history['accuracy']
    loss = history.history['loss']

    # Get it for the validation-set (we only use the test-set).
    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']

    # Plot the accuracy and loss-values for the training-set.
    plt.plot(acc, linestyle='-', color='b', label='Training Acc.')
    plt.plot(loss, 'o', color='b', label='Training Loss')

    # Plot it for the test-set.
    plt.plot(val_acc, linestyle='--', color='r', label='Test Acc.')
    plt.plot(val_loss, 'o', color='r', label='Test Loss')

    # Plot title and legend.
    plt.title('Training and Test Accuracy')
    plt.legend()

    # Ensure the plot shows correctly.
    plt.show()

# %% [cell 5]  (recorded output: (0, 0, 0, 0, 0))
DIMS = (224, 224, 3)
BATCH_SIZE = 32

with open("../data/data.pkl", "rb") as f:
    dataset = pickle.load(f)

frames, labels = zip(*dataset.values())

# preprocess all frames (load_and_prep_image is an instance method, so
# construct an Image per frame rather than calling it unbound)
frames = np.array([Image(frame, DIMS).load_and_prep_image() for frame in frames])

X_train, X_val, y_train, y_val = train_test_split(frames, labels, test_size=0.2, random_state=42)

X_train, X_val = np.array(X_train), np.array(X_val)
y_train, y_val = np.array(y_train), np.array(y_val)

def data_generator(X, y, batch_size):
    num_samples = len(X)
    while True:
        indices = np.arange(num_samples)
        np.random.shuffle(indices)

        for start_idx in range(0, num_samples, batch_size):
            batch_indices = indices[start_idx:start_idx + batch_size]
            X_batch, y_batch = X[batch_indices], y[batch_indices]
            yield X_batch, y_batch

def create_tensorboard_callback(dir_name, experiment_name):
    log_dir = dir_name + "/" + experiment_name + "/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
    print(f"Saving TensorBoard log files to: {log_dir}")
    return tensorboard_callback

# %% [cell 6]
train_generator = data_generator(X_train, y_train, BATCH_SIZE)
val_generator = data_generator(X_val, y_val, BATCH_SIZE)

data_augmentation = keras.Sequential([
    layers.experimental.preprocessing.RandomFlip("horizontal", input_shape=DIMS),
    layers.experimental.preprocessing.RandomRotation(0.2),
    layers.experimental.preprocessing.RandomZoom(0.2),
    layers.experimental.preprocessing.RandomHeight(0.2),
    layers.experimental.preprocessing.RandomWidth(0.2),
], name="data_augmentation")

# %% [cell 7]
BASE_MODEL = tf.keras.applications.EfficientNetV2B0(
    input_shape=DIMS,
    include_top=False,
    weights="imagenet",
)

def EfficientNet(input_shape=DIMS, base_model=BASE_MODEL, num_classes=2):
    # freeze the base model layers
    for layer in base_model.layers:
        layer.trainable = False

    # create input layer
    inputs = keras.Input(shape=input_shape, name="input_layer")
    x = data_augmentation(inputs)
    x = base_model(x, training=False)
    x = layers.GlobalAveragePooling2D(name="global_average_pooling_layer")(x)
    x = layers.Dense(1024, activation="relu", name="dense_layer")(x)
    x = layers.Dropout(0.7, name="dropout_layer")(x)
    outputs = layers.Dense(num_classes, activation="sigmoid", name="output_layer")(x)

    # create a new model with the EfficientNetV2B0 base model and a
    # GlobalAveragePooling2D head; assess performance with accuracy,
    # precision, recall, and AUC
    model = keras.Model(inputs, outputs, name="EfficientNet")
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss="binary_crossentropy",
        metrics=["accuracy", keras.metrics.Precision(), keras.metrics.Recall(), keras.metrics.AUC()]
    )
    return model

# %% [cell 8]
model = EfficientNet()
model.summary()

# %% [cell 9]
print(model.layers[3].dtype)
print(model.layers[3].dtype_policy)

# %% [cell 10]
# keep the History object separate so the trained model can still be saved
history = model.fit(train_generator,
                    epochs=10,
                    steps_per_epoch=len(X_train)//BATCH_SIZE, validation_data=val_generator,
                    validation_steps=len(X_val)//BATCH_SIZE,
                    callbacks=[create_tensorboard_callback("logs", "EfficientNetV2B0")])

# %% [cell 11]
model.save("en_model_v1.h5")
```
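After training, the notebook's own helpers can summarize the run. A sketch of that (assuming `history` from the fit cell, that output index 0 is the drowsy class to match the `prediction[0][0]` threshold in app.py, and an illustrative class-name list):

```python
# Plot the accuracy/loss curves recorded by model.fit, then print a
# confusion matrix on the validation split (class order is an assumption).
plot_training_history(history)

y_pred = (model.predict(X_val)[:, 0] >= 0.5).astype(int)
print_confusion_matrix(["awake", "drowsy"], y_val, y_pred)
```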
model/training/saved_models/README.md
ADDED
@@ -0,0 +1,2 @@
```markdown
# Index of Models
1. [en_model_v0](en_model_v0.h5) --> 04:14:2023:03:53:45
```
requirements.txt
ADDED
@@ -0,0 +1,10 @@
```text
gradio
opencv-python-headless
numpy
polars
seaborn
matplotlib
scikit-learn
scipy
tensorflow
mediapipe
```
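None of these dependencies are version-pinned. Since app.py relies on the Gradio 3.x-era `gr.inputs` API and the Space metadata pins `sdk_version: 3.27.0`, pinning at least `gradio` here (and ideally `tensorflow` and `mediapipe` as well) would keep the Docker build reproducible.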