""" infrastructure/processing/scipy_cardiogan_preprocessor.py ────────────────────────────────────────────────────────── SciPy implementation of CardioGANSignalPreprocessor. """ from __future__ import annotations import numpy as np from scipy.interpolate import interp1d from scipy.signal import butter, filtfilt from src.domain.exceptions.pipeline_exceptions import PreprocessingError from src.domain.interfaces.services.cardiogan_preprocessor import CardioGANSignalPreprocessor from src.shared.constants import ( CARDIOGAN_ORIG_FS, CARDIOGAN_TARGET_FS, CARDIOGAN_WINDOW_SAMPLES, CARDIOGAN_OVERLAP, VGTLNET_WINDOW_SIZE, ) from src.shared.logger import get_logger logger = get_logger(__name__) class SciPyCardioGANPreprocessor(CardioGANSignalPreprocessor): """ CardioGAN signal preprocessor using SciPy for resampling, filtering, and windowing. """ def preprocess_ppg(self, ppg_raw: np.ndarray) -> np.ndarray: """ Full CardioGAN preprocessing pipeline: 1. Resample: 125 Hz -> 128 Hz 2. Filter: Butterworth bandpass 1-8 Hz (order=4) 3. Normalize: Z-score (per-subject) 4. Segment: 512-sample sliding windows with 10% overlap 5. Normalize: Min-max normalize per window to [-1, 1] """ try: ppg_flat = ppg_raw.flatten().astype(np.float32) if len(ppg_flat) == 0: raise PreprocessingError("preprocess_ppg", "PPG signal array is empty") # 1. Resample 125 Hz -> 128 Hz ppg_128 = self._resample_signal( ppg_flat, CARDIOGAN_ORIG_FS, CARDIOGAN_TARGET_FS ) # 2. Filter: Butterworth 1-8 Hz ppg_filt = self._bandpass_butter( ppg_128, CARDIOGAN_TARGET_FS, low=1.0, high=8.0 ) # 3. Z-score normalization ppg_norm = self._zscore_normalize(ppg_filt) # 4. Segment into windows ppg_wins = self._segment_windows( ppg_norm, CARDIOGAN_WINDOW_SAMPLES, CARDIOGAN_OVERLAP ) if len(ppg_wins) == 0: raise PreprocessingError( "preprocess_ppg", f"PPG signal length {len(ppg_flat)} is too short to form any segment of size {CARDIOGAN_WINDOW_SAMPLES}" ) # 5. Min-max normalize per window to [-1, 1] ppg_wins = self._minmax_normalize(ppg_wins, -1.0, 1.0) return ppg_wins except Exception as e: if isinstance(e, PreprocessingError): raise e raise PreprocessingError("preprocess_ppg", f"Unexpected error: {e}") from e def postprocess_ecg(self, ecg_windows_128: np.ndarray) -> np.ndarray: """ Downsamples ECG signals from 128 Hz back to 125 Hz and trims/pads to 224 samples. """ try: if len(ecg_windows_128) == 0: raise PreprocessingError("postprocess_ecg", "ECG windows batch is empty") ecg_segments_out = [] for win in ecg_windows_128: # Downsample from 128 -> 125 Hz n_orig = len(win) n_target = int(n_orig * CARDIOGAN_ORIG_FS / CARDIOGAN_TARGET_FS) t_orig = np.linspace(0, 1, n_orig, endpoint=False) t_target = np.linspace(0, 1, n_target, endpoint=False) interp_fn = interp1d(t_orig, win, kind="linear", fill_value="extrapolate") ecg_win_125 = interp_fn(t_target).astype(np.float32) # Trim or pad to VGTLNET_WINDOW_SIZE (224 samples) if len(ecg_win_125) >= VGTLNET_WINDOW_SIZE: ecg_segments_out.append(ecg_win_125[:VGTLNET_WINDOW_SIZE]) else: padded = np.zeros(VGTLNET_WINDOW_SIZE, dtype=np.float32) padded[:len(ecg_win_125)] = ecg_win_125 ecg_segments_out.append(padded) return np.array(ecg_segments_out, dtype=np.float32) except Exception as e: if isinstance(e, PreprocessingError): raise e raise PreprocessingError("postprocess_ecg", f"Unexpected error: {e}") from e # ── Helper Processing Methods ─────────────────────────────────────────────── @staticmethod def _resample_signal(sig: np.ndarray, orig_fs: int, target_fs: int) -> np.ndarray: n_orig = len(sig) duration = n_orig / orig_fs n_target = int(duration * target_fs) t_orig = np.linspace(0, duration, n_orig, endpoint=False) t_target = np.linspace(0, duration, n_target, endpoint=False) interp_fn = interp1d(t_orig, sig, kind="linear", fill_value="extrapolate") return interp_fn(t_target).astype(np.float32) @staticmethod def _bandpass_butter(sig: np.ndarray, fs: int, low: float, high: float) -> np.ndarray: nyq = fs / 2.0 b, a = butter(4, [low / nyq, high / nyq], btype="band") return filtfilt(b, a, sig).astype(np.float32) @staticmethod def _zscore_normalize(sig: np.ndarray) -> np.ndarray: mu = np.mean(sig) std = np.std(sig) if std < 1e-8: return (sig - mu).astype(np.float32) return ((sig - mu) / std).astype(np.float32) @staticmethod def _segment_windows(sig: np.ndarray, win_len: int, overlap: float) -> np.ndarray: step = int(win_len * (1.0 - overlap)) n_windows = max(0, (len(sig) - win_len) // step + 1) if n_windows == 0: return np.empty((0, win_len), dtype=np.float32) return np.stack([sig[i * step : i * step + win_len] for i in range(n_windows)]).astype(np.float32) @staticmethod def _minmax_normalize(windows: np.ndarray, low: float, high: float) -> np.ndarray: mins = windows.min(axis=-1, keepdims=True) maxs = windows.max(axis=-1, keepdims=True) rng = maxs - mins rng[rng < 1e-8] = 1.0 normalized = (windows - mins) / rng return (normalized * (high - low) + low).astype(np.float32)