tbdavid2019 committed
Commit 96f2d7d
Parent: a5b98b3
Files changed (2)
  1. app.py +468 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,468 @@
+ #!/usr/bin/env python3
+ """
+ rPPG Heart Rate Estimation using OpenCV and the POS algorithm
+ """
+ import os
+ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
+
+ import gradio as gr
+ import cv2
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from scipy import signal
+ from scipy.fft import fft, fftfreq
+ import tempfile
+ import time
+ from tqdm import tqdm
+
+ class SimpleRPPG:
+     def __init__(self, min_bpm=45, max_bpm=180):
+         self.min_bpm = min_bpm
+         self.max_bpm = max_bpm
+         self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
+
+     def detect_faces(self, frame):
+         """Detect faces using OpenCV Haar cascades"""
+         gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
+
+         # Try multiple parameter sets for better detection
+         param_sets = [
+             {"scaleFactor": 1.1, "minNeighbors": 5, "minSize": (50, 50)},
+             {"scaleFactor": 1.05, "minNeighbors": 3, "minSize": (30, 30)},
+             {"scaleFactor": 1.2, "minNeighbors": 6, "minSize": (60, 60)},
+         ]
+
+         for params in param_sets:
+             faces = self.face_cascade.detectMultiScale(gray, **params)
+             if len(faces) > 0:
+                 return faces
+
+         return []
+
+     def extract_roi_signal(self, frame, face_box):
+         """Extract ROI and compute mean RGB values"""
+         x, y, w, h = face_box
+
+         # Define ROI (forehead and cheek areas)
+         roi_y1 = y + int(0.2 * h)
+         roi_y2 = y + int(0.7 * h)
+         roi_x1 = x + int(0.15 * w)
+         roi_x2 = x + int(0.85 * w)
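+         # NOTE: this crop keeps the central 70% of the box width and the
+         # 20%-70% vertical band, which roughly covers the cheeks and lower
+         # forehead while excluding hair, eyes, and background pixels.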
+
+         roi = frame[roi_y1:roi_y2, roi_x1:roi_x2]
+
+         if roi.size == 0:
+             return None
+
+         # Calculate mean RGB values
+         mean_rgb = np.mean(roi, axis=(0, 1))
+         return mean_rgb
+
+     def pos_algorithm(self, rgb_signals, fps):
+         """POS (Plane-Orthogonal-to-Skin) algorithm"""
+         if len(rgb_signals) < 30:  # Need at least 1 second of data at 30 fps
+             return None, None
+
+         rgb_signals = np.array(rgb_signals)
+
+         # Normalize RGB signals
+         mean_rgb = np.mean(rgb_signals, axis=0)
+         normalized_rgb = rgb_signals / mean_rgb
+
+         # POS algorithm
+         X1 = normalized_rgb[:, 0] - normalized_rgb[:, 1]  # R - G
+         X2 = normalized_rgb[:, 0] + normalized_rgb[:, 1] - 2 * normalized_rgb[:, 2]  # R + G - 2B
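+         # NOTE: these two chrominance axes differ from the projection in the
+         # original POS paper (Wang et al., IEEE TBME 2017), which uses
+         # S1 = G - B and S2 = G + B - 2R on the normalized channels; the
+         # R-based variant here follows the same plane-projection idea.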
+
+         # Temporal filtering (bandpass)
+         low_freq = self.min_bpm / 60.0
+         high_freq = self.max_bpm / 60.0
+
+         sos = signal.butter(4, [low_freq, high_freq], btype='band', fs=fps, output='sos')
+         X1_filtered = signal.sosfilt(sos, X1)
+         X2_filtered = signal.sosfilt(sos, X2)
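+         # With the defaults (45-180 BPM) the passband is 0.75-3.0 Hz, so the
+         # video frame rate must exceed 6 fps to keep the upper cutoff below
+         # the Nyquist frequency; otherwise signal.butter raises a ValueError.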
+
+         # POS combination
+         alpha = np.std(X1_filtered) / np.std(X2_filtered)
+         pulse_signal = X1_filtered - alpha * X2_filtered
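+         # The std ratio is the "alpha tuning" step of POS: it rescales the
+         # second axis so both projections contribute comparable variance
+         # before they are combined into a single pulse signal.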
+
+         return pulse_signal, self.estimate_heart_rate(pulse_signal, fps)
+
+     def estimate_heart_rate(self, pulse_signal, fps):
+         """Estimate heart rate using FFT"""
+         if len(pulse_signal) < fps:  # Need at least 1 second
+             return None
+
+         # Apply window function
+         windowed_signal = pulse_signal * signal.windows.hann(len(pulse_signal))
+
+         # FFT
+         freqs = fftfreq(len(windowed_signal), 1/fps)
+         fft_values = np.abs(fft(windowed_signal))
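+         # Frequency resolution is fps / N = 1 / window_seconds Hz per bin,
+         # i.e. 60 / window_seconds BPM: a 10 s window quantizes the estimate
+         # to roughly 6 BPM steps unless the spectrum is interpolated.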
+
+         # Find frequency range corresponding to heart rate
+         min_freq = self.min_bpm / 60.0
+         max_freq = self.max_bpm / 60.0
+
+         valid_indices = (freqs >= min_freq) & (freqs <= max_freq)
+         if not np.any(valid_indices):
+             return None
+
+         valid_freqs = freqs[valid_indices]
+         valid_fft = fft_values[valid_indices]
+
+         # Find peak frequency
+         peak_idx = np.argmax(valid_fft)
+         peak_freq = valid_freqs[peak_idx]
+
+         heart_rate = peak_freq * 60.0
+
+         # Confidence based on peak prominence
+         confidence = np.max(valid_fft) / np.mean(valid_fft)
+         confidence = min(confidence / 10.0, 1.0)  # Normalize to 0-1
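+         # The peak-to-mean ratio is a simple SNR-style heuristic; the 10.0
+         # divisor is an empirical scaling constant, not a calibrated value.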
+
+         return {"hr": heart_rate, "confidence": confidence}
+
+     def process_video(self, video_path, window_seconds=10.0, step_seconds=2.0, conf_threshold=0.3, progress_callback=None):
+         """Process a video and extract heart rate over sliding windows"""
+         cap = cv2.VideoCapture(video_path)
+         fps = cap.get(cv2.CAP_PROP_FPS)
+         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+         if fps <= 0 or total_frames <= 0:
+             return [], [], []
+
+         window_frames = int(window_seconds * fps)
+         step_frames = int(step_seconds * fps)
+
+         results_time = []
+         results_hr = []
+         results_conf = []
+
+         # Console progress bar
+         pbar = tqdm(total=total_frames, desc="Processing video", unit="frames")
+
+         # First check for face detection
+         if progress_callback:
+             progress_callback(0.1, "🔍 Detecting face...")
+
+         face_found = False
+         for i in range(0, min(300, total_frames), 30):  # Check first 10 seconds
+             cap.set(cv2.CAP_PROP_POS_FRAMES, i)
+             ret, frame = cap.read()
+             if ret:
+                 rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                 faces = self.detect_faces(rgb_frame)
+                 if len(faces) > 0:
+                     face_found = True
+                     if progress_callback:
+                         progress_callback(0.15, f"✅ Face detected at frame {i} ({i/fps:.1f}s)!")
+                     break
+
+         if not face_found:
+             if progress_callback:
+                 progress_callback(0.15, "⚠️ No face detected; continuing anyway...")
+
+         # Reset to the beginning and process in chunks
+         cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+
+         estimated_chunks = max(1, (total_frames - window_frames) // step_frames + 1)
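+         # Sliding-window count: one chunk per step until the last full window
+         # fits, i.e. floor((total_frames - window_frames) / step_frames) + 1.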
+         pbar.reset(total=estimated_chunks)
+         pbar.set_description("Processing chunks")
+
+         processed_chunks = 0
+
+         # Process the video in chunks (much more efficient)
+         for chunk_start in range(0, total_frames - window_frames + 1, step_frames):
+             chunk_frames = []
+
+             # Read a batch of frames for this chunk
+             cap.set(cv2.CAP_PROP_POS_FRAMES, chunk_start)
+             batch_frames = []
+
+             # Read all frames for this window at once
+             for i in range(window_frames):
+                 ret, frame = cap.read()
+                 if not ret:
+                     break
+                 rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                 batch_frames.append(rgb_frame)
+
+             # Detect the face only in the first frame of the batch
+             if len(batch_frames) > 0:
+                 faces = self.detect_faces(batch_frames[0])
+                 if len(faces) > 0:
+                     current_face_box = max(faces, key=lambda x: x[2] * x[3])
+
+                     # Extract signals from all frames using the same face box
+                     for rgb_frame in batch_frames:
+                         rgb_signal = self.extract_roi_signal(rgb_frame, current_face_box)
+                         if rgb_signal is not None:
+                             chunk_frames.append(rgb_signal)
+
+             # Process this chunk if we have enough data
+             if len(chunk_frames) >= fps:  # Need at least 1 second of data
+                 pulse_signal, hr_result = self.pos_algorithm(chunk_frames, fps)
+
+                 if hr_result is not None and hr_result["hr"] > 0 and hr_result["confidence"] >= conf_threshold:
+                     t_sec = (chunk_start + window_frames // 2) / fps  # Center time of window
+                     results_time.append(t_sec)
+                     results_hr.append(hr_result["hr"])
+                     results_conf.append(hr_result["confidence"])
+
+                     print(f"✅ Chunk {processed_chunks + 1}: HR = {hr_result['hr']:.1f} BPM at {t_sec:.1f}s")
+
+             processed_chunks += 1
+             pbar.update(1)
+
+             # Update Gradio progress
+             if progress_callback:
+                 progress_val = 0.15 + (processed_chunks / estimated_chunks) * 0.7
+                 if len(results_hr) > 0:
+                     progress_callback(progress_val, f"💓 Found {len(results_hr)} heart-rate measurements")
+                 else:
+                     progress_callback(progress_val, f"Processing chunk {processed_chunks}/{estimated_chunks}...")
+
+             # Early termination once we have enough successful measurements
+             if len(results_hr) >= 10:  # Stop after 10 good measurements
+                 print(f"✅ Early termination: found {len(results_hr)} measurements")
+                 break
+
+         cap.release()
+         pbar.close()  # Close the console progress bar
+
+         if progress_callback:
+             progress_callback(1.0, f"Done! Found {len(results_hr)} heart-rate measurements")
+
+         return results_time, results_hr, results_conf
+
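+ # NOTE: defaulting `progress` to gr.Progress() lets Gradio inject a live
+ # progress tracker when this handler is triggered from the UI; with a plain
+ # None default, the `if progress:` branches below would never run.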
+ def quick_face_check(video_path, progress=gr.Progress()):
+     """Quick face detection check"""
+     if not video_path:
+         return "Please upload a video file first"
+
+     cap = cv2.VideoCapture(video_path)
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     fps = cap.get(cv2.CAP_PROP_FPS)
+
+     if progress:
+         progress(0.1, "🎬 Starting video check...")
+
+     # Load the OpenCV face detector
+     face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
+
+     # Console progress bar for face detection
+     face_pbar = tqdm(total=total_frames//15, desc="Face detection", unit="frames")
+
+     face_detected = False
+     face_found_at_frame = None
+
+     for i in range(0, total_frames, 15):  # Check every 15th frame
+         cap.set(cv2.CAP_PROP_POS_FRAMES, i)
+         ret, frame = cap.read()
+         if ret:
+             gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+
+             # Try multiple parameter combinations
+             param_sets = [
+                 {"scaleFactor": 1.1, "minNeighbors": 5, "minSize": (30, 30)},
+                 {"scaleFactor": 1.05, "minNeighbors": 3, "minSize": (20, 20)},
+                 {"scaleFactor": 1.2, "minNeighbors": 6, "minSize": (40, 40)},
+             ]
+
+             faces_found = False
+             for params in param_sets:
+                 faces = face_cascade.detectMultiScale(gray, **params)
+                 if len(faces) > 0:
+                     faces_found = True
+                     face_detected = True
+                     face_found_at_frame = i
+                     time_stamp = i / fps
+
+                     if progress:
+                         progress(0.8, f"✅ Detected {len(faces)} face(s) at frame {i} ({time_stamp:.1f}s)!")
+                     break
+
+             if faces_found:
+                 break
+
+         face_pbar.update(1)  # Update the console progress bar
+
+         # Update detection progress
+         if progress and i % 150 == 0:
+             detection_progress = 0.1 + min((i / total_frames) * 0.7, 0.7)
+             current_time = i / fps
+             progress(detection_progress, f"🔍 Detecting face... checked up to {current_time:.1f}s")
+
+     cap.release()
+     face_pbar.close()  # Close the console progress bar
+
+     if face_detected:
+         success_msg = f"✅ Success! Face detected at frame {face_found_at_frame} ({face_found_at_frame/fps:.1f}s)"
+         if progress:
+             progress(1.0, success_msg)
+         return success_msg + "\n\n💡 This video is suitable for heart rate analysis!"
+     else:
+         fail_msg = "❌ No face was detected anywhere in the video"
+         if progress:
+             progress(1.0, fail_msg)
+         return fail_msg + "\n\n📋 Suggestions:\n• Make sure the video shows a clear, frontal face\n• Check that the lighting is adequate\n• Avoid excessive head movement"
+
+ def process_video(video_path, method, window, step, min_bpm, max_bpm, conf, progress=gr.Progress()):
+     """Process a video and extract heart rate"""
+     if not video_path:
+         return "Please upload a video file", None, None
+
+     start_time = time.time()
+     print(f"🚀 Starting video processing: {video_path}")
+
+     # Initialize the rPPG processor
+     rppg = SimpleRPPG(min_bpm=min_bpm, max_bpm=max_bpm)
+
+     # Process the video
+     ts, hr, cf = rppg.process_video(
+         video_path,
+         window_seconds=window,
+         step_seconds=step,
+         conf_threshold=conf,
+         progress_callback=progress
+     )
+
+     processing_time = time.time() - start_time
+     print(f"⏱️ Done in {processing_time:.1f} s; found {len(hr)} heart-rate measurements")
+
+     if not hr:
+         return f"No heart-rate data detected. Processing time: {processing_time:.1f}s", None, None
+
+     # Create CSV
+     csv_content = "time_sec,hr_bpm,confidence\n"
+     for a, b, c in zip(ts, hr, cf):
+         csv_content += f"{a:.2f},{b:.2f},{c:.3f}\n"
+
+     # Create plot
+     plt.figure(figsize=(10, 4))
+     plt.plot(ts, hr, 'b-', linewidth=2)
+     plt.xlabel("Time (s)")
+     plt.ylabel("Heart Rate (bpm)")
+     plt.title(f"Heart Rate Estimation (Avg: {np.mean(hr):.1f} BPM)")
+     plt.grid(True)
+     plt.tight_layout()
+
+     # Save the plot to a temp file
+     with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
+         plt.savefig(tmp.name, dpi=150, bbox_inches='tight')
+         plot_path = tmp.name
+
+     plt.close()
+
+     # Save the CSV to a temp file
+     with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp:
+         tmp.write(csv_content)
+         csv_path = tmp.name
+
+     result_msg = f"✅ Analysis complete!\nAverage heart rate: {np.mean(hr):.1f} BPM\nMeasurement points: {len(hr)}\nProcessing time: {processing_time:.1f} s"
+
+     return result_msg, plot_path, csv_path
+
+ # Gradio interface
+ with gr.Blocks(title="rPPG Heart Rate Analysis") as demo:
+     gr.Markdown("# rPPG Heart Rate Analysis")
+     gr.Markdown("Upload a video to estimate heart rate using computer vision.")
+
+     with gr.Tabs():
+         with gr.Tab("Heart Rate Analysis"):
+             with gr.Row():
+                 with gr.Column():
+                     video_input = gr.Video(label="Upload Video")
+
+                     with gr.Row():
+                         method_select = gr.Dropdown(
+                             choices=["POS"],
+                             value="POS",
+                             label="Method"
+                         )
+
+                         conf_slider = gr.Slider(
+                             minimum=0.0,
+                             maximum=1.0,
+                             value=0.3,
+                             step=0.1,
+                             label="Confidence Threshold"
+                         )
+
+                     with gr.Row():
+                         window_slider = gr.Slider(
+                             minimum=5.0,
+                             maximum=30.0,
+                             value=10.0,
+                             step=1.0,
+                             label="Window (sec)"
+                         )
+
+                         step_slider = gr.Slider(
+                             minimum=0.5,
+                             maximum=5.0,
+                             value=2.0,
+                             step=0.5,
+                             label="Step (sec)"
+                         )
+
+                     with gr.Row():
+                         min_bpm = gr.Slider(
+                             minimum=30,
+                             maximum=100,
+                             value=45,
+                             step=5,
+                             label="Min BPM"
+                         )
+
+                         max_bpm = gr.Slider(
+                             minimum=100,
+                             maximum=200,
+                             value=180,
+                             step=5,
+                             label="Max BPM"
+                         )
+
+                     process_btn = gr.Button("Process Video", variant="primary", size="lg")
+
+                 with gr.Column():
+                     result_text = gr.Textbox(label="Results", lines=4)
+                     plot_output = gr.Image(label="Heart Rate Plot")
+                     csv_output = gr.File(label="Download CSV Data")
+
+         with gr.Tab("Face Detection Test"):
+             with gr.Row():
+                 with gr.Column():
+                     test_video_input = gr.Video(label="Upload Video for Face Test")
+                     check_btn = gr.Button("Test Face Detection", variant="secondary", size="lg")
+
+                 with gr.Column():
+                     check_result = gr.Textbox(label="Face Detection Results", lines=8)
+
+     # Connect functions
+     process_btn.click(
+         fn=process_video,
+         inputs=[video_input, method_select, window_slider, step_slider, min_bpm, max_bpm, conf_slider],
+         outputs=[result_text, plot_output, csv_output],
+         show_progress=True
+     )
+
+     check_btn.click(
+         fn=quick_face_check,
+         inputs=[test_video_input],
+         outputs=[check_result],
+         show_progress=True
+     )
+
+ if __name__ == "__main__":
+     demo.launch(
+         server_name="127.0.0.1",
+         server_port=7860,
+         share=False
+     )
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio
+ opencv-python
+ numpy
+ matplotlib
+ scipy
+ tqdm