nina-m-m committed on
Commit ·
44cdef4
1
Parent(s): bf0f149
Update model and implement abstract class
Browse files
ECG2HRV.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3669bf4b201fd873f1fb1b48083c41f26c42be00f58d14d234af6b9aac1f6433
|
| 3 |
+
size 39
|
notebooks/01_Model_Deployment_Research.ipynb
CHANGED
|
@@ -399,21 +399,18 @@
|
|
| 399 |
},
|
| 400 |
{
|
| 401 |
"cell_type": "code",
|
| 402 |
-
"execution_count":
|
| 403 |
"outputs": [],
|
| 404 |
"source": [
|
| 405 |
"from huggingface_hub import hf_hub_download\n",
|
| 406 |
"import joblib\n",
|
| 407 |
"import torch\n",
|
|
|
|
| 408 |
"\n",
|
| 409 |
-
"from src.model import
|
| 410 |
],
|
| 411 |
"metadata": {
|
| 412 |
-
"collapsed": false
|
| 413 |
-
"ExecuteTime": {
|
| 414 |
-
"end_time": "2024-02-21T11:39:25.775871100Z",
|
| 415 |
-
"start_time": "2024-02-21T11:39:25.755838Z"
|
| 416 |
-
}
|
| 417 |
}
|
| 418 |
},
|
| 419 |
{
|
|
@@ -427,11 +424,11 @@
|
|
| 427 |
},
|
| 428 |
{
|
| 429 |
"cell_type": "code",
|
| 430 |
-
"execution_count":
|
| 431 |
"outputs": [],
|
| 432 |
"source": [
|
| 433 |
"# Instantiate model\n",
|
| 434 |
-
"model =
|
| 435 |
"# Save\n",
|
| 436 |
"joblib.dump(model, \"..\\ECG2HRV.joblib\")\n",
|
| 437 |
"# Load in notebook\n",
|
|
@@ -440,15 +437,15 @@
|
|
| 440 |
"metadata": {
|
| 441 |
"collapsed": false,
|
| 442 |
"ExecuteTime": {
|
| 443 |
-
"end_time": "2024-02-
|
| 444 |
-
"start_time": "2024-02-
|
| 445 |
}
|
| 446 |
}
|
| 447 |
},
|
| 448 |
{
|
| 449 |
"cell_type": "markdown",
|
| 450 |
"source": [
|
| 451 |
-
"**Test
|
| 452 |
],
|
| 453 |
"metadata": {
|
| 454 |
"collapsed": false
|
|
@@ -456,38 +453,84 @@
|
|
| 456 |
},
|
| 457 |
{
|
| 458 |
"cell_type": "code",
|
| 459 |
-
"execution_count":
|
| 460 |
"outputs": [],
|
| 461 |
"source": [
|
| 462 |
-
"#
|
| 463 |
-
"
|
| 464 |
-
"
|
| 465 |
"\n",
|
| 466 |
-
"
|
| 467 |
-
"
|
| 468 |
-
")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
],
|
| 470 |
"metadata": {
|
| 471 |
"collapsed": false,
|
| 472 |
"ExecuteTime": {
|
| 473 |
-
"end_time": "2024-02-
|
| 474 |
-
"start_time": "2024-02-
|
| 475 |
}
|
| 476 |
}
|
| 477 |
},
|
| 478 |
{
|
| 479 |
"cell_type": "code",
|
| 480 |
-
"execution_count":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
"outputs": [],
|
| 482 |
"source": [
|
| 483 |
-
"#
|
| 484 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
],
|
| 486 |
"metadata": {
|
| 487 |
"collapsed": false,
|
| 488 |
"ExecuteTime": {
|
| 489 |
-
"end_time": "2024-02-21T11:36:
|
| 490 |
-
"start_time": "2024-02-21T11:36:
|
| 491 |
}
|
| 492 |
}
|
| 493 |
},
|
|
@@ -504,9 +547,8 @@
|
|
| 504 |
}
|
| 505 |
],
|
| 506 |
"source": [
|
| 507 |
-
"# Run
|
| 508 |
-
"
|
| 509 |
-
"print(output)\n"
|
| 510 |
],
|
| 511 |
"metadata": {
|
| 512 |
"collapsed": false,
|
|
|
|
| 399 |
},
|
| 400 |
{
|
| 401 |
"cell_type": "code",
|
| 402 |
+
"execution_count": 1,
|
| 403 |
"outputs": [],
|
| 404 |
"source": [
|
| 405 |
"from huggingface_hub import hf_hub_download\n",
|
| 406 |
"import joblib\n",
|
| 407 |
"import torch\n",
|
| 408 |
+
"import numpy as np\n",
|
| 409 |
"\n",
|
| 410 |
+
"from src.model import ECG2HRV"
|
| 411 |
],
|
| 412 |
"metadata": {
|
| 413 |
+
"collapsed": false
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
}
|
| 415 |
},
|
| 416 |
{
|
|
|
|
| 424 |
},
|
| 425 |
{
|
| 426 |
"cell_type": "code",
|
| 427 |
+
"execution_count": 2,
|
| 428 |
"outputs": [],
|
| 429 |
"source": [
|
| 430 |
"# Instantiate model\n",
|
| 431 |
+
"model = ECG2HRV()\n",
|
| 432 |
"# Save\n",
|
| 433 |
"joblib.dump(model, \"..\\ECG2HRV.joblib\")\n",
|
| 434 |
"# Load in notebook\n",
|
|
|
|
| 437 |
"metadata": {
|
| 438 |
"collapsed": false,
|
| 439 |
"ExecuteTime": {
|
| 440 |
+
"end_time": "2024-02-21T16:08:51.659030Z",
|
| 441 |
+
"start_time": "2024-02-21T16:08:51.605730100Z"
|
| 442 |
}
|
| 443 |
}
|
| 444 |
},
|
| 445 |
{
|
| 446 |
"cell_type": "markdown",
|
| 447 |
"source": [
|
| 448 |
+
"**Test the model locally with random ecg**"
|
| 449 |
],
|
| 450 |
"metadata": {
|
| 451 |
"collapsed": false
|
|
|
|
| 453 |
},
|
| 454 |
{
|
| 455 |
"cell_type": "code",
|
| 456 |
+
"execution_count": 3,
|
| 457 |
"outputs": [],
|
| 458 |
"source": [
|
| 459 |
+
"duration_seconds = 10 # Time duration for ECG signal (in seconds)\n",
|
| 460 |
+
"sample_rate = 100 # Sample rate (samples per second)\n",
|
| 461 |
+
"num_samples = duration_seconds * sample_rate # Number of samples\n",
|
| 462 |
"\n",
|
| 463 |
+
"t = np.linspace(0, duration_seconds, num_samples) # Time array\n",
|
| 464 |
+
"\n",
|
| 465 |
+
"# Generate ECG signal (example synthetic data)\n",
|
| 466 |
+
"ecg_signal = (\n",
|
| 467 |
+
" 0.2 * np.sin(2 * np.pi * 1 * t) +\n",
|
| 468 |
+
" 0.5 * np.sin(2 * np.pi * 0.5 * t) -\n",
|
| 469 |
+
" 0.1 * np.sin(2 * np.pi * 2.5 * t)\n",
|
| 470 |
+
")\n",
|
| 471 |
+
"\n",
|
| 472 |
+
"# Add some random noise\n",
|
| 473 |
+
"ecg_signal += np.random.normal(scale=0.1, size=num_samples)"
|
| 474 |
],
|
| 475 |
"metadata": {
|
| 476 |
"collapsed": false,
|
| 477 |
"ExecuteTime": {
|
| 478 |
+
"end_time": "2024-02-21T16:08:51.669938Z",
|
| 479 |
+
"start_time": "2024-02-21T16:08:51.635032600Z"
|
| 480 |
}
|
| 481 |
}
|
| 482 |
},
|
| 483 |
{
|
| 484 |
"cell_type": "code",
|
| 485 |
+
"execution_count": 4,
|
| 486 |
+
"outputs": [
|
| 487 |
+
{
|
| 488 |
+
"data": {
|
| 489 |
+
"text/plain": "[{'HRV_MeanNN': 413.4782608695652,\n 'HRV_SDNN': 100.97743652790477,\n 'HRV_SDANN1': nan,\n 'HRV_SDNNI1': nan,\n 'HRV_SDANN2': nan,\n 'HRV_SDNNI2': nan,\n 'HRV_SDANN5': nan,\n 'HRV_SDNNI5': nan,\n 'HRV_RMSSD': 92.78518690551262,\n 'HRV_SDSD': 94.96410805236795,\n 'HRV_CVNN': 0.24421462041449105,\n 'HRV_CVSD': 0.22440160870944167,\n 'HRV_MedianNN': 400.0,\n 'HRV_MadNN': 118.60799999999999,\n 'HRV_MCVNN': 0.29651999999999995,\n 'HRV_IQRNN': 150.0,\n 'HRV_SDRMSSD': 1.0882926455785953,\n 'HRV_Prc20NN': 320.0,\n 'HRV_Prc80NN': 490.0,\n 'HRV_pNN50': 52.17391304347826,\n 'HRV_pNN20': 69.56521739130434,\n 'HRV_MinNN': 310.0,\n 'HRV_MaxNN': 640.0,\n 'HRV_HTI': 5.75,\n 'HRV_TINN': 0.0}]"
|
| 490 |
+
},
|
| 491 |
+
"execution_count": 4,
|
| 492 |
+
"metadata": {},
|
| 493 |
+
"output_type": "execute_result"
|
| 494 |
+
}
|
| 495 |
+
],
|
| 496 |
+
"source": [
|
| 497 |
+
"model(input_data=ecg_signal, frequency=100.0)"
|
| 498 |
+
],
|
| 499 |
+
"metadata": {
|
| 500 |
+
"collapsed": false,
|
| 501 |
+
"ExecuteTime": {
|
| 502 |
+
"end_time": "2024-02-21T16:08:51.755181400Z",
|
| 503 |
+
"start_time": "2024-02-21T16:08:51.671014900Z"
|
| 504 |
+
}
|
| 505 |
+
}
|
| 506 |
+
},
|
| 507 |
+
{
|
| 508 |
+
"cell_type": "markdown",
|
| 509 |
+
"source": [
|
| 510 |
+
"**Test if the model can be loaded from the hub and used**"
|
| 511 |
+
],
|
| 512 |
+
"metadata": {
|
| 513 |
+
"collapsed": false
|
| 514 |
+
}
|
| 515 |
+
},
|
| 516 |
+
{
|
| 517 |
+
"cell_type": "code",
|
| 518 |
+
"execution_count": 19,
|
| 519 |
"outputs": [],
|
| 520 |
"source": [
|
| 521 |
+
"# Load from hub\n",
|
| 522 |
+
"REPO_ID = \"HUBII-Platform/ECG2HRV\"\n",
|
| 523 |
+
"FILENAME = \"feature-extractor.joblib\"\n",
|
| 524 |
+
"\n",
|
| 525 |
+
"model = joblib.load(\n",
|
| 526 |
+
" hf_hub_download(repo_id=REPO_ID, filename=FILENAME)\n",
|
| 527 |
+
")"
|
| 528 |
],
|
| 529 |
"metadata": {
|
| 530 |
"collapsed": false,
|
| 531 |
"ExecuteTime": {
|
| 532 |
+
"end_time": "2024-02-21T11:36:52.302912800Z",
|
| 533 |
+
"start_time": "2024-02-21T11:36:52.145834500Z"
|
| 534 |
}
|
| 535 |
}
|
| 536 |
},
|
|
|
|
| 547 |
}
|
| 548 |
],
|
| 549 |
"source": [
|
| 550 |
+
"# Run model\n",
|
| 551 |
+
"model(input_data=ecg_signal, frequency=100.0)"
|
|
|
|
| 552 |
],
|
| 553 |
"metadata": {
|
| 554 |
"collapsed": false,
|
feature-extractor.joblib → notebooks/feature-extractor.joblib
RENAMED
|
File without changes
|
src/ecg2hrv.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import neurokit2 as nk
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from src.feature_extractor import FeatureExtractor
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ECG2HRV(FeatureExtractor):
    """Feature extractor that derives time-domain HRV features from an ECG signal.

    The signal is cleaned with NeuroKit2, R-peaks are detected, and the
    time-domain HRV table is returned as a list of ``{feature_name: value}``
    records. Optionally the features can be normalized against features
    computed from a baseline recording.
    """

    # Normalization strategies understood by ``normalize_features``.
    _NORMALIZATION_METHODS = ("difference", "relative")

    def __init__(self):
        super().__init__()

    def extract_features(self, ecg, frequency, baseline=None, normalization_method=None):
        """Compute HRV features for ``ecg`` and optionally normalize them.

        Args:
            ecg: Array-like ECG signal with at least one dimension.
            frequency: Sampling rate passed to NeuroKit2 as ``sampling_rate``.
            baseline: Optional baseline ECG signal. It is processed with the
                same ``frequency`` as ``ecg`` (assumes both recordings share
                one sampling rate -- TODO confirm with callers).
            normalization_method: ``"difference"`` or ``"relative"``.
                Normalization only happens when both ``baseline`` and
                ``normalization_method`` are given.

        Returns:
            The list of feature records produced by ``get_hrv_features``,
            normalized when requested.

        Raises:
            ValueError: If a signal is a scalar, or the normalization method
                is not supported.
        """
        features = self.get_hrv_features(self._clean(ecg, frequency), frequency)

        if baseline is not None and normalization_method is not None:
            # The baseline goes through exactly the same cleaning and peak
            # detection as the main signal so both feature sets are comparable.
            baseline_features = self.get_hrv_features(
                self._clean(baseline, frequency), frequency
            )
            features = self.normalize_features(
                features, baseline_features, normalization_method
            )

        return features

    def get_hrv_features(self, ecg, frequency):
        """Detect peaks in ``ecg`` and return the time-domain HRV features.

        Args:
            ecg: ECG signal handed to ``nk.ecg_peaks``.
            frequency: Sampling rate passed to NeuroKit2 as ``sampling_rate``.

        Returns:
            ``nk.hrv_time(...)`` converted with ``to_dict(orient="records")``,
            i.e. a list of dict records.
        """
        # Find peaks (the second return value is not needed here).
        peaks, _ = nk.ecg_peaks(ecg, sampling_rate=frequency, method="pantompkins1985")

        # NOTE: only time-domain features are computed; frequency-domain
        # features are intentionally not part of the output yet.
        hrv_features = nk.hrv_time(peaks, sampling_rate=frequency)

        return hrv_features.to_dict(orient="records")

    def normalize_features(self, features, baseline_features, normalization_method=None):
        """Normalize ``features`` against ``baseline_features``.

        Args:
            features: Either a list of dict records (as returned by
                ``get_hrv_features``) or any object supporting ``-`` and ``/``.
            baseline_features: Baseline counterpart of ``features``.
            normalization_method: ``"difference"`` subtracts the baseline,
                ``"relative"`` divides by it.

        Returns:
            The normalized features, in the same layout as ``features``.

        Raises:
            ValueError: If ``normalization_method`` is not supported, or the
                record lists have different lengths.
        """
        if normalization_method not in self._NORMALIZATION_METHODS:
            raise ValueError(f"Normalization method {normalization_method} not supported")

        if isinstance(features, list):
            # List-of-records layout: plain lists do not support arithmetic,
            # so normalize value by value.
            if len(features) != len(baseline_features):
                raise ValueError(
                    "features and baseline_features must contain the same number of records"
                )
            return [
                {
                    # A feature missing from the baseline becomes NaN instead
                    # of aborting the whole normalization.
                    name: self._normalize_value(
                        value, reference.get(name, float("nan")), normalization_method
                    )
                    for name, value in record.items()
                }
                for record, reference in zip(features, baseline_features)
            ]

        # Array-like layout: rely on the operands' own arithmetic.
        if normalization_method == "difference":
            return features - baseline_features
        return features / baseline_features

    @staticmethod
    def _clean(signal, frequency):
        """Validate ``signal`` and return it cleaned by ``nk.ecg_clean``."""
        signal = np.asarray(signal)
        # Ensure the array has at least one dimension (i.e. is not a scalar).
        if signal.ndim < 1:
            raise ValueError("Array must have at least one dimension")
        return nk.ecg_clean(ecg_signal=signal, sampling_rate=frequency, method="pantompkins1985")

    @staticmethod
    def _normalize_value(value, reference, normalization_method):
        """Normalize one scalar feature value against its baseline value."""
        if normalization_method == "difference":
            return value - reference
        # Divide with NumPy semantics so a zero baseline yields inf/nan
        # rather than raising ZeroDivisionError.
        with np.errstate(divide="ignore", invalid="ignore"):
            return float(np.float64(value) / np.float64(reference))
|
src/feature_extractor.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class FeatureExtractor(ABC):
    """Abstract base class for callable feature extractors.

    Calling an instance forwards to ``extract_features``; subclasses must
    implement ``extract_features`` and ``normalize_features``.
    """

    def __init__(self):
        """Initialize the extractor; the base class keeps no state."""

    def __call__(self, input_data, frequency, baseline_data=None, normalization_method=None):
        """Run ``extract_features`` with the given arguments.

        Arguments are forwarded positionally, in the order
        ``(input_data, frequency, baseline_data, normalization_method)``, so
        subclasses may use their own parameter names.
        """
        return self.extract_features(input_data, frequency, baseline_data, normalization_method)

    @abstractmethod
    def extract_features(self, input_data, frequency, baseline_data=None, normalization_method=None):
        """Extract features from ``input_data``.

        The parameter order matches the positional call made by ``__call__``.

        Args:
            input_data: Data to extract features from.
            frequency: Second positional argument supplied by ``__call__``.
            baseline_data: Optional baseline data.
            normalization_method: Optional normalization method identifier.
        """

    @abstractmethod
    def normalize_features(self, features, baseline_features=None, normalization_method=None):
        """Normalize ``features`` against ``baseline_features``.

        Args:
            features: Features to normalize.
            baseline_features: Optional baseline features.
            normalization_method: Optional normalization method identifier.
        """
|