File size: 6,335 Bytes
e391a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
infrastructure/processing/scipy_cardiogan_preprocessor.py
──────────────────────────────────────────────────────────
SciPy implementation of CardioGANSignalPreprocessor.
"""
from __future__ import annotations

import numpy as np
from scipy.interpolate import interp1d
from scipy.signal import butter, filtfilt

from src.domain.exceptions.pipeline_exceptions import PreprocessingError
from src.domain.interfaces.services.cardiogan_preprocessor import CardioGANSignalPreprocessor
from src.shared.constants import (
    CARDIOGAN_ORIG_FS,
    CARDIOGAN_TARGET_FS,
    CARDIOGAN_WINDOW_SAMPLES,
    CARDIOGAN_OVERLAP,
    VGTLNET_WINDOW_SIZE,
)
from src.shared.logger import get_logger

logger = get_logger(__name__)


class SciPyCardioGANPreprocessor(CardioGANSignalPreprocessor):
    """
    CardioGAN signal preprocessor using SciPy for resampling, filtering, and windowing.

    ``preprocess_ppg`` turns a raw PPG recording into normalized fixed-length windows;
    ``postprocess_ecg`` resamples ECG windows back to the original rate and trims/pads
    them to ``VGTLNET_WINDOW_SIZE`` samples. Both methods report every failure as a
    ``PreprocessingError``.
    """

    def preprocess_ppg(self, ppg_raw: np.ndarray) -> np.ndarray:
        """
        Full CardioGAN preprocessing pipeline:
            1. Resample: 125 Hz -> 128 Hz (CARDIOGAN_ORIG_FS -> CARDIOGAN_TARGET_FS)
            2. Filter:   Butterworth bandpass 1-8 Hz (order=4)
            3. Normalize: Z-score over the whole recording
            4. Segment:  512-sample sliding windows with 10% overlap
                         (CARDIOGAN_WINDOW_SAMPLES / CARDIOGAN_OVERLAP)
            5. Normalize: Min-max normalize per window to [-1, 1]

        Args:
            ppg_raw: Raw PPG samples; any array-like, flattened to 1-D before processing.

        Returns:
            float32 array of shape (n_windows, CARDIOGAN_WINDOW_SAMPLES).

        Raises:
            PreprocessingError: If the signal is empty, too short to yield a single
                window, or any processing step fails unexpectedly.
        """
        try:
            ppg_flat = np.asarray(ppg_raw).flatten().astype(np.float32)
            if ppg_flat.size == 0:
                raise PreprocessingError("preprocess_ppg", "PPG signal array is empty")

            # 1. Resample 125 Hz -> 128 Hz
            ppg_128 = self._resample_signal(
                ppg_flat, CARDIOGAN_ORIG_FS, CARDIOGAN_TARGET_FS
            )

            # Reject too-short recordings before filtering: filtfilt refuses inputs
            # shorter than its padding length, which would otherwise surface as an
            # opaque "Unexpected error" instead of this explicit message.
            if len(ppg_128) < CARDIOGAN_WINDOW_SAMPLES:
                raise PreprocessingError(
                    "preprocess_ppg",
                    f"PPG signal length {len(ppg_flat)} is too short to form any segment of size {CARDIOGAN_WINDOW_SAMPLES}"
                )

            # 2. Filter: Butterworth 1-8 Hz
            ppg_filt = self._bandpass_butter(
                ppg_128, CARDIOGAN_TARGET_FS, low=1.0, high=8.0
            )

            # 3. Z-score normalization
            ppg_norm = self._zscore_normalize(ppg_filt)

            # 4. Segment into windows
            ppg_wins = self._segment_windows(
                ppg_norm, CARDIOGAN_WINDOW_SAMPLES, CARDIOGAN_OVERLAP
            )

            # 5. Min-max normalize per window to [-1, 1]
            return self._minmax_normalize(ppg_wins, -1.0, 1.0)

        except PreprocessingError:
            # Already a domain error: propagate untouched (keeps the original traceback).
            raise
        except Exception as e:
            raise PreprocessingError("preprocess_ppg", f"Unexpected error: {e}") from e

    def postprocess_ecg(self, ecg_windows_128: np.ndarray) -> np.ndarray:
        """
        Downsamples ECG signals from 128 Hz back to 125 Hz and trims/pads to 224 samples.

        Args:
            ecg_windows_128: Batch of ECG windows sampled at CARDIOGAN_TARGET_FS. Each
                window must be 1-D after singleton axes are dropped (so ``(n,)``,
                ``(n, 1)`` and ``(1, n)`` windows are all accepted).

        Returns:
            float32 array of shape (n_windows, VGTLNET_WINDOW_SIZE). Windows longer than
            VGTLNET_WINDOW_SIZE after resampling are truncated; shorter ones are
            zero-padded at the end.

        Raises:
            PreprocessingError: If the batch is empty, a window is not effectively 1-D,
                a window has fewer than 2 samples, or resampling fails unexpectedly.
        """
        try:
            if len(ecg_windows_128) == 0:
                raise PreprocessingError("postprocess_ecg", "ECG windows batch is empty")

            ecg_segments_out = []
            for idx, raw_win in enumerate(ecg_windows_128):
                win = np.asarray(raw_win)
                if win.ndim > 1:
                    # Drop singleton channel axes, e.g. (n, 1) or (1, n) -> (n,)
                    win = np.squeeze(win)
                # Linear interpolation needs a 1-D signal with at least two points.
                if win.ndim != 1 or win.size < 2:
                    raise PreprocessingError(
                        "postprocess_ecg",
                        f"ECG window {idx} has shape {win.shape}; expected a 1-D window with at least 2 samples",
                    )

                # Downsample from 128 -> 125 Hz on a normalized [0, 1) time axis
                n_orig = len(win)
                n_target = int(n_orig * CARDIOGAN_ORIG_FS / CARDIOGAN_TARGET_FS)
                t_orig = np.linspace(0, 1, n_orig, endpoint=False)
                t_target = np.linspace(0, 1, n_target, endpoint=False)
                interp_fn = interp1d(t_orig, win, kind="linear", fill_value="extrapolate")
                ecg_win_125 = interp_fn(t_target).astype(np.float32)

                # Trim or pad to VGTLNET_WINDOW_SIZE (224 samples): copy what fits into
                # a zero buffer, so short windows end up zero-padded at the tail.
                segment = np.zeros(VGTLNET_WINDOW_SIZE, dtype=np.float32)
                n_keep = min(len(ecg_win_125), VGTLNET_WINDOW_SIZE)
                segment[:n_keep] = ecg_win_125[:n_keep]
                ecg_segments_out.append(segment)

            return np.array(ecg_segments_out, dtype=np.float32)

        except PreprocessingError:
            # Already a domain error: propagate untouched (keeps the original traceback).
            raise
        except Exception as e:
            raise PreprocessingError("postprocess_ecg", f"Unexpected error: {e}") from e

    # ── Helper Processing Methods ───────────────────────────────────────────────

    @staticmethod
    def _resample_signal(sig: np.ndarray, orig_fs: int, target_fs: int) -> np.ndarray:
        """
        Resample a 1-D signal from ``orig_fs`` to ``target_fs`` by linear interpolation.

        The output has ``int(len(sig) / orig_fs * target_fs)`` samples covering the same
        duration as the input. Returns a float32 array.
        """
        n_orig = len(sig)
        duration = n_orig / orig_fs
        n_target = int(duration * target_fs)
        t_orig = np.linspace(0, duration, n_orig, endpoint=False)
        t_target = np.linspace(0, duration, n_target, endpoint=False)
        # When upsampling, the last target instants fall slightly past the last source
        # instant, hence extrapolation instead of an out-of-bounds error.
        interp_fn = interp1d(t_orig, sig, kind="linear", fill_value="extrapolate")
        return interp_fn(t_target).astype(np.float32)

    @staticmethod
    def _bandpass_butter(sig: np.ndarray, fs: int, low: float, high: float) -> np.ndarray:
        """
        Apply a 4th-order Butterworth band-pass between ``low`` and ``high`` Hz.

        The filter is run forward and backward with ``filtfilt``. Cut-offs are
        normalized by the Nyquist frequency ``fs / 2``. Returns a float32 array.
        """
        nyq = fs / 2.0
        b, a = butter(4, [low / nyq, high / nyq], btype="band")
        return filtfilt(b, a, sig).astype(np.float32)

    @staticmethod
    def _zscore_normalize(sig: np.ndarray) -> np.ndarray:
        """
        Z-score the signal (zero mean, unit standard deviation) as float32.

        A (near-)constant signal (std < 1e-8) is only mean-centered, to avoid
        dividing by zero.
        """
        mu = np.mean(sig)
        std = np.std(sig)
        if std < 1e-8:
            return (sig - mu).astype(np.float32)
        return ((sig - mu) / std).astype(np.float32)

    @staticmethod
    def _segment_windows(sig: np.ndarray, win_len: int, overlap: float) -> np.ndarray:
        """
        Cut a 1-D signal into sliding windows of ``win_len`` samples.

        Args:
            sig: 1-D signal.
            win_len: Window length in samples; must be positive.
            overlap: Fraction of a window shared with the next one; must be < 1.
                Negative values leave gaps between consecutive windows.

        Returns:
            float32 array of shape (n_windows, win_len); shape (0, win_len) when the
            signal is shorter than one window. Trailing samples that do not fill a
            whole window are dropped.

        Raises:
            ValueError: If ``win_len`` is not positive or ``overlap`` is not < 1.
        """
        if win_len <= 0:
            raise ValueError(f"win_len must be positive, got {win_len}")
        if not overlap < 1.0:
            raise ValueError(f"overlap must be < 1, got {overlap}")

        # The small epsilon absorbs binary-float error before truncation (e.g.
        # 500 * (1 - 0.8) evaluates to 99.99999999999997, which must give 100, not 99).
        # The floor of 1 keeps the window start advancing when the product is < 1.
        step = max(1, int(win_len * (1.0 - overlap) + 1e-9))
        n_windows = max(0, (len(sig) - win_len) // step + 1)
        if n_windows == 0:
            return np.empty((0, win_len), dtype=np.float32)
        starts = range(0, n_windows * step, step)
        return np.stack([sig[s : s + win_len] for s in starts]).astype(np.float32)

    @staticmethod
    def _minmax_normalize(windows: np.ndarray, low: float, high: float) -> np.ndarray:
        """
        Rescale each window (last axis) linearly so its min maps to ``low`` and its
        max to ``high``. Returns float32.

        Windows whose range is below 1e-8 use a divisor of 1 instead, so a flat
        window maps entirely to ``low``.
        """
        mins = windows.min(axis=-1, keepdims=True)
        maxs = windows.max(axis=-1, keepdims=True)
        rng = maxs - mins
        rng[rng < 1e-8] = 1.0  # avoid division by zero for flat windows
        normalized = (windows - mins) / rng
        return (normalized * (high - low) + low).astype(np.float32)