alex-schubert commited on
Commit
c5c661c
·
verified ·
1 Parent(s): bd73ec6

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +54 -1
README.md CHANGED
@@ -19,7 +19,60 @@ We are releasing the weights for the three best-performing models:
19
 
20
  ## Usage
21
 
 
22
 
 
23
 
24
- ## Citation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  ## Usage
21
 
22
+ Below are sample snippets for loading each model with its pretrained weights. Once loaded, you can run inference by passing your pre-processed ECG data to the model’s forward method. Please refer to the project’s GitHub repository for end-to-end examples demonstrating how to use these models.
23
 
24
+ ### S4-ECG
25
 
26
+ ```bash
27
+ import torch
28
+ import lightning.pytorch as pl
29
+ from src.lightning import S4Model
30
+
31
+ def load_from_checkpoint(pl_model, checkpoint_path):
32
+ """ load from checkpoint function that is compatible with S4
33
+ """
34
+ lightning_state_dict = torch.load(checkpoint_path)
35
+ state_dict = lightning_state_dict["state_dict"]
36
+
37
+ for name, param in pl_model.named_parameters():
38
+ param.data = state_dict[name].data
39
+ for name, param in pl_model.named_buffers():
40
+ param.data = state_dict[name].data
41
+
42
+ checkpoint_path = "path/to/your/benchmark_acs_state_v0/dmaxlwcg/checkpoints/epoch=49-step=1100.ckpt"
43
+
44
+ model = S4Model(init_lr=1e-4,
45
+ d_input=3,
46
+ d_output=1)
47
+
48
+ load_from_checkpoint(model, checkpoint_path)
49
+ ```
50
+
51
+ ### ResNet-18
52
 
53
+ ```bash
54
+ import torch
55
+ import lightning.pytorch as pl
56
+ from src.lightning import ResNet18_1D
57
+
58
+ checkpoint_path = "path/to/your/benchmark_acs_resnet18_1d_final_vf/2vud5fft/checkpoints/epoch=37-step=418.ckpt"
59
+ model = ResNet18_1D.load_from_checkpoint(checkpoint_path)
60
+ ```
61
+
62
+ ### HuBERT-ECG
63
+
64
+ ```bash
65
+ import torch
66
+ from hubert_ecg import HuBERTECG, HuBERTECGConfig
67
+ from hubert_ecg_classification import HuBERTForECGClassification
68
+
69
+ path = "path/to/your/hubert_3_iteration_300_finetuned_simdmsnv.pt"
70
+ checkpoint = torch.load(path, map_location='cpu')
71
+ config = checkpoint['model_config']
72
+ hubert_ecg = HuBERTECG(config)
73
+ hubert_ecg = HuBERTForECGClassification(hubert_ecg)
74
+ hubert_ecg.load_state_dict(checkpoint['model_state_dict'])
75
+ ```
76
+
77
+ ## Citation
78
+ *Paper currently under review*