File size: 46,645 Bytes
981e922
e5ef7e3
2c58d17
bc9be1b
0453ba8
 
21d112f
bc9be1b
2c58d17
 
c487f43
fe7f1d6
cfacf2f
cb18875
2c58d17
98be869
2c58d17
 
 
21d112f
 
 
 
 
 
 
 
 
c487f43
21d112f
c487f43
5cd0edf
c487f43
4d7914e
fe7f1d6
72d014e
2c58d17
 
fe7f1d6
72d014e
4d7914e
72d014e
 
2c58d17
bc9be1b
fe7f1d6
72d014e
 
bc9be1b
72d014e
 
2c58d17
c487f43
4d7914e
ee0f844
72d014e
 
2c58d17
4d7914e
ee0f844
0453ba8
 
 
 
72d014e
 
2c58d17
72d014e
4d7914e
 
2c58d17
72d014e
 
 
2c58d17
4d7914e
ee0f844
21d112f
0453ba8
21d112f
0453ba8
 
cfacf2f
7b785c9
342950f
 
 
 
21d112f
 
 
 
0453ba8
2c58d17
4d7914e
ee0f844
21d112f
bc9be1b
 
 
fe7f1d6
21d112f
 
fe7f1d6
72d014e
 
2c58d17
0453ba8
ee0f844
4903c34
 
21d112f
72d014e
 
0453ba8
 
ee0f844
4903c34
21d112f
72d014e
 
0453ba8
 
ee0f844
21d112f
 
 
 
 
fe7f1d6
72d014e
0453ba8
ee0f844
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342950f
ee0f844
 
342950f
ee0f844
 
b8cc483
ee0f844
 
 
b8cc483
c487f43
4234d4a
c487f43
ee0f844
2c58d17
4234d4a
556a8b5
21d112f
 
 
 
4903c34
21d112f
 
c487f43
4234d4a
556a8b5
21d112f
4234d4a
 
 
 
21d112f
 
 
4234d4a
c487f43
f56e9dc
556a8b5
f56e9dc
 
 
 
 
 
 
 
 
 
 
 
517b944
f56e9dc
556a8b5
4234d4a
f56e9dc
 
4234d4a
21d112f
4234d4a
517b944
f56e9dc
 
21d112f
fe7f1d6
21d112f
fe7f1d6
21d112f
 
fe7f1d6
342950f
 
bc9be1b
 
4d7914e
 
21d112f
fe7f1d6
21d112f
c487f43
 
342950f
 
4234d4a
2c58d17
c487f43
b8cc483
c487f43
c4b0e29
1b002e6
1ab48fb
556a8b5
1ab48fb
 
40e1ccf
1ab48fb
cfacf2f
1ab48fb
 
 
 
 
 
cfacf2f
 
 
 
1ab48fb
1b002e6
556a8b5
1ab48fb
00354fc
fe7f1d6
c4b0e29
fe7f1d6
1b002e6
c4b0e29
40e1ccf
 
cfacf2f
40e1ccf
cfacf2f
1ab48fb
1b002e6
fe7f1d6
c4b0e29
c487f43
 
c4b0e29
40e1ccf
 
1ab48fb
40e1ccf
 
fe7f1d6
c487f43
841455a
c4b0e29
cfacf2f
 
 
fe7f1d6
1ab48fb
841455a
c487f43
21d112f
 
 
556a8b5
c4b0e29
 
21d112f
 
556a8b5
c4b0e29
21d112f
c4b0e29
21d112f
 
2de4b30
21d112f
c4b0e29
21d112f
2de4b30
556a8b5
c4b0e29
a1eeef9
2de4b30
a1eeef9
2de4b30
a1eeef9
 
 
 
 
2de4b30
a1eeef9
 
2de4b30
a1eeef9
21d112f
c4b0e29
21d112f
2de4b30
a1eeef9
2de4b30
 
a1eeef9
 
342950f
 
a1eeef9
 
 
 
 
 
 
 
 
 
2de4b30
a1eeef9
2de4b30
21d112f
a1eeef9
342950f
 
 
a1eeef9
 
 
 
0244d26
fe7f1d6
a1eeef9
 
 
 
0244d26
a1eeef9
 
 
0244d26
 
 
 
 
 
 
 
 
 
517b944
a1eeef9
 
 
c4b0e29
 
 
a1eeef9
 
 
0244d26
21d112f
0244d26
fe7f1d6
f56e9dc
 
 
 
 
 
fe7f1d6
0244d26
 
 
 
 
 
 
 
 
 
 
f56e9dc
0244d26
f56e9dc
 
 
 
 
0244d26
f56e9dc
21d112f
c9d5e10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403e359
 
c9d5e10
 
 
 
 
 
 
 
a0d32bb
c9d5e10
 
 
 
 
 
 
 
 
a0d32bb
 
 
c9d5e10
a0d32bb
 
 
 
 
 
 
 
c9d5e10
403e359
c9d5e10
a0d32bb
403e359
 
 
 
 
 
 
 
 
 
a0d32bb
 
 
 
 
 
 
 
c9d5e10
a0d32bb
 
 
 
 
 
 
 
 
c9d5e10
403e359
a0d32bb
 
 
403e359
a0d32bb
 
 
 
 
4905f0b
c9d5e10
 
 
 
 
403e359
4905f0b
c9d5e10
 
a0d32bb
 
c9d5e10
 
 
c487f43
fe7f1d6
81a9038
1ab48fb
c487f43
a0d32bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03bc83d
47f3610
 
21d112f
517b944
f56e9dc
4234d4a
21d112f
517b944
f56e9dc
38fd139
4234d4a
47f3610
03bc83d
 
4234d4a
517b944
 
fe7f1d6
517b944
f56e9dc
21d112f
03bc83d
21d112f
 
a1eeef9
 
38fd139
2de4b30
21d112f
2de4b30
21d112f
a1eeef9
03bc83d
a1eeef9
21d112f
2de4b30
a1eeef9
d64b9f3
a0d32bb
 
 
 
81a9038
a0d32bb
 
 
 
 
 
 
 
 
 
 
81a9038
 
a0d32bb
30cc698
81a9038
 
30cc698
 
 
81a9038
30cc698
81a9038
 
a0d32bb
d64b9f3
03bc83d
f56e9dc
 
 
 
 
a1eeef9
 
0244d26
 
f56e9dc
 
 
 
0244d26
 
f56e9dc
 
 
 
 
 
 
 
0244d26
 
a0d32bb
fe7f1d6
d64b9f3
d095c33
fe7f1d6
9fa3181
 
 
 
 
 
 
 
4c8f956
1ab48fb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
import os
import gradio as gr
import numpy as np
import math
import cv2
import tempfile
from scipy.io import loadmat, savemat
from scipy import signal, sparse, linalg
from sklearn.decomposition import FastICA
import matplotlib.pyplot as plt
import time
import mat73  # For reading SCAMPS v7.3 mat files
import pandas as pd # For LinePlot
from typing import Tuple, Dict, List # For type hinting in the new section


# --- Constants ---
# Heart-rate search band in beats per minute; the signal-processing code
# converts these to Hz (divides by 60) for FFT peak picking and band-passing.
MIN_BPM = 45
MAX_BPM = 180

# Fitzpatrick skin type mean RGB values (approximated from research papers).
FITZPATRICK_RGB = {
    "Type I": [239, 207, 186],
    "Type II": [232, 194, 163],
    "Type III": [216, 172, 134],
    "Type IV": [193, 142, 107],
    "Type V": [151, 103, 70],
    "Type VI": [82, 57, 43],
}
# Type names in dictionary (I..VI) order.
FITZPATRICK_TYPES = [*FITZPATRICK_RGB]
# Numeric skin_color id (1-based, as stored in MMPD-style .mat files) -> type name.
SKIN_COLOR_MAP = dict(enumerate(FITZPATRICK_TYPES, start=1))

# =================================================================================
# SECTION 1: CORE RPPG LOGIC (Used by multiple tabs)
# =================================================================================

# --- 1a. Helper Functions for Signal Processing ---
def bandpass_filter(data, fs, min_hz, max_hz):
    """Zero-phase 4th-order Butterworth band-pass filter.

    Args:
        data: 1-D signal to filter.
        fs: Sampling rate in Hz.
        min_hz: Lower cut-off frequency in Hz.
        max_hz: Upper cut-off frequency in Hz.

    Returns:
        The filtered signal, same length as ``data``. ``filtfilt`` runs the
        filter forward and backward, so no phase shift is introduced.
    """
    nyquist = fs / 2.0
    normalized_band = [min_hz / nyquist, max_hz / nyquist]
    numerator, denominator = signal.butter(4, normalized_band, btype='band')
    return signal.filtfilt(numerator, denominator, data)

def calculate_bpm(bvp_signal, fs):
    """Estimate heart rate as the strongest FFT bin inside the allowed band.

    Args:
        bvp_signal: 1-D blood-volume-pulse trace.
        fs: Sampling rate in Hz.

    Returns:
        Frequency of the largest-magnitude FFT bin within
        [MIN_BPM, MAX_BPM] / 60 Hz, expressed in beats per minute, or 0 when
        no bin falls inside that band.
    """
    low_hz = MIN_BPM / 60.0
    high_hz = MAX_BPM / 60.0
    spectrum = np.abs(np.fft.rfft(bvp_signal))
    bin_freqs = np.fft.rfftfreq(len(bvp_signal), 1 / fs)
    in_band = (bin_freqs >= low_hz) & (bin_freqs <= high_hz)
    if not in_band.any():
        return 0
    band_freqs = bin_freqs[in_band]
    return band_freqs[np.argmax(spectrum[in_band])] * 60

def detrend(input_signal, lambda_value):
    """Remove slow trends with the smoothness-priors method (rPPG-Toolbox).

    The trend ``z`` solves ``(I + lambda^2 * D^T D) z = x``, where ``D`` is the
    (N-2) x N second-order difference operator; the detrended signal is
    ``x - z``. This equals the toolbox formula ``(I - inv(I + lambda^2 D^T D)) x``.

    Args:
        input_signal: Signal with N samples along axis 0 (1-D, or 2-D columns).
        lambda_value: Smoothing parameter; larger values give a smoother trend.

    Returns:
        The detrended signal, flattened to 1-D.
    """
    signal_length = input_signal.shape[0]
    identity = np.identity(signal_length)
    ones = np.ones(signal_length)
    diags_data = np.array([ones, -2 * ones, ones])
    diags_index = np.array([0, 1, 2])
    D = sparse.spdiags(diags_data, diags_index, (signal_length - 2), signal_length).toarray()
    if input_signal.ndim == 1:
        input_signal = input_signal[:, np.newaxis]
    system = identity + (lambda_value ** 2) * np.dot(D.T, D)
    # Solve for the trend instead of explicitly inverting ``system``: same
    # result, but cheaper and better conditioned. ``system`` is symmetric
    # positive definite (identity plus a Gram matrix), so Cholesky applies.
    trend = linalg.solve(system, input_signal, assume_a='pos')
    return (input_signal - trend).flatten()

# --- 1b. Implementations of All Unsupervised rPPG Models ---
def rppg_green(raw_signal, fs):
    """GREEN method: the detrended, band-passed green channel is the BVP.

    Args:
        raw_signal: (num_frames, channels) array of per-frame mean colours;
            column 1 is used as the green channel (RGB column order assumed).
        fs: Sampling rate in Hz.

    Returns:
        1-D BVP estimate band-passed to [MIN_BPM, MAX_BPM] / 60 Hz.
    """
    green_trace = signal.detrend(raw_signal[:, 1])
    return bandpass_filter(green_trace, fs, MIN_BPM / 60.0, MAX_BPM / 60.0)

def rppg_ica(raw_signal, fs):
    """Estimate BVP with FastICA, keeping the source with most in-band power.

    Each colour channel is detrended (smoothness priors, lambda=100) and
    z-scored. FastICA then extracts three sources, and the source whose FFT
    power inside [MIN_BPM, MAX_BPM] / 60 Hz is largest is band-passed.

    Args:
        raw_signal: (num_frames, 3) array of per-frame mean colours.
        fs: Sampling rate in Hz.

    Returns:
        1-D band-passed BVP estimate of length num_frames.

    Raises:
        ValueError: If no source yields usable in-band power.
    """
    # Keep a floating input's dtype, but never let an integer dtype leak into
    # the buffer: np.zeros_like() would silently truncate the z-scores.
    if np.issubdtype(raw_signal.dtype, np.floating):
        buffer_dtype = raw_signal.dtype
    else:
        buffer_dtype = np.float64
    normalized_signal = np.zeros(raw_signal.shape, dtype=buffer_dtype)
    for i in range(raw_signal.shape[1]):
        channel_detrended = detrend(raw_signal[:, i], 100)
        normalized_signal[:, i] = (channel_detrended - np.mean(channel_detrended)) / (np.std(channel_detrended) + 1e-6)
    ica = FastICA(n_components=3, random_state=0, max_iter=1000)
    ica_sources = ica.fit_transform(normalized_signal)
    min_hz, max_hz = MIN_BPM / 60.0, MAX_BPM / 60.0
    max_power = -np.inf
    best_component = None
    for component in ica_sources.T:
        fft_data = np.abs(np.fft.rfft(component))
        freqs = np.fft.rfftfreq(len(component), 1 / fs)
        valid_indices = np.where((freqs >= min_hz) & (freqs <= max_hz))
        if len(valid_indices[0]) == 0:
            continue
        power = np.sum(fft_data[valid_indices] ** 2)
        if power > max_power:
            max_power = power
            best_component = component
    if best_component is None:
        raise ValueError("ICA could not find a suitable component.")
    return bandpass_filter(best_component, fs, min_hz, max_hz)

def rppg_chrom(raw_signal, fs):
    """Chrominance-based (CHROM) BVP estimation with overlap-added windows.

    The signal is cut into ~1.6 s windows with 50 % overlap. In each window:
    - colour traces are divided by their mean;
    - they are projected onto X = 3R - 2G and Y = 1.5R + G - 1.5B;
    - both projections are band-passed to 0.7-2.5 Hz (3rd-order Butterworth,
      zero-phase);
    - they are combined as X - alpha*Y with alpha = std(X)/std(Y);
    - the result is Hann-tapered and added into the output.

    Args:
        raw_signal: (num_frames, 3) array of per-frame mean colours, RGB order.
        fs: Sampling rate in Hz.

    Returns:
        1-D BVP estimate of length num_frames. Samples after the last window
        stay zero, and an input shorter than ~1.5 windows yields all zeros.
    """
    num_frames = raw_signal.shape[0]
    win_len = math.ceil(1.6 * fs)
    if win_len % 2:
        win_len += 1  # even length, so the hop is exactly half a window
    num_windows = math.floor((num_frames - win_len / 2) / (win_len / 2))
    nyquist = 0.5 * fs
    b, a = signal.butter(3, [0.7 / nyquist, 2.5 / nyquist], 'bandpass')
    output = np.zeros(num_frames)
    # With an even win_len the last window ends at floor(num_frames / hop) * hop,
    # which never exceeds num_frames, so no window needs truncating.
    for w in range(num_windows):
        start = int(w * win_len / 2)
        stop = int(start + win_len)
        segment = raw_signal[start:stop, :]
        segment = segment / (np.mean(segment, axis=0) + 1e-6)
        red, green, blue = segment[:, 0], segment[:, 1], segment[:, 2]
        x_chrom = signal.filtfilt(b, a, 3 * red - 2 * green)
        y_chrom = signal.filtfilt(b, a, 1.5 * red + green - 1.5 * blue)
        alpha = np.std(x_chrom) / (np.std(y_chrom) + 1e-6)
        output[start:stop] += (x_chrom - alpha * y_chrom) * signal.windows.hann(len(x_chrom))
    return output

def rppg_pos(raw_signal, fs):
    """Plane-Orthogonal-to-Skin (POS) BVP estimation.

    A 1.6 s window slides forward one frame at a time. Each window is
    temporally normalised by its mean colour and projected onto
    [[0, 1, -1], [-2, 1, 1]]. The two projections are combined as
    S0 + (std(S0)/std(S1)) * S1, and the mean-removed result is overlap-added.
    The accumulated trace is then detrended (lambda=100) and band-passed.

    Args:
        raw_signal: (num_frames, 3) array of per-frame mean colours.
        fs: Sampling rate in Hz.

    Returns:
        1-D BVP estimate of length num_frames.
    """
    num_frames = raw_signal.shape[0]
    win_len = math.ceil(1.6 * fs)
    projection = np.array([[0, 1, -1], [-2, 1, 1]])
    accumulated = np.zeros(num_frames)
    for start in range(num_frames - win_len + 1):
        stop = start + win_len
        window = raw_signal[start:stop, :]
        window = window / (np.mean(window, axis=0) + 1e-6)
        projected = np.dot(window, projection.T)
        s0, s1 = projected[:, 0], projected[:, 1]
        h = s0 + (np.std(s0) / (np.std(s1) + 1e-6)) * s1
        accumulated[start:stop] += h - np.mean(h)
    bvp = detrend(accumulated, 100)
    return bandpass_filter(bvp, fs, MIN_BPM / 60.0, MAX_BPM / 60.0)

def rppg_lgi(raw_signal, fs):
    """Local Group Invariance (LGI): project out the dominant colour direction.

    The (3, N) colour traces are decomposed with SVD. The first left singular
    vector ``s`` is removed with the projector ``P = I - s s^T``, and the
    second row of ``P @ X`` is band-passed as the BVP estimate.

    Args:
        raw_signal: (num_frames, 3) array of per-frame mean colours.
        fs: Sampling rate in Hz.

    Returns:
        1-D BVP estimate of length num_frames.
    """
    processed_data = raw_signal.T[np.newaxis, :, :]
    # Only U is used. full_matrices=False skips building the N x N right
    # singular matrix (O(N^2) memory/time for N frames); U is still 3 x 3 for
    # N >= 3, and the projector s s^T does not depend on singular-vector sign.
    U, _, _ = np.linalg.svd(processed_data, full_matrices=False)
    S = U[:, :, 0]
    S = np.expand_dims(S, 2)
    SST = np.matmul(S, np.swapaxes(S, 1, 2))
    p = np.tile(np.identity(3), (S.shape[0], 1, 1))
    P = p - SST
    Y = np.matmul(P, processed_data)
    bvp = Y[:, 1, :].flatten()
    min_hz, max_hz = MIN_BPM / 60.0, MAX_BPM / 60.0
    return bandpass_filter(bvp, fs, min_hz, max_hz)

def rppg_omit(raw_signal, fs):
    """Orthogonal Matrix Imaging Technique (OMIT) via QR decomposition.

    The first orthonormal basis vector from the QR decomposition of the
    (3, N) colour traces is projected out. Row 1 of the projected traces is
    then band-passed as the BVP estimate.

    Args:
        raw_signal: (num_frames, 3) array of per-frame mean colours.
        fs: Sampling rate in Hz.

    Returns:
        1-D BVP estimate of length num_frames.
    """
    channels = raw_signal.T
    q_matrix, _ = np.linalg.qr(channels)
    basis = q_matrix[:, 0].reshape(1, -1)
    projector = np.identity(3) - np.matmul(basis.T, basis)
    projected = np.dot(projector, channels)
    bvp = projected[1, :].flatten()
    return bandpass_filter(bvp, fs, MIN_BPM / 60.0, MAX_BPM / 60.0)

def rppg_pbv(raw_signal, fs):
    """Blood-volume Pulse signature (PBV) method.

    Steps:
    - Each colour channel is divided by its mean.
    - The signature is each channel's std divided by sqrt(sum of the channel
      variances).
    - Weights W are solved from (C C^T) W = signature by least squares.
    - The BVP is C^T W, band-passed.

    Args:
        raw_signal: (num_frames, 3) array of per-frame mean colours.
        fs: Sampling rate in Hz.

    Returns:
        1-D BVP estimate of length num_frames.
    """
    channels = raw_signal.T
    normalized = channels / (np.mean(channels, axis=1, keepdims=True) + 1e-6)
    signature_numerator = np.std(normalized, axis=1)
    signature_denominator = np.sqrt(np.sum(np.var(normalized, axis=1)))
    signature = signature_numerator / (signature_denominator + 1e-6)
    gram = np.matmul(normalized, normalized.T)
    weights, *_ = np.linalg.lstsq(gram, signature, rcond=None)
    bvp = np.matmul(normalized.T, weights)
    return bandpass_filter(bvp, fs, MIN_BPM / 60.0, MAX_BPM / 60.0)

def rppg_additive_pos(raw_signal, fs):
    """Additive POS: augment the POS BVP with band-limited ICA residuals.

    Steps:
    - POS gives the base BVP.
    - Independently, the detrended and z-scored channels are decomposed with
      FastICA. The source with most FFT power in 0.75-2.5 Hz is treated as
      the pulse source, and the other two are summed as a "residual".
    - The residual is band-passed to +/- 0.2 Hz around the POS heart-rate
      peak.
    - The result is z-scored and added (weight 0.5) to the z-scored POS BVP.

    Args:
        raw_signal: (num_frames, 3) array of per-frame mean colours.
        fs: Sampling rate in Hz.

    Returns:
        1-D BVP estimate. This is the plain POS BVP when fewer than two
        residual sources exist or the POS heart-rate estimate is 0.
    """
    bvp_pos = rppg_pos(raw_signal, fs)
    # Keep a floating input's dtype, but never let an integer dtype leak into
    # the buffer: np.zeros_like() would silently truncate the z-scores.
    if np.issubdtype(raw_signal.dtype, np.floating):
        buffer_dtype = raw_signal.dtype
    else:
        buffer_dtype = np.float64
    normalized_signal = np.zeros(raw_signal.shape, dtype=buffer_dtype)
    for i in range(raw_signal.shape[1]):
        channel_detrended = detrend(raw_signal[:, i], 100)
        normalized_signal[:, i] = (channel_detrended - np.mean(channel_detrended)) / (np.std(channel_detrended) + 1e-6)
    ica = FastICA(n_components=3, random_state=0, max_iter=1000)
    ica_sources = ica.fit_transform(normalized_signal)
    max_power = -np.inf
    bvp_component_index = -1
    for i, component in enumerate(ica_sources.T):
        fft_data = np.abs(np.fft.rfft(component))
        freqs = np.fft.rfftfreq(len(component), 1 / fs)
        valid_indices = np.where((freqs >= 0.75) & (freqs <= 2.5))
        if len(valid_indices[0]) == 0:
            continue
        power = np.sum(fft_data[valid_indices] ** 2)
        if power > max_power:
            max_power = power
            bvp_component_index = i
    residual_indices = [i for i in range(3) if i != bvp_component_index]
    if len(residual_indices) < 2:
        return bvp_pos
    residual_signal = ica_sources[:, residual_indices[0]] + ica_sources[:, residual_indices[1]]
    peak_freq = calculate_bpm(bvp_pos, fs) / 60.0
    if peak_freq == 0:
        return bvp_pos
    # Keep only the residual energy close to the detected heart-rate frequency.
    corrected_residual = bandpass_filter(residual_signal, fs, peak_freq - 0.2, peak_freq + 0.2)
    alpha = 0.5
    norm_bvp_pos = (bvp_pos - np.mean(bvp_pos)) / (np.std(bvp_pos) + 1e-6)
    norm_corrected_residual = (corrected_residual - np.mean(corrected_residual)) / (np.std(corrected_residual) + 1e-6)
    return norm_bvp_pos + (alpha * norm_corrected_residual)

# =================================================================================
# SECTION 2: LOGIC FOR FILE ANALYZER TAB
# =================================================================================
# Display name -> rPPG implementation. Every entry is called as
# f(raw_signal, fs) and returns a 1-D BVP trace. Insertion order is kept.
MODEL_DISPATCHER = {
    "Additive POS": rppg_additive_pos,
    "GREEN": rppg_green,
    "ICA": rppg_ica,
    "CHROM": rppg_chrom,
    "POS": rppg_pos,
    "LGI": rppg_lgi,
    "OMIT": rppg_omit,
    "PBV": rppg_pbv,
}

def load_data_from_mat(mat_file):
    """Load an MMPD-style .mat file and render an MP4 preview of its video.

    Args:
        mat_file: Uploaded file object; only its ``.name`` path is used. The
            file must contain 'video' (frames x H x W x 3, values in [0, 1]
            expected) and 'GT_ppg'.

    Returns:
        Tuple ``(raw_signal, gt_signal, fs, preview_path)``:
        - raw_signal: per-frame mean colour.
        - gt_signal: the flattened 'GT_ppg' trace.
        - fs: the sampling rate. This is the file's 'fps' entry when present
          and a positive finite number (as written by this app's converters),
          otherwise 30.0.
        - preview_path: path of a temporary MP4 preview (not deleted here).
    """
    mat_data = loadmat(mat_file.name)
    video_frames = mat_data['video']
    FS = 30.0
    if 'fps' in mat_data:
        try:
            file_fs = float(np.asarray(mat_data['fps']).squeeze())
        except (TypeError, ValueError):
            file_fs = 0.0
        if np.isfinite(file_fs) and file_fs > 0:
            FS = file_fs
    raw_signal = np.mean(video_frames, axis=(1, 2))
    gt_signal = mat_data['GT_ppg'].flatten()
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
        output_path = tmp_file.name
    _, height, width, _ = video_frames.shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, FS, (width, height))
    try:
        for frame_float in video_frames:
            # Clip before the uint8 cast: values slightly outside [0, 1] would
            # otherwise wrap around (e.g. 256 -> 0) and show as speckle.
            frame_uint8 = np.clip(frame_float * 255, 0, 255).astype(np.uint8)
            out.write(cv2.cvtColor(frame_uint8, cv2.COLOR_RGB2BGR))
    finally:
        out.release()
    return raw_signal, gt_signal, FS, output_path

def load_data_from_ubfc(video_file, gt_file):
    """Load a UBFC recording (video + ground-truth text file).

    Args:
        video_file: Uploaded video file object (``.name`` is opened with OpenCV).
        gt_file: Uploaded ground-truth text file; its first line is parsed as
            whitespace-separated PPG samples.

    Returns:
        Tuple ``(raw_signal, gt_signal, fs, video_path)``:
        - raw_signal: (num_frames, 3) per-frame mean colour in RGB order.
        - gt_signal: the ground-truth PPG samples.
        - fs: the frame rate reported by OpenCV.
        - video_path: the uploaded video's own path, used as the preview.

    Raises:
        gr.Error: If no frames can be decoded or the frame rate is unavailable.
    """
    cap = cv2.VideoCapture(video_file.name)
    frame_means = []
    try:
        fs = cap.get(cv2.CAP_PROP_FPS)
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Reduce each frame to its channel means right away rather than
            # keeping the whole decoded video in memory. Reversing the BGR
            # means yields RGB order, matching a BGR->RGB conversion.
            frame_means.append(frame.mean(axis=(0, 1))[::-1])
    finally:
        cap.release()
    if not frame_means:
        raise gr.Error("Could not read any frames from the UBFC video.")
    if not fs or fs <= 0:
        raise gr.Error("Could not determine the frame rate of the UBFC video.")
    raw_signal = np.array(frame_means)
    with open(gt_file.name, "r") as f:
        gt_str = f.read().strip()
    gt_lines = gt_str.split('\n')
    gt_signal = np.array([float(x) for x in gt_lines[0].split()])
    return raw_signal, gt_signal, fs, video_file.name

def load_data_from_scamps_mat(mat_file):
    """Load a SCAMPS (MATLAB v7.3) .mat file and render an MP4 preview.

    Args:
        mat_file: Uploaded file object; only ``.name`` is used.

    Returns:
        Tuple ``(raw_signal, gt_signal, fs, preview_path)``:
        - raw_signal: per-frame mean colour of 'Xsub'.
        - gt_signal: the flattened 'd_ppg' trace.
        - fs: the sampling rate as a float ('fs' from the file if present,
          else 30.0).
        - preview_path: path of a temporary MP4 preview.

    Raises:
        gr.Error: If the 'Xsub' or 'd_ppg' keys are missing.
    """
    mat_data = mat73.loadmat(mat_file.name)
    if 'Xsub' not in mat_data:
        raise gr.Error("SCAMPS .mat file must contain an 'Xsub' key for video.")
    video_frames = mat_data['Xsub']
    if 'd_ppg' not in mat_data:
        raise gr.Error("SCAMPS .mat file must contain a 'd_ppg' key for ground truth.")
    gt_signal = mat_data['d_ppg'].flatten()
    # mat73 may hand back the rate as a NumPy scalar/array; cv2.VideoWriter and
    # the downstream maths want a plain float.
    FS = float(np.asarray(mat_data['fs']).squeeze()) if 'fs' in mat_data else 30.0
    raw_signal = np.mean(video_frames, axis=(1, 2))
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
        output_path = tmp_file.name
    _, height, width, _ = video_frames.shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, FS, (width, height))
    try:
        for frame_float in video_frames:
            # Clip before the uint8 cast so out-of-range values do not wrap.
            frame_uint8 = np.clip(frame_float * 255, 0, 255).astype(np.uint8)
            out.write(cv2.cvtColor(frame_uint8, cv2.COLOR_RGB2BGR))
    finally:
        out.release()
    return raw_signal, gt_signal, FS, output_path

def process_file_and_evaluate(dataset_type, mat_file_mmpd, ubfc_vid_file, ubfc_gt_file, mat_file_scamps, selected_model):
    """Run the selected rPPG model on an uploaded recording and score it.

    Args:
        dataset_type: "MMPD", "UBFC" or "SCAMPS"; selects which uploads are used.
        mat_file_mmpd: MMPD .mat upload (used for "MMPD").
        ubfc_vid_file: UBFC video upload (used for "UBFC").
        ubfc_gt_file: UBFC ground-truth text upload (used for "UBFC").
        mat_file_scamps: SCAMPS .mat upload (used for "SCAMPS").
        selected_model: Key of ``MODEL_DISPATCHER``.

    Returns:
        ``(bpm_text, figure, evaluation_text, video_path)``. The figure overlays
        the z-scored ground truth and predicted BVP; the evaluation text reports
        both BPMs, the absolute error and the percentage error.

    Raises:
        gr.Error: If the uploads required by ``dataset_type`` are missing or the
            dataset type is unknown.
    """
    if dataset_type == "MMPD":
        if mat_file_mmpd is None:
            raise gr.Error("Please upload an MMPD .mat file.")
        loaded = load_data_from_mat(mat_file_mmpd)
    elif dataset_type == "UBFC":
        if ubfc_vid_file is None or ubfc_gt_file is None:
            raise gr.Error("Please upload both a video and a ground truth file for UBFC.")
        loaded = load_data_from_ubfc(ubfc_vid_file, ubfc_gt_file)
    elif dataset_type == "SCAMPS":
        if mat_file_scamps is None:
            raise gr.Error("Please upload a SCAMPS .mat file.")
        loaded = load_data_from_scamps_mat(mat_file_scamps)
    else:
        raise gr.Error("Invalid dataset type selected.")
    raw_signal, gt_signal, fs, video_path = loaded

    # Resample the ground truth onto the video's frame grid when lengths differ.
    if len(gt_signal) != len(raw_signal):
        gt_signal = signal.resample(gt_signal, len(raw_signal))

    predicted_bvp = MODEL_DISPATCHER[selected_model](raw_signal, fs)
    predicted_bpm = calculate_bpm(predicted_bvp, fs)
    gt_bpm = calculate_bpm(gt_signal, fs)

    def zscore(values):
        return (values - np.mean(values)) / (np.std(values) + 1e-6)

    fig, ax = plt.subplots(figsize=(12, 6))
    time_values = np.arange(len(gt_signal)) / fs
    ax.plot(time_values, zscore(gt_signal), label=f'Ground Truth BVP (BPM: {gt_bpm:.2f})', alpha=0.8)
    ax.plot(time_values, zscore(predicted_bvp), label=f'Predicted BVP ({selected_model}) (BPM: {predicted_bpm:.2f})', linestyle='--')
    ax.set_title("Ground Truth vs. Predicted BVP Signal")
    ax.set_xlabel("Time (seconds)")
    ax.set_ylabel("Normalized Amplitude")
    ax.grid(True)
    ax.legend()
    fig.tight_layout()

    mae = abs(predicted_bpm - gt_bpm)
    if gt_bpm != 0:
        mape = (mae / gt_bpm) * 100
    else:
        mape = float('inf')
    eval_results = (f"Ground Truth BPM: {gt_bpm:.2f}\nPredicted BPM: {predicted_bpm:.2f}\n"
                    f"Mean Absolute Error (MAE): {mae:.2f}\nMean Absolute Percentage Error (MAPE): {mape:.2f} %")

    # Detach from pyplot's figure manager so repeated calls do not accumulate
    # figures; the Figure object itself is still returned for rendering.
    plt.close(fig)
    return f"{predicted_bpm:.2f} BPM", fig, eval_results, video_path

# =================================================================================
# SECTION 3: LOGIC FOR LIVE WEBCAM PREDICTION TAB
# =================================================================================
# Live-webcam settings:
# - BUFFER_SIZE: number of per-frame colour samples analysed at once.
# - FS_WEBCAM: sampling rate (Hz) assumed for webcam frames.
# - DETECTION_INTERVAL: run the face detector every N frames and reuse the
#   last box in between.
BUFFER_SIZE = 300
FS_WEBCAM = 30
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
DETECTION_INTERVAL = 5

def analyze_bvp_signal(bvp_signal, fs):
    """Derive heart-rate metrics and plot data from a BVP trace.

    Args:
        bvp_signal: 1-D BVP trace.
        fs: Sampling rate in Hz.

    Returns:
        ``(hr, hrv_rmssd, systolic_bp, diastolic_bp, plot_df)``:
        - hr: heart rate in BPM (``calculate_bpm``).
        - hrv_rmssd: RMSSD of successive peak-to-peak intervals in ms. It is 0
          when fewer than three peaks are found or peak detection fails.
        - systolic_bp, diastolic_bp: a fixed linear function of hr, not a
          physiological measurement.
        - plot_df: DataFrame with "time" (s) and "BVP" columns.
    """
    hr = calculate_bpm(bvp_signal, fs)
    hrv_rmssd = 0
    try:
        # Peaks must rise above the signal mean and be at least 0.5 s apart.
        peaks, _ = signal.find_peaks(bvp_signal, height=np.mean(bvp_signal), distance=fs * 0.5)
    except Exception:
        # Best-effort metric: a failed peak search just reports HRV as 0.
        peaks = np.array([], dtype=int)
    # RMSSD needs at least two RR intervals, i.e. three peaks. With only two
    # peaks np.diff(rr) is empty and its mean would be NaN.
    if len(peaks) > 2:
        rr_intervals = np.diff(peaks) / fs * 1000
        successive_diffs = np.diff(rr_intervals)
        hrv_rmssd = np.sqrt(np.mean(successive_diffs ** 2))
    systolic_bp = (0.5 * hr) + 90
    diastolic_bp = (0.3 * hr) + 60
    time_values = np.arange(len(bvp_signal)) / fs
    plot_df = pd.DataFrame({"time": time_values, "BVP": bvp_signal})
    return hr, hrv_rmssd, systolic_bp, diastolic_bp, plot_df

def process_webcam_frame(frame, signal_buffer, last_update_time, selected_model, frame_counter, last_face_box):
    """Handle one streamed webcam frame for live heart-rate estimation.

    The face detector runs every DETECTION_INTERVAL frames; other frames reuse
    ``last_face_box``. The mean colour of a forehead patch inside the face box
    is appended to ``signal_buffer`` (mutated in place, capped at BUFFER_SIZE).
    Once the buffer is full, metrics are recomputed at most about once per
    second.

    Returns:
        A 10-tuple: status text, HR text, HRV text, BP text, plot data, the
        annotated frame, and then the updated state values ``signal_buffer``,
        ``last_update_time``, ``frame_counter``, ``last_face_box``. Outputs that
        should not change are ``gr.update()``.
    """
    if frame is None:
        return (gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), None,
                signal_buffer, last_update_time, frame_counter, last_face_box)

    annotated = frame.copy()
    frame_counter += 1

    if frame_counter % DETECTION_INTERVAL == 0:
        # Detection frame: trust only what the detector finds right now.
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        detections = face_cascade.detectMultiScale(gray, 1.1, 5)
        face_box = detections[0] if len(detections) > 0 else None
        if face_box is not None:
            last_face_box = face_box
    else:
        face_box = last_face_box

    if face_box is None:
        status_text = "Face not detected..."
    else:
        status_text = "Detecting face..."
        x, y, w, h = face_box
        cv2.rectangle(annotated, (x, y), (x + w, y + h), (0, 255, 0), 2)
        # Forehead patch: central half of the face width, starting 8% below the
        # top of the box and 18% of the box height tall.
        roi_top, roi_height = y + int(0.08 * h), int(0.18 * h)
        roi_left, roi_width = x + int(0.25 * w), int(0.5 * w)
        roi = frame[roi_top:roi_top + roi_height, roi_left:roi_left + roi_width]
        if roi.size > 0:
            signal_buffer.append(np.mean(roi, axis=(0, 1)))
            if len(signal_buffer) > BUFFER_SIZE:
                signal_buffer.pop(0)
            progress = int((len(signal_buffer) / BUFFER_SIZE) * 100)
            status_text = f"Collecting data... ({progress}%)"

    now = time.time()
    ready = len(signal_buffer) == BUFFER_SIZE and (now - last_update_time) > 1
    if not ready:
        return (status_text, gr.update(), gr.update(), gr.update(), gr.update(), annotated,
                signal_buffer, last_update_time, frame_counter, last_face_box)

    bvp_signal = MODEL_DISPATCHER[selected_model](np.array(signal_buffer), FS_WEBCAM)
    hr, hrv, sbp, dbp, plot_data = analyze_bvp_signal(bvp_signal, FS_WEBCAM)
    return (status_text, f"{hr:.2f} BPM", f"{hrv:.2f} ms (RMSSD)", f"{sbp:.1f} / {dbp:.1f} mmHg (est.)",
            plot_data, annotated, signal_buffer, now, frame_counter, last_face_box)

# =================================================================================
# SECTION 4: LOGIC FOR SKIN COLOR MANIPULATION TAB
# =================================================================================
def get_skin_mask(frame_rgb):
    """Return a single-channel uint8 skin mask from a fixed HSV threshold.

    Pixels whose OpenCV HSV values lie in H 0-20, S 48-255 and V 80-255 are
    set to 255; all others are 0.

    Args:
        frame_rgb: uint8 RGB image.
    """
    hsv_frame = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2HSV)
    lower_bound = np.array([0, 48, 80], dtype="uint8")
    upper_bound = np.array([20, 255, 255], dtype="uint8")
    return cv2.inRange(hsv_frame, lower_bound, upper_bound)

def get_skin_info_and_preview(mat_file):
    """Report an MMPD file's skin type and build the target-type dropdown.

    Args:
        mat_file: Uploaded MMPD .mat file object, or None.

    Returns:
        ``(skin_type_name, dropdown, video_path)``. With no file this is
        ("N/A", dropdown of all types, None). Otherwise the dropdown lists every
        Fitzpatrick type except the file's own, with the first one selected.
        An unrecognised 'skin_color' id is reported as "Unknown".

    Raises:
        gr.Error: If the file cannot be read or processed; the underlying
            exception is chained as the cause.
    """
    if mat_file is None:
        return "N/A", gr.Dropdown(choices=FITZPATRICK_TYPES), None
    try:
        mat_data = loadmat(mat_file.name)
        original_skin_type_id = mat_data['skin_color'].item()
        original_skin_type_name = SKIN_COLOR_MAP.get(original_skin_type_id, "Unknown")
        dropdown_choices = [ftype for ftype in FITZPATRICK_TYPES if ftype != original_skin_type_name]
        _, _, _, video_path = load_data_from_mat(mat_file)
    except Exception as e:
        raise gr.Error(f"Failed to process .mat file: {e}") from e
    return original_skin_type_name, gr.Dropdown(choices=dropdown_choices, value=dropdown_choices[0]), video_path

def manipulate_and_compare(mat_file, target_skin_type_name, selected_model):
    """Shift the skin colour of an MMPD video and compare rPPG accuracy.

    Steps:
    - The source skin colour is the mean of a forehead patch of the first face
      found (Haar cascade) in the first frame.
    - Every pixel classified as skin by ``get_skin_mask`` is offset by
      ``FITZPATRICK_RGB[target] - source`` in every frame.
    - The selected model is run on the original and the manipulated video, and
      both results are scored against the ground-truth BPM.

    Args:
        mat_file: Uploaded MMPD .mat file object (``.name`` is read).
        target_skin_type_name: Key of ``FITZPATRICK_RGB``.
        selected_model: Key of ``MODEL_DISPATCHER``.

    Returns:
        ``(video_out_path, new_mat_path, comparison_text, fig)``:
        - video_out_path: MP4 preview of the manipulated video.
        - new_mat_path: a .mat copy of the input with the manipulated 'video'
          and updated 'skin_color'.
        - comparison_text: text summary of both results.
        - fig: the comparison plot.

    Raises:
        gr.Error: If inputs are missing, no face is found in the first frame,
            or the forehead patch is empty.
    """
    if mat_file is None or not target_skin_type_name or not selected_model:
        raise gr.Error("Please upload a file and select all options.")
    mat_data = loadmat(mat_file.name)
    original_video_frames_float = mat_data['video']
    # Clip before casting so values marginally outside [0, 1] do not wrap
    # around in uint8 (e.g. 256 -> 0).
    original_video_frames_uint8 = np.clip(original_video_frames_float * 255, 0, 255).astype(np.uint8)
    gt_signal = mat_data['GT_ppg'].flatten()
    fs = 30.0  # frame rate assumed for MMPD recordings
    original_raw_signal = np.mean(original_video_frames_float, axis=(1, 2))
    if len(gt_signal) != len(original_raw_signal):
        gt_signal = signal.resample(gt_signal, len(original_raw_signal))
    gt_bpm = calculate_bpm(gt_signal, fs)
    model_function = MODEL_DISPATCHER[selected_model]
    original_bvp = model_function(original_raw_signal, fs)
    original_bpm = calculate_bpm(original_bvp, fs)
    original_mae = abs(original_bpm - gt_bpm)

    target_color = np.array(FITZPATRICK_RGB[target_skin_type_name])
    first_frame = original_video_frames_uint8[0]
    gray = cv2.cvtColor(first_frame, cv2.COLOR_RGB2GRAY)
    faces = face_cascade.detectMultiScale(gray, 1.1, 5)
    if len(faces) == 0:
        raise gr.Error("Could not detect a face in the first frame.")
    x, y, w, h = faces[0]
    forehead_roi = first_frame[y + int(0.08 * h):y + int(0.26 * h), x + int(0.25 * w):x + int(0.75 * w)]
    if forehead_roi.size == 0:
        # An empty patch would make the source colour (and every offset) NaN.
        raise gr.Error("Could not sample the forehead region in the first frame.")
    source_color = np.mean(forehead_roi, axis=(0, 1))
    color_offset = target_color - source_color

    manipulated_frames = []
    for frame in original_video_frames_uint8:
        skin_mask_3ch = cv2.cvtColor(get_skin_mask(frame), cv2.COLOR_GRAY2RGB)
        # Apply the offset only where the mask is set (mask / 255 is 0 or 1).
        offset_image = (skin_mask_3ch / 255 * color_offset).astype('int16')
        manipulated_frames.append(np.clip(frame.astype('int16') + offset_image, 0, 255).astype('uint8'))
    manipulated_video_float = np.array(manipulated_frames) / 255.0
    manipulated_raw_signal = np.mean(manipulated_video_float, axis=(1, 2))
    manipulated_bvp = model_function(manipulated_raw_signal, fs)
    manipulated_bpm = calculate_bpm(manipulated_bvp, fs)
    manipulated_mae = abs(manipulated_bpm - gt_bpm)
    comparison_text = (f"--- Analysis using {selected_model} ---\n\n"
                       f"Ground Truth BPM: {gt_bpm:.2f}\n\n"
                       f"Original Video:\n  - Predicted BPM: {original_bpm:.2f}\n  - MAE: {original_mae:.2f}\n\n"
                       f"Manipulated Video ({target_skin_type_name}):\n  - Predicted BPM: {manipulated_bpm:.2f}\n  - MAE: {manipulated_mae:.2f}")

    new_mat_data = mat_data.copy()
    new_mat_data['video'] = manipulated_video_float
    new_skin_type_id = FITZPATRICK_TYPES.index(target_skin_type_name) + 1
    new_mat_data['skin_color'] = np.array([[new_skin_type_id]])
    with tempfile.NamedTemporaryFile(suffix=".mat", delete=False) as tmp_file:
        new_mat_path = tmp_file.name
    savemat(new_mat_path, new_mat_data, do_compression=True)

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
        video_out_path = tmp_file.name
    _, height, width, _ = original_video_frames_uint8.shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(video_out_path, fourcc, fs, (width, height))
    try:
        for frame in manipulated_frames:
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        out.release()

    # Build the figure last so a failure above cannot leave an unclosed figure.
    fig, ax = plt.subplots(figsize=(12, 6))
    time_values = np.arange(len(gt_signal)) / fs
    norm_gt = (gt_signal - np.mean(gt_signal)) / (np.std(gt_signal) + 1e-6)
    norm_orig = (original_bvp - np.mean(original_bvp)) / (np.std(original_bvp) + 1e-6)
    norm_manip = (manipulated_bvp - np.mean(manipulated_bvp)) / (np.std(manipulated_bvp) + 1e-6)
    ax.plot(time_values, norm_gt, label=f'Ground Truth (BPM: {gt_bpm:.2f})', alpha=0.9)
    ax.plot(time_values, norm_orig, label=f'Original Pred. (BPM: {original_bpm:.2f})', linestyle='--')
    ax.plot(time_values, norm_manip, label=f'Manipulated Pred. (BPM: {manipulated_bpm:.2f})', linestyle=':')
    ax.set_title("BVP Signal Comparison")
    ax.set_xlabel("Time (seconds)")
    ax.set_ylabel("Normalized Amplitude")
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    plt.close(fig)  # detach from pyplot's manager; the Figure is still returned
    return video_out_path, new_mat_path, comparison_text, fig

# =================================================================================
# SECTION 5: LOGIC FOR DATASET CONVERTER TAB
# =================================================================================
def convert_ubfc_to_mmpd(video_file, gt_file, resize_option):
    """Convert a UBFC video + ground-truth file into one MMPD-style .mat file.

    Args:
        video_file: Uploaded UBFC video (``.name`` is read with OpenCV).
        gt_file: Uploaded UBFC ground-truth text file; the first line is parsed
            as whitespace-separated PPG samples.
        resize_option: When equal to "Downsample to 320p (Recommended)", each
            frame is resized to width 320, keeping the aspect ratio.

    Returns:
        ``(new_mat_path, message)``: path of the temporary .mat file and a
        status message. Besides 'video' (float32 RGB in [0, 1]), 'GT_ppg' and
        'fps', fixed placeholder metadata (skin_color=3, gender='male', ...) is
        written.

    Raises:
        gr.Error: If an upload is missing or no frames can be read from the video.
    """
    if video_file is None or gt_file is None:
        raise gr.Error("Please upload both UBFC video and ground truth files.")
    resize = resize_option == "Downsample to 320p (Recommended)"
    target_width = 320
    target_size = None
    frames = []
    cap = cv2.VideoCapture(video_file.name)
    try:
        fs = cap.get(cv2.CAP_PROP_FPS)
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if resize:
                if target_size is None:
                    # Frame size is constant, so derive the target once from
                    # the first decoded frame.
                    original_height, original_width = frame.shape[:2]
                    target_size = (target_width, int(target_width * (original_height / original_width)))
                frame = cv2.resize(frame, target_size)
            # Store float32 immediately so the list never holds float64 copies
            # of the whole video.
            frames.append((cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) / 255.0).astype(np.float32))
    finally:
        cap.release()
    if not frames:
        raise gr.Error("Could not read any frames from the UBFC video.")
    message = "Conversion successful!"
    if target_size is not None:
        message = f"Conversion successful! Video resized to {target_size[0]}x{target_size[1]}."
    video_frames_float = np.array(frames, dtype=np.float32)
    with open(gt_file.name, "r") as f:
        gt_str = f.read().strip()
    gt_lines = gt_str.split('\n')
    gt_signal = np.array([float(x) for x in gt_lines[0].split()])
    mmpd_data = {
        'video': video_frames_float, 'GT_ppg': gt_signal, 'skin_color': np.array([[3]]), 'gender': np.array([['male']]),
        'light': np.array([['LED-low']]), 'motion': np.array([['Stationary']]), 'exercise': np.array([['False']]),
        'glasser': np.array([['False']]), 'hair_cover': np.array([['False']]), 'makeup': np.array([['False']]), 'fps': np.array([[fs]])
    }
    with tempfile.NamedTemporaryFile(suffix=".mat", delete=False) as tmp_file:
        new_mat_path = tmp_file.name
    savemat(new_mat_path, mmpd_data, do_compression=True)
    return new_mat_path, message

def convert_scamps_to_mmpd(mat_file, resize_option):
    """Convert a SCAMPS .mat file into an MMPD-compatible .mat file.

    Args:
        mat_file: Uploaded file object whose ``.name`` is a SCAMPS .mat path,
            read with ``mat73.loadmat``. It must contain ``Xsub`` (the video,
            expected as (frames, height, width, 3)) and ``d_ppg`` (the PPG
            signal); ``fs`` is optional and defaults to 30.0.
        resize_option: ``"Downsample to 320p (Recommended)"`` resizes each
            frame to 320 px wide (aspect ratio kept, INTER_AREA); any other
            value keeps the original size.

    Returns:
        ``(mat_path, status_message)``.

    Raises:
        gr.Error: If no file was uploaded, a required key is missing, or
            ``Xsub`` is not a 4-D array with 3 channels in the last axis.
    """
    if mat_file is None:
        raise gr.Error("Please upload a SCAMPS .mat file.")
    mat_data = mat73.loadmat(mat_file.name)
    if 'Xsub' not in mat_data:
        raise gr.Error("SCAMPS .mat file must contain an 'Xsub' key.")
    if 'd_ppg' not in mat_data:
        raise gr.Error("SCAMPS .mat file must contain a 'd_ppg' key.")
    video_frames = np.asarray(mat_data['Xsub'])
    gt_signal = np.asarray(mat_data['d_ppg']).flatten()
    # The resize path unpacks four dims and writes 3-channel frames, and MMPD
    # consumers average frames over axes (1, 2); reject other layouts up front
    # instead of failing later with an opaque unpacking/broadcast error.
    if video_frames.ndim != 4 or video_frames.shape[-1] != 3:
        raise gr.Error(f"SCAMPS 'Xsub' must have shape (frames, height, width, 3); got {video_frames.shape}.")
    fs = mat_data.get('fs', 30.0)

    message = "Conversion successful!"
    if resize_option == "Downsample to 320p (Recommended)":
        num_frames, h, w, _ = video_frames.shape
        target_width = 320
        target_height = int(target_width * (h / w))
        # Pre-allocate so only the downsampled float32 video is materialized.
        video_frames_float = np.zeros((num_frames, target_height, target_width, 3), dtype=np.float32)
        for i, frame in enumerate(video_frames):
            video_frames_float[i] = cv2.resize(frame, (target_width, target_height), interpolation=cv2.INTER_AREA)
        message = f"Conversion successful! Video resized to {target_width}x{target_height}."
    else:
        video_frames_float = np.asarray(video_frames, dtype=np.float32)

    mmpd_data = {
        'video': video_frames_float, 'GT_ppg': gt_signal, 'skin_color': np.array([[3]]), 'gender': np.array([['male']]),
        'light': np.array([['LED-low']]), 'motion': np.array([['Stationary']]), 'exercise': np.array([['False']]),
        'glasser': np.array([['False']]), 'hair_cover': np.array([['False']]), 'makeup': np.array([['False']]), 'fps': np.array([[fs]])
    }
    # Reserve a unique path; delete=False keeps the file for savemat and the download.
    with tempfile.NamedTemporaryFile(suffix=".mat", delete=False) as tmp_file:
        new_mat_path = tmp_file.name
    savemat(new_mat_path, mmpd_data, do_compression=True)
    return new_mat_path, message

# =================================================================================
# SECTION 6: LOGIC FOR SYNTHETIC DATASET GENERATOR TAB
# =================================================================================
class CameraConditionSimulator:
    """Simulate camera and environment degradations on RGB ``uint8`` frames.

    Randomness comes from a private ``np.random.RandomState`` seeded in
    ``__init__``. Results are reproducible per seed, and constructing a
    simulator does not reseed NumPy's global RNG, which other code in the
    process may rely on.
    """

    # Gaussian noise standard deviations, in 0-255 intensity units, per preset.
    NOISE_PARAMS = {
        'low': {'shot_sigma': 2, 'read_sigma': 1},
        'medium': {'shot_sigma': 5, 'read_sigma': 3},
        'high': {'shot_sigma': 10, 'read_sigma': 5},
    }
    # Illumination level treated as neutral by adjust_lighting.
    REFERENCE_ILLUMINATION = 300

    def __init__(self, seed: int = 42):
        # RandomState (not default_rng) yields the same sample stream that
        # ``np.random.seed(seed)`` followed by ``np.random.normal`` produces.
        self.rng = np.random.RandomState(seed)

    def add_sensor_noise(self, frame: np.ndarray, noise_level: str = 'medium') -> np.ndarray:
        """Add intensity-dependent shot noise plus constant read noise.

        Args:
            frame: Image with non-negative intensities on a 0-255 scale.
            noise_level: One of ``'low'``, ``'medium'``, ``'high'``.

        Returns:
            The noisy frame, rounded and clipped to ``uint8``.

        Raises:
            KeyError: If ``noise_level`` is not a known preset.
        """
        params = self.NOISE_PARAMS[noise_level]
        # Shot noise grows with sqrt(intensity); read noise does not depend on it.
        shot_noise = self.rng.normal(0, params['shot_sigma'], frame.shape) * np.sqrt(frame / 255.0)
        read_noise = self.rng.normal(0, params['read_sigma'], frame.shape)
        noisy_frame = frame + shot_noise + read_noise
        # Round before casting: astype(uint8) truncates, biasing every pixel down.
        return np.clip(np.rint(noisy_frame), 0, 255).astype(np.uint8)

    def add_motion_blur(self, frame: np.ndarray, intensity: float = 0.5) -> np.ndarray:
        """Apply a horizontal box blur whose length grows with ``intensity``.

        The kernel is ``2 * int(5 * intensity) + 1`` pixels wide; when that is
        below 3 the frame is returned unchanged (same object).
        """
        kernel_size = int(5 * intensity) * 2 + 1
        if kernel_size < 3:
            return frame
        # Single middle row of equal weights -> averaging along the x axis only.
        kernel = np.zeros((kernel_size, kernel_size))
        kernel[kernel_size // 2, :] = np.ones(kernel_size)
        kernel = kernel / kernel_size
        return cv2.filter2D(frame, -1, kernel).astype(np.uint8)

    def simulate_compression(self, frame: np.ndarray, quality: int = 25) -> np.ndarray:
        """JPEG-encode and decode an RGB frame at ``quality`` (0-100).

        Returns:
            The decoded frame in **BGR** channel order (as ``cv2.imdecode`` returns it).

        Raises:
            RuntimeError: If OpenCV fails to encode the frame.
        """
        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
        ok, encoded = cv2.imencode('.jpg', frame_bgr, encode_param)
        if not ok:
            raise RuntimeError("JPEG encoding failed in simulate_compression.")
        return cv2.imdecode(encoded, cv2.IMREAD_COLOR)

    def adjust_lighting(self, frame: np.ndarray, illumination: int = 300) -> np.ndarray:
        """Rescale brightness relative to ``REFERENCE_ILLUMINATION`` (300).

        Intensities are multiplied by ``illumination / 300``. Below the
        reference a 0.8 gamma is also applied, with medium sensor noise added
        under 300 and high noise under 200. At exactly 300 the frame is
        returned unchanged (as ``uint8``).
        """
        scale_factor = illumination / self.REFERENCE_ILLUMINATION
        gamma = 1.0 if scale_factor >= 1.0 else 0.8
        adjusted = np.power(frame.astype(np.float32) * scale_factor / 255.0, gamma) * 255.0
        if illumination < 200:
            adjusted = self.add_sensor_noise(adjusted, 'high')
        elif illumination < self.REFERENCE_ILLUMINATION:
            adjusted = self.add_sensor_noise(adjusted, 'medium')
        # Round before casting so the neutral setting is lossless (k/255*255
        # can land just below k in float32 and would truncate to k - 1).
        return np.clip(np.rint(adjusted), 0, 255).astype(np.uint8)

def _render_synthetic_video(video_frames_float, res_config, env_config, fps, mat_width=320):
    """Degrade each frame, stream it to an .mp4 and collect what the caller needs.

    Frames are processed one at a time, so the full-resolution synthetic video
    (many GB at 1080p for long clips) is never held in memory.

    Args:
        video_frames_float: Sequence of RGB frames with values in [0, 1]
            (presumably the (frames, height, width, 3) ``video`` array of an
            MMPD .mat — TODO confirm for other producers).
        res_config: Resolution preset (uses ``resolution``, ``noise_level``,
            ``compression_quality``).
        env_config: Environment preset (uses ``illumination``, ``motion_blur``).
        fps: Frame rate written into the .mp4 container.
        mat_width: Width in pixels of the frames kept for the .mat export.

    Returns:
        ``(raw_signal, mat_frames, mp4_path)``: per-frame mean RGB of the
        degraded frames in [0, 1], shape (frames, 3); the degraded frames
        resized to ``mat_width`` px wide as float32 in [0, 1]; and the path of
        the written .mp4 (frames at ``res_config['resolution']``).
    """
    simulator = CameraConditionSimulator()
    out_w, out_h = res_config['resolution']
    mat_height = int(mat_width * (out_h / out_w))
    num_frames = len(video_frames_float)
    raw_signal = np.zeros((num_frames, 3), dtype=np.float32)
    mat_frames = np.zeros((num_frames, mat_height, mat_width, 3), dtype=np.float32)

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
        mp4_path = tmp_file.name
    writer = cv2.VideoWriter(mp4_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (out_w, out_h))
    try:
        for i, frame_float in enumerate(video_frames_float):
            # Round (not truncate) back to 8 bits: float32 k/255 * 255 can land
            # just below k and astype(uint8) would then yield k - 1.
            frame_uint8_rgb = np.clip(np.rint(frame_float * 255), 0, 255).astype(np.uint8)
            processed_rgb = cv2.resize(frame_uint8_rgb, (out_w, out_h), interpolation=cv2.INTER_AREA)
            processed_rgb = simulator.adjust_lighting(processed_rgb, env_config['illumination'])
            if env_config['motion_blur'] > 0:
                processed_rgb = simulator.add_motion_blur(processed_rgb, env_config['motion_blur'])
            processed_rgb = simulator.add_sensor_noise(processed_rgb, res_config['noise_level'])
            # simulate_compression hands back the JPEG-decoded frame in BGR order.
            compressed_bgr = simulator.simulate_compression(processed_rgb, res_config['compression_quality'])
            writer.write(compressed_bgr)
            degraded_rgb = cv2.cvtColor(compressed_bgr, cv2.COLOR_BGR2RGB)
            raw_signal[i] = degraded_rgb.reshape(-1, 3).mean(axis=0) / 255.0
            mat_frames[i] = cv2.resize(degraded_rgb, (mat_width, mat_height), interpolation=cv2.INTER_AREA) / 255.0
    finally:
        writer.release()
    return raw_signal, mat_frames, mp4_path


def generate_synthetic_dataset(mat_file, resolution_key, environment_key, selected_model):
    """Re-render an MMPD video under simulated camera conditions and compare rPPG accuracy.

    Each frame is resized to the target resolution, relit, optionally
    motion-blurred, given sensor noise and JPEG-compressed. The selected rPPG
    model then runs on both the original and the synthetic video, and the
    predicted BPMs are compared with the ground truth.

    Args:
        mat_file: Uploaded file object whose ``.name`` is an MMPD-style .mat
            containing at least ``video`` and ``GT_ppg`` (``fps`` optional, 30).
        resolution_key: ``'480p'``, ``'720p'`` or ``'1080p30'``.
        environment_key: ``'optimal'``, ``'low_light'`` or ``'motion'``.
        selected_model: Key of ``MODEL_DISPATCHER``.

    Returns:
        ``(mp4_path, mat_path, comparison_text, figure)``. The .mat keeps the
        input's other variables, stores the synthetic video downsampled to
        320 px wide, and keeps the source frame rate because frames are not
        temporally resampled.

    Raises:
        gr.Error: If no file was uploaded.
    """
    if mat_file is None:
        raise gr.Error("Please upload an MMPD .mat file first.")

    # No target FPS is listed: frames are not temporally resampled, so the
    # synthetic clip must keep the source rate ``fs``.
    resolution_configs = {
        '480p': {'resolution': (640, 480), 'noise_level': 'medium', 'compression_quality': 28},
        '720p': {'resolution': (1280, 720), 'noise_level': 'medium', 'compression_quality': 25},
        '1080p30': {'resolution': (1920, 1080), 'noise_level': 'low', 'compression_quality': 23},
    }
    environment_presets = {
        'optimal': {'illumination': 300, 'motion_blur': 0.0},
        'low_light': {'illumination': 150, 'motion_blur': 0.0},
        'motion': {'illumination': 300, 'motion_blur': 0.5},
    }
    res_config = resolution_configs[resolution_key]
    env_config = environment_presets[environment_key]

    mat_data = loadmat(mat_file.name)
    original_video_frames_float = mat_data['video']
    gt_signal = mat_data['GT_ppg'].flatten()
    fs = mat_data.get('fps', np.array([[30.0]])).item()

    # Raw rPPG input: per-frame spatial mean of each colour channel.
    original_raw_signal = np.mean(original_video_frames_float, axis=(1, 2))
    if len(gt_signal) != len(original_raw_signal):
        # Align the ground truth to one sample per video frame.
        gt_signal = signal.resample(gt_signal, len(original_raw_signal))
    gt_bpm = calculate_bpm(gt_signal, fs)
    model_function = MODEL_DISPATCHER[selected_model]
    original_bvp = model_function(original_raw_signal, fs)
    original_bpm = calculate_bpm(original_bvp, fs)
    original_mae = abs(original_bpm - gt_bpm)

    synthetic_raw_signal, resized_frames_for_mat, video_out_path = _render_synthetic_video(
        original_video_frames_float, res_config, env_config, fs)
    synthetic_bvp = model_function(synthetic_raw_signal, fs)
    synthetic_bpm = calculate_bpm(synthetic_bvp, fs)
    synthetic_mae = abs(synthetic_bpm - gt_bpm)

    comparison_text = (f"--- Analysis using {selected_model} ---\n\n"
                       f"Ground Truth BPM: {gt_bpm:.2f}\n\n"
                       f"Original Video:\n  - Predicted BPM: {original_bpm:.2f}\n  - MAE: {original_mae:.2f}\n\n"
                       f"Synthetic Video ({resolution_key}, {environment_key}):\n  - Predicted BPM: {synthetic_bpm:.2f}\n  - MAE: {synthetic_mae:.2f}")

    def _zscore(x):
        # Normalize so signals of different scales share one plot; 1e-6 avoids /0.
        return (x - np.mean(x)) / (np.std(x) + 1e-6)

    fig = plt.figure(figsize=(12, 6))
    time_values = np.arange(len(gt_signal)) / fs
    plt.plot(time_values, _zscore(gt_signal), label=f'Ground Truth (BPM: {gt_bpm:.2f})', alpha=0.9)
    plt.plot(time_values, _zscore(original_bvp), label=f'Original Pred. (BPM: {original_bpm:.2f})', linestyle='--')
    plt.plot(time_values, _zscore(synthetic_bvp), label=f'Synthetic Pred. (BPM: {synthetic_bpm:.2f})', linestyle=':')
    plt.title("BVP Signal Comparison")
    plt.xlabel("Time (seconds)")
    plt.ylabel("Normalized Amplitude")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()

    # loadmat adds '__header__', '__version__' and '__globals__'; savemat will not
    # write names starting with '_', so drop them explicitly.
    new_mat_data = {key: value for key, value in mat_data.items() if not key.startswith('__')}
    new_mat_data['video'] = resized_frames_for_mat
    # Frame count is unchanged, so relabelling with a fixed target rate would
    # misstate timing whenever the source fps differs.
    new_mat_data['fps'] = np.array([[fs]])
    with tempfile.NamedTemporaryFile(suffix=".mat", delete=False) as tmp_file:
        new_mat_path = tmp_file.name
    savemat(new_mat_path, new_mat_data, do_compression=True)

    # Detach from pyplot's figure manager; the Figure object stays usable by gr.Plot.
    plt.close(fig)
    return video_out_path, new_mat_path, comparison_text, fig

# =================================================================================
# SECTION 7: GRADIO USER INTERFACE
# =================================================================================
# Option lists shared by several widgets below; defined once so the tabs cannot drift apart.
_UI_MODEL_CHOICES = ["Additive POS", "POS", "CHROM", "ICA", "GREEN", "LGI", "OMIT", "PBV"]
_UI_RESOLUTION_CHOICES = ['480p', '720p', '1080p30']
# These exact strings are compared verbatim by convert_ubfc_to_mmpd / convert_scamps_to_mmpd.
_UI_RESIZE_CHOICES = ["Downsample to 320p (Recommended)", "Keep Original Size (May Fail for Large Files)"]

with gr.Blocks(theme=gr.themes.Soft(), title="rPPG Analysis and Prediction Toolbox") as demo:
    gr.Markdown("# rPPG Analysis and Prediction Toolbox\n*Idea taken from the [rPPG-Toolbox](https://github.com/ubicomplab/rPPG-Toolbox)*")

    with gr.Tabs():
        # ---- Tab: live webcam prediction ----
        with gr.TabItem("Live Webcam Prediction"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## 1. Live Camera Feed")
                    webcam_component = gr.Image(sources=["webcam"], streaming=True, label="Webcam Feed with Face Detection")
                    model_selector_live = gr.Dropdown(choices=_UI_MODEL_CHOICES, value="Additive POS", label="Select rPPG Model for Live Analysis")
                with gr.Column():
                    gr.Markdown("## 2. Real-Time Results")
                    status_live_text = gr.Textbox(label="Status", interactive=False)
                    hr_live_output = gr.Textbox(label="Heart Rate (HR)")
                    hrv_live_output = gr.Textbox(label="Heart Rate Variability (HRV)")
                    bp_live_output = gr.Textbox(label="Blood Pressure (BP) Estimation")
                    gr.Markdown("*Disclaimer: Blood pressure estimation is for demonstration purposes only and is not medically accurate.*")
                    bvp_plot_live = gr.LinePlot(x="time", y="BVP", title="Live BVP Signal", show_label=False, width=400, height=300)
            # Session state threaded through every streamed frame (passed in and returned).
            signal_buffer_state = gr.State([])
            last_update_time_state = gr.State(0)
            frame_counter_state = gr.State(0)
            last_face_box_state = gr.State(None)
            webcam_component.stream(
                fn=process_webcam_frame,
                inputs=[webcam_component, signal_buffer_state, last_update_time_state, model_selector_live, frame_counter_state, last_face_box_state],
                outputs=[status_live_text, hr_live_output, hrv_live_output, bp_live_output, bvp_plot_live, webcam_component, signal_buffer_state, last_update_time_state, frame_counter_state, last_face_box_state],
            )

        # ---- Tab: analyze a dataset file ----
        with gr.TabItem("File Analyzer"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("## 1. Inputs")
                    video_output_file = gr.Video(label="Video Preview", interactive=False)
                    dataset_type_selector = gr.Radio(choices=["MMPD", "UBFC", "SCAMPS"], value="MMPD", label="Select Dataset Format")
                    with gr.Group(visible=True) as mmpd_uploader:
                        mat_input_mmpd = gr.File(label="Upload .mat File (MMPD Format)", file_types=[".mat"])
                    with gr.Group(visible=False) as ubfc_uploader:
                        ubfc_vid_input = gr.File(label="Upload video.avi File", file_types=["video"])
                        ubfc_gt_input = gr.File(label="Upload ground_truth.txt", file_types=[".txt"])
                    with gr.Group(visible=False) as scamps_uploader:
                        mat_input_scamps = gr.File(label="Upload .mat File (SCAMPS Format)", file_types=[".mat"])
                    model_selector_file = gr.Dropdown(choices=_UI_MODEL_CHOICES, value="Additive POS", label="Select rPPG Model")
                    run_button_file = gr.Button("Analyze & Evaluate", variant="primary")
                with gr.Column(scale=2):
                    gr.Markdown("## 2. Results")
                    bpm_output_file = gr.Textbox(label="Predicted Heart Rate (BPM)")
                    bvp_plot_file = gr.Plot(label="BVP Signal Comparison")
                    eval_output_file = gr.Textbox(label="Evaluation Metrics", lines=4)

            def switch_uploader(dataset_type):
                """Show only the uploader group for the chosen format (MMPD, UBFC, else SCAMPS)."""
                if dataset_type == "MMPD":
                    return gr.Group(visible=True), gr.Group(visible=False), gr.Group(visible=False)
                elif dataset_type == "UBFC":
                    return gr.Group(visible=False), gr.Group(visible=True), gr.Group(visible=False)
                else:
                    return gr.Group(visible=False), gr.Group(visible=False), gr.Group(visible=True)

            dataset_type_selector.change(fn=switch_uploader, inputs=dataset_type_selector, outputs=[mmpd_uploader, ubfc_uploader, scamps_uploader])
            run_button_file.click(
                fn=process_file_and_evaluate,
                inputs=[dataset_type_selector, mat_input_mmpd, ubfc_vid_input, ubfc_gt_input, mat_input_scamps, model_selector_file],
                outputs=[bpm_output_file, bvp_plot_file, eval_output_file, video_output_file],
            )

        # ---- Tab: skin color manipulation ----
        with gr.TabItem("Skin Color Manipulation"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## 1. Upload & Select")
                    mat_input_skin = gr.File(label="Upload .mat File (MMPD Format)", file_types=[".mat"])
                    original_skin_type_text = gr.Textbox(label="Original Skin Type", interactive=False)
                    target_skin_type_dropdown = gr.Dropdown(label="Select Target Skin Type")
                    model_selector_skin = gr.Dropdown(choices=_UI_MODEL_CHOICES, value="Additive POS", label="Select rPPG Model for Comparison")
                    manipulate_button = gr.Button("Manipulate & Compare", variant="primary")
                with gr.Column():
                    gr.Markdown("## 2. Preview, Compare & Download")
                    with gr.Row():
                        video_preview_original = gr.Video(label="Original Video", interactive=False)
                        video_preview_manipulated = gr.Video(label="Manipulated Video", interactive=False)
                    comparison_results_text = gr.Textbox(label="Comparison Results", lines=8)
                    comparison_plot = gr.Plot(label="BVP Signal Comparison Plot")
                    download_button = gr.File(label="Download Manipulated .mat File", interactive=True)
            mat_input_skin.upload(fn=get_skin_info_and_preview, inputs=mat_input_skin, outputs=[original_skin_type_text, target_skin_type_dropdown, video_preview_original])
            manipulate_button.click(
                fn=manipulate_and_compare,
                inputs=[mat_input_skin, target_skin_type_dropdown, model_selector_skin],
                outputs=[video_preview_manipulated, download_button, comparison_results_text, comparison_plot],
            )

        # ---- Tab: synthetic resolution / environment generator ----
        with gr.TabItem("Resolution and Environment Converter"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## 1. Upload & Configure")
                    synth_mat_input = gr.File(label="Upload .mat File (MMPD Format)", file_types=[".mat"])
                    original_resolution_text = gr.Textbox(label="Original Video Resolution", interactive=False)
                    synth_resolution_dd = gr.Dropdown(label="Target Resolution", choices=_UI_RESOLUTION_CHOICES, value='720p')
                    synth_env_dd = gr.Dropdown(label="Target Environment", choices=['optimal', 'low_light', 'motion'], value='optimal')
                    model_selector_synth = gr.Dropdown(choices=_UI_MODEL_CHOICES, value="Additive POS", label="Select rPPG Model for Comparison")
                    synth_generate_btn = gr.Button("Generate & Compare", variant="primary")
                with gr.Column():
                    gr.Markdown("## 2. Preview, Compare & Download")
                    with gr.Row():
                        synth_video_original = gr.Video(label="Original Video", interactive=False)
                        synth_video_generated = gr.Video(label="Generated Synthetic Video", interactive=False)
                    synth_comparison_text = gr.Textbox(label="Comparison Results", lines=8)
                    synth_comparison_plot = gr.Plot(label="BVP Signal Comparison Plot")
                    synth_download_btn = gr.File(label="Download Generated .mat File", interactive=True)

            def synth_get_info_and_preview(mat_file):
                """Return (preview video, "WxH" string, target-resolution dropdown) for an uploaded .mat.

                The preset whose height is within 50 px of the source height is
                removed from the dropdown, so the user picks a real change.
                """
                if mat_file is None:
                    return None, "N/A", gr.Dropdown(choices=list(_UI_RESOLUTION_CHOICES))
                _, _, _, video_path = load_data_from_mat(mat_file)
                mat_data = loadmat(mat_file.name)
                h, w = mat_data['video'].shape[1], mat_data['video'].shape[2]
                original_res_str = f"{w}x{h}"
                all_resolutions = {'480p': 480, '720p': 720, '1080p30': 1080}
                original_res_key = next((key for key, val in all_resolutions.items() if abs(h - val) < 50), None)
                new_choices = [res for res in all_resolutions if res != original_res_key]
                return video_path, original_res_str, gr.Dropdown(choices=new_choices, value=new_choices[0] if new_choices else None)

            synth_mat_input.upload(fn=synth_get_info_and_preview, inputs=synth_mat_input, outputs=[synth_video_original, original_resolution_text, synth_resolution_dd])
            synth_generate_btn.click(
                fn=generate_synthetic_dataset,
                inputs=[synth_mat_input, synth_resolution_dd, synth_env_dd, model_selector_synth],
                outputs=[synth_video_generated, synth_download_btn, synth_comparison_text, synth_comparison_plot],
            )

        # ---- Tab: dataset format converter ----
        with gr.TabItem("Dataset Converter"):
            with gr.Column():
                gr.Markdown("## 1. Select Conversion Type")
                conversion_type_selector = gr.Radio(choices=["UBFC to MMPD", "SCAMPS to MMPD"], value="UBFC to MMPD", label="Select Conversion")
                with gr.Group(visible=True) as ubfc_converter_group:
                    gr.Markdown("### Upload UBFC Files")
                    ubfc_conv_vid = gr.File(label="Upload video.avi File", file_types=["video"])
                    ubfc_conv_gt = gr.File(label="Upload ground_truth.txt", file_types=[".txt"])
                    ubfc_resize_option = gr.Radio(choices=_UI_RESIZE_CHOICES, value=_UI_RESIZE_CHOICES[0], label="Video Size Option")
                    gr.Markdown("*Warning: Keeping original size may cause an OverflowError if the video file is very large.*")
                    convert_button_ubfc = gr.Button("Convert UBFC to MMPD", variant="primary")
                with gr.Group(visible=False) as scamps_converter_group:
                    gr.Markdown("### Upload SCAMPS File")
                    scamps_conv_mat = gr.File(label="Upload .mat File (SCAMPS Format)", file_types=[".mat"])
                    scamps_resize_option = gr.Radio(choices=_UI_RESIZE_CHOICES, value=_UI_RESIZE_CHOICES[0], label="Video Size Option")
                    gr.Markdown("*Warning: Keeping original size may cause an OverflowError if the video file is very large.*")
                    convert_button_scamps = gr.Button("Convert SCAMPS to MMPD", variant="primary")
                gr.Markdown("## 2. Download Converted File")
                status_text_conv = gr.Textbox(label="Status", interactive=False)
                download_button_conv = gr.File(label="Download Converted .mat File", interactive=True)

            def switch_converter(conv_type):
                """Show the UBFC group for "UBFC to MMPD", otherwise the SCAMPS group."""
                if conv_type == "UBFC to MMPD":
                    return gr.Group(visible=True), gr.Group(visible=False)
                else:
                    return gr.Group(visible=False), gr.Group(visible=True)

            conversion_type_selector.change(fn=switch_converter, inputs=conversion_type_selector, outputs=[ubfc_converter_group, scamps_converter_group])
            convert_button_ubfc.click(fn=convert_ubfc_to_mmpd, inputs=[ubfc_conv_vid, ubfc_conv_gt, ubfc_resize_option], outputs=[download_button_conv, status_text_conv])
            convert_button_scamps.click(fn=convert_scamps_to_mmpd, inputs=[scamps_conv_mat, scamps_resize_option], outputs=[download_button_conv, status_text_conv])

    gr.Markdown("---")
    gr.Markdown("This program was made by Indra Dewaji for the purpose of research in the Doctor of Professional Practice at Universiti Teknologi Malaysia")

# Enable the request queue (queue() returns the Blocks itself) and start the server.
# Binds all interfaces; the hosting platform may choose the port via $PORT
# (default 7860). Server-side rendering is turned off via ssr_mode=False.
_server_port = int(os.environ.get("PORT", 7860))
demo.queue().launch(
    server_name="0.0.0.0",
    server_port=_server_port,
    show_error=True,
    ssr_mode=False,
)