ahczhg committed
Commit 69a8706 · verified · 1 Parent(s): 143c2f7

Upload models_mamba_ecg.py with huggingface_hub

Files changed (1)
  1. models_mamba_ecg.py +795 -0
models_mamba_ecg.py ADDED
@@ -0,0 +1,795 @@
+ # Copyright (c) 2015-present, Facebook, Inc.
+ # All rights reserved.
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as functional
+ from functools import partial
+ from torch import Tensor
+ from typing import Optional
+
+ from timm.models.vision_transformer import VisionTransformer, _cfg
+ # from timm.models.registry import register_model
+ # from timm.models.layers import trunc_normal_, lecun_normal_
+
+ from timm.models import register_model
+ from timm.layers import trunc_normal_, lecun_normal_
+
+ from timm.layers import DropPath, to_2tuple
+
+ # from timm.models.layers import DropPath, to_2tuple
+ from timm.models.vision_transformer import _load_weights
+
+ import math
+
+ from collections import namedtuple
+
+ from mamba_ssm.modules.mamba_simple import Mamba
+ from mamba_ssm.utils.generation import GenerationMixin
+ from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
+
+ from rope import *
+ import random
+ import sys
+
+ try:
+     from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
+ except ImportError:
+     RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
+
+
+ # layer_norm_fn and rms_norm_fn are fused add+normalize functions (Triton
+ # kernels); they fall back to None when the Triton ops are unavailable.
+
+ __all__ = [
+     'vim_tiny_patch16_224', 'vim_small_patch16_224', 'vim_base_patch16_224',
+     'vim_tiny_patch16_384', 'vim_small_patch16_384', 'vim_base_patch16_384',
+ ]
+
+
+ '''
+ In the original script (ft-vim-s.sh):
+     img_size = 224,
+     patch_size = 16,
+     stride = 8,
+     in_chans = 3,
+     embed_dim = 768
+ ------------------------------------
+ self.img_size: (224, 224)
+ self.patch_size: (16, 16)
+ self.grid_size: (27, 27)
+ self.num_patches: 729
+ self.flatten: True
+ self.norm: nn.Identity()
+ '''
+ class PatchEmbed(nn.Module):
+     """ 2D Image to Patch Embedding
+     """
+     def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+         super().__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.grid_size = ((img_size[0] - patch_size[0]) // stride + 1, (img_size[1] - patch_size[1]) // stride + 1)
+         self.num_patches = self.grid_size[0] * self.grid_size[1]
+         self.flatten = flatten
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)
+         # if norm_layer is given, self.norm = norm_layer(embed_dim);
+         # otherwise self.norm = nn.Identity()
+         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         assert H == self.img_size[0] and W == self.img_size[1], \
+             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+
+         x = self.proj(x)
+         print("This is the shape after the CNN:", x.shape)
+
+         if self.flatten:
+             x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
+             print("This is the shape after the flatten:", x.shape)
+
+         x = self.norm(x)
+         return x
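+
+ # A minimal sanity-check sketch (hypothetical helper, not used elsewhere in
+ # this file) for the grid arithmetic quoted in the comment block above: with
+ # img_size=224, patch_size=16, stride=8, the conv output-length formula gives
+ # a 27x27 grid, i.e. the 729 patches that VisionMamba hard-codes below.
+ def _demo_patch_embed_grid():
+     pe = PatchEmbed(img_size=224, patch_size=16, stride=8, embed_dim=384)
+     assert pe.grid_size == ((224 - 16) // 8 + 1, (224 - 16) // 8 + 1) == (27, 27)
+     assert pe.num_patches == 729
+     out = pe(torch.zeros(1, 3, 224, 224))  # prints the intermediate shapes
+     return out.shape                        # torch.Size([1, 729, 384])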
+
+ class PatchEmbed_spectrogram(nn.Module):
+     """ 2D spectrogram to Patch Embedding
+     """
+     def __init__(self, img_size_f=128, img_size_t=64, patch_size=6, stride=3, in_chans=12, embed_dim=432, flatten=True):
+         super().__init__()
+         # img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+
+         self.img_size_f = img_size_f
+         self.img_size_t = img_size_t
+
+         self.patch_size = patch_size
+         self.grid_size = ((img_size_f - patch_size[0]) // stride + 1, (img_size_t - patch_size[1]) // stride + 1)
+         self.num_patches = self.grid_size[0] * self.grid_size[1]
+         self.flatten = flatten
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)
+         # if norm_layer were given, self.norm = norm_layer(embed_dim);
+         # otherwise self.norm = nn.Identity()
+         # self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         assert H == self.img_size_f and W == self.img_size_t, \
+             f"Input spectrogram size ({H}*{W}) doesn't match model ({self.img_size_f}*{self.img_size_t})."
+
+         x = self.proj(x)
+
+         # This is the shape after the CNN: torch.Size([1, 432, 41, 20])
+         # print("This is the shape after the CNN:", x.shape)
+         if self.flatten:
+             x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
+             # This is the shape after the flatten: torch.Size([1, 820, 432])
+             # print("This is the shape after the flatten:", x.shape)
+         return x
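+
+ # A sanity-check sketch (hypothetical helper, not used elsewhere in this
+ # file) for the spectrogram shapes quoted above: a 12-channel 128x64 input
+ # gives a (128-6)//3+1 = 41 by (64-6)//3+1 = 20 grid, i.e. 41*20 = 820
+ # tokens of dimension 432.
+ def _demo_spectrogram_grid():
+     pe = PatchEmbed_spectrogram()
+     assert pe.grid_size == (41, 20) and pe.num_patches == 820
+     return pe(torch.zeros(1, 12, 128, 64)).shape  # torch.Size([1, 820, 432])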
+
+
+ class CNN_layers(nn.Module):
+
+     def __init__(self, embed_size=384):
+         super().__init__()
+         self.multiple_cnn = nn.Sequential(
+             nn.Conv1d(12, 128, kernel_size=14, stride=3, padding=2, bias=False),
+             nn.BatchNorm1d(128),
+             nn.ReLU(inplace=True),
+
+             nn.Conv1d(128, embed_size, kernel_size=15, stride=4, padding=100, bias=False),
+             nn.BatchNorm1d(embed_size),
+             nn.ReLU(inplace=True))
+
+     def forward(self, x):
+         # print("This is the shape of the encoder input:", x.shape)
+         # shape of the encoder input: torch.Size([44, 12, 8192])
+         x = self.multiple_cnn(x)
+         x = x.transpose(1, 2)
+         # print("This is the shape of the encoder output:", x.shape)
+         # shape of the encoder output: torch.Size([44, 729, 384])
+
+         # sys.exit()
+         return x
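+
+ # A sketch (hypothetical helper, not used elsewhere in this file) of the
+ # Conv1d length arithmetic behind the shapes above, using
+ # L_out = (L_in + 2*padding - kernel_size) // stride + 1. It explains the
+ # hard-coded num_patches = 729 in VisionMamba.__init__ below.
+ def _demo_cnn_token_count(seq_len=8192):
+     def out_len(length, kernel, stride, padding):
+         return (length + 2 * padding - kernel) // stride + 1
+     l1 = out_len(seq_len, 14, 3, 2)  # (8192 + 4 - 14) // 3 + 1 = 2728
+     l2 = out_len(l1, 15, 4, 100)     # (2728 + 200 - 15) // 4 + 1 = 729
+     return l2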
+
+ '''
+ dim: 384
+ mixer_cls: an instance of Mamba (as a partial class)
+ drop_path = 0.
+ norm_cls = nn.LayerNorm
+ fused_add_norm = True,
+ residual_in_fp32 = True
+ '''
+ class Block(nn.Module):
+     def __init__(
+         self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False, drop_path=0.,
+     ):
+         """
+         Simple block wrapping a mixer class with LayerNorm/RMSNorm and a residual connection.
+
+         This Block has a slightly different structure compared to a regular
+         prenorm Transformer block.
+         The standard block is: LN -> MHA/MLP -> Add.
+         [Ref: https://arxiv.org/abs/2002.04745]
+         Here we have: Add -> LN -> Mixer, returning both
+         the hidden_states (output of the mixer) and the residual.
+         This is purely for performance reasons, as we can fuse add and LayerNorm.
+         The residual needs to be provided (except for the very first block).
+         """
+         super().__init__()
+
+         self.residual_in_fp32 = residual_in_fp32
+         self.fused_add_norm = fused_add_norm
+         self.mixer = mixer_cls(dim)
+         self.norm = norm_cls(dim)
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+         # fused_add_norm is True in the original script
+         if self.fused_add_norm:
+             assert RMSNorm is not None, "RMSNorm import fails"
+             assert isinstance(self.norm, (nn.LayerNorm, RMSNorm)), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
+
+     def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None):
+         r"""Pass the input through the encoder layer.
+
+         Args:
+             hidden_states: the sequence to the encoder layer (required).
+             residual: hidden_states = Mixer(LN(residual))
+         """
+         if not self.fused_add_norm:
+             if residual is None:
+                 residual = hidden_states
+             else:
+                 residual = residual + self.drop_path(hidden_states)
+
+             hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
+             if self.residual_in_fp32:
+                 residual = residual.to(torch.float32)
+
+         # fused_add_norm is True in the original script, so the branch below runs:
+         # fused_add_norm_fn = layer_norm_fn, called with
+         #   hidden_states: Tensor
+         #   self.norm.weight / self.norm.bias: the LayerNorm parameters
+         #   residual: Optional[Tensor] = None on the very first block
+         #   prenorm=True
+         #   residual_in_fp32=True
+         #   eps = self.norm.eps
+         else:
+             fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
+             if residual is None:
+                 hidden_states, residual = fused_add_norm_fn(
+                     hidden_states,
+                     self.norm.weight,
+                     self.norm.bias,
+                     residual=residual,
+                     prenorm=True,
+                     residual_in_fp32=self.residual_in_fp32,
+                     eps=self.norm.eps,
+                 )
+             else:
+                 hidden_states, residual = fused_add_norm_fn(
+                     self.drop_path(hidden_states),
+                     self.norm.weight,
+                     self.norm.bias,
+                     residual=residual,
+                     prenorm=True,
+                     residual_in_fp32=self.residual_in_fp32,
+                     eps=self.norm.eps,
+                 )
+
+         # inference_params=None during training
+         hidden_states = self.mixer(hidden_states, inference_params=inference_params)
+         return hidden_states, residual
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         print("the code is going through allocate_inference_cache in the block")
+         return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
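+
+ # A chaining sketch (hypothetical helper; assumes the Vim fork of mamba_ssm
+ # and a CUDA device, since the Mamba selective-scan kernels are GPU-only):
+ # blocks are stacked by threading (hidden_states, residual) through the
+ # list, exactly as VisionMamba.forward_features does below.
+ def _demo_block_chain(dim=384, depth=2):
+     blocks = [Block(dim, partial(Mamba, layer_idx=i)).cuda() for i in range(depth)]
+     hidden_states, residual = torch.zeros(1, 8, dim, device='cuda'), None
+     for blk in blocks:
+         hidden_states, residual = blk(hidden_states, residual)
+     return hidden_states.shape  # torch.Size([1, 8, 384])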
+
+ '''
+ Arguments used when building the nn.ModuleList() of blocks:
+
+ embed_dim = 384 (embedding dimension)
+ device: None
+ dtype: None
+ ssm_cfg: None
+ norm_epsilon: float = 1e-5
+ rms_norm: bool = False
+ residual_in_fp32 = True
+ fused_add_norm = True
+ if_bimamba = False
+ bimamba_type = "v2"
+ inter_dpr: [0.0, 0.0, 0.004347826354205608, ..., 0.09565217792987823, 0.10000000149011612]
+ if_devide_out = True
+ init_layer_scale = None
+ layer_idx = i
+ '''
+
+ def create_block(
+     d_model,
+     ssm_cfg=None,
+     norm_epsilon=1e-5,
+     drop_path=0.,
+     rms_norm=False,
+     residual_in_fp32=False,
+     fused_add_norm=False,
+     layer_idx=None,
+     device=None,
+     dtype=None,
+     if_bimamba=False,
+     bimamba_type="none",
+     if_devide_out=False,  # (sic) upstream Vim spelling of "divide"
+     init_layer_scale=None,
+     block_name="default_value",
+ ):
+     if if_bimamba:
+         bimamba_type = "v1"
+
+     if ssm_cfg is None:
+         ssm_cfg = {}
+
+     factory_kwargs = {"device": device, "dtype": dtype}
+
+     if block_name == "VisionMamba":
+         mixer_cls = partial(Mamba, layer_idx=layer_idx, bimamba_type=bimamba_type, if_devide_out=if_devide_out, init_layer_scale=init_layer_scale, **ssm_cfg, **factory_kwargs)
+     elif block_name == "OriginalMamba":
+         mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
+     else:
+         raise ValueError(f"No matching condition for value: {block_name}")
+
+     # rms_norm = False by default
+     norm_cls = partial(nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs)
+
+     block = Block(
+         d_model,
+         mixer_cls,
+         norm_cls=norm_cls,
+         drop_path=drop_path,
+         fused_add_norm=fused_add_norm,
+         residual_in_fp32=residual_in_fp32,
+     )
+     block.layer_idx = layer_idx
+     return block
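+
+ # A sketch (hypothetical helper) of how VisionMamba.__init__ below calls
+ # create_block for layer i, mirroring the ft-vim-s.sh settings quoted in the
+ # comments (assumes the Vim fork of Mamba, which accepts bimamba_type and
+ # if_devide_out, plus the Triton norm kernels required by fused_add_norm).
+ def _demo_create_block(i=0, drop_path=0.0):
+     return create_block(
+         d_model=384, rms_norm=True, residual_in_fp32=True, fused_add_norm=True,
+         layer_idx=i, bimamba_type="v2", if_devide_out=True,
+         drop_path=drop_path, block_name="VisionMamba",
+     )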
+
+
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
+ def _init_weights(
+     module,
+     n_layer,
+     initializer_range=0.02,  # Now only used for embedding layer.
+     rescale_prenorm_residual=True,
+     n_residuals_per_layer=1,  # Change to 2 if we have MLP
+ ):
+     if isinstance(module, nn.Linear):
+         if module.bias is not None:
+             if not getattr(module.bias, "_no_reinit", False):
+                 nn.init.zeros_(module.bias)
+     elif isinstance(module, nn.Embedding):
+         nn.init.normal_(module.weight, std=initializer_range)
+
+     if rescale_prenorm_residual:
+         # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+         # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+         # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+         # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+         #
+         # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+         for name, p in module.named_parameters():
+             if name in ["out_proj.weight", "fc2.weight"]:
+                 # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                 # Following PyTorch init, except scale by 1/sqrt(2 * n_layer).
+                 # We need to reinit p since this code could be called multiple times;
+                 # having just p *= scale would repeatedly scale it down.
+                 nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                 with torch.no_grad():
+                     p /= math.sqrt(n_residuals_per_layer * n_layer)
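+
+ # A one-line sketch (hypothetical helper) of the rescale factor applied
+ # above: with depth = 24 and one residual branch per layer, out_proj.weight
+ # is divided by sqrt(1 * 24), i.e. multiplied by roughly 0.204.
+ def _demo_residual_rescale(n_layer=24, n_residuals_per_layer=1):
+     return 1.0 / math.sqrt(n_residuals_per_layer * n_layer)  # ~0.2041 for 24 layers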
+
+
+ def segm_init_weights(m):
+     if isinstance(m, nn.Linear):
+         trunc_normal_(m.weight, std=0.02)
+         if isinstance(m, nn.Linear) and m.bias is not None:
+             nn.init.constant_(m.bias, 0)
+     elif isinstance(m, nn.Conv2d):
+         # NOTE conv was left to pytorch default in my original init
+         lecun_normal_(m.weight)
+         if m.bias is not None:
+             nn.init.zeros_(m.bias)
+     elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+         nn.init.zeros_(m.bias)
+         nn.init.ones_(m.weight)
+
+ '''
+ Below are the settings from the original script (ft-vim-s.sh):
+
+ patch_size = 16,
+ stride = 8,
+ embed_dim = 384 (embedding dimension)
+ depth = 24 (the number of blocks)
+ rms_norm = True,
+ residual_in_fp32 = True,
+ fused_add_norm = True,
+ final_pool_type = 'mean',
+ if_abs_pos_embed = True,
+ if_rope = False,
+ if_rope_residual = False,
+ bimamba_type = "v2",
+ if_cls_token = True,
+ if_devide_out = True,
+ use_middle_cls_token = True,
+ **kwargs
+ '''
+
+ class VisionMamba(nn.Module):
+     def __init__(self,
+                  img_size=224,
+                  patch_size=16,
+                  stride=16,
+                  depth=24,
+                  embed_dim=192,
+                  channels=3,
+                  num_classes=26,
+                  ssm_cfg=None,
+                  drop_rate=0.,
+                  drop_path_rate=0,
+                  norm_epsilon: float = 1e-5,
+                  rms_norm: bool = False,
+                  initializer_cfg=None,
+                  fused_add_norm=False,
+                  residual_in_fp32=False,
+                  device=None,
+                  dtype=None,
+                  ft_seq_len=None,
+                  pt_hw_seq_len=14,
+                  if_bidirectional=False,
+                  final_pool_type='none',
+                  if_abs_pos_embed=False,
+                  if_rope=False,
+                  if_rope_residual=False,
+                  flip_img_sequences_ratio=-1.,
+                  if_bimamba=False,
+                  bimamba_type="none",
+                  if_cls_token=False,
+                  if_devide_out=False,
+                  init_layer_scale=None,
+                  use_double_cls_token=False,
+                  use_middle_cls_token=False,
+                  **kwargs):
+         # print("The program is coming to the init")
+         factory_kwargs = {"device": device, "dtype": dtype}
+         # factory_kwargs: {'device': None, 'dtype': None}
+         # read the block variant ("VisionMamba" / "OriginalMamba") out of
+         # kwargs, then fold factory_kwargs into kwargs
+         block_name = kwargs.get('block', 'default_value')
+
+         kwargs.update(factory_kwargs)
+
+         super().__init__()
+         self.residual_in_fp32 = residual_in_fp32
+         self.fused_add_norm = fused_add_norm
+         self.if_bidirectional = if_bidirectional
+         self.final_pool_type = final_pool_type
+         self.if_abs_pos_embed = if_abs_pos_embed
+         self.if_rope = if_rope
+         self.if_rope_residual = if_rope_residual
+         self.flip_img_sequences_ratio = flip_img_sequences_ratio
+         self.if_cls_token = if_cls_token
+         self.use_double_cls_token = use_double_cls_token
+         self.use_middle_cls_token = use_middle_cls_token
+         self.num_tokens = 1 if if_cls_token else 0
+
+         # pretrain parameters
+         self.num_classes = num_classes
+         self.d_model = self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+
+         # self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, stride=stride, in_chans=channels, embed_dim=embed_dim)
+         # num_patches = self.patch_embed.num_patches
+
+         self.CNN_layers = CNN_layers()
+         num_patches = 729  # output length of CNN_layers for 8192-sample inputs
+
+         # self.ECG_patch_embedding = ECG_Patch_embedding()
+         # num_patches = 1023
+         # self.LC = LeadCombiner(lead=6, out_ch=8)
+
+         # self.CNN_layers = D2_CNN_layers()
+         # num_patches = 828
+
+         # self.CNN_layers = CNN_layers()
+         # num_patches = 775
+
+         # if_cls_token: True (in the original script)
+         if if_cls_token:
+             # use_double_cls_token: False (in the original script)
+             if use_double_cls_token:
+                 self.cls_token_head = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+                 self.cls_token_tail = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+                 self.num_tokens = 2
+             else:
+                 self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+                 # self.num_tokens = 1
+
+         # if_abs_pos_embed: True (in the original script)
+         if if_abs_pos_embed:
+             self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, self.embed_dim))
+             self.pos_drop = nn.Dropout(p=drop_rate)
+
+         # if_rope: False (in the original script); the rope module below is
+         # commented out, so if_rope must stay False for forward_features to work
+         if if_rope:
+             pass
+             # half_head_dim = embed_dim // 2
+             # self.rope = TimeSeriesRotaryEmbeddingFast(dim=embed_dim, seq_len=num_patches + self.num_tokens)
+
+         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+         # self.head_LC = nn.Linear(16, num_classes)
+
+         # depth: 24; drop_path_rate: 0.1
+         # TODO: release this comment
+         # dpr: [0.0, 0.004347826354205608, 0.008695652708411217, ..., 0.09130434691905975, 0.09565217792987823, 0.10000000149011612]
+         if drop_path_rate == 0:
+             print("This is the drop_path_rate:", drop_path_rate)
+             dpr = [x.item() for x in torch.full((depth,), drop_path_rate)]
+         else:
+             print("This is the drop_path_rate:", drop_path_rate)
+             print("follow the stochastic depth decay rule")
+             dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+
+         # import ipdb;ipdb.set_trace()
+         # inter_dpr: [0.0, 0.0, 0.004347826354205608, ..., 0.09565217792987823, 0.10000000149011612]
+         inter_dpr = [0.0] + dpr
+         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+         # mamba blocks (one per layer, depth in total)
+         self.layers = nn.ModuleList(
+             [
+                 create_block(
+                     embed_dim,
+                     ssm_cfg=ssm_cfg,
+                     norm_epsilon=norm_epsilon,
+                     rms_norm=rms_norm,
+                     residual_in_fp32=residual_in_fp32,
+                     fused_add_norm=fused_add_norm,
+                     layer_idx=i,
+                     if_bimamba=if_bimamba,
+                     bimamba_type=bimamba_type,
+                     drop_path=inter_dpr[i],
+                     if_devide_out=if_devide_out,
+                     init_layer_scale=init_layer_scale,
+                     block_name=block_name,
+                     **factory_kwargs,
+                 )
+                 for i in range(depth)
+             ]
+         )
+
+         # output head
+         self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(embed_dim, eps=norm_epsilon, **factory_kwargs)
+
+         # self.pre_logits = nn.Identity()
+
+         # original init
+         # self.patch_embed.apply(segm_init_weights)
+         self.CNN_layers.apply(segm_init_weights)
+
+         # self.ECG_patch_embedding.apply(segm_init_weights)
+         self.head.apply(segm_init_weights)
+
+         # self.head_LC.apply(segm_init_weights)
+         # self.LC.apply(segm_init_weights)
+
+         # if_abs_pos_embed: True (in the original script)
+         if if_abs_pos_embed:
+             trunc_normal_(self.pos_embed, std=.02)
+         # if_cls_token: True (in the original script)
+         if if_cls_token:
+             if use_double_cls_token:
+                 trunc_normal_(self.cls_token_head, std=.02)
+                 trunc_normal_(self.cls_token_tail, std=.02)
+             # the code takes this branch in the original script
+             else:
+                 trunc_normal_(self.cls_token, std=.02)
+
+         # mamba init
+         self.apply(partial(_init_weights, n_layer=depth, **(initializer_cfg if initializer_cfg is not None else {})))
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         print("the code is going through allocate_inference_cache in the vision mamba")
+         return {
+             i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+             for i, layer in enumerate(self.layers)
+         }
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {"pos_embed", "cls_token", "dist_token", "cls_token_head", "cls_token_tail"}
+
+     @torch.jit.ignore()
+     def load_pretrained(self, checkpoint_path, prefix=""):
+         _load_weights(self, checkpoint_path, prefix)
+
+     # x: the input tensor; here a batch of 12-lead ECG signals (B, 12, 8192)
+     # inference_params: None
+     # if_random_cls_token_position: False
+     # if_random_token_rank: False
+     def forward_features(self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
+         # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+         # with slight modifications to add the dist_token
+
+         # x = self.patch_embed(x)
+         x = self.CNN_layers(x)
+
+         # B: batch size, e.g. 16
+         # M: 729, the number of patches (27*27 for images; here the CNN output length)
+         # D: the hidden-state dimension, 384 for the small variant (768 in the conventional ViT)
+         # N: the SSM state dimension, set to 16
+         # L: the number of blocks, set to 24
+         B, M, _ = x.shape
+
+         # if_cls_token: True (in the original script)
+         if self.if_cls_token:
+             # self.use_double_cls_token: False (in the original script)
+             if self.use_double_cls_token:
+                 cls_token_head = self.cls_token_head.expand(B, -1, -1)
+                 cls_token_tail = self.cls_token_tail.expand(B, -1, -1)
+                 token_position = [0, M + 1]
+                 x = torch.cat((cls_token_head, x, cls_token_tail), dim=1)
+                 M = x.shape[1]
+             else:
+                 # self.use_middle_cls_token: True (in the original script)
+                 if self.use_middle_cls_token:
+                     cls_token = self.cls_token.expand(B, -1, -1)
+                     token_position = M // 2
+                     # add cls token in the middle
+                     x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
+                 elif if_random_cls_token_position:
+                     cls_token = self.cls_token.expand(B, -1, -1)
+                     token_position = random.randint(0, M)
+                     x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
+                     print("token_position: ", token_position)
+                 else:
+                     cls_token = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+                     token_position = 0
+                     x = torch.cat((cls_token, x), dim=1)
+                 M = x.shape[1]
+
+         # if_abs_pos_embed: True (in the original script)
+         if self.if_abs_pos_embed:
+             # if new_grid_size[0] == self.patch_embed.grid_size[0] and new_grid_size[1] == self.patch_embed.grid_size[1]:
+             #     x = x + self.pos_embed
+             # else:
+             #     pos_embed = interpolate_pos_embed_online(
+             #         self.pos_embed, self.patch_embed.grid_size, new_grid_size, 0
+             #     )
+             x = x + self.pos_embed
+             x = self.pos_drop(x)
+
+         if_flip_img_sequences = False
+         if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5:
+             x = x.flip([1])
+             if_flip_img_sequences = True
+
+         # mamba impl
+         # if_bidirectional: False
+         # inference_params: None
+         residual = None
+         hidden_states = x
+         if not self.if_bidirectional:
+             for layer in self.layers:
+
+                 # False in the original script
+                 if if_flip_img_sequences and self.if_rope:
+                     hidden_states = hidden_states.flip([1])
+                     if residual is not None:
+                         residual = residual.flip([1])
+
+                 # RoPE branch; default is False
+                 if self.if_rope:
+                     hidden_states = self.rope(hidden_states)
+                     if residual is not None and self.if_rope_residual:
+                         residual = self.rope(residual)
+
+                 # False in the original script
+                 if if_flip_img_sequences and self.if_rope:
+                     hidden_states = hidden_states.flip([1])
+                     if residual is not None:
+                         residual = residual.flip([1])
+
+                 hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params)
+                 # sys.exit()
+         else:
+             # process two layers per iteration: one forward pass and one over
+             # the flipped sequence, then merge the two directions
+             for i in range(len(self.layers) // 2):
+                 if self.if_rope:
+                     hidden_states = self.rope(hidden_states)
+                     if residual is not None and self.if_rope_residual:
+                         residual = self.rope(residual)
+
+                 hidden_states_f, residual_f = self.layers[i * 2](
+                     hidden_states, residual, inference_params=inference_params
+                 )
+                 hidden_states_b, residual_b = self.layers[i * 2 + 1](
+                     hidden_states.flip([1]), None if residual is None else residual.flip([1]), inference_params=inference_params
+                 )
+                 hidden_states = hidden_states_f + hidden_states_b.flip([1])
+                 residual = residual_f + residual_b.flip([1])
+
+         # fused_add_norm: True (in the original script)
+         if not self.fused_add_norm:
+             if residual is None:
+                 residual = hidden_states
+             else:
+                 residual = residual + self.drop_path(hidden_states)
+             hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
+         else:
+             # Set prenorm = False here since we don't need the residual
+             fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
+
+             hidden_states = fused_add_norm_fn(
+                 self.drop_path(hidden_states),
+                 self.norm_f.weight,
+                 self.norm_f.bias,
+                 eps=self.norm_f.eps,
+                 residual=residual,
+                 prenorm=False,
+                 residual_in_fp32=self.residual_in_fp32,
+             )
+
+         # return only the cls token if it exists
+         # if_cls_token: True, use_middle_cls_token: True (in the original script)
+         if self.if_cls_token:
+             if self.use_double_cls_token:
+                 return (hidden_states[:, token_position[0], :] + hidden_states[:, token_position[1], :]) / 2
+             else:
+                 if self.use_middle_cls_token:
+                     return hidden_states[:, token_position, :]
+                 elif if_random_cls_token_position:
+                     return hidden_states[:, token_position, :]
+                 else:
+                     return hidden_states[:, token_position, :]
+
+         # self.final_pool_type = 'mean' in the original script
+         if self.final_pool_type == 'none':
+             return hidden_states[:, -1, :]
+         elif self.final_pool_type == 'mean':
+             return hidden_states.mean(dim=1)
+         elif self.final_pool_type == 'max':
+             return hidden_states  # max pooling is applied after the head in forward()
+         elif self.final_pool_type == 'all':
+             return hidden_states
+         else:
+             raise NotImplementedError
+
+     def forward(self, x, return_features=False, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
+         x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank)
+         # batch_number = x.shape[0]
+         # x = x.view(batch_number, 8, 6, 8)
+         # x = self.LC(x)
+         # x = self.head_LC(x)
+         # print("This is the shape of X (after the fully connected layer):", x.shape)
+         # return_features = False
+         # print("This is the return feature:", return_features)
+         # if return_features:
+         #     return x
+
+         # print("This is the shape of X (before the fully connected layer):", x.shape)
+         x = self.head(x)
+
+         # final_pool_type = 'mean' in the original script
+         if self.final_pool_type == 'max':
+             x = x.max(dim=1)[0]
+         # sys.exit()
+         return x
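+
+ # A sketch (hypothetical helper) of the middle-cls-token splice used in
+ # forward_features: the token is inserted at index M // 2 and read back from
+ # the same index after the blocks.
+ def _demo_middle_cls_token(B=2, M=6, D=4):
+     x = torch.zeros(B, M, D)
+     cls_token = torch.ones(B, 1, D)
+     token_position = M // 2
+     x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
+     return x.shape, token_position  # (torch.Size([2, 7, 4]), 3)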
+
+
+ # below is the registered entry point for the Vision Mamba variant
+ @register_model
+ def ecg_vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, depth=5, fused_add_norm=True, drop_path_rate=0.1, if_divide_out=True, use_middle_cls_token=True, **kwargs):
+     model = VisionMamba(patch_size=16, stride=8, embed_dim=384, depth=depth, rms_norm=True, residual_in_fp32=True, drop_path_rate=drop_path_rate, fused_add_norm=fused_add_norm, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=if_divide_out, use_middle_cls_token=use_middle_cls_token, **kwargs)
+
+     # As a reminder:
+     print("fused_add_norm:", fused_add_norm)
+     print("if_divide_out:", if_divide_out)
+     print("use_middle_cls_token:", use_middle_cls_token)
+
+     model.default_cfg = _cfg()
+     return model
+
+ # below is the variant with the original Mamba mixer and 24 blocks
+ # @register_model
+ # def ecg_vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
+ #     model = VisionMamba(patch_size=16, stride=8, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
+ #     model.default_cfg = _cfg()
+ #     return model
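+
+ # A usage sketch (hypothetical helper; assumes the Vim fork of mamba_ssm and
+ # a CUDA device, since fused_add_norm=True and bimamba_type="v2" are passed
+ # through): build the registered ECG model via timm and run a batch of
+ # 12-lead, 8192-sample ECG signals through it.
+ def _demo_build_and_run():
+     import timm
+     model = timm.create_model(
+         'ecg_vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2',
+         block='VisionMamba', num_classes=26,
+     ).cuda()
+     x = torch.randn(2, 12, 8192, device='cuda')  # batch of 12-lead ECGs
+     return model(x).shape                        # expected: torch.Size([2, 26])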