import pickle
from io import StringIO

import numpy as np
import pandas as pd
from scipy.signal import resample
from sklearn import preprocessing
from tslearn.preprocessing import TimeSeriesScalerMeanVariance
from tslearn.utils import to_time_series_dataset
from wfdb import processing
def preproc(X):
    # To be called in inference/API: z-normalize each beat before the model.
    in_shape = X.shape
    if X.shape[1] != 180:
        print('File shape is not (n, 180) but', in_shape)
    X = to_time_series_dataset(X)            # (n, sz) -> (n, sz, 1)
    X = X.reshape(in_shape[0], -1)
    scaler = TimeSeriesScalerMeanVariance()  # zero mean, unit variance per series
    X = scaler.fit_transform(X)              # returns (n, sz, 1)
    return X.reshape(in_shape)
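
# A minimal sketch (not part of the original pipeline) of how preproc is meant
# to be used: a batch of fixed-length beats of shape (n, 180) comes back
# z-normalized per beat with the same shape. The dummy data is an assumption.
def _demo_preproc():
    rng = np.random.default_rng(0)
    demo = rng.normal(loc=1.0, scale=3.0, size=(4, 180))  # fake beats
    out = preproc(demo)
    assert out.shape == demo.shape
    print(out.mean(axis=1), out.std(axis=1))  # ~0 mean, ~1 std per beat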
def apple_csv_to_data(file_content):
    # Extract the sampling rate from the CSV header, e.g. "Sample Rate,512 Hz".
    sample_rate = None
    for il, line in enumerate(file_content.decode('utf-8').splitlines()):
        if line.startswith("Sample Rate"):
            sample_rate = int(line.split(",")[1].split()[0])  # keep the numeric part
            print(f"Sample Rate: {sample_rate}")
            break
        if il > 30:
            break
    if sample_rate is None:
        print("Could not find the sample rate in the first 30 lines")
        return None, None
    # The voltage samples start after the 14-line metadata header.
    X = pd.read_csv(StringIO(file_content.decode('utf-8')), skiprows=14, header=None)
    return X, sample_rate
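
# Hedged usage sketch (the file name "ecg_export.csv" is hypothetical): the
# function expects the raw bytes of an Apple Health ECG export, e.g. as
# received from an upload widget, and returns the data plus the sampling rate.
def _demo_apple_csv_to_data():
    with open("ecg_export.csv", "rb") as f:
        content = f.read()
    X, sample_rate = apple_csv_to_data(content)
    if X is not None:
        print(X.shape, sample_rate)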
def apple_trim_join(X, sample_rate=512, ns=2):
    # There should be a less horrible way of doing this.
    # Apple's export can use a comma as the decimal separator, so read_csv
    # splits each sample across two columns: X[0] holds the integer part and
    # X[1] the fractional digits. Rejoin them, then drop the first and last
    # `ns` noisy seconds (ns=2 on a 30 s recording leaves 26 s of ECG).
    X[1] = X[1].fillna(0)
    # X[1] is typically a float column (it contained NaNs), so its string form
    # carries a trailing ".0"; the -2 discounts those two characters when
    # counting fractional digits.
    X = X[0] + X[1] / (10 ** (X[1].astype(str).str.len() - 2))
    print(f"Ignoring first and last {ns} seconds")
    X = X[ns * sample_rate:-ns * sample_rate].to_numpy()
    return X
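
# Worked example (illustrative, with made-up numbers) of the column rejoin
# above: a sample exported as "123,45" is read as X[0]=123, X[1]=45.0, and
# 45.0 / 10**(len("45.0") - 2) == 0.45 restores 123.45.
def _demo_apple_trim_join():
    X = pd.DataFrame({0: [123] * 8, 1: [45.0] * 8})
    joined = apple_trim_join(X, sample_rate=2, ns=1)  # drops 2 samples per side
    print(joined)  # array of 123.45, length 4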
def apple_extract_beats(X, sample_rate=512):
    X = apple_trim_join(X, sample_rate=sample_rate, ns=3)
    # Scale and remove NaNs (should not happen anymore)
    X = preprocessing.scale(X[~np.isnan(X)])
    # I tried to hack the detection to make it learn peaks rather than
    # go with the defaults, but it doesn't work. I have tried:
    # - Hardwiring n_calib_beats (not possible from the user side)
    #   to a lower number (5, 3).
    # - Setting qrs_width to lower and higher values.
    # - Relaxing the correlation requirement with the Ricker wavelet.
    # Maybe explore correlation with more robust wavelets:
    #   wavelet = pywt.Wavelet('db4')
    # (lib/python3.10/site-packages/wfdb/processing/qrs.py)
    # Conf = processing.XQRS.Conf(qrs_width=0.1)
    # qrs = processing.XQRS(sig=X, fs=sample_rate, conf=Conf)
    # The wfdb library doesn't allow setting n_calib_beats.
    qrs = processing.XQRS(sig=X, fs=sample_rate)
    qrs.detect()
    peaks = qrs.qrs_inds
    print("Number of beats detected:", len(peaks))
    target_length = 180
    # Skip the first and last peaks, whose windows could run off the signal;
    # allocating len(peaks) rows here would leave two all-zero "beats".
    beats = np.zeros((max(len(peaks) - 2, 0), target_length))
    for i, peak in enumerate(peaks[1:-1]):
        rr_interval = peaks[i + 1] - peaks[i]  # distance from the previous peak
        window_size = int(rr_interval * 1.2)   # extend by 20% to capture the full P-QRS-T cycle
        # Define the dynamic window around the R-peak
        start = max(0, peak - window_size // 2)
        end = min(len(X), peak + window_size // 2)
        beats[i] = resample(X[start:end], target_length)
    return beats
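
# End-to-end sketch tying the pieces together (the file name is hypothetical):
# raw Apple export bytes -> beats matrix -> scaled input ready for a
# classifier that expects (n, 180) beats.
def _demo_pipeline():
    with open("ecg_export.csv", "rb") as f:
        X, sample_rate = apple_csv_to_data(f.read())
    if X is None:
        return
    beats = apple_extract_beats(X, sample_rate=sample_rate)
    model_input = preproc(beats)
    print(model_input.shape)  # (n_beats, 180)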
def save_beats_csv(beats, filepath_csv):
    pd.DataFrame(beats).to_csv(filepath_csv, index=False)
def label_decoding(values, path):
    # Load the pickled label -> integer mapping and invert it to decode
    # model predictions back into label names.
    with open(path, "rb") as f:
        mapping = pickle.load(f)
    inverse_mapping = {v: k for k, v in mapping.items()}
    return [inverse_mapping[value] for value in values]
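
# Hedged sketch of the expected pickle contents (the mapping below is an
# assumption, not the real label set): a dict from class name to the integer
# the model predicts, which label_decoding inverts.
def _demo_label_decoding(tmp_path="mapping_demo.pkl"):
    mapping = {"N": 0, "S": 1, "V": 2}  # hypothetical label -> code mapping
    with open(tmp_path, "wb") as f:
        pickle.dump(mapping, f)
    print(label_decoding([2, 0, 1], tmp_path))  # ['V', 'N', 'S']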