File size: 6,335 Bytes
e391a84 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 | """
infrastructure/processing/scipy_cardiogan_preprocessor.py
=========================================================
SciPy implementation of CardioGANSignalPreprocessor.
"""
from __future__ import annotations
import numpy as np
from scipy.interpolate import interp1d
from scipy.signal import butter, filtfilt
from src.domain.exceptions.pipeline_exceptions import PreprocessingError
from src.domain.interfaces.services.cardiogan_preprocessor import CardioGANSignalPreprocessor
from src.shared.constants import (
CARDIOGAN_ORIG_FS,
CARDIOGAN_TARGET_FS,
CARDIOGAN_WINDOW_SAMPLES,
CARDIOGAN_OVERLAP,
VGTLNET_WINDOW_SIZE,
)
from src.shared.logger import get_logger
# Module-level logger, named after this module via the project's get_logger helper.
logger = get_logger(__name__)
class SciPyCardioGANPreprocessor(CardioGANSignalPreprocessor):
    """
    CardioGAN signal preprocessor using SciPy for resampling, filtering, and windowing.

    The helpers are stateless static methods. The public methods wrap any failure
    that is not already a ``PreprocessingError`` in a ``PreprocessingError``
    tagged with the calling stage name.
    """

    def preprocess_ppg(self, ppg_raw: np.ndarray) -> np.ndarray:
        """
        Full CardioGAN preprocessing pipeline:
          1. Resample: CARDIOGAN_ORIG_FS -> CARDIOGAN_TARGET_FS (125 Hz -> 128 Hz)
          2. Filter: zero-phase Butterworth bandpass 1-8 Hz (order=4)
          3. Normalize: Z-score over the whole (per-subject) signal
          4. Segment: CARDIOGAN_WINDOW_SAMPLES-sample (512) sliding windows with
             CARDIOGAN_OVERLAP (10%) overlap; trailing samples that do not fill a
             complete window are dropped
          5. Normalize: min-max normalize per window to [-1, 1]

        Args:
            ppg_raw: Raw PPG samples at CARDIOGAN_ORIG_FS. Any array-like is
                accepted and flattened to 1-D.

        Returns:
            float32 array of shape (n_windows, CARDIOGAN_WINDOW_SAMPLES), n_windows >= 1.

        Raises:
            PreprocessingError: if the signal is empty, contains NaN/inf values,
                is too short to yield a single window, or any step fails.
        """
        try:
            ppg_flat = np.asarray(ppg_raw, dtype=np.float32).ravel()
            if ppg_flat.size == 0:
                raise PreprocessingError("preprocess_ppg", "PPG signal array is empty")
            # filtfilt smears a single NaN/inf over the entire signal, which would
            # silently yield all-NaN windows; reject such input explicitly.
            if not np.isfinite(ppg_flat).all():
                raise PreprocessingError(
                    "preprocess_ppg", "PPG signal contains NaN or infinite values"
                )

            # Validate length against the *resampled* size, since windows are cut
            # at CARDIOGAN_TARGET_FS. Checking before resampling also gives a clear
            # error for inputs too short to interpolate (e.g. a single sample).
            n_resampled = self._resampled_length(
                ppg_flat.size, CARDIOGAN_ORIG_FS, CARDIOGAN_TARGET_FS
            )
            if n_resampled < CARDIOGAN_WINDOW_SAMPLES:
                raise PreprocessingError(
                    "preprocess_ppg",
                    f"PPG signal length {ppg_flat.size} ({n_resampled} samples after "
                    f"resampling to {CARDIOGAN_TARGET_FS} Hz) is too short to form any "
                    f"segment of size {CARDIOGAN_WINDOW_SAMPLES}",
                )

            # 1. Resample CARDIOGAN_ORIG_FS -> CARDIOGAN_TARGET_FS
            ppg_128 = self._resample_signal(
                ppg_flat, CARDIOGAN_ORIG_FS, CARDIOGAN_TARGET_FS
            )
            # 2. Zero-phase Butterworth bandpass, 1-8 Hz
            ppg_filt = self._bandpass_butter(
                ppg_128, CARDIOGAN_TARGET_FS, low=1.0, high=8.0
            )
            # 3. Z-score normalization over the whole signal
            ppg_norm = self._zscore_normalize(ppg_filt)
            # 4. Sliding windows; the length check above guarantees at least one.
            ppg_wins = self._segment_windows(
                ppg_norm, CARDIOGAN_WINDOW_SAMPLES, CARDIOGAN_OVERLAP
            )
            # 5. Min-max normalize each window to [-1, 1]
            return self._minmax_normalize(ppg_wins, -1.0, 1.0)
        except PreprocessingError:
            raise
        except Exception as e:
            raise PreprocessingError("preprocess_ppg", f"Unexpected error: {e}") from e

    def postprocess_ecg(self, ecg_windows_128: np.ndarray) -> np.ndarray:
        """
        Downsample ECG windows from 128 Hz back to 125 Hz and trim/pad each one
        to VGTLNET_WINDOW_SIZE (224) samples.

        Args:
            ecg_windows_128: Batch (array or sequence) of ECG windows sampled at
                CARDIOGAN_TARGET_FS. Singleton axes inside a window, such as a
                trailing channel axis in (n_windows, n_samples, 1), are squeezed away.

        Returns:
            float32 array of shape (n_windows, VGTLNET_WINDOW_SIZE). Windows that
            are shorter than that after downsampling are zero-padded at the end;
            longer ones are truncated.

        Raises:
            PreprocessingError: if the batch is empty, a window is not 1-D after
                squeezing, or resampling fails.
        """
        try:
            if len(ecg_windows_128) == 0:
                raise PreprocessingError("postprocess_ecg", "ECG windows batch is empty")

            ecg_out = np.zeros(
                (len(ecg_windows_128), VGTLNET_WINDOW_SIZE), dtype=np.float32
            )
            # Iterate per window (rather than resampling a 2-D array at once) so
            # that ragged sequences of windows keep working.
            for i, win in enumerate(ecg_windows_128):
                win_1d = np.squeeze(np.asarray(win, dtype=np.float32))
                if win_1d.ndim != 1:
                    raise PreprocessingError(
                        "postprocess_ecg",
                        f"ECG window {i} must be 1-D after squeezing singleton axes, "
                        f"got shape {np.shape(win)}",
                    )
                # Downsample CARDIOGAN_TARGET_FS -> CARDIOGAN_ORIG_FS (128 -> 125 Hz)
                win_125 = self._resample_signal(
                    win_1d, CARDIOGAN_TARGET_FS, CARDIOGAN_ORIG_FS
                )
                # Trim to VGTLNET_WINDOW_SIZE; the zero-initialised row pads short windows.
                n_keep = min(win_125.shape[0], VGTLNET_WINDOW_SIZE)
                ecg_out[i, :n_keep] = win_125[:n_keep]
            return ecg_out
        except PreprocessingError:
            raise
        except Exception as e:
            raise PreprocessingError("postprocess_ecg", f"Unexpected error: {e}") from e

    # -- Helper processing methods -------------------------------------------

    @staticmethod
    def _resampled_length(n_samples: int, orig_fs: int, target_fs: int) -> int:
        """Return how many samples ``_resample_signal`` yields for ``n_samples`` inputs."""
        # Floor division in exact (integer) arithmetic; ``int(n / orig_fs * target_fs)``
        # can truncate e.g. 2.9999999 -> 2 for some rate combinations.
        return int(n_samples * target_fs // orig_fs)

    @staticmethod
    def _resample_signal(sig: np.ndarray, orig_fs: int, target_fs: int) -> np.ndarray:
        """
        Linearly resample ``sig`` from ``orig_fs`` to ``target_fs`` along its last axis.

        Both sample grids start at t=0 and cover the same duration
        (n_orig / orig_fs). Target points beyond the last original sample are
        linearly extrapolated. The output is float32 and has
        ``_resampled_length(n_orig, orig_fs, target_fs)`` samples on the last axis.
        At least two input samples are needed for interpolation.
        """
        n_orig = np.shape(sig)[-1]
        duration = n_orig / orig_fs
        n_target = SciPyCardioGANPreprocessor._resampled_length(n_orig, orig_fs, target_fs)
        t_orig = np.linspace(0, duration, n_orig, endpoint=False)
        t_target = np.linspace(0, duration, n_target, endpoint=False)
        # interp1d interpolates along axis=-1 by default.
        interp_fn = interp1d(t_orig, sig, kind="linear", fill_value="extrapolate")
        return interp_fn(t_target).astype(np.float32)

    @staticmethod
    def _bandpass_butter(sig: np.ndarray, fs: int, low: float, high: float) -> np.ndarray:
        """
        Apply a zero-phase (forward-backward) order-4 Butterworth bandpass.

        ``low``/``high`` are in Hz and must satisfy 0 < low < high < fs / 2
        (``butter`` validates the normalized band). ``filtfilt`` requires the
        input to be longer than its default pad length, 3 * max(len(a), len(b)).

        NOTE(review): SciPy recommends the SOS form (``output="sos"`` +
        ``sosfiltfilt``) for numerical robustness; the (b, a) form is kept so
        outputs stay identical to the existing pipeline.
        """
        nyq = fs / 2.0
        b, a = butter(4, [low / nyq, high / nyq], btype="band")
        return filtfilt(b, a, sig).astype(np.float32)

    @staticmethod
    def _zscore_normalize(sig: np.ndarray) -> np.ndarray:
        """Z-score ``sig``; a (near-)constant signal (std < 1e-8) is only mean-centred."""
        mu = np.mean(sig)
        std = np.std(sig)
        if std < 1e-8:
            return (sig - mu).astype(np.float32)
        return ((sig - mu) / std).astype(np.float32)

    @staticmethod
    def _segment_windows(sig: np.ndarray, win_len: int, overlap: float) -> np.ndarray:
        """
        Split 1-D ``sig`` into windows of ``win_len`` samples whose starts are
        ``int(win_len * (1 - overlap))`` samples apart.

        Only complete windows are returned; trailing samples that do not fill a
        whole window are dropped. Returns float32 of shape (n_windows, win_len),
        where n_windows may be 0.

        Raises:
            ValueError: if ``win_len``/``overlap`` give a step smaller than one sample.
        """
        step = int(win_len * (1.0 - overlap))
        if step < 1:
            raise ValueError(
                f"overlap={overlap} with win_len={win_len} gives a window step < 1 sample"
            )
        n_windows = max(0, (len(sig) - win_len) // step + 1)
        if n_windows == 0:
            return np.empty((0, win_len), dtype=np.float32)
        # Row i holds the indices [i * step, i * step + win_len); gathering them in
        # one fancy-index avoids a Python-level loop over windows.
        idx = np.arange(n_windows)[:, np.newaxis] * step + np.arange(win_len)
        return np.asarray(sig)[idx].astype(np.float32)

    @staticmethod
    def _minmax_normalize(windows: np.ndarray, low: float, high: float) -> np.ndarray:
        """
        Linearly rescale each window (last axis) so its min maps to ``low`` and
        its max to ``high``. Windows whose range is < 1e-8 map to ``low``.
        """
        mins = windows.min(axis=-1, keepdims=True)
        maxs = windows.max(axis=-1, keepdims=True)
        # Avoid division by ~0 for flat windows.
        rng = np.where((maxs - mins) < 1e-8, 1.0, maxs - mins)
        normalized = (windows - mins) / rng
        return (normalized * (high - low) + low).astype(np.float32)
|