# EchoingECG / preprocessor.py
# Uploaded by gaoyua19
# Commit 38dfcf0 (verified): modelweights (#2)
import torch
from math import gcd
from typing import Optional, Union
import joblib
import numpy as np
from scipy import signal
def load_scaler_joblib(path: str) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Load ecg_scaler.pkl and return center and scale as torch tensors.

    The file is expected to hold a fitted scikit-learn ``StandardScaler``
    (or any object exposing ``mean_`` / ``scale_`` and, optionally, the
    ``with_mean`` / ``with_std`` flags). A statistic the scaler does not
    apply (``with_mean=False`` or ``with_std=False``) is replaced by the
    identity -- zeros for the center, ones for the scale -- so that
    ``(x - center) / scale`` matches what the scaler's own ``transform``
    would produce.

    Security: ``joblib.load`` unpickles the file, which can execute
    arbitrary code. Only load files from a trusted source.

    Args:
        path: Path to the joblib file.
    Returns:
        center: torch.Tensor, float32, shape (n_features,)
        scale: torch.Tensor, float32, shape (n_features,), clamped to >= 1e-8
            to avoid division by zero.
    """
    sc = joblib.load(path)

    # StandardScaler keeps ``mean_`` populated when with_mean=False, but its
    # transform() does not subtract it. ``scale_`` is None when with_std=False.
    mean = sc.mean_ if getattr(sc, "with_mean", True) else None
    std = sc.scale_ if getattr(sc, "with_std", True) else None

    if mean is None and std is None:
        # Both statistics disabled: only the feature count is available.
        n_features = int(sc.n_features_in_)
    else:
        n_features = len(mean if mean is not None else std)

    if mean is None:
        center_np = np.zeros(n_features, dtype=np.float32)
    else:
        center_np = np.asarray(mean).astype(np.float32)
    if std is None:
        scale_np = np.ones(n_features, dtype=np.float32)
    else:
        scale_np = np.asarray(std).astype(np.float32)

    center = torch.from_numpy(center_np)
    scale = torch.from_numpy(scale_np).clamp_min(1e-8)
    return center, scale
class ECGTransform:
    """
    Unified ECG preprocessing: optional resampling (preceded by zero-phase
    Butterworth filtering) followed by per-lead standardization.

    Usage:
        transform = ECGTransform(center, scale, src_fs=512, target_fs=100)
        ecg_out = transform(ecg_in)

    Attributes:
        center: float32 tensor, one offset per lead.
        scale: float32 tensor, one divisor per lead (clamped to >= 1e-8).
            This is a data attribute. The standardization step is the
            ``normalize`` method, named so it is not shadowed by this
            attribute.
    """

    def __init__(
        self,
        center: Union[np.ndarray, torch.Tensor],
        scale: Union[np.ndarray, torch.Tensor],
        src_fs: int = 100,  # we assume the input ECG is already at 100Hz
        target_fs: int = 100,
        band: Optional[tuple[float, float]] = (0.5, 40.0),
        bp_order: int = 4,
        axis: int = -1,
    ) -> None:
        """
        Args:
            center: Per-lead offset subtracted during normalization.
            scale: Per-lead divisor; clamped to >= 1e-8 to avoid division by 0.
            src_fs: Input sampling rate in Hz (must be integral when
                resampling is needed).
            target_fs: Output sampling rate in Hz (must be integral when
                resampling is needed).
            band: (lowcut, highcut) in Hz for the filter applied before
                resampling, or None to skip filtering. A lowcut <= 0 selects a
                low-pass filter instead of a band-pass. Filtering only happens
                when src_fs != target_fs.
            bp_order: Butterworth filter order.
            axis: Time axis of the input signal.
        """
        self.center = torch.as_tensor(center, dtype=torch.float32)
        self.scale = torch.as_tensor(scale, dtype=torch.float32).clamp_min(1e-8)
        self.src_fs = src_fs
        self.target_fs = target_fs
        self.band = band
        self.bp_order = bp_order
        self.axis = axis

    @staticmethod
    def _as_int_rate(fs: float, name: str) -> int:
        """Return ``fs`` as an int, rejecting non-integral or non-positive rates."""
        fs_int = int(fs)
        if fs_int != fs or fs_int <= 0:
            raise ValueError(
                f"{name} must be a positive integer sampling rate, got {fs!r}"
            )
        return fs_int

    def downsample(self, x: np.ndarray) -> np.ndarray:
        """
        Filter (if ``band`` is set) and polyphase-resample ``x`` from
        ``src_fs`` to ``target_fs`` along ``axis``.

        Despite the name, upsampling (target_fs > src_fs) also works.

        Raises:
            ValueError: if a sampling rate is not a positive integer, or the
                pass band collapses (lowcut >= highcut) after clipping.
        """
        x = np.asarray(x)
        src_fs = self._as_int_rate(self.src_fs, "src_fs")
        target_fs = self._as_int_rate(self.target_fs, "target_fs")

        if self.band is not None:
            lowcut, highcut = self.band
            # Keep the cutoff 10% below the Nyquist frequency of BOTH rates:
            # below the input's so the filter design is valid, and below the
            # output's so the resampled signal is not aliased.
            highcut = min(highcut, 0.45 * min(src_fs, target_fs))
            if lowcut >= highcut:
                raise ValueError(
                    f"band {self.band!r} collapses to ({lowcut}, {highcut}) Hz "
                    f"for src_fs={src_fs}, target_fs={target_fs}"
                )
            nyq = src_fs / 2.0
            if lowcut <= 0:
                sos = signal.butter(
                    self.bp_order, highcut / nyq, btype="low", output="sos"
                )
            else:
                sos = signal.butter(
                    self.bp_order,
                    (lowcut / nyq, highcut / nyq),
                    btype="band",
                    output="sos",
                )
            x = signal.sosfiltfilt(sos, x, axis=self.axis)

        g = gcd(src_fs, target_fs)
        return signal.resample_poly(
            x,
            target_fs // g,
            src_fs // g,
            axis=self.axis,
            window=("kaiser", 5.0),
            padtype="median",
        )

    def normalize(self, ecg: torch.Tensor) -> torch.Tensor:
        """
        Standardize each lead as ``(ecg - center) / scale`` in float32.

        ``center`` / ``scale`` hold one value per lead. They are broadcast over
        the time axis ``self.axis``, with the leads on the dimension just before
        time once time is moved last. Both are moved to ``ecg``'s device.
        """
        ecg = ecg.to(torch.float32)
        center = self.center.to(ecg.device)[:, None]
        scale = self.scale.to(ecg.device)[:, None]
        # Move time last so the (leads, 1) statistics line up, then restore
        # the caller's layout.
        ecg = ecg.movedim(self.axis, -1)
        ecg = (ecg - center) / scale
        return ecg.movedim(-1, self.axis)

    def __call__(self, x: np.ndarray) -> torch.Tensor:
        """
        Resample (only when src_fs != target_fs) and standardize ECG data.

        Args:
            x: array-like or torch.Tensor, shape (leads, time) for the default
               axis=-1.
        Returns:
            torch.Tensor (float32) in the same layout, with the time axis
            resampled to target_fs.
        """
        if self.src_fs != self.target_fs:
            x = self.downsample(x)
        if not isinstance(x, torch.Tensor):
            # torch.from_numpy accepts only ndarrays and rejects negative
            # strides (e.g. x[:, ::-1]); make a contiguous ndarray first.
            x = torch.from_numpy(np.ascontiguousarray(x))
        return self.normalize(x)