riccardolunelli committed
Commit 7730918 · verified · 1 Parent(s): cfc8bc1

Create xECG.py

Files changed (1)
  1. xECG.py +224 -0
xECG.py ADDED
@@ -0,0 +1,224 @@
+ import torch
+ from torch import nn
+
+ from xlstm import FeedForwardConfig, mLSTMLayerConfig, mLSTMBlockConfig, sLSTMLayerConfig, sLSTMBlockConfig, xLSTMBlockStackConfig, xLSTMBlockStack
+ import numpy as np
+ from huggingface_hub import PyTorchModelHubMixin
+
+
+ class xECG(
+     nn.Module,
+     PyTorchModelHubMixin,
+     repo_url="https://github.com/dlaskalab/bench-xecg/",
+     pipeline_tag="other",
+     license="mit"
+ ):
+
+     def __init__(
+         self,
+         cls_type,
+         config,
+     ):
+         super(xECG, self).__init__()
+
+         self.dropout = nn.Dropout(config['dropout'])
+         self.sampling_freq = config['sampling_freq']
+         self.patch_size = config['patch_size']
+         self.embedding_size = config['embedding_size']
+         self.cls_type = cls_type
+         assert self.cls_type in ['max', 'avg', 'mean', None], f"cls_type {self.cls_type} not supported"
+
+         self.patch_embedding = LinearPatchEmbedding(
+             patch_size=config['patch_size'],
+             num_hiddens=config['embedding_size'],
+             num_channels=12
+         )
+
+         self.core = get_xlstm(config)
+         self.mask_token = nn.Parameter(torch.zeros(config['embedding_size']))
+
+     def pooling(self, out, padding_mask=None):
+         cls = None
+         if self.cls_type == 'max':
+             if padding_mask is None:
+                 cls = out.max(dim=1)[0]
+             else:
+                 # do not consider padded values in the max
+                 cls = out.masked_fill(padding_mask, -torch.inf).max(dim=1)[0]
+         elif self.cls_type == 'mean' or self.cls_type == 'avg':
+             if padding_mask is None:
+                 cls = out.mean(dim=1)
+             else:
+                 # do not consider padded values in the mean
+                 cls = out.masked_fill(padding_mask, 0).sum(dim=1) / (out.shape[1] - padding_mask.sum(dim=1)).clamp(min=1)
+         return cls, out
+
+
+     def forward(self, x):
+         # find the padded part of the signal
+         padding_mask = self.get_padding_mask(x)
+
+         x = self.patch_embedding(x)
+
+         out = self.core(x)  # [batch_size, num_patches, embedding_dim]
+         cls, out = self.pooling(out, padding_mask)
+
+         return cls, out
+
+
+     def get_padding_mask(self, x):
+         # a time step counts as padding when all leads are zero
+         padding_mask = (x.abs().sum(dim=-1) == 0).unsqueeze(-1)
+         # keep one flag per patch and broadcast it over the embedding dimension
+         num_patches = x.shape[1] // self.patch_size
+         padding_mask_patched = padding_mask.view(-1, num_patches, self.patch_size)[:, :, 0].unsqueeze(-1).expand(-1, -1, self.embedding_size)
+         return padding_mask_patched
+
+
+     def trainable_parameters(self):
+         return self.parameters()
+
+     def get_layers(self):
+         """
+         Return the layers of the model to which layer-wise learning-rate decay is applied.
+         """
+         return self.core.model.blocks
+
+     def additional_params(self, lr, last_layer_lr, wd):
+         """
+         Return additional parameter groups used by the model (e.g. a classification token and so on).
+         """
+         params = []
+         params.append({"params": self.patch_embedding.parameters(), "lr": last_layer_lr, "name": "patch_embedding", "weight_decay": wd})
+
+         if hasattr(self.core.model, 'post_blocks_norm'):
+             params.append({'params': self.core.model.post_blocks_norm.parameters(), 'lr': lr, 'name': 'post_block_norm', 'weight_decay': wd})
+
+         return params
+
+
+     def format_keys(self, key):
+         # drop a leading 'model.' prefix if present
+         if key.startswith('model.'):
+             key = key[6:]
+
+         key = key.replace('xlstm.model', 'core.model')  # rename the pretraining prefix to match this module
+         return key
+
+     def load_checkpoint(self, checkpoint_path):
+         checkpoint = torch.load(checkpoint_path, weights_only=False)
+         new_state_dict = {self.format_keys(k): v for k, v in checkpoint['state_dict'].items()}
+
+         # for k, v in new_state_dict.items():
+         #     if "slstm_cell._recurrent_kernel_" in k:
+         #         new_state_dict[k] = v.permute(0, 2, 1)
+
+         # remove the fc layer
+         new_state_dict = {k: v for k, v in new_state_dict.items() if 'fc' not in k}
+         message = self.load_state_dict(new_state_dict, strict=False)
+         print(message)
+
+
+ class LinearPatchEmbedding(nn.Module):
+     def __init__(self, patch_size=64, num_hiddens=256, num_channels=12):
+         super().__init__()
+         self.conv = nn.Conv1d(num_channels, num_hiddens, kernel_size=patch_size, stride=patch_size, bias=False)
+
+     def forward(self, x, permute=True):
+         if permute:
+             x = x.permute(0, 2, 1)  # put the channels in the middle
+         x = self.conv(x).flatten(2).transpose(1, 2)
+         return x
+
+
+ class vanillaxLSTMWrapper(nn.Module):
+     """xLSTM wrapper that adds bidirectionality and drop path (stochastic depth)."""
+
+     def __init__(self, xlstm, dropout=0.2, bidirectional=False, drop_path=0.):
+         super(vanillaxLSTMWrapper, self).__init__()
+         self.model = xlstm
+         self.dropout = nn.Dropout(dropout)
+         self.bidirectional = bidirectional
+         self.drop_path = DropPath()
+         # per-block drop-path rates, increasing linearly from 0 to drop_path
+         self.dropout_rates = [x.item() for x in torch.linspace(0, drop_path, len(self.model.blocks))]
+
+     def step(self, x, state=None):
+         return self.model.step(x, state=state)
+
+     def forward(self, x: torch.Tensor):
+
+         for i, block in enumerate(self.model.blocks):
+             if self.bidirectional:
+                 # flip the sequence so consecutive blocks read it in alternating directions
+                 if i > 0:
+                     x = x.flip(1)
+
+             if self.dropout_rates[i] == 0. or not self.training:
+                 x = block(x)
+             else:
+                 x = self.drop_path(x, block, self.dropout_rates[i])
+
+         x = self.model.post_blocks_norm(x)
+         return x
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks)."""
+     def __init__(self, is_large_mlstm=False):
+         super(DropPath, self).__init__()
+         self.is_large_mlstm = is_large_mlstm
+
+     def forward(self, x, block, drop_path_prob, state=None):
+         if drop_path_prob == 0. or not self.training:
+             if self.is_large_mlstm:
+                 return block(x, state)
+             else:
+                 return block(x)
+
+         # shuffled indexes of the batch
+         idxs = torch.randperm(x.shape[0])
+         num_to_keep = int(np.ceil((1.0 - drop_path_prob) * x.shape[0]))
+         idxs_to_keep = idxs[:num_to_keep]  # the first N samples are kept
+
+         if self.is_large_mlstm:
+             out, _ = block(x[idxs_to_keep], None)
+             x[idxs_to_keep] = out
+             # no state is needed during training
+             return x, None
+         else:
+             x[idxs_to_keep] = block(x[idxs_to_keep])
+             return x
+
+
+ def get_xlstm(config):
+     slstm_positions = [idx for idx, b in enumerate(config['xlstm_config']) if b == 's']
+     cfg = xLSTMBlockStackConfig(
+         mlstm_block=mLSTMBlockConfig(
+             mlstm=mLSTMLayerConfig(
+                 conv1d_kernel_size=4,
+                 qkv_proj_blocksize=config['num_heads'],
+                 num_heads=config['num_heads'],
+                 proj_factor=config['proj_factor']
+             )
+         ),
+         slstm_block=sLSTMBlockConfig(
+             slstm=sLSTMLayerConfig(
+                 num_heads=config['num_heads'],
+                 backend=config.get('backend') or "cuda",
+                 conv1d_kernel_size=4,
+                 bias_init="powerlaw_blockdependent",
+                 batch_size=config['batch_size'],
+             ),
+             feedforward=FeedForwardConfig(proj_factor=1.3, act_fn=config['activation_fn']),
+         ),
+         context_length=8000,
+         num_blocks=len(config['xlstm_config']),
+         embedding_dim=config['embedding_size'],
+         slstm_at=slstm_positions,
+         dropout=config['dropout'],
+         add_post_blocks_norm=config.get('use_final_layer_norm', False)
+     )
+     print('creating xlstm with slstm at: ', slstm_positions)
+
+     return vanillaxLSTMWrapper(
+         xLSTMBlockStack(cfg),
+         dropout=config['dropout'],
+         bidirectional=True,
+         drop_path=config['drop_path_prob']
+     )
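
For context, a minimal usage sketch (not part of the committed file, and assuming the module is importable as xECG). The config keys are the ones xECG and get_xlstm actually read; every value below is an illustrative assumption, and 'backend': 'vanilla' is the CPU fallback for the sLSTM blocks.

import torch
from xECG import xECG

# illustrative config; key names match what the module reads, values are assumptions
config = {
    'dropout': 0.1,
    'sampling_freq': 500,
    'patch_size': 64,
    'embedding_size': 256,
    'num_heads': 4,
    'proj_factor': 2.0,
    'backend': 'vanilla',      # use "cuda" when the GPU sLSTM kernel is available
    'batch_size': 2,
    'activation_fn': 'gelu',
    'xlstm_config': 'msmsms',  # one letter per block: 'm' = mLSTM, 's' = sLSTM
    'use_final_layer_norm': True,
    'drop_path_prob': 0.1,
}

model = xECG(cls_type='mean', config=config).eval()

# dummy 12-lead ECG: [batch, time, leads]; time must be a multiple of patch_size
x = torch.randn(2, 4096, 12)
with torch.no_grad():
    cls, tokens = model(x)
print(cls.shape, tokens.shape)  # expected: (2, 256) and (2, 64, 256)

# via PyTorchModelHubMixin, pretrained weights could also be loaded with
# xECG.from_pretrained("<repo_id>", cls_type='mean', config=config)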