nina-m-m committed on
Commit
0ffeb19
·
verified ·
1 Parent(s): 7a8997c

Upload source files

Browse files
Files changed (9) hide show
  1. __init__.py +0 -0
  2. configs.py +41 -0
  3. conversion.py +85 -0
  4. ecg_feature_extraction.py +37 -0
  5. ecg_preprocessing.py +93 -0
  6. ecg_processing.py +54 -0
  7. logger.py +29 -0
  8. pydantic_models.py +167 -0
  9. utils.py +80 -0
__init__.py ADDED
File without changes
configs.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module contains all the configurations and statics for the project."""
2
+ from enum import Enum
3
+
4
+
5
+ class SignalEnum(str, Enum):
6
+ chest = 'chest'
7
+ wrest = 'wrest'
8
+
9
+
10
+ class WindowSlicingMethodEnum(str, Enum):
11
+ time_related = 'time_related'
12
+ label_related_before = 'label_related_before'
13
+ label_related_after = 'label_related_after'
14
+ label_related_middle = 'label_related_centered'
15
+
16
+
17
+ class NormalizationMethodEnum(str, Enum):
18
+ baseline_difference = 'baseline_difference'
19
+ baseline_relative = 'baseline_relative'
20
+ separate = 'separate'
21
+
22
+
23
+ class BColors(str, Enum):
24
+ HEADER = '\033[95m'
25
+ OKBLUE = '\033[94m'
26
+ OKCYAN = '\033[96m'
27
+ INFO = '\033[92m'
28
+ WARNING = '\033[93m'
29
+ FAIL = '\033[91m'
30
+ ENDC = '\033[0m'
31
+ BOLD = '\033[1m'
32
+ UNDERLINE = '\033[4m'
33
+
34
+
35
+ class OutputFormats(str, Enum):
36
+ JSON = 'json'
37
+ CSV = 'csv'
38
+ EXCEL_SPREADSHEET = 'excel_spreadsheet'
39
+
40
+
41
+ selected_features = ["HRV_MeanNN", "HRV_SDNN", "HRV_RMSSD", "HRV_pNN50"]
conversion.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import h5py
2
+ import pandas as pd
3
+
4
def h5_to_pandas(h5_file, ecg_channel_name='channel_1') -> pd.DataFrame:
    """
    Converts a h5 file to a pandas DataFrame. It must contain the following attributes: sampling rate, date, time,
    nsamples, device, device name, duration, and raw. The raw attribute must contain the ecg data in a 2D numpy array.
    The DataFrame will contain the following columns: record_date, frequency, device_name, timestamp_idx, ecg.

    h5 formats supported are of the company Bioplux (https://www.pluxbiosignals.com/) with its Recording Software
    OpenSignals Revolution (https://support.pluxbiosignals.com/knowledge-base/introducing-opensignals-revolution/).

    :param h5_file: Path to the h5 file.
    :type h5_file: str
    :param ecg_channel_name: The name of the ecg channel in the h5 file.
    :type ecg_channel_name: str

    :return: The pandas DataFrame.
    :rtype: pd.DataFrame

    :raises ValueError: If the reconstructed timestamp range deviates from the
        file's ``duration`` attribute by one second or more.
    """
    with h5py.File(h5_file, 'r') as file:
        # The recording is stored under a single top-level group whose name is device-specific.
        group_key = next(iter(file.keys()))
        h5_group = file[group_key]

        # Convert ECG data to a flattened 1-D float array
        ecg = h5_group['raw'][ecg_channel_name][:].astype(float).flatten()

        # Extract metadata
        # NOTE(review): assumes string attributes arrive as str (not bytes) — verify with the h5py version in use.
        attrs = h5_group.attrs
        sampling_rate = attrs['sampling rate']
        date = attrs['date']
        time = attrs['time']
        num_samples = attrs['nsamples']
        device = attrs['device']
        device_name = attrs['device name']
        duration = attrs['duration']

        # Create one timestamp per sample, spaced at the sampling interval.
        # A Timedelta step avoids the deprecated float-'S' frequency string.
        start = pd.to_datetime(date + ' ' + time)
        step = pd.Timedelta(seconds=1 / sampling_rate)
        timestamps = pd.date_range(start=start, periods=num_samples, freq=step)

        # Check that the calculated timestamps fit the given duration attribute of the h5 file.
        # Raise (instead of assert) so the check is not stripped under `python -O`.
        end = start + pd.Timedelta(duration)
        if abs((end - timestamps[-1]).total_seconds()) >= 1:
            raise ValueError(
                f'Calculated timestamp range does not match the h5 duration attribute ({duration}): '
                f'expected end near {end}, got {timestamps[-1]}.')

        # Create the DataFrame
        df = pd.DataFrame({
            'record_date': date,
            'frequency': sampling_rate,
            'device_name': f'{device}_{device_name}',
            'timestamp_idx': timestamps,
            'ecg': ecg,
        })

    return df
59
+
60
def csv_to_pandas(path: str) -> pd.DataFrame:
    """ Converts a CSV file in a pandas dataframe fitted to the ECG-HRV pipeline pydantic models.

    The first line must be a comment of the form ``# {...}`` holding a dict
    literal with the batch/config metadata; the remaining lines are the samples.

    :param path: Path to the csv file.
    :type path: str

    :return: The pandas DataFrame.
    :rtype: pd.DataFrame
    """
    import ast

    # Get metadata of csv file (first line, e.g. "# {'config_x': 1, 'batch_y': 2}")
    with open(path, 'r') as file:
        metadata = file.readline()
    metadata = metadata.replace('# ', '')
    # Security fix: parse the header as a Python literal instead of eval()-ing
    # arbitrary code read from the file.
    metadata = ast.literal_eval(metadata)

    configs = {key: value for key, value in metadata.items() if key.startswith('config')}
    batch = {key: value for key, value in metadata.items() if key.startswith('batch')}

    # Get samples from csv file (the metadata line is skipped as a comment)
    df = pd.read_csv(path, comment='#')

    # Add metadata to samples as constant columns
    df = df.assign(**batch)
    df = df.assign(**configs)

    return df
ecg_feature_extraction.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This file contains the ecg feature extraction pipelines and functions used for calculating the features."""
2
+
3
+ import neurokit2 as nk
4
+ from .configs import selected_features
5
+
6
+
7
def get_hrv_features(ecg_signal, fs):
    """Compute the selected time-domain HRV features for one ECG window.

    :param ecg_signal: 1-D array of ECG samples (cleaned signal expected by the caller).
    :param fs: Sampling rate in Hz.
    :return: List of dicts (one per record) restricted to ``selected_features``.
    """
    # Find R-peaks using the Pan-Tompkins (1985) detector
    peaks, info = nk.ecg_peaks(ecg_signal, sampling_rate=fs, method="pantompkins1985")

    # Compute time domain features
    hrv_time_features = nk.hrv_time(peaks, sampling_rate=fs)

    # Frequency-domain features are currently disabled:
    #hrv_frequency_features = nk.hrv_frequency(peaks, sampling_rate=fs, method="welch", show=False)
    #hrv_features = pd.concat([hrv_time_features, hrv_frequency_features], axis=1)
    hrv_features = hrv_time_features

    # Keep only the configured feature columns and convert to a list of dicts
    hrv_features = hrv_features[selected_features].to_dict(orient="records")

    return hrv_features
25
+
26
+
27
def normalize_features(features_df, normalization_method):
    """Normalize the per-window features against the baseline window.

    :param features_df: DataFrame with one row per window, a boolean 'baseline'
        column marking the baseline row, and the ``selected_features`` columns.
    :param normalization_method: One of 'baseline_difference'/'difference',
        'baseline_relative'/'relative', 'separate', or None (no normalization).
    :return: The DataFrame (feature columns modified in place for the
        difference/relative methods).
    :raises ValueError: For an unrecognized normalization method.
    """
    is_baseline = features_df['baseline'] == True
    feature_cols = features_df.columns.isin(selected_features)
    # Bug fix: only the short names "difference"/"relative" used to match here,
    # while NormalizationMethodEnum provides 'baseline_difference' /
    # 'baseline_relative' — so enum-configured runs silently skipped
    # normalization. Accept both spellings (str-Enum compares equal to its value).
    if normalization_method in ("difference", "baseline_difference"):
        baseline_features = features_df[is_baseline].iloc[0]
        features_df.loc[~is_baseline, feature_cols] -= baseline_features
    elif normalization_method in ("relative", "baseline_relative"):
        baseline_features = features_df[is_baseline].iloc[0]
        features_df.loc[~is_baseline, feature_cols] /= baseline_features
    elif (normalization_method == "separate") or (normalization_method is None):
        pass
    else:
        raise ValueError(f"normalization_method {normalization_method} not supported")

    return features_df
ecg_preprocessing.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from scipy.signal import filtfilt, butter, resample
2
+ from sklearn.preprocessing import StandardScaler
3
+
4
+
5
def remove_basline_wander(data, sampling_rate=360, cutoff_freq=0.05, filter_order=3):
    """
    Remove baseline wander from ECG data using a high-pass filter. The high-pass filter will remove all frequencies
    below the cutoff frequency. The cutoff frequency should be set to the lowest frequency that is still considered
    baseline wander and not part of the ECG signal. For example, baseline wander is typically between 0.05 Hz and
    0.5 Hz. Therefore, a cutoff frequency of 0.05 Hz is a good starting point. However, if the ECG signal contains
    low-frequency components of interest, such as the T wave or P wave, then a higher cutoff frequency may be necessary
    to avoid over-filtering and loss of important ECG signal components.
    See https://en.wikipedia.org/wiki/High-pass_filter for more information on high-pass filters.

    :param data: ECG data as a 1-dimensional numpy array.
    :type data: numpy array
    :param sampling_rate: Sampling rate of ECG data (Hz), defaults to 360.
    :type sampling_rate: int, optional
    :param cutoff_freq: cutoff frequency of high-pass filter (Hz), defaults to 0.05.
    :type cutoff_freq: float, optional
    :param filter_order: Order of the Butterworth filter; higher orders give a
        steeper roll-off, defaults to 3 (the previously hard-coded value).
    :type filter_order: int, optional

    :return: ECG data with baseline wander removed.
    :rtype: numpy array
    """
    # Nyquist frequency - the highest frequency that can be represented given
    # the sampling frequency; it is half the sampling rate (in Hz).
    nyquist_freq = 0.5 * sampling_rate
    # Design the high-pass Butterworth filter with the normalized cutoff.
    b, a = butter(filter_order, cutoff_freq / nyquist_freq, 'highpass')
    # filtfilt applies the filter forward and backward, cancelling phase shift.
    filtered_data = filtfilt(b, a, data)

    return filtered_data
36
+
37
+
38
def remove_noise(data, sampling_rate=360, lowcut=0.5, highcut=45, filter_order=4):
    """
    Remove noise from ECG data using a band-pass filter. The band-pass filter will remove all frequencies below the
    lowcut frequency and above the highcut frequency. The lowcut frequency should be set to the lowest frequency that
    is still considered noise and not part of the ECG signal. For example, noise is typically between 0.5 Hz and 45
    Hz. Therefore, a lowcut frequency of 0.5 Hz is a good starting point. However, if the ECG signal contains
    low-frequency components of interest, such as the T wave or P wave, then a higher lowcut frequency may be
    necessary to avoid over-filtering and loss of important ECG signal components. For this reason,
    a lowcut frequency of 5 Hz is also a good starting point. The highcut frequency should be set to the highest
    frequency that is still considered noise and not part of the ECG signal; a lower highcut (e.g. 15 Hz) may be
    necessary if high-frequency components of interest, such as the QRS complex, must be preserved.
    See https://en.wikipedia.org/wiki/Band-pass_filter for more information on band-pass filters.

    :param data: ECG data as a 1-dimensional numpy array.
    :type data: numpy array
    :param sampling_rate: The sampling rate of ECG data (Hz), defaults to 360.
    :type sampling_rate: int, optional
    :param lowcut: The lowcut frequency of band-pass filter (Hz), defaults to 0.5.
    :type lowcut: float, optional
    :param highcut: The highcut frequency of band-pass filter (Hz), defaults to 45.
    :type highcut: float, optional
    :param filter_order: Order of the Butterworth filter, defaults to 4 (the
        previously hard-coded value).
    :type filter_order: int, optional

    :return: ECG data with noise removed
    :rtype: numpy array
    """
    # Nyquist frequency (half the sampling rate)
    nyquist_freq = 0.5 * sampling_rate
    # Normalized cutoff frequencies (remove everything below lowcut and above highcut)
    low = lowcut / nyquist_freq
    high = highcut / nyquist_freq
    # Initialize the band-pass Butterworth filter
    b, a = butter(filter_order, [low, high], btype='band')
    # Apply the filter forward and backward to cancel phase shift. See
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.filtfilt.html for more information on filtfilt.
    filtered_data = filtfilt(b, a, data)

    return filtered_data
79
+
80
+
81
def preprocess_ecg(data, sampling_rate=1000, new_sampling_rate=360):
    """Full ECG preprocessing chain: baseline removal, band-pass denoising,
    optional resampling, and z-score normalization.

    :param data: ECG data as a 1-dimensional numpy array.
    :param sampling_rate: Sampling rate of ``data`` in Hz, defaults to 1000.
    :param new_sampling_rate: Target sampling rate in Hz; pass ``None`` (or the
        original rate) to skip resampling. Defaults to 360.
    :return: Preprocessed ECG as a 1-D numpy array with zero mean and unit variance.
    """
    # Remove basline wander using highpass filter
    filtered_data = remove_basline_wander(data=data, sampling_rate=sampling_rate)
    # Remove noise from ECG data using bandpass filter
    filtered_data = remove_noise(data=filtered_data, sampling_rate=sampling_rate)
    # Resample ECG data to a new sampling rate (FFT-based scipy.signal.resample)
    if new_sampling_rate is not None and new_sampling_rate != sampling_rate:
        filtered_data = resample(filtered_data, int(len(filtered_data) * new_sampling_rate / sampling_rate))
    # Normalize ECG data to have zero mean and unit variance
    # (StandardScaler expects 2-D input, hence the reshape round-trip).
    scaler = StandardScaler()
    normalized_data = scaler.fit_transform(filtered_data.reshape(-1, 1)).reshape(-1)

    return normalized_data
ecg_processing.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This file contains the ecg processing pipelines."""
2
+ import pandas as pd
3
+
4
+ import neurokit2 as nk
5
+
6
+ from .ecg_feature_extraction import get_hrv_features, normalize_features
7
+ from .utils import cut_out_window, create_windows
8
+
9
+ pd.set_option('display.float_format', '{:.6f}'.format)
10
+
11
+ from .logger import setup_logger
12
+ logger = setup_logger(__name__)
13
+
14
+
15
def process_window(window, window_id, frequency):
    """Extract HRV features for one window and attach its metadata.

    :param window: DataFrame slice with 'ecg', 'subject_id', 'sample_id' and
        'timestamp_idx' columns.
    :param window_id: Index of the window; id 0 marks the baseline window.
    :param frequency: Sampling rate of the ECG signal in Hz.
    :return: One-row DataFrame holding the features plus window metadata.
    """
    hrv = get_hrv_features(window['ecg'].values, frequency)
    row = pd.DataFrame(hrv, index=[0])
    time_fmt = '%Y-%m-%d %H:%M:%S'
    row['subject_id'] = window['subject_id'].unique()[0]
    row['sample_id'] = str(window['sample_id'].unique()[0])
    row['window_id'] = window_id
    row['w_start_time'] = window['timestamp_idx'].min().strftime(time_fmt)
    row['w_end_time'] = window['timestamp_idx'].max().strftime(time_fmt)
    row['baseline'] = window_id == 0
    row['frequency'] = frequency
    return row
26
+
27
+
28
def process_batch(samples, configs):
    """Run the full HRV feature-extraction pipeline over a batch of ECG samples.

    :param samples: Sequence of ECGSample models (must support ``len()``), one per subject.
    :param configs: ECGConfig with windowing, baseline, and normalization settings.
    :return: DataFrame with one row of HRV features per window.
    """
    features_list = []
    for i, sample in enumerate(samples):
        logger.info(f"Processing sample ({i + 1}/{len(samples)})...")
        sample_df = pd.DataFrame.from_dict(sample.dict())
        # Preprocess the ecg signal
        logger.info("Preprocess ECG signals...")
        sample_df['ecg'] = nk.ecg_clean(sample_df['ecg'], sampling_rate=sample.frequency, method="pantompkins1985")
        # Cut out the windows and process them
        if configs.baseline_start:
            logger.info("Cut out baseline window...")
            baseline_window = cut_out_window(sample_df, 'timestamp_idx', start=configs.baseline_start,
                                             end=configs.baseline_end)
            # Everything after the baseline window is the actual experiment data.
            sample_df = sample_df[sample_df['timestamp_idx'] > baseline_window['timestamp_idx'].max()]
            logger.info("Processing baseline window...")
            # Window id 0 is reserved for the baseline window.
            features_list.append(process_window(baseline_window, 0, sample.frequency))
        logger.info("Cut out windows...")
        windows = create_windows(df=sample_df, time_column='timestamp_idx', window_size=configs.window_size,
                                 window_slicing_method=configs.window_slicing_method)
        logger.info(f"Processing windows (Total: {len(windows)})...")
        # Non-baseline windows are numbered starting at 1 (the generator's `i`
        # shadows the outer loop variable but has its own scope).
        features_list.extend(process_window(window, i, sample.frequency) for i, window in enumerate(windows, start=1))
    features_df = pd.concat(features_list, ignore_index=True)
    # Normalize the features via baseline subtraction
    if configs.baseline_start:
        features_df = normalize_features(features_df, configs.normalization_method)

    return features_df
logger.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
import sys
import colorlog


# Shared colored formatter/handler: every logger created via setup_logger
# writes to stdout with level-dependent colors.
formatter = colorlog.ColoredFormatter(
    "%(asctime)s [%(blue)s%(name)s:%(lineno)s%(reset)s] [%(log_color)s%(levelname)s%(reset)s] >>>> %(message)s",
    log_colors={  # 'DEBUG': 'cyan',
        'INFO': 'green',
        'WARNING': 'yellow',
        'ERROR': 'red',
        'CRITICAL': 'red,bg_white',
    }
)
stream_handler = colorlog.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)


def setup_logger(name, level=logging.DEBUG):
    """Setup a logger with the given name and level.

    :param name: The name of the logger.
    :param level: The level of the logger.
    :return: The logger.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Bug fix: the handler used to be added unconditionally, so calling
    # setup_logger twice for the same name duplicated every log line.
    if stream_handler not in logger.handlers:
        logger.addHandler(stream_handler)
    return logger
pydantic_models.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Pydantic models for use in the API. """
2
+ import json
3
+ from datetime import datetime, timedelta, date
4
+ from typing import Union, Dict, Any
5
+ from uuid import UUID, uuid4
6
+
7
+ from pydantic import BaseModel, Field, model_validator
8
+
9
+ from .configs import SignalEnum, WindowSlicingMethodEnum, NormalizationMethodEnum
10
+
11
+ # Try opening json file samples
12
+ try:
13
+ with open('data/examples/example0_input.json') as json_file:
14
+ example0 = json.load(json_file)
15
+ with open('data/examples/example1_input.json') as json_file:
16
+ example1 = json.load(json_file)
17
+ except FileNotFoundError:
18
+ print(
19
+ "Example Files for interface not found. Please run the Jupyter Notebook in notebooks/1_Data_Formatting_and_transformation.py first.")
20
+ example0 = {}
21
+ example1 = {}
22
+
23
+
24
class ECGSample(BaseModel):
    """ Model of the results of a single subject of an experiment with ECG biosignals. """
    sample_id: UUID = Field(example="f70c1033-36ae-4b8b-8b89-099a96dccca5", default_factory=uuid4)
    subject_id: str = Field(..., example="participant_1")
    frequency: int = Field(..., example=1000)
    device_name: str = Field(example="bioplux", default=None)
    # pydantic will process either an int or float (unix timestamp) (e.g. 1496498400),
    # an int or float as a string (assumed as Unix timestamp), or
    # a string representing the date (e.g. "YYYY-MM-DD[T]HH:MM[:SS[.ffffff]][Z or [±]HH[:]MM]")
    timestamp_idx: list[datetime] = Field(..., min_items=2, example=[1679709871, 1679713471, 1679720671])
    ecg: list[float] = Field(..., min_items=2, example=[1.0, -1.100878, -3.996840])
    label: list[str] = Field(min_items=2, example=["undefined", "stress", "undefined"], default=None)

    class Config:
        json_schema_extra = {
            "example": {
                "sample_id": "f70c1033-36ae-4b8b-8b89-099a96dccca5",
                "subject_id": "participant_1",
                "frequency": 1000,
                "device_name": "bioplux",
                "timestamp_idx": [1679709871, 1679713471, 1679720671],
                "ecg": [1.0, -1.100878, -3.996840],
                "label": ["undefined", "stress", "undefined"]
            }
        }

    @model_validator(mode='before')
    @classmethod
    def set_label_default(cls, values: Any) -> Any:
        """
        Set default for list parameter "label" if list has empty values.
        """
        # Bug fix: values['label'] raised KeyError when the key was absent even
        # though the field declares default=None; use .get() instead. The
        # required lists are also guarded so their absence surfaces as a
        # pydantic validation error rather than a KeyError here.
        if isinstance(values, dict) and 'timestamp_idx' in values and 'ecg' in values:
            max_len = max(len(values['timestamp_idx']), len(values['ecg']))
            if values.get('label') is None:
                values['label'] = ['undefined'] * max_len
            elif len(values['label']) < max_len:
                values['label'] += ['undefined'] * (max_len - len(values['label']))
        return values

    @model_validator(mode='after')
    def check_length(self) -> 'ECGSample':
        """
        Validates that given lists have the same length.
        """
        lengths = [len(self.timestamp_idx), len(self.ecg)]
        if len(set(lengths)) != 1:
            raise ValueError('Given timestamp and ecg list must have the same length!')
        return self
73
+
74
+
75
class ECGConfig(BaseModel):
    """ Model of the configuration of an experiment with ECG biosignals. """
    signal: SignalEnum = Field(example=SignalEnum.chest, default=None)
    window_slicing_method: WindowSlicingMethodEnum = Field(example=WindowSlicingMethodEnum.time_related,
                                                           default=WindowSlicingMethodEnum.time_related)
    window_size: float = Field(example=1.0, default=5.0)
    # pydantic will process either an int or float (unix timestamp) (e.g. 1496498400),
    # an int or float as a string (assumed as Unix timestamp), or
    # a string representing the date (e.g. "YYYY-MM-DD[T]HH:MM[:SS[.ffffff]][Z or [±]HH[:]MM]")
    baseline_start: datetime = Field(example="2034-01-16T00:00:00", default=None)
    baseline_end: datetime = Field(example="2034-01-16T00:01:00", default=None)
    baseline_duration: int = Field(example=60, default=None)  # in seconds
    # Fixed: the annotation previously mixed both union syntaxes (`Union[X | None]`).
    normalization_method: Union[NormalizationMethodEnum, None] = Field(
        example=NormalizationMethodEnum.baseline_difference,
        default=NormalizationMethodEnum.baseline_difference)
    extra: Dict[str, Any] = Field(default=None)

    class Config:
        json_schema_extra = {
            "example": {
                "signal": "chest",
                "window_slicing_method": "time_related",
                "window_size": 60,
                "baseline_start": "2023-05-23 22:58:01.335",
                "baseline_duration": 60,
                "test": "test"
            }
        }

    @model_validator(mode='before')
    @classmethod
    def build_extra(cls, values: Any) -> Any:
        """Move all unknown input keys into the ``extra`` dict field."""
        if not isinstance(values, dict):
            return values
        # Bug fix: the previous version collected `field.alias`, which is None
        # for un-aliased fields, so *every* input key was treated as unknown and
        # moved into `extra`. Match on field names plus declared aliases.
        known_fields = set(cls.model_fields.keys())
        known_fields.update(field.alias for field in cls.model_fields.values() if field.alias)
        known_fields.discard('extra')
        extra: Dict[str, Any] = {}
        for field_name in list(values):
            if field_name not in known_fields:
                extra[field_name] = values.pop(field_name)
        values['extra'] = extra
        return values

    @model_validator(mode='after')
    def check_baseline_start(self) -> 'ECGConfig':
        """
        Validates that baseline_start and either baseline_duration or baseline_end are given when a baseline is used.
        If baseline_end is not provided, it is calculated as baseline_start + baseline_duration.
        """
        if self.baseline_start:
            if self.baseline_duration is None and self.baseline_end is None:
                raise ValueError(
                    'If baseline_start is given, either baseline_duration or baseline_end must be provided.')
            if self.baseline_end is None:
                # baseline_duration is guaranteed non-None here by the check above
                # (the previous duplicate re-check was dead code).
                self.baseline_end = self.baseline_start + timedelta(seconds=self.baseline_duration)
        # Fixed: `(a or b) is not None` skipped this error when baseline_duration
        # was falsy (e.g. 0); test each value explicitly.
        elif self.baseline_duration is not None or self.baseline_end is not None:
            raise ValueError(
                'If baseline_duration or baseline_end is given, baseline_start must be provided as well. Delete the '
                'baseline parameters if the baseline is not needed.')
        return self

    @classmethod
    def __get_validators__(cls):
        # Legacy (pydantic v1 style) hook kept so the model can also be parsed
        # from a raw JSON string, e.g. when submitted as a form field.
        yield cls.validate_to_json

    @classmethod
    def validate_to_json(cls, value):
        """Accept either a JSON string or a dict/model and validate it."""
        if isinstance(value, str):
            return cls.model_validate(json.loads(value.encode()))
        return cls.model_validate(value)
146
+
147
+
148
class ECGBatch(BaseModel):
    """ Input Model for Data Validation. The Input being the results of an experiment with ECG biosignals,
    including a batch of ecg data of different subjects. """
    supervisor: str = Field(..., example="Lieschen Mueller")
    # pydantic will process either an int or float (unix timestamp) (e.g. 1496498400),
    # an int or float as a string (assumed as Unix timestamp), or
    # a string representing the date (e.g. "YYYY-MM-DD")
    # Fixed: the default_factory was datetime.utcnow, which is deprecated and
    # produces a datetime where a date is declared; date.today yields a proper
    # date (NOTE(review): local date, not UTC — confirm which is intended).
    record_date: date = Field(example="2034-01-16", default_factory=date.today)
    configs: ECGConfig = Field(..., example=ECGConfig.Config.json_schema_extra)
    samples: list[ECGSample] = Field(..., min_items=1,
                                     example=[ECGSample.Config.json_schema_extra, ECGSample.Config.json_schema_extra])

    class Config:
        json_schema_extra = {
            "example": example1,
            "examples": [
                example0,
                example1
            ]
        }
utils.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This file contains a collection of utility functions that can be used for common tasks in the ecg processing."""
2
+ from datetime import datetime, timedelta
3
+ from typing import Union
4
+
5
+ import pandas as pd
6
+
7
+
8
def cut_out_window(df: pd.DataFrame,
                   time_column: str,
                   start: Union[datetime, pd.Timestamp],
                   end: Union[datetime, pd.Timestamp, None] = None,
                   duration: Union[timedelta, int, None] = None) -> pd.DataFrame:
    """
    Cuts out a window from a DataFrame based on the given start and end timestamps or duration. The dataframe must have
    a time column containing timestamps.

    :param df: The dataframe to cut out the window from (the time column is
        converted to datetime in place if necessary).
    :type df: pandas.DataFrame
    :param time_column: The name of the column containing the timestamps.
    :type time_column: str
    :param start: The start timestamp of the window.
    :type start: datetime.datetime or pandas.Timestamp
    :param end: The end timestamp of the window.
    :type end: datetime.datetime or pandas.Timestamp or None
    :param duration: The duration of the window, as a timedelta or in seconds.
    :type duration: datetime.timedelta or int or None

    :return: The window as a dataframe (rows with start <= timestamp <= end).
    :rtype: pandas.DataFrame

    :raises ValueError: If neither end nor duration is given.
    """
    # Convert the timestamp column to datetime if it's not already
    if not pd.api.types.is_datetime64_ns_dtype(df[time_column]):
        df[time_column] = pd.to_datetime(df[time_column])

    # Derive the window end from the duration when no explicit end is given
    if end is None:
        if duration is None:
            raise ValueError('Either end or duration must be given!')
        # Bug fix: duration is documented as timedelta or int, but the code
        # always wrapped it in pd.Timedelta(seconds=...), which fails for
        # timedelta inputs. Handle both types.
        end = start + (duration if isinstance(duration, timedelta) else pd.Timedelta(seconds=duration))

    window = df[(df[time_column] >= start) & (df[time_column] <= end)]
    return window
43
+
44
+
45
def create_windows(df, time_column, label_column=None, window_size=5.0, window_slicing_method='time_related'):
    """
    Slices a dataframe into windows of a given size. The dataframe must have a
    column containing timestamps.

    :param df: The dataframe to slice (the time column is converted to datetime
        in place if necessary).
    :type df: pandas.DataFrame
    :param time_column: The name of the column containing the timestamps.
    :type time_column: str
    :param label_column: The name of the column containing the labels (reserved
        for the label-related methods, which are not implemented yet).
    :type label_column: str
    :param window_size: The size of the windows in seconds.
    :type window_size: float
    :param window_slicing_method: The method used to slice the windows.
    :type window_slicing_method: str

    :return: A list of dataframes, one per window.
    :rtype: list

    :raises NotImplementedError: For the label-related slicing methods.
    :raises ValueError: If the slicing method is unknown.
    """
    # Convert the timestamp column to datetime if it's not already
    if not pd.api.types.is_datetime64_ns_dtype(df[time_column]):
        df[time_column] = pd.to_datetime(df[time_column])

    if window_slicing_method == 'time_related':
        # Group the rows into fixed-length time bins of `window_size` seconds.
        # A Timedelta frequency (instead of the f'{window_size}S' alias string)
        # supports fractional sizes and avoids the deprecated uppercase alias.
        freq = pd.Timedelta(seconds=window_size)
        return [group for _, group in df.groupby(pd.Grouper(key=time_column, freq=freq))]
    if window_slicing_method in ('label_related_before', 'label_related_after', 'label_related_centered'):
        # Bug fix: these branches previously fell through with `pass`, so the
        # function silently returned None; fail loudly instead.
        raise NotImplementedError(f'window_slicing_method {window_slicing_method} is not implemented yet')
    raise ValueError(f'window_slicing_method {window_slicing_method} not supported')