# EchoingECG / preprocessor.py
# Author: Yuan Gao
# ECG preprocessing code (revision 38b8a8c).
import torch
from math import gcd
from typing import Optional, Union
import joblib
import numpy as np
from scipy import signal
def load_scaler_joblib(path: str) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Read a fitted scaler (e.g. ecg_scaler.pkl) and expose its statistics as tensors.

    The loaded object is expected to provide ``mean_`` and ``scale_`` array
    attributes (presumably a scikit-learn ``StandardScaler`` — confirm against
    the code that produced the file).

    Args:
        path: Location of the joblib file.

    Returns:
        center: float32 tensor built from the object's ``mean_``.
        scale: float32 tensor built from the object's ``scale_``, clamped to a
            minimum of 1e-8 so it is always safe to divide by.

    Note:
        ``joblib.load`` unpickles the file, which can execute arbitrary code;
        only load scaler files from a trusted source.
    """
    fitted = joblib.load(path)

    def _as_f32(arr: np.ndarray) -> torch.Tensor:
        # Cast before wrapping so the tensor owns a float32 copy of the data.
        return torch.from_numpy(arr.astype(np.float32))

    return _as_f32(fitted.mean_), _as_f32(fitted.scale_).clamp_min(1e-8)
class ECGTransform:
    """
    Unified ECG preprocessing: optional filtering + resampling, then per-lead scaling.

    Usage:
        transform = ECGTransform(center, scale, src_fs=512, target_fs=100)
        ecg_out = transform(ecg_in)

    Attributes:
        center: float32 tensor of per-lead offsets subtracted during scaling.
        scale: float32 tensor of per-lead divisors (clamped to >= 1e-8).
            This is a tensor attribute. The scaling *operation* is the
            ``normalize`` method; a method named ``scale`` would be shadowed by
            this attribute on every instance and could never be called.
    """

    def __init__(
        self,
        center: Union[np.ndarray, torch.Tensor],
        scale: Union[np.ndarray, torch.Tensor],
        src_fs: int = 100,  # we assume the input ECG is already at 100Hz
        target_fs: int = 100,
        band: Optional[tuple[float, float]] = (0.5, 40.0),
        bp_order: int = 4,
        axis: int = -1,
    ) -> None:
        """
        Args:
            center: Per-lead values subtracted from the signal.
            scale: Per-lead values the centered signal is divided by; clamped
                to a minimum of 1e-8 to avoid division by zero.
            src_fs: Sampling rate of the incoming signal in Hz.
            target_fs: Desired output sampling rate in Hz.
            band: (lowcut, highcut) in Hz for the Butterworth filter applied
                before resampling, or None to skip filtering. A lowcut <= 0
                selects a low-pass filter instead of a band-pass. Filtering
                only happens when src_fs != target_fs.
            bp_order: Butterworth filter order.
            axis: Axis along which filtering and resampling are applied.
        """
        self.center = torch.as_tensor(center, dtype=torch.float32)
        self.scale = torch.as_tensor(scale, dtype=torch.float32).clamp_min(1e-8)
        self.src_fs = src_fs
        self.target_fs = target_fs
        self.band = band
        self.bp_order = bp_order
        self.axis = axis

    def downsample(self, x: np.ndarray) -> np.ndarray:
        """
        Filter (if ``band`` is set) and polyphase-resample ``x`` along ``self.axis``.

        Args:
            x: Array-like signal sampled at ``src_fs``.

        Returns:
            The signal resampled to ``target_fs`` (float ndarray).
        """
        x = np.asarray(x)
        if self.band is not None:
            lowcut, highcut = self.band
            # Cap the passband below the *output* Nyquist (0.5 * target_fs),
            # presumably as an anti-aliasing margin before resampling.
            highcut = min(highcut, 0.45 * self.target_fs)
            nyq = self.src_fs / 2.0
            if lowcut <= 0:
                sos = signal.butter(self.bp_order, highcut / nyq, btype="low", output="sos")
            else:
                sos = signal.butter(
                    self.bp_order, (lowcut / nyq, highcut / nyq), btype="band", output="sos"
                )
            # Forward-backward filtering, so no phase shift is introduced.
            x = signal.sosfiltfilt(sos, x, axis=self.axis)
        # Reduce target_fs/src_fs to the smallest integer up/down factors.
        g = gcd(self.src_fs, self.target_fs)
        up = self.target_fs // g
        down = self.src_fs // g
        return signal.resample_poly(
            x, up, down, axis=self.axis, window=("kaiser", 5.0), padtype="median"
        )

    def normalize(self, ecg: torch.Tensor) -> torch.Tensor:
        """
        Standardize each lead: ``(ecg - center) / scale``.

        ``center`` and ``scale`` are broadcast as ``[:, None]``, so the lead
        dimension is expected at axis -2 and time at axis -1.

        Args:
            ecg: Tensor of shape (..., leads, time); cast to float32.

        Returns:
            float32 tensor with the same shape and device as ``ecg``.
        """
        ecg = ecg.to(torch.float32)
        # Follow the input's device so CUDA inputs don't hit a device mismatch.
        center = self.center.to(ecg.device)[:, None]
        scale = self.scale.to(ecg.device)[:, None]
        return (ecg - center) / scale

    def __call__(self, x: np.ndarray) -> torch.Tensor:
        """
        Downsample and scale ECG data.

        Args:
            x: np.ndarray (or tensor), shape (leads, time).

        Returns:
            torch.Tensor, shape (leads, time), dtype float32.
        """
        if self.src_fs != self.target_fs:
            x = self.downsample(x)
        if not isinstance(x, torch.Tensor):
            # torch.from_numpy rejects negative-stride views (e.g. x[:, ::-1])
            # and non-ndarray input; a contiguous ndarray is always accepted.
            x = torch.from_numpy(np.ascontiguousarray(x))
        return self.normalize(x)