File size: 2,948 Bytes
38dfcf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
from math import gcd
from typing import Optional, Union
import joblib

import numpy as np
from scipy import signal

def load_scaler_joblib(path: str) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Load a fitted scaler (e.g. ecg_scaler.pkl) and return its center and
    scale as float32 torch tensors.

    The centering statistic is read from ``mean_`` (as exposed by
    scikit-learn's StandardScaler). If that attribute is missing or None,
    ``center_`` (as exposed by RobustScaler) is used instead.

    Args:
        path: Path to the joblib file.
    Returns:
        center: torch.Tensor (float32)
        scale: torch.Tensor (float32), clamped to a minimum of 1e-8 so it is
            always safe to divide by.
    Raises:
        AttributeError: If the loaded object has no usable centering
            statistic (``mean_`` / ``center_``) or no ``scale_``.
    """
    # NOTE(security): joblib.load unpickles the file and can therefore execute
    # arbitrary code. Only call this on scaler files from a trusted source.
    sc = joblib.load(path)

    # Prefer mean_ so scalers that already worked keep producing the same
    # values; fall back to center_ only when mean_ is unavailable.
    center_values = getattr(sc, "mean_", None)
    if center_values is None:
        center_values = getattr(sc, "center_", None)
    scale_values = getattr(sc, "scale_", None)

    if center_values is None or scale_values is None:
        raise AttributeError(
            f"{type(sc).__name__} loaded from {path!r} must provide "
            "'mean_' (or 'center_') and 'scale_'"
        )

    # astype() always copies, so the returned tensors never alias the
    # arrays owned by the scaler object.
    center = torch.from_numpy(np.asarray(center_values).astype(np.float32))
    scale = torch.from_numpy(
        np.asarray(scale_values).astype(np.float32)
    ).clamp_min(1e-8)
    return center, scale

class ECGTransform:
    """
    Unified ECG preprocessing: downsampling and scaling.

    Calling the transform resamples the signal from ``src_fs`` to
    ``target_fs`` (only when the two rates differ; a zero-phase Butterworth
    filter is applied first unless ``band`` is None) and then standardises
    it as ``(x - center) / scale``.

    Usage:
        transform = ECGTransform(center, scale, src_fs=512, target_fs=100)
        ecg_out = transform(ecg_in)

    Attributes:
        center: float32 tensor subtracted from the signal. It is indexed as
            ``center[:, None]``, i.e. expected to hold one value per lead.
        scale: float32 tensor the signal is divided by, clamped to >= 1e-8.
            Indexed the same way as ``center``.
        src_fs: Sampling rate of the incoming signal in Hz.
        target_fs: Sampling rate of the returned signal in Hz.
        band: ``(lowcut, highcut)`` in Hz for the pre-resampling filter, or
            None to skip filtering.
        bp_order: Order passed to the Butterworth filter design.
        axis: Axis along which filtering and resampling run (time axis).
    """

    def __init__(
        self,
        center: Union[np.ndarray, torch.Tensor],
        scale: Union[np.ndarray, torch.Tensor],
        src_fs: int = 100,  # we assume the input ECG is already at 100Hz
        target_fs: int = 100,
        band: Optional[tuple[float, float]] = (0.5, 40.0),
        bp_order: int = 4,
        axis: int = -1,
    ) -> None:
        self.center = torch.as_tensor(center, dtype=torch.float32)
        # Clamp so that a zero/near-zero scale can never cause a division
        # by zero during normalisation.
        self.scale = torch.as_tensor(scale, dtype=torch.float32).clamp_min(1e-8)
        self.src_fs = src_fs
        self.target_fs = target_fs
        self.band = band
        self.bp_order = bp_order
        self.axis = axis

    def downsample(self, x: np.ndarray) -> np.ndarray:
        """
        Band-limit (optional) and resample ``x`` from src_fs to target_fs.

        Args:
            x: Array-like signal; filtering/resampling run along ``self.axis``.
        Returns:
            np.ndarray resampled by the rational factor target_fs / src_fs.
        """
        x = np.asarray(x)
        if self.band is not None:
            lowcut, highcut = self.band
            # Keep the upper cutoff below the Nyquist frequency of the
            # target rate so the filter also serves as anti-aliasing.
            max_high = 0.45 * self.target_fs
            highcut = min(highcut, max_high)
            nyq = self.src_fs / 2.0
            if lowcut <= 0:
                # No usable lower edge: fall back to a plain low-pass.
                wn = highcut / nyq
                sos = signal.butter(self.bp_order, wn, btype="low", output="sos")
            else:
                wn = (lowcut / nyq, highcut / nyq)
                sos = signal.butter(self.bp_order, wn, btype="band", output="sos")
            # Forward-backward filtering -> zero phase distortion.
            x = signal.sosfiltfilt(sos, x, axis=self.axis)
        # Reduce target_fs / src_fs to the smallest integer up/down factors.
        g = gcd(self.src_fs, self.target_fs)
        up = self.target_fs // g
        down = self.src_fs // g
        y = signal.resample_poly(
            x, up, down, axis=self.axis, window=("kaiser", 5.0), padtype="median"
        )
        return y

    def apply_scale(self, ecg: torch.Tensor) -> torch.Tensor:
        """
        Standardise ``ecg`` as ``(ecg - center) / scale`` in float32.

        This method must not be called ``scale`` on instances: ``__init__``
        stores the scale tensor in ``self.scale``, which would shadow a
        method of the same name and make ``self.scale(...)`` fail.

        Args:
            ecg: torch.Tensor, shape (leads, time)
        Returns:
            torch.Tensor (float32) with the same shape as ``ecg``.
        """
        ecg = ecg.to(torch.float32)
        # [:, None] adds a trailing axis so per-lead statistics broadcast
        # across the time dimension.
        return (ecg - self.center[:, None]) / self.scale[:, None]

    # Backward-compatible class-level name. On instances ``scale`` is the
    # tensor set in __init__, so this is only reachable as
    # ``ECGTransform.scale(instance, ecg)``.
    scale = apply_scale

    def __call__(self, x: np.ndarray) -> torch.Tensor:
        """
        Downsample and scale ECG data.
        Args:
            x: np.ndarray, shape (leads, time)
        Returns:
            torch.Tensor, shape (leads, time)
        """
        if self.src_fs != self.target_fs:
            x = self.downsample(x)
        if not isinstance(x, torch.Tensor):
            # ascontiguousarray copies only when required; it lets
            # torch.from_numpy accept lists and negative-stride views
            # (e.g. a time-flipped array).
            x = torch.from_numpy(np.ascontiguousarray(x))
        return self.apply_scale(x)