# Source: LIBRE/src/infrastructure/processing/scipy_cardiogan_preprocessor.py
# Last commit e391a84 (RyZ): "feat: adding full working local ETL Pipeline" — 6.34 kB
"""
infrastructure/processing/scipy_cardiogan_preprocessor.py
──────────────────────────────────────────────────────────
SciPy implementation of CardioGANSignalPreprocessor.
"""
from __future__ import annotations
import numpy as np
from scipy.interpolate import interp1d
from scipy.signal import butter, filtfilt
from src.domain.exceptions.pipeline_exceptions import PreprocessingError
from src.domain.interfaces.services.cardiogan_preprocessor import CardioGANSignalPreprocessor
from src.shared.constants import (
CARDIOGAN_ORIG_FS,
CARDIOGAN_TARGET_FS,
CARDIOGAN_WINDOW_SAMPLES,
CARDIOGAN_OVERLAP,
VGTLNET_WINDOW_SIZE,
)
from src.shared.logger import get_logger
# Module-level logger named after this module, obtained from the project's get_logger helper.
# NOTE(review): nothing in this module currently logs through it.
logger = get_logger(__name__)
class SciPyCardioGANPreprocessor(CardioGANSignalPreprocessor):
    """
    CardioGAN signal preprocessor using SciPy for resampling, filtering, and windowing.

    ``preprocess_ppg`` turns a raw PPG recording into fixed-length windows scaled
    to [-1, 1] at CARDIOGAN_TARGET_FS. ``postprocess_ecg`` maps generated ECG
    windows back to CARDIOGAN_ORIG_FS and trims/zero-pads each one to
    VGTLNET_WINDOW_SIZE samples.
    """

    def preprocess_ppg(self, ppg_raw: np.ndarray) -> np.ndarray:
        """
        Full CardioGAN preprocessing pipeline:
        1. Resample: 125 Hz -> 128 Hz
        2. Filter: Butterworth bandpass 1-8 Hz (order=4)
        3. Normalize: Z-score (per-subject)
        4. Segment: 512-sample sliding windows with 10% overlap
        5. Normalize: Min-max normalize per window to [-1, 1]

        Args:
            ppg_raw: Raw PPG samples at CARDIOGAN_ORIG_FS. Any numeric array-like
                is accepted and flattened to 1-D.

        Returns:
            float32 array of shape (n_windows, CARDIOGAN_WINDOW_SAMPLES), each
            window min-max scaled to [-1, 1]. ``n_windows`` is at least 1.

        Raises:
            PreprocessingError: if the signal is None, empty, contains NaN/inf,
                is too short to fill one window, or any processing step fails.
        """
        try:
            if ppg_raw is None:
                raise PreprocessingError("preprocess_ppg", "PPG signal is None")
            ppg_flat = np.asarray(ppg_raw, dtype=np.float32).ravel()
            if ppg_flat.size == 0:
                raise PreprocessingError("preprocess_ppg", "PPG signal array is empty")
            if not np.isfinite(ppg_flat).all():
                # NaN/inf would propagate through filtfilt and poison every window.
                raise PreprocessingError(
                    "preprocess_ppg", "PPG signal contains NaN or infinite values"
                )
            # interp1d needs at least two samples; a signal this short can never
            # fill a window, so report it with the regular "too short" error.
            if ppg_flat.size < 2:
                raise self._too_short_error(ppg_flat.size)

            # 1. Resample CARDIOGAN_ORIG_FS -> CARDIOGAN_TARGET_FS
            ppg_128 = self._resample_signal(
                ppg_flat, CARDIOGAN_ORIG_FS, CARDIOGAN_TARGET_FS
            )
            # Reject short signals *before* filtering: fewer samples than one
            # window can never yield a segment, and filtfilt would otherwise fail
            # on very short inputs with an opaque "padlen" error.
            if len(ppg_128) < CARDIOGAN_WINDOW_SAMPLES:
                raise self._too_short_error(ppg_flat.size)

            # 2. Filter: Butterworth band-pass 1-8 Hz
            ppg_filt = self._bandpass_butter(
                ppg_128, CARDIOGAN_TARGET_FS, low=1.0, high=8.0
            )
            # 3. Z-score over the whole recording
            ppg_norm = self._zscore_normalize(ppg_filt)
            # 4. Segment; the length check above guarantees at least one window.
            ppg_wins = self._segment_windows(
                ppg_norm, CARDIOGAN_WINDOW_SAMPLES, CARDIOGAN_OVERLAP
            )
            # 5. Min-max normalize each window to [-1, 1]
            return self._minmax_normalize(ppg_wins, -1.0, 1.0)
        except PreprocessingError:
            raise
        except Exception as e:
            raise PreprocessingError("preprocess_ppg", f"Unexpected error: {e}") from e

    def postprocess_ecg(self, ecg_windows_128: np.ndarray) -> np.ndarray:
        """
        Downsample ECG windows from 128 Hz back to 125 Hz and trim/pad each to 224 samples.

        Args:
            ecg_windows_128: Batch of ECG windows sampled at CARDIOGAN_TARGET_FS.
                Singleton axes inside a window are dropped, so both
                (n_windows, n_samples) and (n_windows, n_samples, 1) are accepted.

        Returns:
            float32 array of shape (n_windows, VGTLNET_WINDOW_SIZE). A resampled
            window longer than VGTLNET_WINDOW_SIZE keeps only its first
            VGTLNET_WINDOW_SIZE samples; a shorter one is zero-padded at the end.

        Raises:
            PreprocessingError: if the batch is None/empty, a window is not 1-D
                after squeezing, or resampling fails.
        """
        try:
            if ecg_windows_128 is None or len(ecg_windows_128) == 0:
                raise PreprocessingError("postprocess_ecg", "ECG windows batch is empty")

            out = np.zeros((len(ecg_windows_128), VGTLNET_WINDOW_SIZE), dtype=np.float32)
            for idx, window in enumerate(ecg_windows_128):
                win = np.asarray(window)
                if win.ndim > 1:
                    # Tolerate singleton axes, e.g. a trailing channel dim (512, 1).
                    win = np.squeeze(win)
                if win.ndim != 1:
                    raise PreprocessingError(
                        "postprocess_ecg",
                        f"ECG window {idx} has shape {np.shape(window)}; expected a 1-D signal",
                    )
                # Same time base as the forward resampling, just in reverse.
                ecg_win_125 = self._resample_signal(
                    win, CARDIOGAN_TARGET_FS, CARDIOGAN_ORIG_FS
                )
                # Trim or zero-pad to VGTLNET_WINDOW_SIZE (224 samples).
                n_keep = min(len(ecg_win_125), VGTLNET_WINDOW_SIZE)
                out[idx, :n_keep] = ecg_win_125[:n_keep]
            return out
        except PreprocessingError:
            raise
        except Exception as e:
            raise PreprocessingError("postprocess_ecg", f"Unexpected error: {e}") from e

    # ── Helper Processing Methods ───────────────────────────────────────────────
    @staticmethod
    def _too_short_error(n_raw: int) -> PreprocessingError:
        """Build the error raised when a PPG recording cannot fill a single window."""
        return PreprocessingError(
            "preprocess_ppg",
            f"PPG signal length {n_raw} is too short to form any segment of size {CARDIOGAN_WINDOW_SAMPLES}",
        )

    @staticmethod
    def _resample_signal(sig: np.ndarray, orig_fs: int, target_fs: int) -> np.ndarray:
        """Linearly resample a 1-D signal from ``orig_fs`` to ``target_fs``.

        The output has ``int(len(sig) / orig_fs * target_fs)`` samples over the
        same duration. Points past the last input sample are linearly
        extrapolated. Requires at least two input samples (interp1d limit).
        Returns float32.
        """
        n_orig = len(sig)
        duration = n_orig / orig_fs
        n_target = int(duration * target_fs)
        t_orig = np.linspace(0, duration, n_orig, endpoint=False)
        t_target = np.linspace(0, duration, n_target, endpoint=False)
        interp_fn = interp1d(t_orig, sig, kind="linear", fill_value="extrapolate")
        return interp_fn(t_target).astype(np.float32)

    @staticmethod
    def _bandpass_butter(
        sig: np.ndarray, fs: int, low: float, high: float, order: int = 4
    ) -> np.ndarray:
        """Zero-phase Butterworth band-pass of ``sig`` between ``low`` and ``high`` Hz.

        ``filtfilt`` applies the filter forward and backward, so the output has
        no phase delay. ``order`` defaults to the original hard-coded 4.
        Returns float32.
        """
        nyq = fs / 2.0
        b, a = butter(order, [low / nyq, high / nyq], btype="band")
        return filtfilt(b, a, sig).astype(np.float32)

    @staticmethod
    def _zscore_normalize(sig: np.ndarray) -> np.ndarray:
        """Standardize ``sig`` to zero mean and unit variance (float32).

        A near-constant signal (std < 1e-8) is only mean-centred, to avoid
        dividing by ~0.
        """
        mu = np.mean(sig)
        std = np.std(sig)
        if std < 1e-8:
            return (sig - mu).astype(np.float32)
        return ((sig - mu) / std).astype(np.float32)

    @staticmethod
    def _segment_windows(sig: np.ndarray, win_len: int, overlap: float) -> np.ndarray:
        """Cut ``sig`` into ``win_len``-sample windows starting every
        ``int(win_len * (1 - overlap))`` samples.

        Trailing samples that do not fill a whole window are dropped. Returns a
        float32 array of shape (n_windows, win_len); n_windows may be 0.

        Raises:
            ValueError: if ``overlap`` leaves a step smaller than one sample.
        """
        step = int(win_len * (1.0 - overlap))
        if step < 1:
            # A zero step would divide by zero below; a negative one would walk backwards.
            raise ValueError(
                f"window step must be at least 1 sample (win_len={win_len}, overlap={overlap})"
            )
        n_windows = max(0, (len(sig) - win_len) // step + 1)
        if n_windows == 0:
            return np.empty((0, win_len), dtype=np.float32)
        return np.stack(
            [sig[start : start + win_len] for start in range(0, n_windows * step, step)]
        ).astype(np.float32)

    @staticmethod
    def _minmax_normalize(windows: np.ndarray, low: float, high: float) -> np.ndarray:
        """Linearly map each window (last axis) onto ``[low, high]`` (float32).

        Windows whose range is below 1e-8 have their range replaced by 1 to
        avoid division by ~0. Such windows therefore map to (approximately)
        ``low`` everywhere.
        """
        mins = windows.min(axis=-1, keepdims=True)
        maxs = windows.max(axis=-1, keepdims=True)
        rng = maxs - mins
        rng[rng < 1e-8] = 1.0
        normalized = (windows - mins) / rng
        return (normalized * (high - low) + low).astype(np.float32)