---
license: apache-2.0
---

# Diagnosis of Acute Coronary Syndrome using ECG waveforms: A machine learning framework and benchmark dataset

Code repository: https://github.com/alexmschubert/ACS-BenchWork

## Project description

In this project, we benchmarked various machine learning (ML) approaches for detecting acute coronary syndrome (ACS), commonly known as a "heart attack", from 12-lead ECG waveforms. Our findings show that ML models can identify large groups of high-risk patients who lack the classical ECG features that cardiologists typically associate with ACS (e.g., ST-segment elevation or depression). We are releasing the weights of our best-performing models to facilitate further development of data-driven ACS detection. The training dataset, ACS-BenchWork, is available on [Nightingale Open Science](https://docs.ngsci.org/datasets/ed-bwh-ecg/), and we invite researchers to build on these resources to improve the accuracy and utility of ECG-based ACS screening tools.

## Models

We are releasing the weights of our three best-performing models:

- **S4-ECG:** A Structured State Space (S4) model well suited to ECG waveforms, capturing both local and long-range temporal features for ACS detection. Builds on work by [Strodthoff et al.](https://github.com/AI4HealthUOL/ECG-MIMIC).
- **ResNet-18:** A 1D convolutional residual network adapted for ECG data, whose skip connections mitigate vanishing gradients so that a deeper network can learn complex patterns.
- **HuBERT-ECG:** A transformer-based architecture originally developed for self-supervised speech representation learning (HuBERT), repurposed and fine-tuned on ECG signals for ACS prediction. Based on work by [Coppola et al.](https://github.com/Edoar-do/HuBERT-ECG/tree/master/code).
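
To illustrate the skip connections mentioned for ResNet-18, here is a minimal PyTorch sketch of a 1D residual "basic block" of the kind ResNet-18 stacks. It is a generic textbook block, not the `ResNet18_1D` implementation from the repository; the class name `BasicBlock1d`, the kernel size, and the 12-lead, 1000-sample input shape are illustrative assumptions.

```python
import torch
import torch.nn as nn


class BasicBlock1d(nn.Module):
    """Generic ResNet-style basic block with 1D convolutions (illustrative sketch,
    not the repository's ResNet18_1D)."""

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # When the shape changes, project the input with a 1x1 conv
        # so it can be added to the block output.
        self.shortcut = nn.Identity()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Skip connection: add the (possibly projected) input before the final ReLU.
        return self.relu(out + self.shortcut(x))


# Example: 2 recordings, 12 leads, 1000 samples each (hypothetical shape).
block = BasicBlock1d(in_channels=12, out_channels=64, stride=2)
y = block(torch.randn(2, 12, 1000))
print(y.shape)  # torch.Size([2, 64, 500])
```

Because the input is added back to the block output, gradients can flow through the shortcut path unchanged, which is what allows deep stacks of such blocks to train well.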

## Usage

Below are sample snippets for loading each model with its pretrained weights. Once a model is loaded, run inference by passing your pre-processed ECG data to it (i.e., calling the model's `forward` method). Please refer to the project's GitHub repository for end-to-end examples of how to use these models.
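
A minimal sketch of that inference step in PyTorch, assuming the model takes a pre-processed tensor and returns one raw logit per recording (as with `d_output=1` in the S4-ECG snippet). The helper name `predict_acs_probability`, the `(batch, leads, samples)` layout, and the stand-in model are hypothetical; if a model already applies a sigmoid, or returns a tuple rather than a tensor, adapt the last steps accordingly.

```python
import torch
import torch.nn as nn


@torch.no_grad()  # no gradient tracking during inference
def predict_acs_probability(model: nn.Module, ecg: torch.Tensor) -> torch.Tensor:
    """Return per-recording probabilities from a model that emits one logit each.

    Assumes `ecg` is already pre-processed and shaped as the model expects
    (hypothetically (batch, leads, samples) for convolutional models).
    """
    model.eval()                  # disable dropout, use running BatchNorm stats
    logits = model(ecg)           # calls model.forward
    logits = logits.reshape(-1)   # (batch, 1) -> (batch,)
    return torch.sigmoid(logits)  # map logits to probabilities in (0, 1)


# Usage with a hypothetical stand-in model; replace it with a loaded checkpoint.
dummy = nn.Sequential(nn.Conv1d(12, 8, 5), nn.AdaptiveAvgPool1d(1),
                      nn.Flatten(), nn.Linear(8, 1))
probs = predict_acs_probability(dummy, torch.randn(4, 12, 1000))
print(probs.shape)  # torch.Size([4])
```

Note that the helper calls `model.eval()`, so a model passed in training mode is left in evaluation mode afterwards.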

### S4-ECG

```python
import torch

from src.lightning import S4Model


def load_from_checkpoint(pl_model, checkpoint_path):
    """Copy parameters and buffers from a Lightning checkpoint into an S4 model.

    Tensors are assigned name by name, which is compatible with the S4 layers.
    """
    # On PyTorch >= 2.6, torch.load defaults to weights_only=True; if unpickling
    # fails, pass weights_only=False, but only for checkpoints you trust.
    lightning_state_dict = torch.load(checkpoint_path, map_location="cpu")
    state_dict = lightning_state_dict["state_dict"]

    for name, param in pl_model.named_parameters():
        param.data = state_dict[name].data
    for name, buf in pl_model.named_buffers():
        buf.data = state_dict[name].data


checkpoint_path = "path/to/your/benchmark_acs_state_v0/dmaxlwcg/checkpoints/epoch=49-step=1100.ckpt"

model = S4Model(init_lr=1e-4,
                d_input=3,
                d_output=1)

load_from_checkpoint(model, checkpoint_path)
model.eval()  # inference mode (disables dropout, etc.)
```
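
The `load_from_checkpoint` helper above raises a `KeyError` if a model name is missing from the checkpoint, and it silently ignores extra checkpoint entries. A small check like the one below reports both cases before copying; it looks up the same parameter and buffer names that the helper iterates over. The function name `compare_state_dict_keys` and the toy model are hypothetical, not part of the repository.

```python
import torch
import torch.nn as nn


def compare_state_dict_keys(model: nn.Module, state_dict: dict) -> tuple[set, set]:
    """Return (missing, unexpected) names between a model and a checkpoint state_dict.

    'missing': parameter/buffer names the model has but the checkpoint lacks.
    'unexpected': checkpoint entries the model does not use.
    """
    # Same names the load_from_checkpoint helper looks up.
    model_keys = {name for name, _ in model.named_parameters()}
    model_keys |= {name for name, _ in model.named_buffers()}
    ckpt_keys = set(state_dict)
    return model_keys - ckpt_keys, ckpt_keys - model_keys


# Toy example with a fake checkpoint dict (hypothetical keys).
toy = nn.Sequential(nn.Linear(4, 2), nn.BatchNorm1d(2))
fake_ckpt = {k: v.clone() for k, v in toy.state_dict().items()}
fake_ckpt.pop("0.bias")
fake_ckpt["extra.weight"] = torch.zeros(1)
missing, unexpected = compare_state_dict_keys(toy, fake_ckpt)
print(missing, unexpected)  # {'0.bias'} {'extra.weight'}
```

For the S4 checkpoint, the second argument would be `torch.load(checkpoint_path, map_location="cpu")["state_dict"]`; two empty sets mean every parameter and buffer will be found.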

### ResNet-18

```python
from src.lightning import ResNet18_1D

checkpoint_path = "path/to/your/benchmark_acs_resnet18_1d_final_vf/2vud5fft/checkpoints/epoch=37-step=418.ckpt"

# Rebuild the model from the hyperparameters stored in the checkpoint and load its weights.
model = ResNet18_1D.load_from_checkpoint(checkpoint_path, map_location="cpu")
model.eval()  # inference mode
```

### HuBERT-ECG

```python
import torch

from hubert_ecg import HuBERTECG, HuBERTECGConfig  # config class used by the stored config
from hubert_ecg_classification import HuBERTForECGClassification

path = "path/to/your/hubert_3_iteration_300_finetuned_simdmsnv.pt"

# The checkpoint stores both the model config and the fine-tuned weights.
# On PyTorch >= 2.6 you may need weights_only=False to unpickle the config
# object; only do this for files you trust.
checkpoint = torch.load(path, map_location="cpu")
config = checkpoint["model_config"]

backbone = HuBERTECG(config)                  # HuBERT-ECG encoder
model = HuBERTForECGClassification(backbone)  # encoder + classification head
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode
```

## Citation

*Paper currently under review*