Spaces:
Sleeping
Sleeping
| import spaces | |
| import time | |
| import numpy as np | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| from scipy.signal import butter, filtfilt, find_peaks, coherence | |
| from scipy.interpolate import interp1d | |
| from dataclasses import dataclass | |
| from typing import Optional, Tuple | |
# GPU support: use CuPy only when it imports cleanly AND a CUDA device is
# visible; otherwise alias `cp` to NumPy so later code can reference `cp`
# unconditionally.
try:
    import cupy as cp
    from cupyx.scipy.signal import filtfilt as cp_filtfilt
    # `import cupy` can succeed on hosts without a usable GPU/driver; probe the
    # runtime so GPU_AVAILABLE reflects reality instead of every GPU call
    # failing (and falling back) individually later on.
    GPU_AVAILABLE = cp.cuda.runtime.getDeviceCount() > 0
except Exception:  # ImportError, or CUDA runtime/driver errors from the probe
    GPU_AVAILABLE = False
if GPU_AVAILABLE:
    print("GPU acceleration enabled with CuPy")
else:
    cp = np  # Fallback to NumPy
    print("CuPy not available, using CPU (NumPy)")
def get_array_module(x):
    """Return the array namespace (CuPy or NumPy) that *x* belongs to.

    CuPy is returned only when GPU support is enabled and *x* is a CuPy
    array; every other input maps to NumPy.
    """
    on_device = GPU_AVAILABLE and isinstance(x, cp.ndarray)
    return cp if on_device else np
def to_gpu(x: np.ndarray) -> "cp.ndarray":
    """Copy *x* to device memory when GPU support is enabled; otherwise return it as-is."""
    return cp.asarray(x) if GPU_AVAILABLE else x
def to_cpu(x) -> np.ndarray:
    """Return *x* as a host NumPy array, downloading it first if it is a CuPy array."""
    if not (GPU_AVAILABLE and isinstance(x, cp.ndarray)):
        return np.asarray(x)
    return cp.asnumpy(x)
def bandpass_filter(x: np.ndarray, fs: float, low: float = 0.7, high: float = 2.5, order: int = 4) -> np.ndarray:
    """Zero-phase Butterworth band-pass filter, GPU-accelerated when CuPy is active.

    Args:
        x: 1-D signal sampled uniformly at ``fs`` Hz.
        fs: Sampling rate in Hz.
        low: Lower pass-band edge in Hz (default 0.7 Hz, i.e. 42 bpm).
        high: Upper pass-band edge in Hz (default 2.5 Hz, i.e. 150 bpm).
        order: Butterworth order passed to ``scipy.signal.butter``.

    Returns:
        The filtered signal, or ``x`` itself (unfiltered) when the input is too
        short for ``filtfilt``'s edge padding or the band edges are invalid
        for ``fs``.
    """
    if len(x) < max(15, order * 3):
        return x
    nyq = 0.5 * fs
    lowc = low / nyq
    highc = high / nyq
    if not (0 < lowc < highc < 1):
        return x
    b, a = butter(order, [lowc, highc], btype="band")
    # filtfilt pads each end with 3 * max(len(a), len(b)) samples and raises
    # ValueError unless the input is strictly longer than that. A band-pass
    # design has 2 * order + 1 coefficients, so the cheap check above is not
    # sufficient on its own (order 4 -> padlen 27).
    padlen = 3 * max(len(a), len(b))
    if len(x) <= padlen:
        return x
    if GPU_AVAILABLE:
        try:
            result_gpu = cp_filtfilt(to_gpu(b), to_gpu(a), to_gpu(x))
            return to_cpu(result_gpu)
        except Exception as e:
            # Fall back to the CPU implementation if the GPU path fails.
            print(f"GPU filter failed, using CPU: {e}")
    return filtfilt(b, a, x)
def resample_uniform(t: np.ndarray, x: np.ndarray, fs: float) -> Tuple[np.ndarray, np.ndarray]:
    """Linearly resample an irregularly timed signal onto a uniform ``fs`` Hz grid.

    Samples are sorted by time. Any sample within 1e-6 s of its predecessor is
    dropped so the interpolation abscissae are strictly increasing. The grid
    starts at the first timestamp and steps by ``1 / fs`` up to the last one.

    Returns:
        ``(grid_times, resampled_values)``. When fewer than five usable samples
        or grid points are available, the inputs are returned instead (sorted
        and de-duplicated if that stage was reached).
    """
    if len(t) < 5:
        return t, x
    order = np.argsort(t)
    t_sorted = t[order]
    x_sorted = x[order]
    distinct = np.concatenate(([True], np.diff(t_sorted) > 1e-6))
    t_sorted = t_sorted[distinct]
    x_sorted = x_sorted[distinct]
    if len(t_sorted) < 5:
        return t_sorted, x_sorted
    start = t_sorted[0]
    stop = t_sorted[-1]
    if stop <= start:
        return t_sorted, x_sorted
    count = int(np.floor((stop - start) * fs)) + 1
    if count < 5:
        return t_sorted, x_sorted
    grid = start + np.arange(count) / fs
    interpolant = interp1d(t_sorted, x_sorted, kind="linear", fill_value="extrapolate", assume_sorted=True)
    return grid, interpolant(grid)
def detrend_and_normalize(x: np.ndarray) -> np.ndarray:
    """Return *x* as float64 with its mean removed and scaled to unit standard deviation.

    The CuPy path is attempted first when GPU support is enabled; any failure
    there silently falls back to NumPy. A 1e-9 epsilon in the divisor keeps
    constant inputs from dividing by zero.
    """
    if GPU_AVAILABLE:
        try:
            dev = to_gpu(x.astype(np.float64))
            centered = dev - cp.mean(dev)
            return to_cpu(centered / (cp.std(centered) + 1e-9))
        except Exception:
            pass
    # CPU fallback
    host = x.astype(np.float64)
    centered = host - np.mean(host)
    return centered / (np.std(centered) + 1e-9)
def _spectral_sqi_core(x, fs: float, band, xp) -> Tuple[float, float]:
    """Compute (peak_dominance, band_power_ratio) for *x* using array module *xp* (NumPy or CuPy)."""
    x = x - xp.mean(x)
    n = len(x)
    freqs = xp.fft.rfftfreq(n, d=1 / fs)
    spec = xp.abs(xp.fft.rfft(x)) + 1e-12
    power = spec ** 2
    # Ignore DC and very-low-frequency drift below 0.1 Hz.
    valid = freqs >= 0.1
    freqs_v = freqs[valid]
    spec_v = spec[valid]
    power_v = power[valid]
    if len(freqs_v) < 10:
        return 0.0, 0.0
    lo, hi = band
    in_band = (freqs_v >= lo) & (freqs_v <= hi)
    if int(xp.sum(in_band)) < 5:
        return 0.0, 0.0
    band_spec = spec_v[in_band]
    band_power = xp.sum(power_v[in_band])
    total_power = xp.sum(power_v)
    peak = float(xp.max(band_spec))
    med = float(xp.median(band_spec))
    peak_dominance = peak / (med + 1e-12)
    band_power_ratio = float(band_power / (total_power + 1e-12))
    return peak_dominance, band_power_ratio


def spectral_sqi(x: np.ndarray, fs: float, band=(0.7, 2.5), min_duration_s: float = 5.0) -> Tuple[float, float]:
    """Spectral signal-quality indices for a uniformly sampled signal.

    Args:
        x: Signal sampled at ``fs`` Hz.
        fs: Sampling rate in Hz.
        band: ``(low, high)`` frequency band in Hz expected to hold the pulse.
        min_duration_s: Minimum signal length in seconds; shorter inputs
            score ``(0.0, 0.0)``.

    Returns:
        ``(peak_dominance, band_power_ratio)``: the ratio of the largest to the
        median in-band amplitude-spectrum value, and the fraction of spectral
        power (above 0.1 Hz) that lies inside ``band``. Both are 0.0 when
        there is too little data or too few spectral bins.
    """
    if len(x) < int(fs * min_duration_s):
        return 0.0, 0.0
    if GPU_AVAILABLE:
        try:
            return _spectral_sqi_core(to_gpu(x), fs, band, cp)
        except Exception as e:
            print(f"GPU spectral_sqi failed: {e}")
    # CPU fallback
    return _spectral_sqi_core(np.asarray(x), fs, band, np)
def rr_clean(rr_ms: np.ndarray) -> np.ndarray:
    """Remove implausible RR intervals (milliseconds), returning a float64 array.

    Intervals outside the open range (300, 2000) ms, i.e. 30-200 bpm, are
    dropped first. If at least three intervals survive, any interval whose
    distance from their median is 25 % of the median or more is dropped too.
    """
    intervals = np.asarray(rr_ms, dtype=np.float64)
    plausible = intervals[(intervals > 300) & (intervals < 2000)]
    if plausible.size < 3:
        return plausible
    centre = np.median(plausible)
    near_median = np.abs(plausible - centre) < 0.25 * centre
    return plausible[near_median]
def compute_hr_and_peaks(tu: np.ndarray, xu: np.ndarray, fs: float, min_samples_s: float = 3.0) -> Tuple[Optional[float], Optional[np.ndarray], Optional[np.ndarray]]:
    """Detect pulse peaks in a uniformly sampled PPG trace and estimate heart rate.

    The trace is normalised, band-passed (0.7-2.5 Hz) and normalised again.
    Peaks at least 0.35 s apart with prominence of at least 0.3 standard
    deviations are kept. The resulting RR intervals are cleaned with
    ``rr_clean``.

    Args:
        tu: Sample times in seconds, same length as ``xu``.
        xu: Signal values sampled at ``fs`` Hz.
        fs: Sampling rate in Hz.
        min_samples_s: Minimum signal duration in seconds required to try.

    Returns:
        ``(hr_bpm, peak_indices, rr_ms)``:
        ``(None, None, None)`` when the signal is too short;
        ``(None, peaks, None)`` with fewer than 3 peaks;
        ``(None, peaks, rr_ms)`` with fewer than 2 clean RR intervals;
        otherwise the HR is ``60000 / median(rr_ms)``.
    """
    if len(xu) < int(fs * min_samples_s):
        return None, None, None
    sig = detrend_and_normalize(xu)
    sig = bandpass_filter(sig, fs=fs, low=0.7, high=2.5, order=4)
    sig = detrend_and_normalize(sig)
    # A 0.35 s refractory distance caps detectable HR at about 171 bpm.
    # find_peaks requires distance >= 1, which int(0.35 * fs) violates for
    # fs < ~2.86 Hz.
    min_dist = max(1, int(0.35 * fs))
    prom = 0.3 * np.std(sig)
    peaks, _ = find_peaks(sig, distance=min_dist, prominence=prom)
    if len(peaks) < 3:
        return None, peaks, None
    peak_t = tu[peaks]
    rr_ms = rr_clean(np.diff(peak_t) * 1000.0)
    if len(rr_ms) < 2:
        return None, peaks, rr_ms
    hr = 60000.0 / np.median(rr_ms)  # ms per minute / ms per beat = beats per minute
    return float(hr), peaks, rr_ms
def compute_hrv(rr_ms: np.ndarray) -> Tuple[float, float, float]:
    """Time-domain HRV statistics for RR intervals given in milliseconds.

    Returns ``(sdnn, rmssd, pnn50)``: the sample standard deviation (ddof=1)
    of the intervals, the root-mean-square of successive differences, and the
    percentage of successive differences larger than 50 ms. A statistic that
    cannot be formed (too few intervals) is reported as 0.0. CuPy is tried
    for more than 10 intervals when GPU support is enabled; any GPU error
    falls back to NumPy.
    """
    def _stats(xp, rr):
        sdnn = float(xp.std(rr, ddof=1)) if len(rr) > 1 else 0.0
        successive = xp.diff(rr)
        if len(successive) == 0:
            return sdnn, 0.0, 0.0
        rmssd = float(xp.sqrt(xp.mean(successive ** 2)))
        pnn50 = float(xp.mean(xp.abs(successive) > 50.0) * 100.0)
        return sdnn, rmssd, pnn50

    if GPU_AVAILABLE and len(rr_ms) > 10:
        try:
            return _stats(cp, to_gpu(rr_ms))
        except Exception:
            pass
    # CPU fallback
    return _stats(np, rr_ms)
def coherence_sqi(xg: np.ndarray, xr: np.ndarray, fs: float, band=(0.7, 2.5), nperseg: Optional[int] = None) -> float:
    """Mean magnitude-squared coherence between two channels inside *band*.

    Args:
        xg: First channel (green), sampled at ``fs`` Hz.
        xr: Second channel (red), sampled at ``fs`` Hz.
        fs: Sampling rate in Hz.
        band: ``(low, high)`` frequency band in Hz to average over.
        nperseg: Welch segment length. Defaults to ``min(256, len(xg) // 4)``
            (at least 8). This ensures several segments are averaged:
            with a single segment the magnitude-squared coherence is
            identically 1 at every frequency, which would make this SQI
            unable to reject anything for short windows.

    Returns:
        Mean coherence over the in-band frequencies. Returns 0.0 when either
        input is shorter than 5 s or fewer than 3 frequency bins fall in
        ``band``.
    """
    # scipy.signal.coherence has no GPU counterpart here, so this stays on CPU.
    if len(xg) < int(fs * 5) or len(xr) < int(fs * 5):
        return 0.0
    if nperseg is None:
        nperseg = min(256, max(8, len(xg) // 4))
    f, cxy = coherence(xg, xr, fs=fs, nperseg=nperseg)
    lo, hi = band
    m = (f >= lo) & (f <= hi)
    if np.sum(m) < 3:
        return 0.0
    return float(np.mean(cxy[m]))
@dataclass
class Metrics:
    """Display strings for one set of measurements; every field defaults to "Measuring..."."""

    spo2: str = "Measuring..."
    hr: str = "Measuring..."
    sdnn: str = "Measuring..."
    rmssd: str = "Measuring..."
    pnn50: str = "Measuring..."
    sqi: str = "Measuring..."
class PPGProcessor:
    """Estimate heart rate, indicative SpO2 and HRV from fingertip webcam frames.

    For every frame, the mean red and green intensity of a centred ROI is
    appended to rolling buffers holding the last ``buffer_s`` seconds.

    - The plot is refreshed at most every 0.5 s.
    - Metrics are recomputed at most every ``hr_update_interval_s`` seconds,
      and only once the buffer spans a full ``hr_window_s``.
    - HR and SpO2 use the last ``hr_window_s`` seconds and are gated by
      spectral, coherence and HR-range checks. A failed check clears the
      buffers, so the next estimate uses only fresh samples.
    - HRV metrics use the last ``hrv_window_s`` seconds under stricter checks.
    """

    def __init__(self, roi_size=(250, 250), target_fs=30.0, hr_window_s=5.0, hrv_window_s=60.0,
                 buffer_s=65.0, hr_update_interval_s=5.0):
        self.roi_size = roi_size  # (width, height) of the centred ROI in pixels
        self.fs = target_fs  # uniform resampling rate in Hz
        self.hr_window_s = hr_window_s  # seconds of data per HR/SpO2 estimate
        self.hrv_window_s = hrv_window_s  # seconds of data for HRV metrics
        self.buffer_s = buffer_s  # history retained in the sample buffers
        self.hr_update_interval_s = hr_update_interval_s  # min seconds between metric updates
        self._t = []  # sample times, seconds since self._start_time
        self._red = []  # mean red ROI intensity per sample
        self._green = []  # mean green ROI intensity per sample
        # Monotonic clock: these timestamps are only used for intervals, which
        # must not jump (or run backwards) when the wall clock is adjusted.
        self._start_time = time.monotonic()
        self._last_hr_update_t = float("-inf")  # lets the first update run as soon as data allows
        self._last_plot_update_t = float("-inf")
        self._latest_metrics = Metrics()
        self._plot_fig = None
        # Set by reset_buffers(); cleared once the refilled buffer spans hr_window_s.
        self._waiting_for_fresh_data = False
        self._fresh_data_start_t = None  # monotonic time of the last reset, if any
        if GPU_AVAILABLE:
            print("PPGProcessor initialized with GPU support (CuPy)")
        else:
            print("PPGProcessor initialized with CPU only")

    def reset_buffers(self):
        """Discard all buffered samples and restart the clock so the next estimate uses only fresh data."""
        self._t = []
        self._red = []
        self._green = []
        self._start_time = time.monotonic()
        self._waiting_for_fresh_data = True
        self._fresh_data_start_t = self._start_time

    @staticmethod
    def _window_start(t: np.ndarray, duration_s: float) -> int:
        """Index where a window covering the last ``duration_s`` seconds of ``t`` should start.

        The window starts at the last sample at or before ``t[-1] - duration_s``
        (0 if there is none). Whenever the buffer itself spans ``duration_s``,
        the window therefore does too. ``t`` must be non-decreasing.
        """
        idx = int(np.searchsorted(t, t[-1] - duration_s, side="right")) - 1
        return max(idx, 0)

    def estimate_spo2(self, r: np.ndarray, g: np.ndarray) -> float:
        """Indicative SpO2 (%) from raw red and green intensity traces (ratio of ratios).

        Computes ``R = (AC_red / DC_red) / (AC_green / DC_green)``, where DC is
        the channel's mean intensity and AC is the standard deviation of its
        band-passed (pulsatile) component. Returns
        ``-1.146 * R + 98.313`` clipped to [80, 100].

        NOTE(review): the linear coefficients are fixed constants; no
        calibration for this camera/ROI setup is visible here, so treat the
        value as indicative only.

        Args:
            r: Raw (not mean-removed or normalised) red trace, sampled at ``self.fs``.
            g: Raw (not mean-removed or normalised) green trace, sampled at ``self.fs``.
                Raw traces are required: on mean-removed signals DC is about 0
                and the ratio is meaningless.

        Returns:
            SpO2 in percent. Returns NaN with fewer than 10 samples, a
            non-positive DC level, or zero green AC.
        """
        if len(r) < 10 or len(g) < 10:
            return float("nan")
        r = np.asarray(r, dtype=np.float64)
        g = np.asarray(g, dtype=np.float64)
        r_dc = float(np.mean(r))
        g_dc = float(np.mean(g))
        if r_dc <= 0.0 or g_dc <= 0.0:
            return float("nan")
        r_ac = float(np.std(bandpass_filter(r - r_dc, fs=self.fs)))
        g_ac = float(np.std(bandpass_filter(g - g_dc, fs=self.fs)))
        if g_ac <= 0.0:
            return float("nan")
        ratio = (r_ac / r_dc) / (g_ac / g_dc)
        spo2 = -1.146 * ratio + 98.313
        return float(np.clip(spo2, 80.0, 100.0))

    def process_frame(self, frame_rgb: np.ndarray):
        """Ingest one RGB frame and return the annotated frame plus the latest results.

        Returns:
            ``(frame_with_roi_box, hr, spo2, sdnn, rmssd, pnn50, sqi, plot_figure)``,
            where the middle six entries are display strings from the latest
            ``Metrics``. Returns eight ``None`` values when ``frame_rgb`` is None.
        """
        if frame_rgb is None:
            return None, None, None, None, None, None, None, None
        frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        t_rel = time.monotonic() - self._start_time
        h, w, _ = frame_bgr.shape
        rw, rh = self.roi_size
        x1 = max(0, w // 2 - rw // 2)
        y1 = max(0, h // 2 - rh // 2)
        x2 = min(w, w // 2 + rw // 2)
        y2 = min(h, h // 2 + rh // 2)
        roi = frame_bgr[y1:y2, x1:x2]
        # Sample the ROI before drawing the rectangle so the overlay does not bias the means.
        mean_bgr = cv2.mean(roi)[:3]
        g = float(mean_bgr[1])
        r = float(mean_bgr[2])
        cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)
        frame_rgb_out = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

        self._t.append(t_rel)
        self._red.append(r)
        self._green.append(g)
        cutoff = t_rel - self.buffer_s
        while self._t and self._t[0] < cutoff:
            self._t.pop(0)
            self._red.pop(0)
            self._green.pop(0)

        self.compute_updates_if_due()
        m = self._latest_metrics
        return (frame_rgb_out, m.hr, m.spo2, m.sdnn, m.rmssd, m.pnn50, m.sqi, self._plot_fig)

    def compute_updates_if_due(self):
        """Refresh the plot and, when due, the displayed metrics from the sample buffers."""
        now = time.monotonic()
        t = np.array(self._t)
        if len(t) < 20:  # need a minimum number of raw samples for anything
            return
        r = np.array(self._red)
        g = np.array(self._green)

        # Plots update more often than metrics.
        if now - self._last_plot_update_t >= 0.5:
            self._plot_fig = self._make_plots(t, r, g)
            self._last_plot_update_t = now

        # Metrics need a buffer spanning a full HR window; after a reset this
        # is also the "fresh data" wait. Gating on the data span (not on
        # elapsed wall time) guarantees the selected window spans at least
        # hr_window_s. With the default 5 s window, the resampled signal is
        # then long enough for the SQI helpers' minimum-length checks.
        if t[-1] - t[0] < self.hr_window_s:
            return
        if self._waiting_for_fresh_data:
            self._waiting_for_fresh_data = False
            self._fresh_data_start_t = None

        if now - self._last_hr_update_t < self.hr_update_interval_s:
            return
        self._last_hr_update_t = now

        start = self._window_start(t, self.hr_window_s)
        th, rh, gh = t[start:], r[start:], g[start:]
        thu, ghu = resample_uniform(th, gh, self.fs)
        _, rhu = resample_uniform(th, rh, self.fs)
        ghu_f = bandpass_filter(detrend_and_normalize(ghu), fs=self.fs)
        rhu_f = bandpass_filter(detrend_and_normalize(rhu), fs=self.fs)
        peak_dom, band_ratio = spectral_sqi(ghu_f, fs=self.fs)
        coh = coherence_sqi(ghu_f, rhu_f, fs=self.fs)
        hr_bpm, _, _ = compute_hr_and_peaks(thu, ghu, fs=self.fs, min_samples_s=3.0)
        # Relaxed SQI thresholds for the short HR window.
        sqi_pass = (peak_dom >= 1.5 and band_ratio >= 0.20 and coh >= 0.15 and
                    (hr_bpm is not None) and (40.0 <= hr_bpm <= 200.0))

        m = Metrics()
        gpu_status = " [GPU]" if GPU_AVAILABLE else " [CPU]"
        m.sqi = f"PeakDom={peak_dom:.2f}, BandPow={band_ratio:.2f}, Coh={coh:.2f}{gpu_status}"
        if not sqi_pass:
            m.hr = "Measuring... (low signal quality)"
            m.spo2 = "Measuring..."
            # Keep showing the previous HRV values while signal quality is low.
            m.sdnn = self._latest_metrics.sdnn
            m.rmssd = self._latest_metrics.rmssd
            m.pnn50 = self._latest_metrics.pnn50
            self._latest_metrics = m
            # Start a fresh data collection for the next estimate.
            self.reset_buffers()
            return

        # SpO2 needs the raw traces: their DC level is lost by normalisation.
        spo2 = self.estimate_spo2(rhu, ghu)
        m.spo2 = f"{spo2:.2f}%" if np.isfinite(spo2) else "Measuring..."
        m.hr = f"{hr_bpm:.2f} bpm"
        m.sdnn, m.rmssd, m.pnn50 = self._hrv_strings(t, r, g)
        self._latest_metrics = m

    def _hrv_strings(self, t: np.ndarray, r: np.ndarray, g: np.ndarray):
        """Return display strings (SDNN, RMSSD, pNN50) for the last hrv_window_s seconds.

        Each is "Measuring..." while the window is more than 1 s short of
        hrv_window_s, or when the stricter HRV signal-quality checks fail.
        """
        pending = ("Measuring...", "Measuring...", "Measuring...")
        start = self._window_start(t, self.hrv_window_s)
        tv, rv, gv = t[start:], r[start:], g[start:]
        if (tv[-1] - tv[0]) < (self.hrv_window_s - 1.0):
            return pending
        tvu, gvu = resample_uniform(tv, gv, self.fs)
        _, rvu = resample_uniform(tv, rv, self.fs)
        gvu_f = bandpass_filter(detrend_and_normalize(gvu), fs=self.fs)
        rvu_f = bandpass_filter(detrend_and_normalize(rvu), fs=self.fs)
        peak_dom_v, band_ratio_v = spectral_sqi(gvu_f, fs=self.fs)
        coh_v = coherence_sqi(gvu_f, rvu_f, fs=self.fs)
        _, _, rr_v = compute_hr_and_peaks(tvu, gvu, fs=self.fs)
        sqi_pass_v = (peak_dom_v >= 3.0 and band_ratio_v >= 0.35 and coh_v >= 0.35 and
                      (rr_v is not None) and (len(rr_v) >= 10))
        if not sqi_pass_v:
            return pending
        sdnn, rmssd, pnn50 = compute_hrv(rr_v)
        return f"{sdnn:.2f} ms", f"{rmssd:.2f} ms", f"{pnn50:.2f} %"

    def _make_plots(self, t: np.ndarray, r: np.ndarray, g: np.ndarray):
        """Build a two-panel figure of the last hr_window_s seconds, or None with too little data.

        The left panel shows the filtered red and green traces. The right
        panel shows the filtered green trace with the detected peaks marked.
        """
        if len(t) < 20:
            return None
        start = self._window_start(t, self.hr_window_s)
        th, rh, gh = t[start:], r[start:], g[start:]
        if len(th) < 10:
            return None
        thu, ghu = resample_uniform(th, gh, self.fs)
        _, rhu = resample_uniform(th, rh, self.fs)
        ghu_f = bandpass_filter(detrend_and_normalize(ghu), fs=self.fs)
        rhu_f = bandpass_filter(detrend_and_normalize(rhu), fs=self.fs)
        hr_bpm, peaks_idx, _ = compute_hr_and_peaks(thu, ghu, fs=self.fs)

        fig = plt.figure(figsize=(12, 4))
        ax1 = fig.add_subplot(1, 2, 1)
        ax2 = fig.add_subplot(1, 2, 2)
        ax1.plot(thu, rhu_f, label="Red (filtered)")
        ax1.plot(thu, ghu_f, label="Green (filtered)")
        ax1.set_xlabel("Time (s)")
        ax1.set_ylabel("Normalized intensity")
        ax1.set_title(f"Filtered PPG (last {self.hr_window_s:g}s)")
        ax1.legend(loc="upper right")
        ax2.plot(thu, ghu_f, label="Green (filtered)")
        if peaks_idx is not None and len(peaks_idx) > 0:
            ax2.scatter(thu[peaks_idx], ghu_f[peaks_idx], label="Peaks")
        ax2.set_xlabel("Time (s)")
        ax2.set_ylabel("Normalized intensity")
        ax2.set_title(f"Peaks (HR={hr_bpm:.1f} bpm)" if hr_bpm is not None else "Peaks")
        ax2.legend(loc="upper right")
        fig.tight_layout()
        # A new figure is built every 0.5 s. Unregister it from pyplot so the
        # figure manager does not keep every one alive; the returned Figure
        # object remains renderable.
        plt.close(fig)
        return fig
# Use 5-second windows for HR calculation, update every 5 seconds.
# NOTE(review): this single module-level instance serves every call to
# process_webcam_frame, so if several sessions stream at once their samples
# would be mixed in the same buffers — confirm whether per-session state is needed.
processor = PPGProcessor(
    roi_size=(250, 250),
    target_fs=30.0,
    hr_window_s=5.0,
    hrv_window_s=60.0,
    buffer_s=65.0,
    hr_update_interval_s=5.0,
)


def process_webcam_frame(frame):
    """Gradio stream callback: forward a webcam frame to the shared processor.

    Returns the eight outputs expected by the UI (frame, HR, SpO2, SDNN,
    RMSSD, pNN50, SQI, plot); placeholder strings are returned while no
    frame is available.
    """
    if frame is not None:
        return processor.process_frame(frame)
    placeholder = "Measuring..."
    return None, "Waiting for frame...", placeholder, placeholder, placeholder, placeholder, placeholder, None
with gr.Blocks(title="Real-Time PPG Monitor") as demo:
    gr.Markdown("## Real-Time Heart Rate & HRV from Finger PPG\n**Position your finger on the camera and stay still for best results**")
    # GPU status indicator
    gpu_status_text = "🟢 GPU Acceleration: Enabled (CuPy)" if GPU_AVAILABLE else "🔴 GPU: Not Available (Using CPU)"
    gr.Markdown(f"**{gpu_status_text}**")
    # Window lengths in the labels come from the processor so they cannot
    # drift from the analysis settings (the HR label previously said 10s
    # while the HR window is hr_window_s).
    hr_window_label = f"{processor.hr_window_s:g}s"
    hrv_window_label = f"{processor.hrv_window_s:g}s"
    with gr.Row():
        with gr.Column(scale=1):
            webcam = gr.Image(sources=["webcam"], type="numpy", label="Webcam Feed", streaming=True)
        with gr.Column(scale=1):
            plot_output = gr.Plot(label="Signals & Peaks")
    with gr.Row():
        hr_output = gr.Textbox(label=f"Heart Rate ({hr_window_label})", value="Measuring...", interactive=False)
        spo2_output = gr.Textbox(label="SpO2 (indicative)", value="Measuring...", interactive=False)
    with gr.Row():
        sdnn_output = gr.Textbox(label=f"SDNN ({hrv_window_label})", value="Measuring...", interactive=False)
        rmssd_output = gr.Textbox(label=f"RMSSD ({hrv_window_label})", value="Measuring...", interactive=False)
        pnn50_output = gr.Textbox(label=f"pNN50 ({hrv_window_label})", value="Measuring...", interactive=False)
    with gr.Row():
        sqi_output = gr.Textbox(label="Signal Quality", value="Measuring...", interactive=False)
    webcam.stream(
        fn=process_webcam_frame,
        inputs=webcam,
        outputs=[webcam, hr_output, spo2_output, sdnn_output, rmssd_output, pnn50_output, sqi_output, plot_output],
        time_limit=None,
    )
if __name__ == "__main__":
    # Enable the request queue (used by the streaming event) and start the server.
    app = demo.queue()
    app.launch()