| """ |
| infrastructure/processing/scipy_cardiogan_preprocessor.py |
| ========================================================= |
| SciPy implementation of CardioGANSignalPreprocessor. |
| """ |
| from __future__ import annotations |
|
|
| import numpy as np |
| from scipy.interpolate import interp1d |
| from scipy.signal import butter, filtfilt |
|
|
| from src.domain.exceptions.pipeline_exceptions import PreprocessingError |
| from src.domain.interfaces.services.cardiogan_preprocessor import CardioGANSignalPreprocessor |
| from src.shared.constants import ( |
| CARDIOGAN_ORIG_FS, |
| CARDIOGAN_TARGET_FS, |
| CARDIOGAN_WINDOW_SAMPLES, |
| CARDIOGAN_OVERLAP, |
| VGTLNET_WINDOW_SIZE, |
| ) |
| from src.shared.logger import get_logger |
|
|
# Module-level logger named after this module.
# NOTE(review): not referenced anywhere in the visible part of this file —
# confirm whether it should be used for diagnostics or can be dropped.
logger = get_logger(__name__)
|
|
|
|
class SciPyCardioGANPreprocessor(CardioGANSignalPreprocessor):
    """
    CardioGAN signal preprocessor using SciPy for resampling, filtering, and windowing.

    PPG input is resampled from CARDIOGAN_ORIG_FS to CARDIOGAN_TARGET_FS,
    band-pass filtered, z-scored, cut into overlapping windows and min-max
    scaled. Generated ECG windows are resampled back to CARDIOGAN_ORIG_FS and
    truncated / zero-padded to VGTLNET_WINDOW_SIZE samples.
    """

    def preprocess_ppg(self, ppg_raw: np.ndarray) -> np.ndarray:
        """
        Full CardioGAN preprocessing pipeline:
          1. Resample: 125 Hz -> 128 Hz (CARDIOGAN_ORIG_FS -> CARDIOGAN_TARGET_FS)
          2. Filter: Butterworth bandpass 1-8 Hz (order=4), zero-phase (filtfilt)
          3. Normalize: Z-score over the whole signal (per-subject)
          4. Segment: 512-sample sliding windows with 10% overlap
             (CARDIOGAN_WINDOW_SAMPLES / CARDIOGAN_OVERLAP)
          5. Normalize: Min-max normalize per window to [-1, 1]

        Args:
            ppg_raw: Raw PPG samples at CARDIOGAN_ORIG_FS; any shape is
                accepted and flattened.

        Returns:
            float32 array of shape (n_windows, CARDIOGAN_WINDOW_SAMPLES).

        Raises:
            PreprocessingError: if the signal is empty, contains NaN/inf, is
                too short to yield a single window, or any step fails.
        """
        try:
            ppg_flat = np.asarray(ppg_raw).flatten().astype(np.float32)
            if ppg_flat.size == 0:
                raise PreprocessingError("preprocess_ppg", "PPG signal array is empty")

            # NaN/inf would propagate silently through filtfilt and both
            # normalizations and corrupt every window, so reject them here.
            if not np.isfinite(ppg_flat).all():
                raise PreprocessingError(
                    "preprocess_ppg", "PPG signal contains NaN or infinite values"
                )

            too_short_msg = (
                f"PPG signal length {len(ppg_flat)} is too short to form any "
                f"segment of size {CARDIOGAN_WINDOW_SAMPLES}"
            )
            # Linear interpolation needs at least two samples.
            if ppg_flat.size < 2:
                raise PreprocessingError("preprocess_ppg", too_short_msg)

            # 1. Resample to the model's sampling rate.
            ppg_128 = self._resample_signal(
                ppg_flat, CARDIOGAN_ORIG_FS, CARDIOGAN_TARGET_FS
            )

            # A signal shorter than one window can never be segmented. Checking
            # now (before filtering) also avoids filtfilt's own length error,
            # which would otherwise surface as an opaque "Unexpected error".
            if len(ppg_128) < CARDIOGAN_WINDOW_SAMPLES:
                raise PreprocessingError("preprocess_ppg", too_short_msg)

            # 2. Zero-phase band-pass filter.
            ppg_filt = self._bandpass_butter(
                ppg_128, CARDIOGAN_TARGET_FS, low=1.0, high=8.0
            )

            # 3. Per-subject z-score.
            ppg_norm = self._zscore_normalize(ppg_filt)

            # 4. Overlapping windows (at least one, given the length check above).
            ppg_wins = self._segment_windows(
                ppg_norm, CARDIOGAN_WINDOW_SAMPLES, CARDIOGAN_OVERLAP
            )

            # 5. Per-window scaling to the generator's input range.
            return self._minmax_normalize(ppg_wins, -1.0, 1.0)

        except PreprocessingError:
            raise
        except Exception as e:
            raise PreprocessingError("preprocess_ppg", f"Unexpected error: {e}") from e

    def postprocess_ecg(self, ecg_windows_128: np.ndarray) -> np.ndarray:
        """
        Downsamples ECG signals from 128 Hz back to 125 Hz and trims/pads to 224 samples.

        Each window is resampled from CARDIOGAN_TARGET_FS to CARDIOGAN_ORIG_FS
        with the same linear-interpolation resampler used for PPG. It is then
        truncated to its first VGTLNET_WINDOW_SIZE samples, or right-padded
        with zeros when shorter.

        Args:
            ecg_windows_128: Sequence of 1-D ECG windows sampled at
                CARDIOGAN_TARGET_FS (e.g. a 2-D array, one window per row).

        Returns:
            float32 array of shape (n_windows, VGTLNET_WINDOW_SIZE).

        Raises:
            PreprocessingError: if the batch is empty or resampling fails.
        """
        try:
            if len(ecg_windows_128) == 0:
                raise PreprocessingError("postprocess_ecg", "ECG windows batch is empty")

            # Pre-zeroed output: rows shorter than the target keep zero padding.
            out = np.zeros((len(ecg_windows_128), VGTLNET_WINDOW_SIZE), dtype=np.float32)
            for row, win in zip(out, ecg_windows_128):
                ecg_win_125 = self._resample_signal(
                    win, CARDIOGAN_TARGET_FS, CARDIOGAN_ORIG_FS
                )
                n_keep = min(len(ecg_win_125), VGTLNET_WINDOW_SIZE)
                row[:n_keep] = ecg_win_125[:n_keep]  # row is a view into `out`
            return out

        except PreprocessingError:
            raise
        except Exception as e:
            raise PreprocessingError("postprocess_ecg", f"Unexpected error: {e}") from e

    @staticmethod
    def _resample_signal(sig: np.ndarray, orig_fs: int, target_fs: int) -> np.ndarray:
        """
        Resample ``sig`` from ``orig_fs`` to ``target_fs`` by linear interpolation.

        The output has int(len(sig) / orig_fs * target_fs) samples on a uniform
        grid over the same duration. Points past the last input sample are
        linearly extrapolated. scipy raises ValueError for fewer than 2 samples.
        Returns float32.
        """
        n_orig = len(sig)
        duration = n_orig / orig_fs
        n_target = int(duration * target_fs)
        t_orig = np.linspace(0, duration, n_orig, endpoint=False)
        t_target = np.linspace(0, duration, n_target, endpoint=False)
        interp_fn = interp1d(t_orig, sig, kind="linear", fill_value="extrapolate")
        return interp_fn(t_target).astype(np.float32)

    @staticmethod
    def _bandpass_butter(
        sig: np.ndarray, fs: int, low: float, high: float, order: int = 4
    ) -> np.ndarray:
        """
        Zero-phase Butterworth band-pass of ``sig`` between ``low`` and ``high`` Hz.

        ``order`` is the design order passed to ``scipy.signal.butter``. The
        filter is applied forward and backward with ``filtfilt``, so there is
        no phase shift. Returns float32.
        """
        nyq = fs / 2.0
        b, a = butter(order, [low / nyq, high / nyq], btype="band")
        return filtfilt(b, a, sig).astype(np.float32)

    @staticmethod
    def _zscore_normalize(sig: np.ndarray) -> np.ndarray:
        """
        Standardize ``sig`` to zero mean and unit variance.

        A (near-)constant signal is only mean-centred, to avoid dividing by
        ~0. Returns float32.
        """
        mu = np.mean(sig)
        std = np.std(sig)
        if std < 1e-8:
            return (sig - mu).astype(np.float32)
        return ((sig - mu) / std).astype(np.float32)

    @staticmethod
    def _segment_windows(sig: np.ndarray, win_len: int, overlap: float) -> np.ndarray:
        """
        Cut ``sig`` into windows of ``win_len`` samples overlapping by ``overlap``.

        Trailing samples that do not fill a complete window are dropped.

        Returns:
            float32 array of shape (n_windows, win_len); n_windows may be 0.

        Raises:
            ValueError: if ``win_len`` < 1 or ``overlap`` is outside [0, 1).
        """
        if win_len < 1:
            raise ValueError(f"win_len must be >= 1, got {win_len}")
        if not 0.0 <= overlap < 1.0:
            raise ValueError(f"overlap must be in [0, 1), got {overlap}")

        # The hop is truncated (e.g. 512 * 0.9 = 460.8 -> 460), but a tiny
        # epsilon absorbs float error such as 10 * (1 - 0.9) = 0.9999999999999998,
        # which would otherwise truncate to a zero hop and divide by zero below.
        step = max(1, int(win_len * (1.0 - overlap) + 1e-9))
        n_windows = max(0, (len(sig) - win_len) // step + 1)
        if n_windows == 0:
            return np.empty((0, win_len), dtype=np.float32)
        starts = range(0, n_windows * step, step)
        return np.stack([sig[s : s + win_len] for s in starts]).astype(np.float32)

    @staticmethod
    def _minmax_normalize(windows: np.ndarray, low: float, high: float) -> np.ndarray:
        """
        Linearly rescale each window (last axis) to the range [low, high].

        Windows whose range is below 1e-8 are divided by 1 instead, so a flat
        window maps entirely to ``low``. Returns float32.
        """
        mins = windows.min(axis=-1, keepdims=True)
        maxs = windows.max(axis=-1, keepdims=True)
        rng = maxs - mins
        rng[rng < 1e-8] = 1.0
        normalized = (windows - mins) / rng
        return (normalized * (high - low) + low).astype(np.float32)
|
|