File size: 19,682 Bytes
20dff18
13541be
20dff18
 
f2b80f5
20dff18
 
 
 
 
 
1cab498
20dff18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bfb42f
20dff18
 
 
 
 
 
 
 
 
 
 
 
0bfb42f
20dff18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bfb42f
20dff18
 
 
 
 
 
 
 
 
 
 
0bfb42f
20dff18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bfb42f
20dff18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bfb42f
20dff18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13541be
0bfb42f
20dff18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c1a91c
20dff18
 
 
3c1a91c
13541be
20dff18
 
 
 
b2dc4cd
20dff18
 
b2dc4cd
20dff18
 
 
 
 
 
 
 
 
 
13541be
 
 
20dff18
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
import spaces
import time
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
import cv2
from scipy.signal import butter, filtfilt, find_peaks, coherence
from scipy.interpolate import interp1d
from dataclasses import dataclass
from typing import Optional, Tuple
# GPU Support: Try to import CuPy for GPU acceleration
try:
    import cupy as cp
    from cupyx.scipy.signal import filtfilt as cp_filtfilt
    GPU_AVAILABLE = True
    print("GPU acceleration enabled with CuPy")
except ImportError:
    GPU_AVAILABLE = False
    cp = np  # Fallback to NumPy
    print("CuPy not available, using CPU (NumPy)")


def get_array_module(x):
    """Return the array library that owns *x*.

    Gives back ``cp`` (CuPy) only when GPU support was detected at import time
    and *x* is a CuPy array; every other input maps to NumPy.
    """
    on_device = GPU_AVAILABLE and isinstance(x, cp.ndarray)
    return cp if on_device else np


def to_gpu(x: np.ndarray) -> "cp.ndarray":
    """Move *x* to the GPU via ``cp.asarray`` when CuPy is usable.

    Without GPU support the very same input object is handed back.
    """
    return cp.asarray(x) if GPU_AVAILABLE else x


def to_cpu(x) -> np.ndarray:
    """Bring *x* back to host memory as a NumPy array.

    CuPy arrays are copied off the device with ``cp.asnumpy``; anything else
    is passed through ``np.asarray``.
    """
    is_device_array = GPU_AVAILABLE and isinstance(x, cp.ndarray)
    if is_device_array:
        return cp.asnumpy(x)
    return np.asarray(x)


def bandpass_filter(x: np.ndarray, fs: float, low: float = 0.7, high: float = 2.5, order: int = 4) -> np.ndarray:
    """Zero-phase Butterworth band-pass filter, GPU-accelerated when CuPy is available.

    Args:
        x: 1-D signal, assumed uniformly sampled at ``fs`` Hz.
        fs: Sampling rate in Hz.
        low: Lower pass-band edge in Hz.
        high: Upper pass-band edge in Hz.
        order: Butterworth design order.

    Returns:
        The ``filtfilt``-filtered signal. ``x`` is returned unchanged when it
        is too short to filter, or when the band edges are not strictly inside
        (0, Nyquist).
    """
    if len(x) < max(15, order * 3):
        return x
    nyq = 0.5 * fs
    lowc = low / nyq
    highc = high / nyq
    if not (0 < lowc < highc < 1):
        return x
    b, a = butter(order, [lowc, highc], btype="band")

    # filtfilt pads both edges with 3 * max(len(a), len(b)) samples by default
    # and raises ValueError unless the input is longer than that padding. A
    # band-pass of order N has 2N + 1 coefficients, so the 3 * order guard
    # above does not cover it (27 samples for the default order 4).
    padlen = 3 * max(len(a), len(b))
    if len(x) <= padlen:
        return x

    if GPU_AVAILABLE:
        try:
            result_gpu = cp_filtfilt(to_gpu(b), to_gpu(a), to_gpu(x))
            return to_cpu(result_gpu)
        except Exception as e:
            # Best-effort GPU path: any CuPy failure falls back to SciPy below.
            print(f"GPU filter failed, using CPU: {e}")
    return filtfilt(b, a, x)


def resample_uniform(t: np.ndarray, x: np.ndarray, fs: float) -> Tuple[np.ndarray, np.ndarray]:
    """Linearly resample irregular samples ``(t, x)`` onto a uniform grid at ``fs`` Hz.

    Samples are ordered by time. Any sample within 1e-6 s of its predecessor
    is dropped. The grid starts at the first kept timestamp and never passes
    the last one.

    Returns:
        ``(tu, xu)``, the uniform grid and the interpolated values. If fewer
        than 5 samples or grid points are available, the current (possibly
        sorted and de-duplicated) ``(t, x)`` pair is returned without
        resampling.
    """
    if len(t) < 5:
        return t, x

    order = np.argsort(t)
    t_sorted, x_sorted = t[order], x[order]

    # Keep the first sample plus every sample that is strictly after its predecessor.
    distinct = np.concatenate(([True], np.diff(t_sorted) > 1e-6))
    t_sorted, x_sorted = t_sorted[distinct], x_sorted[distinct]
    if len(t_sorted) < 5:
        return t_sorted, x_sorted

    start, stop = t_sorted[0], t_sorted[-1]
    if stop <= start:
        return t_sorted, x_sorted

    n_points = int(np.floor((stop - start) * fs)) + 1
    if n_points < 5:
        return t_sorted, x_sorted

    grid = start + np.arange(n_points) / fs
    interpolator = interp1d(t_sorted, x_sorted, kind="linear", fill_value="extrapolate", assume_sorted=True)
    return grid, interpolator(grid)


def detrend_and_normalize(x: np.ndarray) -> np.ndarray:
    """Return *x* as float64 with its mean removed and scaled to unit std.

    A 1e-9 term is added to the standard deviation so constant inputs do not
    divide by zero. Uses CuPy when available; on any GPU error the NumPy
    computation is used instead.
    """
    def _zscore(xp, arr):
        centered = arr - xp.mean(arr)
        return centered / (xp.std(centered) + 1e-9)

    if GPU_AVAILABLE:
        try:
            return to_cpu(_zscore(cp, to_gpu(x.astype(np.float64))))
        except Exception:
            pass  # best-effort GPU path; fall through to NumPy

    return _zscore(np, x.astype(np.float64))


def spectral_sqi(x: np.ndarray, fs: float, band=(0.7, 2.5)) -> Tuple[float, float]:
    """Spectral signal-quality scores for a PPG segment.

    The mean-removed signal's magnitude spectrum is restricted to bins at or
    above 0.1 Hz. Two scores are derived from it:

    * peak dominance: maximum in-band magnitude over the in-band median;
    * band power ratio: fraction of the (>= 0.1 Hz) power that lies in ``band``.

    Returns ``(0.0, 0.0)`` in three cases:

    * ``x`` has fewer than ``int(fs * 5)`` samples;
    * fewer than 10 bins remain above 0.1 Hz;
    * fewer than 5 bins fall inside ``band``.

    Runs on CuPy when available. A GPU failure is printed and the NumPy path
    is used instead.
    """
    if len(x) < int(fs * 5):
        return 0.0, 0.0

    def _scores(xp, sig):
        sig = sig - xp.mean(sig)
        freqs = xp.fft.rfftfreq(len(sig), d=1/fs)
        mag = xp.abs(xp.fft.rfft(sig)) + 1e-12
        pw = mag**2
        above = freqs >= 0.1
        freqs, mag, pw = freqs[above], mag[above], pw[above]
        if len(freqs) < 10:
            return 0.0, 0.0
        lo, hi = band
        inside = (freqs >= lo) & (freqs <= hi)
        if xp.sum(inside) < 5:
            return 0.0, 0.0
        in_mag = mag[inside]
        dominance = float(xp.max(in_mag)) / (float(xp.median(in_mag)) + 1e-12)
        ratio = float(xp.sum(pw[inside]) / (xp.sum(pw) + 1e-12))
        return dominance, ratio

    if GPU_AVAILABLE:
        try:
            return _scores(cp, to_gpu(x))
        except Exception as e:
            print(f"GPU spectral_sqi failed: {e}")

    return _scores(np, x)


def rr_clean(rr_ms: np.ndarray) -> np.ndarray:
    """Filter RR intervals (ms) down to physiologically plausible, consistent beats.

    Cleaning happens in two passes:

    1. Keep intervals strictly between 300 and 2000 ms.
    2. If at least 3 intervals remain, also drop any interval that deviates
       from their median by 25 % or more.

    Returns a float64 array in the original order.
    """
    intervals = np.asarray(rr_ms, dtype=np.float64)
    plausible = intervals[np.logical_and(intervals > 300, intervals < 2000)]
    if len(plausible) < 3:
        return plausible
    centre = np.median(plausible)
    return plausible[np.abs(plausible - centre) < 0.25 * centre]


def compute_hr_and_peaks(tu: np.ndarray, xu: np.ndarray, fs: float, min_samples_s: float = 3.0) -> Tuple[Optional[float], Optional[np.ndarray], Optional[np.ndarray]]:
    """Estimate heart rate from a uniformly resampled PPG segment.

    Processing steps:

    1. Z-score the signal.
    2. Band-pass it to 0.7-2.5 Hz.
    3. Z-score it again.
    4. Locate peaks with ``find_peaks``, using a 0.35 s minimum spacing and a
       prominence of 0.3 standard deviations.
    5. Turn consecutive peak times into RR intervals and clean them with
       ``rr_clean``.

    Args:
        tu: Sample timestamps in seconds, index-aligned with ``xu``.
        xu: PPG samples on a uniform grid at ``fs`` Hz.
        fs: Sampling rate in Hz.
        min_samples_s: Minimum segment duration, in seconds of samples.

    Returns:
        ``(hr_bpm, peak_indices, rr_ms)``:

        * ``(None, None, None)`` if the segment is shorter than ``min_samples_s``;
        * ``(None, peaks, None)`` if fewer than 3 peaks are found;
        * ``(None, peaks, rr_ms)`` if fewer than 2 RR intervals survive cleaning;
        * otherwise the HR in beats per minute, derived from the median cleaned
          RR interval.
    """
    if len(xu) < int(fs * min_samples_s):
        return None, None, None
    sig = detrend_and_normalize(xu)
    sig = bandpass_filter(sig, fs=fs, low=0.7, high=2.5, order=4)
    sig = detrend_and_normalize(sig)
    # Minimum 0.35 s between beats. find_peaks rejects distance < 1, so clamp
    # it for very low sampling rates where int(0.35 * fs) would be 0.
    min_dist = max(1, int(0.35 * fs))
    prom = 0.3 * np.std(sig)
    peaks, _ = find_peaks(sig, distance=min_dist, prominence=prom)
    if len(peaks) < 3:
        return None, peaks, None
    rr_ms = rr_clean(np.diff(tu[peaks]) * 1000.0)
    if len(rr_ms) < 2:
        return None, peaks, rr_ms
    # 60 000 ms per minute divided by the median beat-to-beat interval.
    hr = 60000.0 / np.median(rr_ms)
    return float(hr), peaks, rr_ms


def compute_hrv(rr_ms: np.ndarray) -> Tuple[float, float, float]:
    """Time-domain HRV statistics from RR intervals given in milliseconds.

    Returns ``(sdnn, rmssd, pnn50)``:

    * sdnn: sample standard deviation (ddof=1) of the intervals, 0.0 for a
      single interval;
    * rmssd: root mean square of successive differences;
    * pnn50: percentage of successive differences larger than 50 ms.

    rmssd and pnn50 are 0.0 when there are no successive differences. Series
    longer than 10 intervals go through CuPy when available; any GPU error
    silently falls back to NumPy.
    """
    def _stats(xp, rr):
        sdnn = float(xp.std(rr, ddof=1)) if len(rr) > 1 else 0.0
        successive = xp.diff(rr)
        if len(successive) == 0:
            return sdnn, 0.0, 0.0
        rmssd = float(xp.sqrt(xp.mean(successive**2)))
        pnn50 = float(xp.mean(xp.abs(successive) > 50.0) * 100.0)
        return sdnn, rmssd, pnn50

    if GPU_AVAILABLE and len(rr_ms) > 10:
        try:
            return _stats(cp, to_gpu(rr_ms))
        except Exception:
            pass  # best-effort GPU path; fall through to NumPy

    return _stats(np, rr_ms)


def coherence_sqi(xg: np.ndarray, xr: np.ndarray, fs: float, band=(0.7, 2.5)) -> float:
    """Mean magnitude-squared coherence between two PPG channels inside ``band``.

    Welch-based coherence computed from a single segment is identically 1,
    whatever the inputs are. The segment length is therefore capped at half
    the (shorter) input length, so at least two 50 %-overlapping segments are
    averaged. Long inputs still use 256-sample segments.

    Returns 0.0 in two cases:

    * either input has fewer than ``int(fs * 5)`` samples;
    * fewer than 3 frequency bins fall inside ``band``.
    """
    # scipy.signal.coherence has no CuPy counterpart here, so this stays on the CPU.
    if len(xg) < int(fs * 5) or len(xr) < int(fs * 5):
        return 0.0
    nperseg = max(1, min(256, min(len(xg), len(xr)) // 2))
    f, cxy = coherence(xg, xr, fs=fs, nperseg=nperseg)
    lo, hi = band
    m = (f >= lo) & (f <= hi)
    if np.sum(m) < 3:
        return 0.0
    return float(np.mean(cxy[m]))


@dataclass
class Metrics:
    """Display strings for every metric shown in the UI.

    Each field holds pre-formatted text and starts as the "Measuring..."
    placeholder.
    """

    # Shared placeholder; it has no annotation, so it is not a dataclass field.
    _PENDING = "Measuring..."

    spo2: str = _PENDING
    hr: str = _PENDING
    sdnn: str = _PENDING
    rmssd: str = _PENDING
    pnn50: str = _PENDING
    sqi: str = _PENDING


class PPGProcessor:
    """Stateful fingertip-PPG pipeline fed one webcam frame at a time.

    Each frame contributes the mean red and green intensity of a centred ROI,
    and samples older than ``buffer_s`` seconds are discarded.

    Updates run on three schedules:

    * The signal plot is rebuilt every 0.5 s.
    * Heart rate, an indicative SpO2 and a signal-quality summary use the
      latest ``hr_window_s`` seconds. They are recomputed at most every
      ``hr_update_interval_s`` seconds.
    * HRV (SDNN, RMSSD, pNN50) is also reported once about ``hrv_window_s``
      seconds are buffered and a stricter quality check passes.

    A failed short-window quality check clears the buffers. The next HR
    attempt then waits for ``hr_window_s`` seconds of fresh data.
    """

    def __init__(self, roi_size=(250, 250), target_fs=30.0, hr_window_s=5.0, hrv_window_s=60.0, buffer_s=65.0, hr_update_interval_s=5.0):
        """Configure the processor.

        Args:
            roi_size: ``(width, height)`` in pixels of the centred ROI.
            target_fs: Uniform resampling rate in Hz used for all analysis.
            hr_window_s: Seconds of data used for HR / SpO2 / SQI.
            hrv_window_s: Seconds of data used for HRV.
            buffer_s: Seconds of raw samples retained.
            hr_update_interval_s: Minimum seconds between HR updates.
        """
        self.roi_size = roi_size
        self.fs = target_fs
        self.hr_window_s = hr_window_s
        self.hrv_window_s = hrv_window_s
        self.buffer_s = buffer_s
        self.hr_update_interval_s = hr_update_interval_s
        self._t = []      # sample timestamps, seconds since self._start_time
        self._red = []    # mean ROI red intensity per frame
        self._green = []  # mean ROI green intensity per frame
        self._start_time = time.time()
        self._last_hr_update_t = 0.0
        self._last_plot_update_t = 0.0
        self._latest_metrics = Metrics()
        self._plot_fig = None
        self._waiting_for_fresh_data = False  # set after a reset; blocks HR until a full fresh window exists
        self._fresh_data_start_t = None       # wall-clock time the fresh collection started

        if GPU_AVAILABLE:
            print("PPGProcessor initialized with GPU support (CuPy)")
        else:
            print("PPGProcessor initialized with CPU only")

    def reset_buffers(self):
        """Clear all samples, restart the relative clock and wait for fresh data."""
        self._t = []
        self._red = []
        self._green = []
        self._start_time = time.time()
        self._waiting_for_fresh_data = True
        self._fresh_data_start_t = time.time()

    def estimate_spo2(self, r: np.ndarray, g: np.ndarray) -> float:
        """Indicative SpO2 (%) from the red/green ratio-of-ratios.

        For each channel, AC is its standard deviation and DC is its mean
        (+1e-9). The calibration ``-1.146 * R + 98.313`` is applied and the
        result is clipped to [80, 100]. The inputs must therefore keep their
        DC level; z-scored signals give a meaningless ratio.

        Returns NaN when either input has fewer than 10 samples.
        """
        if len(r) < 10 or len(g) < 10:
            return float("nan")

        if GPU_AVAILABLE:
            try:
                r_gpu = to_gpu(r.astype(np.float64))
                g_gpu = to_gpu(g.astype(np.float64))
                r_ac, r_dc = float(cp.std(r_gpu)), float(cp.mean(r_gpu)) + 1e-9
                g_ac, g_dc = float(cp.std(g_gpu)), float(cp.mean(g_gpu)) + 1e-9
                ratio = (r_ac / r_dc) / (g_ac / g_dc)
                spo2 = -1.146 * ratio + 98.313
                return float(np.clip(spo2, 80.0, 100.0))
            except Exception:
                pass  # best-effort GPU path; fall through to NumPy

        r = np.asarray(r, dtype=np.float64)
        g = np.asarray(g, dtype=np.float64)
        r_ac, r_dc = np.std(r), np.mean(r) + 1e-9
        g_ac, g_dc = np.std(g), np.mean(g) + 1e-9
        ratio = (r_ac / r_dc) / (g_ac / g_dc)
        spo2 = -1.146 * ratio + 98.313
        return float(np.clip(spo2, 80.0, 100.0))

    def _pulsatile_with_dc(self, x: np.ndarray) -> np.ndarray:
        """Return the band-passed component of *x* shifted back onto its mean.

        The standard deviation of the result reflects only the pass band of
        ``bandpass_filter``. Its mean keeps the signal's DC level, which the
        ratio-of-ratios needs as its denominator.
        """
        dc = float(np.mean(x))
        return bandpass_filter(x - dc, fs=self.fs) + dc

    def process_frame(self, frame_rgb: np.ndarray):
        """Ingest one RGB frame and return the current UI outputs.

        The frame's centred-ROI mean red/green intensity is appended together
        with a timestamp relative to ``_start_time``. Samples older than
        ``buffer_s`` are then trimmed and any due updates are run.

        Returns:
            ``(annotated_frame, hr, spo2, sdnn, rmssd, pnn50, sqi, plot_fig)``,
            or eight ``None`` values for a ``None`` frame.
        """
        if frame_rgb is None:
            return None, None, None, None, None, None, None, None
        frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        now = time.time()
        t_rel = now - self._start_time
        h, w, _ = frame_bgr.shape
        rw, rh = self.roi_size
        x1 = max(0, w // 2 - rw // 2)
        y1 = max(0, h // 2 - rh // 2)
        x2 = min(w, w // 2 + rw // 2)
        y2 = min(h, h // 2 + rh // 2)
        roi = frame_bgr[y1:y2, x1:x2]
        mean_bgr = cv2.mean(roi)[:3]
        g = float(mean_bgr[1])
        r = float(mean_bgr[2])
        # Outline the sampled ROI on the preview that is streamed back.
        cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)
        frame_rgb_out = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        self._t.append(t_rel)
        self._red.append(r)
        self._green.append(g)
        # Drop the leading run of samples older than the retention window with
        # one slice deletion instead of repeated O(n) list.pop(0) calls.
        cutoff = t_rel - self.buffer_s
        n_stale = next((i for i, ts in enumerate(self._t) if ts >= cutoff), len(self._t))
        if n_stale:
            del self._t[:n_stale]
            del self._red[:n_stale]
            del self._green[:n_stale]
        self.compute_updates_if_due()
        m = self._latest_metrics
        return (frame_rgb_out, m.hr, m.spo2, m.sdnn, m.rmssd, m.pnn50, m.sqi, self._plot_fig)

    def compute_updates_if_due(self):
        """Rebuild the plot and recompute the metrics when their intervals elapse.

        Nothing happens with fewer than 20 buffered samples. When the
        short-window quality check fails:

        * the HR field shows a low-quality message and SpO2 shows "Measuring...";
        * the previous HRV strings are kept;
        * the buffers are reset.
        """
        now = time.time()
        t = np.array(self._t)
        r = np.array(self._red)
        g = np.array(self._green)

        if len(t) < 20:
            return

        # Refresh the plot every 0.5 s. pyplot keeps every figure it creates
        # registered until it is closed, so close the figure being replaced;
        # otherwise one figure leaks per refresh.
        if now - self._last_plot_update_t >= 0.5:
            new_fig = self._make_plots(t, r, g)
            if self._plot_fig is not None:
                plt.close(self._plot_fig)
            self._plot_fig = new_fig
            self._last_plot_update_t = now

        # After a reset, wait until a full HR window of fresh samples exists.
        if self._waiting_for_fresh_data:
            if self._fresh_data_start_t is None:
                self._fresh_data_start_t = now
            if now - self._fresh_data_start_t < self.hr_window_s:
                return
            self._waiting_for_fresh_data = False
            self._fresh_data_start_t = None

        # Throttle HR updates to once per hr_update_interval_s.
        if now - self._last_hr_update_t < self.hr_update_interval_s:
            return

        self._last_hr_update_t = now
        t_end = t[-1]

        # Short-window signals for HR, SpO2 and the quality gate.
        hr_mask = t >= (t_end - self.hr_window_s)
        th, rh, gh = t[hr_mask], r[hr_mask], g[hr_mask]
        thu, ghu = resample_uniform(th, gh, self.fs)
        _, rhu = resample_uniform(th, rh, self.fs)
        ghu_f = bandpass_filter(detrend_and_normalize(ghu), fs=self.fs)
        rhu_f = bandpass_filter(detrend_and_normalize(rhu), fs=self.fs)
        peak_dom, band_ratio = spectral_sqi(ghu_f, fs=self.fs)
        coh = coherence_sqi(ghu_f, rhu_f, fs=self.fs)
        hr_bpm, _, _ = compute_hr_and_peaks(thu, ghu, fs=self.fs, min_samples_s=3.0)

        # Relaxed SQI thresholds for the short HR window.
        sqi_pass = (peak_dom >= 1.5 and band_ratio >= 0.20 and coh >= 0.15 and
                    (hr_bpm is not None) and (40.0 <= hr_bpm <= 200.0))

        m = Metrics()
        gpu_status = " [GPU]" if GPU_AVAILABLE else " [CPU]"
        m.sqi = f"PeakDom={peak_dom:.2f}, BandPow={band_ratio:.2f}, Coh={coh:.2f}{gpu_status}"

        if not sqi_pass:
            m.hr = "Measuring... (low signal quality)"
            m.spo2 = "Measuring..."
            m.sdnn = self._latest_metrics.sdnn
            m.rmssd = self._latest_metrics.rmssd
            m.pnn50 = self._latest_metrics.pnn50
            self._latest_metrics = m
            # Discard the poor-quality data and start a fresh collection window.
            self.reset_buffers()
            return

        # estimate_spo2 divides each channel's std (AC) by its mean (DC). The
        # z-scored, band-passed signals above have a mean of ~0, so they would
        # make that ratio meaningless. Use the band-passed component re-centred
        # on each channel's raw mean instead.
        spo2 = self.estimate_spo2(self._pulsatile_with_dc(rhu), self._pulsatile_with_dc(ghu))
        m.spo2 = f"{spo2:.2f}%" if np.isfinite(spo2) else "Measuring..."
        m.hr = f"{hr_bpm:.2f} bpm"
        m.sdnn, m.rmssd, m.pnn50 = self._hrv_fields(t, r, g, t_end)
        self._latest_metrics = m

    def _hrv_fields(self, t: np.ndarray, r: np.ndarray, g: np.ndarray, t_end: float) -> Tuple[str, str, str]:
        """Return the display strings ``(sdnn, rmssd, pnn50)`` for the HRV window.

        All three are "Measuring..." unless three conditions hold:

        * the window spans at least ``hrv_window_s - 1`` seconds;
        * the stricter long-window quality check passes;
        * at least 10 cleaned RR intervals are available.
        """
        pending = ("Measuring...", "Measuring...", "Measuring...")
        hrv_mask = t >= (t_end - self.hrv_window_s)
        tv, rv, gv = t[hrv_mask], r[hrv_mask], g[hrv_mask]
        if (tv[-1] - tv[0]) < (self.hrv_window_s - 1.0):
            return pending
        tvu, gvu = resample_uniform(tv, gv, self.fs)
        _, rvu = resample_uniform(tv, rv, self.fs)
        gvu_f = bandpass_filter(detrend_and_normalize(gvu), fs=self.fs)
        rvu_f = bandpass_filter(detrend_and_normalize(rvu), fs=self.fs)
        peak_dom_v, band_ratio_v = spectral_sqi(gvu_f, fs=self.fs)
        coh_v = coherence_sqi(gvu_f, rvu_f, fs=self.fs)
        _, _, rr_v = compute_hr_and_peaks(tvu, gvu, fs=self.fs)
        sqi_pass_v = (peak_dom_v >= 3.0 and band_ratio_v >= 0.35 and coh_v >= 0.35 and
                      (rr_v is not None) and (len(rr_v) >= 10))
        if not sqi_pass_v:
            return pending
        sdnn, rmssd, pnn50 = compute_hrv(rr_v)
        return f"{sdnn:.2f} ms", f"{rmssd:.2f} ms", f"{pnn50:.2f} %"

    def _make_plots(self, t: np.ndarray, r: np.ndarray, g: np.ndarray):
        """Build a two-panel figure of the last ``hr_window_s`` seconds.

        The left panel shows the filtered red and green signals. The right
        panel shows the filtered green signal with detected peaks. Returns
        ``None`` when there are too few samples.
        """
        if len(t) < 20:
            return None
        t_end = t[-1]
        mask = t >= (t_end - self.hr_window_s)
        th, rh, gh = t[mask], r[mask], g[mask]
        if len(th) < 10:
            return None
        thu, ghu = resample_uniform(th, gh, self.fs)
        _, rhu = resample_uniform(th, rh, self.fs)
        ghu_f = bandpass_filter(detrend_and_normalize(ghu), fs=self.fs)
        rhu_f = bandpass_filter(detrend_and_normalize(rhu), fs=self.fs)
        hr_bpm, peaks_idx, _ = compute_hr_and_peaks(thu, ghu, fs=self.fs)
        fig = plt.figure(figsize=(12, 4))
        ax1 = fig.add_subplot(1, 2, 1)
        ax2 = fig.add_subplot(1, 2, 2)
        ax1.plot(thu, rhu_f, label="Red (filtered)")
        ax1.plot(thu, ghu_f, label="Green (filtered)")
        ax1.set_xlabel("Time (s)")
        ax1.set_ylabel("Normalized intensity")
        ax1.set_title(f"Filtered PPG (last {self.hr_window_s:g}s)")
        ax1.legend(loc="upper right")
        ax2.plot(thu, ghu_f, label="Green (filtered)")
        if peaks_idx is not None and len(peaks_idx) > 0:
            ax2.scatter(thu[peaks_idx], ghu_f[peaks_idx], label="Peaks")
        ax2.set_xlabel("Time (s)")
        ax2.set_ylabel("Normalized intensity")
        ax2.set_title(f"Peaks (HR={hr_bpm:.1f} bpm)" if hr_bpm is not None else "Peaks")
        ax2.legend(loc="upper right")
        fig.tight_layout()
        return fig

# Single module-level processor shared by every streamed frame:
# 5 s HR windows refreshed every 5 s, 60 s HRV windows, 65 s of retained samples.
processor = PPGProcessor(
    roi_size=(250, 250),
    target_fs=30.0,
    hr_window_s=5.0,
    hrv_window_s=60.0,
    buffer_s=65.0,
    hr_update_interval_s=5.0,
)

@spaces.GPU
def process_webcam_frame(frame):
    """Gradio streaming callback: run one webcam frame through ``processor``.

    Returns the eight values passed to the stream outputs: frame, HR, SpO2,
    SDNN, RMSSD, pNN50, SQI text and plot. Until a frame arrives, a
    placeholder tuple is returned instead.
    """
    # NOTE(review): all state lives in the module-level ``processor``; confirm
    # that the ``spaces.GPU`` runtime preserves that state between calls.
    if frame is not None:
        return processor.process_frame(frame)
    return None, "Waiting for frame...", "Measuring...", "Measuring...", "Measuring...", "Measuring...", "Measuring...", None

with gr.Blocks(title="Real-Time PPG Monitor") as demo:
    gr.Markdown("## Real-Time Heart Rate & HRV from Finger PPG\n**Position your finger on the camera and stay still for best results**")

    # GPU status indicator
    gpu_status_text = "🟢 GPU Acceleration: Enabled (CuPy)" if GPU_AVAILABLE else "🔴 GPU: Not Available (Using CPU)"
    gr.Markdown(f"**{gpu_status_text}**")

    # Derive the window lengths shown in the labels from the processor. The
    # HR label previously said "10s" while the processor uses 5 s windows.
    hr_window_label = f"{processor.hr_window_s:g}s"
    hrv_window_label = f"{processor.hrv_window_s:g}s"

    with gr.Row():
        with gr.Column(scale=1):
            webcam = gr.Image(sources=["webcam"], type="numpy", label="Webcam Feed", streaming=True)
        with gr.Column(scale=1):
            plot_output = gr.Plot(label="Signals & Peaks")
    with gr.Row():
        hr_output = gr.Textbox(label=f"Heart Rate ({hr_window_label})", value="Measuring...", interactive=False)
        spo2_output = gr.Textbox(label="SpO2 (indicative)", value="Measuring...", interactive=False)
    with gr.Row():
        sdnn_output = gr.Textbox(label=f"SDNN ({hrv_window_label})", value="Measuring...", interactive=False)
        rmssd_output = gr.Textbox(label=f"RMSSD ({hrv_window_label})", value="Measuring...", interactive=False)
        pnn50_output = gr.Textbox(label=f"pNN50 ({hrv_window_label})", value="Measuring...", interactive=False)
    with gr.Row():
        sqi_output = gr.Textbox(label="Signal Quality", value="Measuring...", interactive=False)
    # Output order must match the tuple returned by process_webcam_frame.
    webcam.stream(
        fn=process_webcam_frame,
        inputs=webcam,
        outputs=[webcam, hr_output, spo2_output, sdnn_output, rmssd_output, pnn50_output, sqi_output, plot_output],
        time_limit=None
    )

if __name__ == "__main__":
    # Enable request queuing on the app, then start serving it.
    queued_demo = demo.queue()
    queued_demo.launch()