|
|
import torch |
|
|
from math import gcd |
|
|
from typing import Optional, Union |
|
|
import joblib |
|
|
|
|
|
import numpy as np |
|
|
from scipy import signal |
|
|
|
|
|
def load_scaler_joblib(path: str) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Load a fitted scaler (e.g. ecg_scaler.pkl) and return its center and scale as torch tensors.

    The returned pair is chosen so that ``(x - center) / scale`` reproduces the
    scaler's ``transform``:

    * StandardScaler-like objects (``mean_``): ``mean_`` is only used as the
      center when ``with_mean`` is true (scikit-learn keeps ``mean_`` populated
      even with ``with_mean=False`` but does not subtract it). ``scale_`` is only
      used when ``with_std`` is true.
    * RobustScaler-like objects (``center_``): ``center_`` / ``scale_`` are used
      directly.
    * A disabled or None center becomes zeros, and a disabled or None scale
      becomes ones.

    Security: ``joblib.load`` unpickles arbitrary Python objects. Only load
    files from trusted sources.

    Args:
        path: Path to the joblib file.

    Returns:
        center: torch.Tensor, float32, one value per feature.
        scale: torch.Tensor, float32, one value per feature, clamped to >= 1e-8.

    Raises:
        AttributeError: if the loaded object exposes neither ``mean_`` nor ``center_``.
    """
    sc = joblib.load(path)

    if hasattr(sc, "mean_"):
        # StandardScaler-like: honour the flags that transform() itself checks.
        center = sc.mean_ if getattr(sc, "with_mean", True) else None
        scale = sc.scale_ if getattr(sc, "with_std", True) else None
    elif hasattr(sc, "center_"):
        # RobustScaler-like: attributes are None when centering/scaling is disabled.
        center = sc.center_
        scale = sc.scale_
    else:
        raise AttributeError(
            f"Object loaded from {path!r} ({type(sc).__name__}) has neither "
            "'mean_' nor 'center_'; expected a fitted StandardScaler/RobustScaler."
        )

    # Feature count is needed to build identity values for disabled parts.
    ref = center if center is not None else scale
    n_features = np.size(ref) if ref is not None else int(sc.n_features_in_)

    if center is None:
        center_t = torch.zeros(n_features, dtype=torch.float32)
    else:
        center_t = torch.from_numpy(np.asarray(center).astype(np.float32))

    if scale is None:
        scale_t = torch.ones(n_features, dtype=torch.float32)
    else:
        scale_t = torch.from_numpy(np.asarray(scale).astype(np.float32))

    # Guard against division by zero for constant features.
    return center_t, scale_t.clamp_min(1e-8)
|
|
|
|
|
class ECGTransform:
    """
    Unified ECG preprocessing: optional zero-phase Butterworth filtering and
    polyphase resampling, followed by per-lead normalization.

    Usage:
        transform = ECGTransform(center, scale, src_fs=512, target_fs=100)
        ecg_out = transform(ecg_in)

    Attributes:
        center: float32 tensor subtracted per lead.
        scale: float32 tensor (clamped to >= 1e-8) each lead is divided by.
        src_fs, target_fs: input and output sampling rates.
        band: (lowcut, highcut) for the pre-resampling filter, or None.
        bp_order: Butterworth filter order.
        axis: time axis of the input.
    """

    def __init__(
        self,
        center: Union[np.ndarray, torch.Tensor],
        scale: Union[np.ndarray, torch.Tensor],
        src_fs: int = 100,
        target_fs: int = 100,
        band: Optional[tuple[float, float]] = (0.5, 40.0),
        bp_order: int = 4,
        axis: int = -1,
    ) -> None:
        """
        Args:
            center: Per-lead values subtracted during normalization.
            scale: Per-lead divisors; clamped to at least 1e-8 to avoid division by zero.
            src_fs: Sampling rate of the input signal.
            target_fs: Sampling rate of the output signal.
            band: (lowcut, highcut) in the same units as the sampling rates,
                applied before resampling. None disables filtering, and
                lowcut <= 0 selects a low-pass filter.
            bp_order: Butterworth filter order.
            axis: Time axis of the input.
        """
        self.center = torch.as_tensor(center, dtype=torch.float32)
        self.scale = torch.as_tensor(scale, dtype=torch.float32).clamp_min(1e-8)
        self.src_fs = src_fs
        self.target_fs = target_fs
        self.band = band
        self.bp_order = bp_order
        self.axis = axis

    @staticmethod
    def _as_int_rate(fs, name: str) -> int:
        """Return a sampling rate as int; resample_poly needs an integer up/down ratio."""
        if float(fs) != int(fs):
            raise ValueError(f"{name} must be an integer sampling rate, got {fs!r}")
        return int(fs)

    def _design_filter(self):
        """Return the SOS Butterworth filter for the source rate, or None if band is None."""
        if self.band is None:
            return None
        lowcut, highcut = self.band
        # Keep the passband below 0.45x the lower of the two rates. The target
        # rate guards against aliasing when downsampling; the source rate keeps
        # Wn < 1 when upsampling.
        highcut = min(highcut, 0.45 * min(self.src_fs, self.target_fs))
        nyq = self.src_fs / 2.0
        if lowcut <= 0:
            return signal.butter(self.bp_order, highcut / nyq, btype="low", output="sos")
        if lowcut >= highcut:
            raise ValueError(
                f"band lowcut {lowcut} must be below the effective highcut {highcut} "
                "(capped at 0.45 * min(src_fs, target_fs))"
            )
        return signal.butter(
            self.bp_order, (lowcut / nyq, highcut / nyq), btype="band", output="sos"
        )

    def downsample(self, x: np.ndarray) -> np.ndarray:
        """
        Filter (if ``band`` is set) and resample ``x`` from src_fs to target_fs along ``axis``.

        The rates may be any integers; upsampling also works despite the method name.

        Args:
            x: Array-like signal with time on ``self.axis``.

        Returns:
            np.ndarray resampled along ``self.axis``.

        Raises:
            ValueError: non-integer sampling rates, or a band that is empty after capping.
        """
        x = np.asarray(x)
        src_fs = self._as_int_rate(self.src_fs, "src_fs")
        target_fs = self._as_int_rate(self.target_fs, "target_fs")

        sos = self._design_filter()
        if sos is not None:
            # Zero-phase filtering so the waveform is not shifted in time.
            x = signal.sosfiltfilt(sos, x, axis=self.axis)

        g = gcd(src_fs, target_fs)
        return signal.resample_poly(
            x, target_fs // g, src_fs // g,
            axis=self.axis, window=("kaiser", 5.0), padtype="median",
        )

    # Named `normalize` (not `scale`): __init__ sets the instance attribute
    # `self.scale` (a tensor), which would shadow a method of the same name and
    # make `self.scale(x)` raise "'Tensor' object is not callable".
    def normalize(self, ecg: torch.Tensor) -> torch.Tensor:
        """
        Return ``(ecg - center) / scale`` per lead, as float32.

        Statistics broadcast over the axis just before the time axis
        (``self.axis``). That is the lead axis for (leads, time) input with the
        default ``axis=-1``.
        """
        ecg = torch.as_tensor(ecg, dtype=torch.float32)
        center = self.center.to(ecg.device).reshape(-1, 1)
        scale = self.scale.to(ecg.device).reshape(-1, 1)
        # Put time last so the (leads, 1) stats line up with the lead axis.
        moved = ecg.movedim(self.axis, -1)
        out = (moved - center) / scale
        return out.movedim(-1, self.axis)

    def __call__(self, x: np.ndarray) -> torch.Tensor:
        """
        Resample (if needed) and normalize ECG data.

        Args:
            x: array-like or torch.Tensor, shape (leads, time) with the default axis.

        Returns:
            torch.Tensor (float32), shape (leads, time).
        """
        # NOTE(review): the band filter lives inside downsample(), so it is only
        # applied when src_fs != target_fs. Confirm this is intended.
        if self.src_fs != self.target_fs:
            x = self.downsample(x)
        if not isinstance(x, torch.Tensor):
            # ascontiguousarray handles lists and negative-stride views, which
            # torch.from_numpy rejects.
            x = torch.from_numpy(np.ascontiguousarray(x))
        return self.normalize(x)