rodolphethinks1 commited on
Commit
e34b3e2
·
verified ·
1 Parent(s): 92f9678

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +287 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,290 @@
1
- import altair as alt
 
 
2
  import numpy as np
 
3
  import pandas as pd
4
- import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import cv2
2
+ import mediapipe as mp
3
+ import streamlit as st
4
  import numpy as np
5
+ import time
6
  import pandas as pd
7
+ import altair as alt
8
+ from scipy.signal import butter, filtfilt, find_peaks
9
+
10
+ # -------------------------------
11
+ # Constants
12
+ # -------------------------------
13
+ LEFT_EYE = [33, 160, 158, 133, 153, 144]
14
+ RIGHT_EYE = [362, 385, 387, 263, 373, 380]
15
+ FOREHEAD_ROI = [10, 109, 67, 103, 54, 21, 162, 127, 234, 93, 132, 58, 172, 136, 150, 149, 176, 148]
16
+ BUFFER_SIZE = 300 # ~10s if 30 FPS
17
+ EAR_THRESHOLD = 0.22
18
+ DROWSINESS_TIME_THRESHOLD = 2.0 # seconds
19
+
20
+ # -------------------------------
21
+ # Mediapipe face mesh
22
+ # -------------------------------
23
+ mp_face_mesh = mp.solutions.face_mesh
24
+
25
+ # -------------------------------
26
+ # Helper Functions
27
+ # -------------------------------
28
+ def get_eye_aspect_ratio(landmarks, eye_indices):
29
+ """Calculates the Eye Aspect Ratio (EAR) for a single eye."""
30
+ pts = np.array([(landmarks[i].x, landmarks[i].y) for i in eye_indices])
31
+ A = np.linalg.norm(pts[1] - pts[5])
32
+ B = np.linalg.norm(pts[2] - pts[4])
33
+ C = np.linalg.norm(pts[0] - pts[3])
34
+ ear = (A + B) / (2.0 * C)
35
+ return ear
36
+
37
+ def bandpass_filter(data, low=0.8, high=2.5, fs=30):
38
+ """Applies a bandpass filter to the signal."""
39
+ nyq = 0.5 * fs
40
+ b, a = butter(1, [low/nyq, high/nyq], btype="band")
41
+ return filtfilt(b, a, data)
42
+
43
+ def compute_hr_hrv(signal, times, fs=30):
44
+ """Computes Heart Rate (HR) and Heart Rate Variability (HRV) from the rPPG signal."""
45
+ if len(signal) < fs * 3: # need at least 3s of data
46
+ return None, None
47
+
48
+ try:
49
+ # Detrending the signal to remove baseline wander
50
+ signal_detrended = signal - np.mean(signal)
51
+
52
+ filtered = bandpass_filter(signal_detrended, fs=fs)
53
+
54
+ # Using a more robust peak finding
55
+ peaks, properties = find_peaks(filtered, distance=fs*0.7, prominence=np.std(filtered)*0.3)
56
+
57
+ if len(peaks) < 3: # Need at least 3 peaks for a more stable HR
58
+ return None, None
59
+
60
+ peak_times = np.array(times)[peaks]
61
+ rr_intervals = np.diff(peak_times) # in seconds
62
+
63
+ # Basic outlier removal for RR intervals
64
+ median_rr = np.median(rr_intervals)
65
+ valid_rr = rr_intervals[np.abs(rr_intervals - median_rr) < 0.3 * median_rr]
66
+
67
+ if len(valid_rr) < 2:
68
+ return None, None
69
+
70
+ hr = 60.0 / np.mean(valid_rr)
71
+ hrv = np.std(valid_rr) * 1000 # RMSSD is a better HRV metric, but this is a start
72
+
73
+ # Plausible HR range
74
+ if not (40 < hr < 160):
75
+ return None, None
76
+
77
+ return hr, hrv
78
+ except (np.linalg.LinAlgError, ValueError):
79
+ return None, None
80
+
81
+ def initialize_session_state():
82
+ """Initializes Streamlit session state variables."""
83
+ if "history" not in st.session_state:
84
+ st.session_state.history = {
85
+ "time": [], "blink_rate": [], "hr": [], "hrv": []
86
+ }
87
+ if "blink_count" not in st.session_state:
88
+ st.session_state.blink_count = 0
89
+ if "last_eye_state" not in st.session_state:
90
+ st.session_state.last_eye_state = "open"
91
+ if "start_time" not in st.session_state:
92
+ st.session_state.start_time = time.time()
93
+ if "signal_buffer" not in st.session_state:
94
+ st.session_state.signal_buffer = []
95
+ if "time_buffer" not in st.session_state:
96
+ st.session_state.time_buffer = []
97
+ if "drowsy_start_time" not in st.session_state:
98
+ st.session_state.drowsy_start_time = None
99
+
100
+ def update_dashboard(placeholders, data):
101
+ """Updates the Streamlit dashboard with new data."""
102
+ placeholders["stframe"].image(data["frame"], channels="BGR")
103
+
104
+ # --- Metrics ---
105
+ placeholders["metrics"]["blink"].metric("Blink Rate (per min)", f"{data['blink_rate']:.2f}")
106
+ placeholders["metrics"]["hr"].metric("Heart Rate (bpm)", f"{data['hr']:.1f}" if data['hr'] is not None else "N/A")
107
+ placeholders["metrics"]["hrv"].metric("HRV (ms)", f"{data['hrv']:.1f}" if data['hrv'] is not None else "N/A")
108
+
109
+ # --- Drowsiness Alert ---
110
+ if data["drowsy_alert"]:
111
+ placeholders["alert"].warning("🚨 Drowsiness Detected!")
112
+ else:
113
+ placeholders["alert"].empty()
114
+
115
+ # --- Charts ---
116
+ df = pd.DataFrame(st.session_state.history)
117
+ df["time"] = pd.to_datetime(df["time"], unit="s")
118
+
119
+ with placeholders["charts"]["blink_tab"]:
120
+ chart = alt.Chart(df).mark_line().encode(
121
+ x=alt.X('time:T', title='Time'),
122
+ y=alt.Y('blink_rate:Q', title='Blink Rate (per min)')
123
+ ).properties(title="Blink Rate Over Time")
124
+ placeholders["charts"]["blink_chart"].altair_chart(chart, use_container_width=True)
125
+
126
+ with placeholders["charts"]["hr_tab"]:
127
+ chart = alt.Chart(df).mark_line().encode(
128
+ x=alt.X('time:T', title='Time'),
129
+ y=alt.Y('hr:Q', title='Heart Rate (bpm)')
130
+ ).properties(title="Heart Rate Over Time")
131
+ placeholders["charts"]["hr_chart"].altair_chart(chart, use_container_width=True)
132
+
133
+ with placeholders["charts"]["hrv_tab"]:
134
+ chart = alt.Chart(df).mark_line().encode(
135
+ x=alt.X('time:T', title='Time'),
136
+ y=alt.Y('hrv:Q', title='HRV (ms)')
137
+ ).properties(title="HRV Over Time")
138
+ placeholders["charts"]["hrv_chart"].altair_chart(chart, use_container_width=True)
139
+
140
+ def main():
141
+ """Main function to run the Streamlit application."""
142
+ st.set_page_config(page_title="DriFit - Driver Monitoring", layout="wide")
143
+ st.title("DriFit: In-Car Driver Health & Fatigue Monitoring")
144
+ st.info("This application uses your webcam to monitor driver fatigue and health metrics in real-time.")
145
+
146
+ initialize_session_state()
147
+
148
+ # --- UI Placeholders ---
149
+ col1, col2 = st.columns([2, 1])
150
+ with col1:
151
+ stframe = st.empty()
152
+ alert_placeholder = st.empty()
153
+ with col2:
154
+ m_col1, m_col2, m_col3 = st.columns(3)
155
+ st.subheader("Metrics")
156
+ blink_metric_placeholder = m_col1.empty()
157
+ hr_metric_placeholder = m_col2.empty()
158
+ hrv_metric_placeholder = m_col3.empty()
159
+
160
+ st.subheader("Metrics Over Time")
161
+ blink_tab, hr_tab, hrv_tab = st.tabs(["Blink Rate", "Heart Rate", "HRV"])
162
+ with blink_tab:
163
+ blink_chart_placeholder = st.empty()
164
+ with hr_tab:
165
+ hr_chart_placeholder = st.empty()
166
+ with hrv_tab:
167
+ hrv_chart_placeholder = st.empty()
168
+
169
+ placeholders = {
170
+ "stframe": stframe,
171
+ "alert": alert_placeholder,
172
+ "metrics": {"blink": blink_metric_placeholder, "hr": hr_metric_placeholder, "hrv": hrv_metric_placeholder},
173
+ "charts": {
174
+ "blink_tab": blink_tab, "hr_tab": hr_tab, "hrv_tab": hrv_tab,
175
+ "blink_chart": blink_chart_placeholder,
176
+ "hr_chart": hr_chart_placeholder,
177
+ "hrv_chart": hrv_chart_placeholder
178
+ }
179
+ }
180
+
181
+ # --- Webcam and Face Mesh ---
182
+ if 'cap' not in st.session_state:
183
+ st.session_state.cap = cv2.VideoCapture(0)
184
+ if 'face_mesh' not in st.session_state:
185
+ st.session_state.face_mesh = mp_face_mesh.FaceMesh(refine_landmarks=True)
186
+
187
+ cap = st.session_state.cap
188
+ face_mesh = st.session_state.face_mesh
189
+
190
+ run = st.checkbox('Run')
191
+
192
+ if not cap.isOpened():
193
+ st.error("Could not open webcam. Please grant access and refresh.")
194
+ return
195
+
196
+ while run:
197
+ ret, frame = cap.read()
198
+ if not ret:
199
+ st.warning("Could not read frame from webcam. Stopping.")
200
+ run = False
201
+ break
202
+
203
+ frame = cv2.flip(frame, 1)
204
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
205
+ results = face_mesh.process(rgb_frame)
206
+
207
+ eye_state = "open"
208
+ hr, hrv = None, None
209
+ drowsy_alert = False
210
+
211
+ if results.multi_face_landmarks:
212
+ face_landmarks = results.multi_face_landmarks[0]
213
+
214
+ # --- Eye tracking (fatigue) ---
215
+ left_ear = get_eye_aspect_ratio(face_landmarks.landmark, LEFT_EYE)
216
+ right_ear = get_eye_aspect_ratio(face_landmarks.landmark, RIGHT_EYE)
217
+ avg_ear = (left_ear + right_ear) / 2.0
218
+
219
+ if avg_ear < EAR_THRESHOLD:
220
+ eye_state = "closed"
221
+ if st.session_state.drowsy_start_time is None:
222
+ st.session_state.drowsy_start_time = time.time()
223
+ elif time.time() - st.session_state.drowsy_start_time > DROWSINESS_TIME_THRESHOLD:
224
+ drowsy_alert = True
225
+ else:
226
+ eye_state = "open"
227
+ st.session_state.drowsy_start_time = None
228
+
229
+ if st.session_state.last_eye_state == "closed" and eye_state == "open":
230
+ st.session_state.blink_count += 1
231
+ st.session_state.last_eye_state = eye_state
232
+
233
+ # --- rPPG HR & HRV (forehead ROI) ---
234
+ h, w, _ = frame.shape
235
+ forehead_pts = np.array([(face_landmarks.landmark[i].x * w, face_landmarks.landmark[i].y * h) for i in FOREHEAD_ROI], dtype=np.int32)
236
+
237
+ mask = np.zeros(frame.shape[:2], dtype=np.uint8)
238
+ cv2.fillConvexPoly(mask, forehead_pts, 255)
239
+
240
+ roi = cv2.bitwise_and(frame, frame, mask=mask)
241
+
242
+ x, y, w_roi, h_roi = cv2.boundingRect(forehead_pts)
243
+
244
+ if w_roi > 0 and h_roi > 0:
245
+ roi_cropped = roi[y:y+h_roi, x:x+w_roi]
246
+ if roi_cropped.size > 0:
247
+ green_mean = np.mean(roi_cropped[:, :, 1])
248
+ st.session_state.signal_buffer.append(green_mean)
249
+ st.session_state.time_buffer.append(time.time())
250
+
251
+ if len(st.session_state.signal_buffer) > BUFFER_SIZE:
252
+ st.session_state.signal_buffer.pop(0)
253
+ st.session_state.time_buffer.pop(0)
254
+
255
+ hr, hrv = compute_hr_hrv(st.session_state.signal_buffer, st.session_state.time_buffer)
256
+
257
+ cv2.polylines(frame, [forehead_pts], isClosed=True, color=(0, 255, 0), thickness=1)
258
+
259
+ # --- Data Update ---
260
+ elapsed = time.time() - st.session_state.start_time
261
+ blink_rate = (st.session_state.blink_count / (elapsed / 60)) if elapsed > 5 else 0.0
262
+
263
+ # Update history
264
+ st.session_state.history["time"].append(time.time())
265
+ st.session_state.history["blink_rate"].append(blink_rate)
266
+ st.session_state.history["hr"].append(hr)
267
+ st.session_state.history["hrv"].append(hrv)
268
+
269
+ for key in st.session_state.history:
270
+ st.session_state.history[key] = st.session_state.history[key][-100:]
271
+
272
+ # --- Dashboard Update ---
273
+ update_data = {
274
+ "frame": frame,
275
+ "blink_rate": blink_rate,
276
+ "hr": hr,
277
+ "hrv": hrv,
278
+ "drowsy_alert": drowsy_alert
279
+ }
280
+ update_dashboard(placeholders, update_data)
281
+
282
+ else:
283
+ if 'cap' in st.session_state:
284
+ st.session_state.cap.release()
285
+ del st.session_state.cap
286
+
287
 
288
+ if __name__ == "__main__":
289
+ main()
290
+