File size: 10,709 Bytes
11e9a40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
"""
ECG Feature Extraction for XGBoost model.
Extracts 25 signal features from a single-lead ECG using NeuroKit2.
This module is shared between ml/ training and backend/ inference.
"""

import numpy as np
import neurokit2 as nk
from scipy.stats import kurtosis, skew


def extract_ecg_features(ecg_signal: np.ndarray, sample_rate: int = 100,
                         heart_rate_sensor: float = None,
                         spo2: float = None) -> dict:
    """
    Extract 25 features from a single-lead ECG signal.

    Feature groups (keys match the XGBoost training feature names):
      * HRV time domain (7): mean_rr, sdnn, rmssd, pnn50, mean_hr_ecg,
        hr_std, rr_range
      * ECG morphology (8): qrs_duration, r_amplitude, r_amplitude_std,
        qt_interval, qtc, st_level, t_amplitude, p_wave_ratio
      * Signal statistics (6): rms, entropy, zero_crossing_rate, kurtosis,
        skewness, snr
      * Device sensor (4): heart_rate_sensor, spo2, hr_sensor_ecg_diff,
        ecg_quality

    NOTE: this function is shared by training and inference, so any change
    to how a feature is computed requires retraining the model.

    Args:
        ecg_signal: 1D sequence of ECG samples (converted to float).
        sample_rate: Sampling rate in Hz (100 for ESP32, 500 for PTB-XL)
        heart_rate_sensor: HR from MAX30100 (optional). Falsy values (None
            or 0) fall back to the ECG-derived mean heart rate.
        spo2: SpO2 from MAX30100 (optional). Falsy values fall back to 97.0.

    Returns:
        dict of 25 features. When fewer than 3 R-peaks are detected or any
        step raises, the defaults from ``_fallback_features`` are returned.
        Morphology features that cannot be measured get fixed defaults.
    """
    features = {}

    try:
        raw_signal = np.asarray(ecg_signal, dtype=float)

        # Clean the ECG signal
        ecg_cleaned = nk.ecg_clean(raw_signal, sampling_rate=sample_rate)

        # Detect R-peaks
        _, rpeaks = nk.ecg_peaks(ecg_cleaned, sampling_rate=sample_rate)
        r_peak_indices = np.asarray(rpeaks.get("ECG_R_Peaks", []), dtype=int)

        # At least 3 peaks -> at least 2 RR intervals for the variability features.
        if len(r_peak_indices) < 3:
            return _fallback_features(ecg_signal, heart_rate_sensor, spo2)

        # --- HRV Time-Domain Features (7) ---
        rr_intervals = np.diff(r_peak_indices) / sample_rate * 1000  # ms

        features["mean_rr"] = float(np.mean(rr_intervals))
        features["sdnn"] = float(np.std(rr_intervals, ddof=1)) if len(rr_intervals) > 1 else 0.0
        features["rmssd"] = float(np.sqrt(np.mean(np.diff(rr_intervals) ** 2))) if len(rr_intervals) > 1 else 0.0

        nn_diff = np.abs(np.diff(rr_intervals))
        features["pnn50"] = float(np.sum(nn_diff > 50) / len(nn_diff) * 100) if len(nn_diff) > 0 else 0.0

        hr_from_rr = 60000.0 / rr_intervals
        features["mean_hr_ecg"] = float(np.mean(hr_from_rr))
        features["hr_std"] = float(np.std(hr_from_rr))
        features["rr_range"] = float(np.max(rr_intervals) - np.min(rr_intervals))

        # --- ECG Morphology Features (8) ---
        try:
            # Delineate ECG waves. The per-wave lists are assumed to hold one
            # entry per beat (NaN where the wave was not found), so positions
            # are paired beat-by-beat BEFORE NaNs are dropped; filtering each
            # list on its own first would shift later beats out of alignment.
            _, waves = nk.ecg_delineate(ecg_cleaned, rpeaks, sampling_rate=sample_rate, method="dwt")
            q_peaks = waves.get("ECG_Q_Peaks", [])

            # QRS duration, approximated as Q-peak -> S-peak
            qrs_durations = _paired_intervals_ms(q_peaks, waves.get("ECG_S_Peaks", []), sample_rate)
            features["qrs_duration"] = float(np.mean(qrs_durations)) if qrs_durations.size else 100.0

            # R amplitude
            r_amplitudes = ecg_cleaned[r_peak_indices]
            features["r_amplitude"] = float(np.mean(r_amplitudes))
            features["r_amplitude_std"] = float(np.std(r_amplitudes))

            # QT interval (Q-peak -> T-offset) and Bazett's QTc = QT / sqrt(RR in s)
            qt_intervals = _paired_intervals_ms(q_peaks, waves.get("ECG_T_Offsets", []), sample_rate)
            if qt_intervals.size:
                features["qt_interval"] = float(np.mean(qt_intervals))
                mean_rr_sec = features["mean_rr"] / 1000
                features["qtc"] = (float(features["qt_interval"] / np.sqrt(mean_rr_sec))
                                   if mean_rr_sec > 0 else 440.0)
            else:
                features["qt_interval"] = 400.0
                features["qtc"] = 440.0

            # ST level: cleaned amplitude 40 ms after each R-peak (J-point proxy)
            j_indices = r_peak_indices + int(0.04 * sample_rate)
            j_indices = j_indices[j_indices < len(ecg_cleaned)]
            features["st_level"] = float(np.mean(ecg_cleaned[j_indices])) if j_indices.size else 0.0

            # T-wave amplitude
            t_amps = _amplitudes_at(ecg_cleaned, waves.get("ECG_T_Peaks", []))
            features["t_amplitude"] = float(np.mean(t_amps)) if t_amps.size else 0.0

            # P-wave ratio (P amplitude / R amplitude)
            p_amps = _amplitudes_at(ecg_cleaned, waves.get("ECG_P_Peaks", []))
            if p_amps.size and features["r_amplitude"] != 0:
                features["p_wave_ratio"] = float(np.mean(p_amps) / features["r_amplitude"])
            else:
                features["p_wave_ratio"] = 0.1

        except Exception:
            # Delineation failed part-way: keep whatever was computed and fill
            # the rest with defaults instead of discarding the HRV features.
            features.setdefault("qrs_duration", 100.0)
            features.setdefault("r_amplitude", float(np.max(ecg_cleaned) - np.min(ecg_cleaned)))
            features.setdefault("r_amplitude_std", 0.0)
            features.setdefault("qt_interval", 400.0)
            features.setdefault("qtc", 440.0)
            features.setdefault("st_level", 0.0)
            features.setdefault("t_amplitude", 0.0)
            features.setdefault("p_wave_ratio", 0.1)

        # --- Signal Statistics (6) ---
        features["rms"] = float(np.sqrt(np.mean(ecg_cleaned ** 2)))
        features["entropy"] = float(_sample_entropy(ecg_cleaned))
        features["zero_crossing_rate"] = float(
            np.sum(np.diff(np.sign(ecg_cleaned - np.mean(ecg_cleaned))) != 0) / len(ecg_cleaned)
        )
        features["kurtosis"] = float(kurtosis(ecg_cleaned))
        features["skewness"] = float(skew(ecg_cleaned))
        # _estimate_snr cleans its input and treats the residual as noise, so it
        # must receive the RAW signal; passing ecg_cleaned would only compare
        # the filter with itself.
        features["snr"] = float(_estimate_snr(raw_signal, sample_rate))

        # --- Device Sensor Features (4) ---
        features["heart_rate_sensor"] = float(heart_rate_sensor) if heart_rate_sensor else features["mean_hr_ecg"]
        features["spo2"] = float(spo2) if spo2 else 97.0

        hr_diff = abs(features["heart_rate_sensor"] - features["mean_hr_ecg"])
        features["hr_sensor_ecg_diff"] = float(hr_diff)

        # ECG quality score: 1 - coefficient of variation of RR intervals, floored at 0
        if len(rr_intervals) > 1:
            cv = np.std(rr_intervals) / np.mean(rr_intervals)
            features["ecg_quality"] = float(max(0, 1 - cv))
        else:
            features["ecg_quality"] = 0.5

    except Exception:
        return _fallback_features(ecg_signal, heart_rate_sensor, spo2)

    return features


def _paired_intervals_ms(starts, ends, sample_rate) -> np.ndarray:
    """Per-beat |end - start| in ms, skipping beats where either position is NaN."""
    start_arr = np.asarray(starts, dtype=float)
    end_arr = np.asarray(ends, dtype=float)
    n_beats = min(start_arr.size, end_arr.size)
    start_arr, end_arr = start_arr[:n_beats], end_arr[:n_beats]
    valid = np.isfinite(start_arr) & np.isfinite(end_arr)
    return np.abs(end_arr[valid] - start_arr[valid]) / sample_rate * 1000


def _amplitudes_at(signal: np.ndarray, positions) -> np.ndarray:
    """Values of *signal* at the finite, in-bounds sample *positions*."""
    idx = np.asarray(positions, dtype=float)
    idx = idx[np.isfinite(idx)].astype(int)
    idx = idx[(idx >= 0) & (idx < len(signal))]
    return signal[idx]


def _sample_entropy(signal, m=2, r_factor=0.2, max_templates=200, seed=0):
    """
    Approximate sample entropy (SampEn) of a 1D signal.

    Templates of length ``m`` and ``m + 1`` are compared using the max-abs
    (Chebyshev) distance and count as matches when closer than
    ``r = r_factor * std(signal)``; self-matches are excluded. For speed,
    only ``max_templates`` reference templates are checked. They are drawn
    with a fixed-seed local generator, so the feature is deterministic
    (identical at training and inference time) and the global NumPy RNG
    state is not touched.

    Args:
        signal: 1D sequence of samples.
        m: embedding dimension.
        r_factor: match tolerance as a fraction of the signal's std.
        max_templates: maximum number of reference templates to evaluate.
        seed: seed for the reference-template subsample.

    Returns:
        ``-ln(count_{m+1} / count_m)`` as a float, or 0.0 when the signal is
        too short, flat, non-finite, has no matches, or anything raises.
    """
    try:
        x = np.asarray(signal, dtype=float)
        r = r_factor * np.std(x)
        n = len(x)
        if n < m + 2 or r == 0 or not np.isfinite(r):
            return 0.0

        templates_m = np.array([x[i:i + m] for i in range(n - m)])
        templates_m1 = np.array([x[i:i + m + 1] for i in range(n - m - 1)])

        # Subsample reference templates for speed, deterministically.
        n_templates = len(templates_m)
        if n_templates > max_templates:
            rng = np.random.default_rng(seed)
            indices = rng.choice(n_templates, max_templates, replace=False)
        else:
            indices = range(n_templates)

        count_m = 0
        count_m1 = 0
        for i in indices:
            dist_m = np.max(np.abs(templates_m - templates_m[i]), axis=1)
            count_m += np.sum(dist_m < r) - 1  # exclude the self-match

            if i < len(templates_m1):
                dist_m1 = np.max(np.abs(templates_m1 - templates_m1[i]), axis=1)
                count_m1 += np.sum(dist_m1 < r) - 1

        if count_m == 0 or count_m1 == 0:
            return 0.0

        return float(-np.log(count_m1 / count_m))
    except Exception:
        return 0.0


def _estimate_snr(signal, sample_rate):
    """
    Estimate the signal-to-noise ratio in dB.

    The input is cleaned with ``nk.ecg_clean``. The cleaned output is the
    "signal", and whatever the cleaning removed (``signal - cleaned``) is the
    "noise". Returns 30.0 when the residual has zero power, and 10.0 if
    anything raises.
    """
    try:
        filtered = nk.ecg_clean(signal, sampling_rate=sample_rate)
        residual = signal - filtered
        noise_power = np.mean(residual ** 2)
        if noise_power == 0:
            return 30.0
        signal_power = np.mean(filtered ** 2)
        power_ratio = signal_power / noise_power
        return float(10 * np.log10(power_ratio))
    except Exception:
        return 10.0


def _fallback_features(ecg_signal, heart_rate_sensor=None, spo2=None) -> dict:
    """
    Return default features when ECG processing fails.

    Every value is a fixed default except:
      * ``rms``: computed from the raw samples (0.5 when there are none or
        they are not numeric),
      * ``heart_rate_sensor`` / ``spo2``: the sensor readings when truthy,
        otherwise 75.0 / 97.0.

    The returned dict has the same 25 keys as ``FEATURE_NAMES``.
    """
    return {
        "mean_rr": 800.0, "sdnn": 50.0, "rmssd": 30.0, "pnn50": 10.0,
        "mean_hr_ecg": 75.0, "hr_std": 5.0, "rr_range": 200.0,
        "qrs_duration": 100.0, "r_amplitude": 1.0, "r_amplitude_std": 0.1,
        "qt_interval": 400.0, "qtc": 440.0, "st_level": 0.0,
        "t_amplitude": 0.3, "p_wave_ratio": 0.1,
        "rms": _fallback_rms(ecg_signal),
        "entropy": 0.5, "zero_crossing_rate": 0.1,
        "kurtosis": 0.0, "skewness": 0.0, "snr": 10.0,
        "heart_rate_sensor": float(heart_rate_sensor) if heart_rate_sensor else 75.0,
        "spo2": float(spo2) if spo2 else 97.0,
        "hr_sensor_ecg_diff": 0.0, "ecg_quality": 0.5,
    }


def _fallback_rms(ecg_signal) -> float:
    """RMS of the raw samples, or 0.5 when there are none or they are not numeric.

    Samples are converted to float first. This lets list input work, and it
    stops integer ADC arrays (e.g. int16) from overflowing when squared.
    """
    if ecg_signal is None:
        return 0.5
    try:
        samples = np.asarray(ecg_signal, dtype=float).ravel()
    except (TypeError, ValueError):
        return 0.5
    if samples.size == 0:
        return 0.5
    return float(np.sqrt(np.mean(samples ** 2)))


# Ordered ECG feature names for XGBoost. The order must match the order used
# at training time.
FEATURE_NAMES = [
    # HRV time domain (7)
    "mean_rr", "sdnn", "rmssd", "pnn50",
    "mean_hr_ecg", "hr_std", "rr_range",
    # ECG morphology (8)
    "qrs_duration", "r_amplitude", "r_amplitude_std",
    "qt_interval", "qtc", "st_level",
    "t_amplitude", "p_wave_ratio",
    # Signal statistics (6)
    "rms", "entropy", "zero_crossing_rate",
    "kurtosis", "skewness", "snr",
    # Device sensor (4)
    "heart_rate_sensor", "spo2", "hr_sensor_ecg_diff", "ecg_quality",
]

# User profile feature names (appended after the ECG features).
PROFILE_FEATURE_NAMES = [
    "age",
    "sex",
    "bmi",
    "is_diabetic",
    "is_hypertensive",
    "is_smoker",
    "family_history",
]

# Historical baseline feature names (appended after the profile features).
HISTORY_FEATURE_NAMES = [
    "hr_baseline_24h",
    "hr_baseline_7d",
    "spo2_baseline_24h",
    "hr_deviation",
    "spo2_deviation",
    "resting_hr_trend",
    "readings_count_24h",
]

# Full model input order: ECG, then profile, then history features.
ALL_FEATURE_NAMES = [*FEATURE_NAMES, *PROFILE_FEATURE_NAMES, *HISTORY_FEATURE_NAMES]


def features_to_array(features: dict, include_profile: bool = False) -> np.ndarray:
    """
    Pack a feature dict into a float32 vector ordered for the XGBoost model.

    Args:
        features: mapping of feature name -> value. Names missing from the
            dict are filled with 0.0.
        include_profile: when True, use ``ALL_FEATURE_NAMES`` (ECG + profile
            + history features); otherwise use ``FEATURE_NAMES`` only.

    Returns:
        1D ``np.float32`` array with one entry per selected feature name, in
        list order.
    """
    ordered_names = FEATURE_NAMES
    if include_profile:
        ordered_names = ALL_FEATURE_NAMES
    values = [features.get(key, 0.0) for key in ordered_names]
    return np.array(values, dtype=np.float32)