ericzhang0328 committed on
Commit
0c1e054
·
verified ·
1 Parent(s): 269f7c8

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. models/MAE_SDT.py +639 -0
  3. models/__init__.py +0 -0
  4. models/__pycache__/MAE_SDT.cpython-311.pyc +0 -0
  5. models/__pycache__/MAE_SDT.cpython-312.pyc +0 -0
  6. models/__pycache__/__init__.cpython-311.pyc +0 -0
  7. models/__pycache__/__init__.cpython-312.pyc +0 -0
  8. models/__pycache__/__init__.cpython-39.pyc +0 -0
  9. models/__pycache__/encoder.cpython-311.pyc +0 -0
  10. models/__pycache__/encoder.cpython-312.pyc +0 -0
  11. models/__pycache__/metaformer.cpython-311.pyc +0 -0
  12. models/__pycache__/metaformer.cpython-312.pyc +0 -0
  13. models/__pycache__/neuron.cpython-311.pyc +0 -0
  14. models/__pycache__/neuron.cpython-312.pyc +0 -0
  15. models/__pycache__/qk_model_v1_1003.cpython-311.pyc +0 -0
  16. models/__pycache__/qkformer.cpython-311.pyc +0 -0
  17. models/__pycache__/qkformer.cpython-312.pyc +0 -0
  18. models/__pycache__/sd_former_v1.cpython-311.pyc +0 -0
  19. models/__pycache__/sd_former_v1.cpython-312.pyc +0 -0
  20. models/__pycache__/sdtv3.cpython-311.pyc +0 -0
  21. models/__pycache__/sdtv3.cpython-312.pyc +0 -0
  22. models/__pycache__/sdtv3.cpython-39.pyc +0 -0
  23. models/__pycache__/sdtv3_large.cpython-311.pyc +0 -0
  24. models/__pycache__/sdtv3_large.cpython-312.pyc +0 -0
  25. models/__pycache__/spikformer.cpython-311.pyc +0 -0
  26. models/__pycache__/spikformer.cpython-312.pyc +0 -0
  27. models/__pycache__/vit.cpython-311.pyc +3 -0
  28. models/__pycache__/vit.cpython-312.pyc +3 -0
  29. models/encoder.py +158 -0
  30. models/metaformer.py +1538 -0
  31. models/neuron.py +1587 -0
  32. models/q_vit/Quant.py +185 -0
  33. models/q_vit/__init__.py +0 -0
  34. models/q_vit/__pycache__/Quant.cpython-311.pyc +0 -0
  35. models/q_vit/__pycache__/Quant.cpython-312.pyc +0 -0
  36. models/q_vit/__pycache__/__init__.cpython-311.pyc +0 -0
  37. models/q_vit/__pycache__/__init__.cpython-312.pyc +0 -0
  38. models/q_vit/__pycache__/_quan_base.cpython-311.pyc +0 -0
  39. models/q_vit/__pycache__/_quan_base.cpython-312.pyc +0 -0
  40. models/q_vit/__pycache__/quant_vision_transformer.cpython-311.pyc +0 -0
  41. models/q_vit/__pycache__/quant_vision_transformer.cpython-312.pyc +0 -0
  42. models/q_vit/_quan_base.py +208 -0
  43. models/q_vit/quant_vision_transformer.py +527 -0
  44. models/qk_model_v1_1003.py +426 -0
  45. models/qk_model_with_delay/__init__.py +0 -0
  46. models/qk_model_with_delay/__pycache__/__init__.cpython-311.pyc +0 -0
  47. models/qk_model_with_delay/__pycache__/delay_synaptic_func_inter.cpython-311.pyc +0 -0
  48. models/qk_model_with_delay/__pycache__/delay_synaptic_inter_model.cpython-311.pyc +0 -0
  49. models/qk_model_with_delay/delay_synaptic_func_inter.py +169 -0
  50. models/qk_model_with_delay/delay_synaptic_inter_model.py +459 -0
.gitattributes CHANGED
@@ -91,3 +91,5 @@ visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_B8_att
91
  visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_B9_attn_proj.pdf filter=lfs diff=lfs merge=lfs -text
92
  visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_all_layers.pdf filter=lfs diff=lfs merge=lfs -text
93
  visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_average.pdf filter=lfs diff=lfs merge=lfs -text
 
 
 
91
  visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_B9_attn_proj.pdf filter=lfs diff=lfs merge=lfs -text
92
  visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_all_layers.pdf filter=lfs diff=lfs merge=lfs -text
93
  visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_average.pdf filter=lfs diff=lfs merge=lfs -text
94
+ models/__pycache__/vit.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
95
+ models/__pycache__/vit.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
models/MAE_SDT.py ADDED
@@ -0,0 +1,639 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchinfo
5
+ from timm.models.layers import to_2tuple, trunc_normal_, DropPath
6
+ from timm.models.registry import register_model
7
+ from timm.models.vision_transformer import _cfg
8
+ from einops.layers.torch import Rearrange
9
+ import torch.nn.functional as F
10
+ from timm.models.vision_transformer import PatchEmbed, Block
11
+
12
+ from spikingjelly.clock_driven import layer
13
+ import copy
14
+ from torchvision import transforms
15
+ import matplotlib.pyplot as plt
16
+
17
+ import models.encoder as encoder
18
+ from .util.pos_embed import get_2d_sincos_pos_embed
19
+
20
+ import torch
21
+
22
# Number of simulation timesteps; also the spike-count saturation level.
T = 4


class multispike(torch.autograd.Function):
    """Integer multi-spike activation with a rectangular surrogate gradient.

    Forward: ``floor(clamp(x, 0, lens) + 0.5)`` — i.e. round-half-up into
    integer spike counts in ``[0, lens]``.
    Backward: straight-through estimator that passes the incoming gradient
    only where ``0 < x < lens`` (zero outside the clipping window).
    """

    @staticmethod
    def forward(ctx, input, lens=T):
        ctx.save_for_backward(input)
        ctx.lens = lens
        clipped = torch.clamp(input, 0, lens)
        return torch.floor(clipped + 0.5)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Rectangular window: gradient flows only strictly inside (0, lens).
        inside = (input > 0) & (input < ctx.lens)
        return grad_output * inside.float(), None
40
+
41
+
42
class Multispike(nn.Module):
    """Module wrapper around :class:`multispike`, rescaling counts to [0, 1].

    The raw integer spike counts are divided by ``norm`` (the timestep count
    ``T`` by default), so the output lies in ``[0, 1]``.
    """

    def __init__(self, spike=multispike, norm=T):
        super().__init__()
        # ``lens`` mirrors ``norm``; kept as a separate attribute for
        # compatibility with code that reads either name.
        self.lens = norm
        self.spike = spike
        self.norm = norm

    def forward(self, inputs):
        counts = self.spike.apply(inputs)
        return counts / self.norm
51
+
52
+
53
+
54
+
55
def MS_conv_unit(in_channels, out_channels, kernel_size=1, padding=0, groups=1):
    """Build a sparse (mask-aware) conv + BN unit applied per timestep.

    The ``SeqToANNContainer`` flattens the leading time dimension so the 2D
    conv/BN modules see ordinary ``(T*B, C, H, W)`` tensors.
    """
    conv = encoder.SparseConv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        padding=padding,
        groups=groups,
        bias=True,
    )
    bn = encoder.SparseBatchNorm2d(out_channels)
    return nn.Sequential(layer.SeqToANNContainer(conv, bn))
62
class MS_ConvBlock(nn.Module):
    """Residual spiking conv block.

    spike -> 3x3 conv (expand by ``mlp_ratio``) -> spike -> 3x3 conv
    (project back to ``dim``), with an identity shortcut around the pair.
    """

    def __init__(self, dim, mlp_ratio=4.0):
        super().__init__()
        hidden = dim * mlp_ratio
        self.neuron1 = Multispike()
        self.conv1 = MS_conv_unit(dim, hidden, 3, 1)
        self.neuron2 = Multispike()
        self.conv2 = MS_conv_unit(hidden, dim, 3, 1)

    def forward(self, x, mask=None):
        # ``mask`` is accepted for interface compatibility but unused here.
        residual = x
        out = self.conv1(self.neuron1(x))
        out = self.conv2(self.neuron2(out))
        return out + residual
82
+
83
class MS_MLP(nn.Module):
    """Spiking two-layer token MLP implemented with 1x1 Conv1d layers.

    Input/output layout is ``(T, B, C, N)``: timesteps, batch, channels,
    tokens.  Each layer is spike -> 1x1 conv -> BatchNorm1d.  Note the
    output is reshaped back to the *input* channel count, so callers must
    use ``out_features == in_features`` (the default).
    """

    def __init__(
        self, in_features, hidden_features=None, out_features=None, drop=0.0, layer=0
    ):
        super().__init__()
        # ``drop`` and ``layer`` are accepted for interface compatibility
        # but unused in this implementation.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)
        self.fc1_bn = nn.BatchNorm1d(hidden_features)
        self.fc1_lif = Multispike()

        self.fc2_conv = nn.Conv1d(
            hidden_features, out_features, kernel_size=1, stride=1
        )
        self.fc2_bn = nn.BatchNorm1d(out_features)
        self.fc2_lif = Multispike()

        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        T, B, C, N = x.shape

        h = self.fc1_conv(self.fc1_lif(x).flatten(0, 1))
        h = self.fc1_bn(h).reshape(T, B, self.c_hidden, N).contiguous()

        h = self.fc2_conv(self.fc2_lif(h).flatten(0, 1))
        h = self.fc2_bn(h).reshape(T, B, C, N).contiguous()
        return h
116
+
117
class RepConv(nn.Module):
    """Two stacked 1x1 Conv1d+BN layers, widening to 1.5x channels between.

    ``bias`` is accepted for interface compatibility, but both convolutions
    are constructed bias-free (BatchNorm supplies the affine shift).
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        bias=False,
    ):
        super().__init__()
        # TODO in_channel-> 2*in_channel->in_channel
        mid = int(in_channel * 1.5)
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channel, mid, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(mid),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(mid, out_channel, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(out_channel),
        )

    def forward(self, x):
        hidden = self.conv1(x)
        return self.conv2(hidden)
130
class RepConv2(nn.Module):
    """Identical architecture to :class:`RepConv` (1x1 conv-BN pair with a
    1.5x-wide middle); kept as a distinct class so callers can configure it
    independently.  ``bias`` is accepted but both convs are bias-free.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        bias=False,
    ):
        super().__init__()
        # TODO in_channel-> 2*in_channel->in_channel
        mid = int(in_channel * 1.5)
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channel, mid, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(mid),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(mid, out_channel, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(out_channel),
        )

    def forward(self, x):
        hidden = self.conv1(x)
        return self.conv2(hidden)
143
+
144
class MS_Attention_Conv_qkv_id(nn.Module):
    """Spiking attention with RepConv Q/K/V projections on (T, B, C, N) tokens.

    Attention is evaluated in "linear attention" order — (K^T @ V) first,
    then Q @ (K^T V) — scaled by a fixed 0.125 rather than 1/sqrt(head_dim).
    ``sr_ratio`` widens the V / attention channels by that factor before the
    final projection maps back to ``dim``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        # Fixed scale; note qk_scale / dropout args are accepted but unused.
        self.scale = 0.125
        self.sr_ratio = sr_ratio

        self.head_lif = Multispike()

        # track 1: split convs
        self.q_conv = nn.Sequential(RepConv(dim, dim), nn.BatchNorm1d(dim))
        self.k_conv = nn.Sequential(RepConv(dim, dim), nn.BatchNorm1d(dim))
        self.v_conv = nn.Sequential(RepConv(dim, dim * sr_ratio), nn.BatchNorm1d(dim * sr_ratio))

        # track 2: merge (prefer) NOTE: need `chunk` in forward
        # self.qkv_conv = nn.Sequential(RepConv(dim,dim * 3), nn.BatchNorm2d(dim * 3))

        self.q_lif = Multispike()

        self.k_lif = Multispike()

        self.v_lif = Multispike()

        self.attn_lif = Multispike()

        self.proj_conv = nn.Sequential(RepConv(sr_ratio * dim, dim), nn.BatchNorm1d(dim))

    def forward(self, x):
        T, B, C, N = x.shape

        x = self.head_lif(x)

        # Fold time into batch for the Conv1d projections.
        x_for_qkv = x.flatten(0, 1)
        q_conv_out = self.q_conv(x_for_qkv).reshape(T, B, C, N)

        q_conv_out = self.q_lif(q_conv_out)

        # (T, B, C, N) -> (T, B, heads, N, C // heads)
        q = q_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4)

        k_conv_out = self.k_conv(x_for_qkv).reshape(T, B, C, N)

        k_conv_out = self.k_lif(k_conv_out)

        k = k_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4)

        # V carries sr_ratio times more channels than Q/K.
        v_conv_out = self.v_conv(x_for_qkv).reshape(T, B, self.sr_ratio * C, N)

        v_conv_out = self.v_lif(v_conv_out)

        v = v_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, self.sr_ratio * C // self.num_heads).permute(0, 1, 3, 2, 4)

        # Linear attention: aggregate K^T V once (d x d'), then apply Q.
        x = k.transpose(-2, -1) @ v
        x = (q @ x) * self.scale
        x = x.transpose(3, 4).reshape(T, B, self.sr_ratio * C, N)
        x = self.attn_lif(x)

        # Project widened channels back down to ``dim``.
        x = self.proj_conv(x.flatten(0, 1)).reshape(T, B, C, N)
        return x
207
+
208
+
209
+
210
+
211
class MS_DownSampling(nn.Module):
    """Strided sparse-conv downsampling stage for (T, B, C, H, W) inputs.

    Non-first stages spike their inputs before the convolution; the first
    stage (raw pixels) skips the spiking neuron entirely.
    """

    def __init__(
        self,
        in_channels=2,
        embed_dims=256,
        kernel_size=3,
        stride=2,
        padding=1,
        first_layer=True,
    ):
        super().__init__()

        self.encode_conv = encoder.SparseConv2d(
            in_channels,
            embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        self.encode_bn = encoder.SparseBatchNorm2d(embed_dims)
        self.first_layer = first_layer
        if not first_layer:
            # Only created for later stages; forward checks via hasattr.
            self.encode_spike = Multispike()

    def forward(self, x):
        T, B = x.shape[0], x.shape[1]
        if hasattr(self, "encode_spike"):
            x = self.encode_spike(x)
        out = self.encode_conv(x.flatten(0, 1))
        H, W = out.shape[-2], out.shape[-1]
        return self.encode_bn(out).reshape(T, B, -1, H, W)
246
+
247
class MS_Block(nn.Module):
    """Transformer-style spiking block.

    Layout: optional RepConv2 residual branch (only when ``choice == "base"``),
    then spiking attention, then spiking MLP — each with a residual add.
    When ``finetune`` is True, the attention/MLP branches are scaled by
    learnable per-channel layer-scale parameters initialised to ``init_values``.
    ``norm_layer`` is accepted for interface compatibility but unused.
    """

    def __init__(
        self,
        dim,
        choice,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        sr_ratio=1, init_values=1e-6, finetune=False,
    ):
        super().__init__()
        self.model = choice
        if self.model == "base":
            self.rep_conv = RepConv2(dim, dim)  # if have param==83M
            self.lif = Multispike()
        self.attn = MS_Attention_Conv_qkv_id(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
        )
        self.finetune = finetune
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MS_MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)

        if self.finetune:
            self.layer_scale1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.layer_scale2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        if self.model == "base":
            # BUG FIX: T/B/C/N were previously read from a commented-out
            # unpacking line, raising NameError whenever this path ran.
            T, B, C, N = x.shape
            x = x + self.rep_conv(self.lif(x).flatten(0, 1)).reshape(T, B, C, N)
        # TODO: need channel-wise layer scale, init as 1e-6
        if self.finetune:
            x = x + self.drop_path(self.attn(x) * self.layer_scale1.unsqueeze(0).unsqueeze(0).unsqueeze(-1))
            x = x + self.drop_path(self.mlp(x) * self.layer_scale2.unsqueeze(0).unsqueeze(0).unsqueeze(-1))
        else:
            x = x + self.attn(x)
            x = x + self.mlp(x)
        return x
297
+
298
class Spikmae(nn.Module):
    """Spiking-transformer masked autoencoder (spiking encoder + ViT decoder).

    The encoder is a hierarchical spiking network — three conv downsampling
    stages followed by an MS_Block transformer stage at 14x14 = 196 tokens —
    operating on (T, B, C, H, W).  Masking is done by zeroing masked-patch
    pixels and gating the sparse conv/BN layers via ``encoder._cur_active``;
    tokens are never dropped.  The decoder is a stack of standard timm ViT
    ``Block``s that reconstructs 16x16 pixel patches.
    Geometry is hard-coded for 224x224 inputs / patch 16 (196 patches).
    """

    def __init__(self, T=1, choice=None,
                 img_size_h=224,
                 img_size_w=224,
                 patch_size=16,
                 embed_dim=[128, 256, 512],
                 num_heads=8,
                 mlp_ratios=4,
                 in_channels=3,
                 qk_scale=None,
                 drop_rate=0.0,
                 attn_drop_rate=0.0,
                 drop_path_rate=0.0,
                 num_classes=1000,
                 qkv_bias=False,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),  # norm_layer=nn.LayerNorm shaokun
                 depths=8,
                 sr_ratios=1,
                 decoder_embed_dim=768,
                 decoder_depth=4,
                 decoder_num_heads=16,
                 mlp_ratio=4.,
                 norm_pix_loss=False, nb_classes=1000):
        super().__init__()

        self.num_classes = num_classes
        self.depths = depths
        # NOTE(review): the ``T`` argument is ignored — timesteps are fixed
        # at 1 here; confirm this is intentional.
        self.T = 1

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depths)
        ]  # stochastic depth decay rule

        # Stem: 224 -> 112, at half of the first embed dim.
        self.downsample1_1 = MS_DownSampling(
            in_channels=in_channels,
            embed_dims=embed_dim[0] // 2,
            kernel_size=7,
            stride=2,
            padding=3,
            first_layer=True,
        )

        self.ConvBlock1_1 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[0] // 2, mlp_ratio=mlp_ratios)]
        )

        # 112 -> 56.
        self.downsample1_2 = MS_DownSampling(
            in_channels=embed_dim[0] // 2,
            embed_dims=embed_dim[0],
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )

        self.ConvBlock1_2 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[0], mlp_ratio=mlp_ratios)]
        )

        # 56 -> 28.
        self.downsample2 = MS_DownSampling(
            in_channels=embed_dim[0],
            embed_dims=embed_dim[1],
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )

        self.ConvBlock2_1 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)]
        )

        self.ConvBlock2_2 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)]
        )

        # 28 -> 14.
        self.downsample3 = MS_DownSampling(
            in_channels=embed_dim[1],
            embed_dims=embed_dim[2],
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )

        # Transformer stage over the 196 tokens.
        self.block3 = nn.ModuleList(
            [
                MS_Block(
                    dim=embed_dim[2],
                    choice=choice,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratios,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[j],
                    norm_layer=norm_layer,
                    sr_ratio=sr_ratios,
                    finetune=False,
                )
                for j in range(depths)
            ]
        )

        self.norm = nn.BatchNorm1d(embed_dim[-1])
        # Total spatial reduction of the encoder (224 / 14).
        # NOTE(review): "raito" typo kept — the name is referenced elsewhere.
        self.downsample_raito = 16

        num_patches = 196  # 14 * 14; assumes 224x224 input with patch 16

        # Fixed (non-learned) 2D sin-cos positional embedding, channel-first.
        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim[-1], num_patches), requires_grad=False)

        ## MAE decoder vit
        self.decoder_embed = nn.Linear(embed_dim[-1], decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        # Try learned decoder
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), requires_grad=False)
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=False, norm_layer=norm_layer)
            for i in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_channels, bias=True)  # decoder to patch
        self.initialize_weights()

    def initialize_weights(self):
        """Fill the fixed sin-cos positional embeddings and init all weights."""
        num_patches = 196
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[1], int(num_patches ** .5),
                                            cls_token=False)
        # Encoder pos-embed is stored channel-first, hence the transpose.
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed.transpose(1, 0)).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                    int(num_patches ** .5), cls_token=False)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        torch.nn.init.normal_(self.mask_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # timm-style truncated-normal init for Linear, unit affine for LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: (T, B, C, H, W); only the batch size is read.
        Returns (ids_keep, active, ids_restore) where ``active`` is a (B, 196)
        binary mask with 1 = kept patch, 0 = masked patch.
        """
        num_patches = 196
        T, N, _, _, _ = x.shape  # batch, length, dim
        L = num_patches
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        # active is inverse mask
        active = torch.ones([N, L], device=x.device)
        active[:, len_keep:] = 0
        active = torch.gather(active, dim=1, index=ids_restore)

        return ids_keep, active, ids_restore

    def forward_encoder(self, x, mask_ratio=1.0):
        """Mask the image, run the spiking encoder, return token features.

        Returns (tokens (B, 196, C), active mask (B, 196), ids_restore,
        pixel-level active mask ``active_hw``).
        """
        x = (x.unsqueeze(0)).repeat(self.T, 1, 1, 1, 1)
        # step1. Mask
        ids_keep, active, ids_restore = self.random_masking(x, mask_ratio)
        B, N = active.shape
        active_b1ff = active.reshape(B, 1, 14, 14)

        # Publish the patch mask to the sparse conv/BN layers (module-global
        # hook read by encoder.SparseConv2d / SparseBatchNorm2d).
        encoder._cur_active = active_b1ff
        # Expand the 14x14 patch mask to full 224x224 pixel resolution.
        active_hw = active_b1ff.repeat_interleave(self.downsample_raito, 2).repeat_interleave(self.downsample_raito, 3)
        active_hw = active_hw.unsqueeze(0)
        masked_bchw = x * active_hw
        x = masked_bchw
        x = self.downsample1_1(x)
        for blk in self.ConvBlock1_1:
            x = blk(x)
        x = self.downsample1_2(x)
        for blk in self.ConvBlock1_2:
            x = blk(x)

        x = self.downsample2(x)
        for blk in self.ConvBlock2_1:
            x = blk(x)
        for blk in self.ConvBlock2_2:
            x = blk(x)

        x = self.downsample3(x)
        x = x.flatten(3)  # (T, B, C, 14, 14) -> (T, B, C, 196)
        for blk in self.block3:
            x = blk(x)

        x = x.mean(0)  # average over timesteps
        x = self.norm(x).transpose(-1, -2).contiguous()  # -> (B, 196, C)
        return x, active, ids_restore, active_hw

    def forward_decoder(self, x, ids_restore):
        """Project tokens, append mask tokens, unshuffle, run the ViT decoder.

        NOTE(review): this encoder keeps all 196 tokens (masking is by
        zeroing), so ``ids_restore.shape[1] - x.shape[1]`` is 0 and no mask
        tokens are actually appended — confirm against the training recipe.
        """
        # embed tokens
        B, N, C = x.shape
        x = self.decoder_embed(x)  # B, N, C
        # append mask tokens to sequence
        # ids_restore#1,196
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
        x_ = torch.cat([x[:, :, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = x_
        #
        # add pos embed
        x = x + self.decoder_pos_embed
        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)

        return x

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = 16
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = 16
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def forward_loss(self, imgs, pred, mask):
        """
        Normalized-pixel L2 reconstruction loss on masked patches only.
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: here the *active* mask (1 = kept), so the loss is taken where
        mask == 0.  Also returns the per-patch mean/std used to normalize,
        so callers can de-normalize reconstructions.
        """
        inp, rec = self.patchify(imgs), pred  # inp and rec: (B, L = f*f, N = C*downsample_raito**2)
        mean = inp.mean(dim=-1, keepdim=True)
        var = (inp.var(dim=-1, keepdim=True) + 1e-6) ** .5
        inp = (inp - mean) / var
        l2_loss = ((rec - inp) ** 2).mean(dim=2, keepdim=False)  # (B, L, C) ==mean==> (B, L)
        non_active = mask.logical_not().int().view(mask.shape[0], -1)  # (B, 1, f, f) => (B, L)
        recon_loss = l2_loss.mul_(non_active).sum() / (non_active.sum() + 1e-8)  # loss only on masked (non-active) patches
        return recon_loss, mean, var

    def forward(self, imgs, mask_ratio=0.5, vis=False):
        """Full MAE pass.

        Returns the scalar reconstruction loss, or — when ``vis`` is True —
        the triple (original images, masked images, reconstruction blended
        with the visible pixels).
        """
        latent, active, ids_restore, active_hw = self.forward_encoder(imgs, mask_ratio)
        rec = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        recon_loss, mean, var = self.forward_loss(imgs, rec, active)
        if vis:
            masked_bchw = imgs * active_hw.flatten(0, 1)
            rec_bchw = self.unpatchify(rec * var + mean)
            # Keep visible pixels from the input, fill masked ones with the
            # de-normalized reconstruction.
            rec_or_inp = torch.where(active_hw.flatten(0, 1).bool(), imgs, rec_bchw)
            return imgs, masked_bchw, rec_or_inp
        else:
            return recon_loss
591
+
592
+
593
def spikmae_12_512(**kwargs):
    """Spiking MAE: 12 transformer blocks, final embed dim 512, "base" choice."""
    return Spikmae(
        T=1,
        choice="base",
        img_size_h=224,
        img_size_w=224,
        patch_size=16,
        embed_dim=[128, 256, 512],
        num_heads=8,
        mlp_ratios=4,
        in_channels=3,
        num_classes=1000,
        qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=12,
        sr_ratios=1,
        decoder_embed_dim=256,
        decoder_depth=4,
        decoder_num_heads=4,
        **kwargs,
    )
611
def spikmae_12_768(**kwargs):
    """Spiking MAE: 12 transformer blocks, final embed dim 768, "large" choice."""
    return Spikmae(
        T=1,
        choice="large",
        img_size_h=224,
        img_size_w=224,
        patch_size=16,
        embed_dim=[192, 384, 768],
        num_heads=8,
        mlp_ratios=4,
        in_channels=3,
        num_classes=1000,
        qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=12,
        sr_ratios=1,
        decoder_embed_dim=256,
        decoder_depth=4,
        decoder_num_heads=4,
        **kwargs,
    )
629
+
630
+
631
+
632
+
633
if __name__ == "__main__":
    # Smoke test: one forward pass at 50% masking, then a parameter summary.
    model = spikmae_12_768()
    dummy = torch.randn(1, 3, 224, 224)
    loss = model(dummy, mask_ratio=0.50)
    print('loss', loss)
    torchinfo.summary(model, (1, 3, 224, 224))
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"number of params: {n_params}")
models/__init__.py ADDED
File without changes
models/__pycache__/MAE_SDT.cpython-311.pyc ADDED
Binary file (36.2 kB). View file
 
models/__pycache__/MAE_SDT.cpython-312.pyc ADDED
Binary file (32.1 kB). View file
 
models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (201 Bytes). View file
 
models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (212 Bytes). View file
 
models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (153 Bytes). View file
 
models/__pycache__/encoder.cpython-311.pyc ADDED
Binary file (10.6 kB). View file
 
models/__pycache__/encoder.cpython-312.pyc ADDED
Binary file (9.31 kB). View file
 
models/__pycache__/metaformer.cpython-311.pyc ADDED
Binary file (72.1 kB). View file
 
models/__pycache__/metaformer.cpython-312.pyc ADDED
Binary file (63.8 kB). View file
 
models/__pycache__/neuron.cpython-311.pyc ADDED
Binary file (78.9 kB). View file
 
models/__pycache__/neuron.cpython-312.pyc ADDED
Binary file (75.7 kB). View file
 
models/__pycache__/qk_model_v1_1003.cpython-311.pyc ADDED
Binary file (30 kB). View file
 
models/__pycache__/qkformer.cpython-311.pyc ADDED
Binary file (31.5 kB). View file
 
models/__pycache__/qkformer.cpython-312.pyc ADDED
Binary file (27.1 kB). View file
 
models/__pycache__/sd_former_v1.cpython-311.pyc ADDED
Binary file (29.7 kB). View file
 
models/__pycache__/sd_former_v1.cpython-312.pyc ADDED
Binary file (25.6 kB). View file
 
models/__pycache__/sdtv3.cpython-311.pyc ADDED
Binary file (64.6 kB). View file
 
models/__pycache__/sdtv3.cpython-312.pyc ADDED
Binary file (55.6 kB). View file
 
models/__pycache__/sdtv3.cpython-39.pyc ADDED
Binary file (21 kB). View file
 
models/__pycache__/sdtv3_large.cpython-311.pyc ADDED
Binary file (27.3 kB). View file
 
models/__pycache__/sdtv3_large.cpython-312.pyc ADDED
Binary file (23.7 kB). View file
 
models/__pycache__/spikformer.cpython-311.pyc ADDED
Binary file (28 kB). View file
 
models/__pycache__/spikformer.cpython-312.pyc ADDED
Binary file (24.6 kB). View file
 
models/__pycache__/vit.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c206e15daa2f79c2abc87acf17b7e6263bb292fe86a0581f10e58b94da50c3d5
3
+ size 204918
models/__pycache__/vit.cpython-312.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:541616a16f1f3839624aff7ca6c0d0f168227ee19a642340a85e71e77d6ea63d
3
+ size 183274
models/encoder.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from timm.models.layers import DropPath
10
+
11
+
12
# Current (B, 1, f, f) binary mask of active (unmasked) patches.  Set
# externally (e.g. by the MAE encoder) before running the Sparse* modules.
_cur_active: torch.Tensor = None  # B1ff
# todo: try to use `gather` for speed?
def _get_active_ex_or_ii(H, W, returning_active_ex=True):
    """Upsample the active-patch mask to (H, W).

    Returns the expanded (B, 1, H, W) mask when ``returning_active_ex`` is
    True, otherwise the (bi, hi, wi) index tuple of active positions.
    """
    rep_h = H // _cur_active.shape[-2]
    rep_w = W // _cur_active.shape[-1]
    mask = _cur_active.repeat_interleave(rep_h, dim=2)
    mask = mask.repeat_interleave(rep_w, dim=3)
    if returning_active_ex:
        return mask
    return mask.squeeze(1).nonzero(as_tuple=True)  # ii: bi, hi, wi
18
+
19
+
20
def sp_conv_forward(self, x: torch.Tensor):
    """Run the parent conv/pool forward, then zero out masked positions.

    Bound as ``forward`` on the Sparse* conv/pool classes below.  The
    ``super(type(self), self)`` call dispatches to the original nn layer;
    NOTE(review): this dynamic lookup would recurse if these classes were
    ever subclassed — confirm they never are.
    """
    x = super(type(self), self).forward(x)
    x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=True)  # (BCHW) *= (B1HW), mask the output of conv
    return x
24
+
25
+
26
def sp_bn_forward(self, x: torch.Tensor):
    """BatchNorm over only the active (unmasked) positions of a BCHW map.

    Bound as ``forward`` on the Sparse*BatchNorm classes below: gathers the
    features at active positions into a flat (n, C) tensor, normalizes them
    with the parent 1D norm, and scatters them back, leaving masked
    positions at exactly zero.
    """
    ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False)

    bhwc = x.permute(0, 2, 3, 1)
    nc = bhwc[ii]  # select the features on non-masked positions to form a flatten feature `nc`
    nc = super(type(self), self).forward(nc)  # use BN1d to normalize this flatten feature `nc`

    bchw = torch.zeros_like(bhwc)
    bchw[ii] = nc
    bchw = bchw.permute(0, 3, 1, 2)
    return bchw
37
+
38
+
39
class SparseConv2d(nn.Conv2d):
    """Conv2d whose output is re-masked at inactive positions."""
    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details
41
+
42
+
43
class SparseMaxPooling(nn.MaxPool2d):
    """MaxPool2d whose output is re-masked at inactive positions."""
    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details
45
+
46
+
47
class SparseAvgPooling(nn.AvgPool2d):
    """AvgPool2d whose output is re-masked at inactive positions."""
    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details
49
+
50
+
51
class SparseBatchNorm2d(nn.BatchNorm1d):
    """2D-named batch norm for masked maps; subclasses BatchNorm1d because
    ``sp_bn_forward`` normalizes the flattened (n, C) active features."""
    forward = sp_bn_forward  # hack: override the forward function; see `sp_bn_forward` above for more details
53
+
54
+
55
class SparseSyncBatchNorm2d(nn.SyncBatchNorm):
    """Distributed (synchronized) variant of SparseBatchNorm2d."""
    forward = sp_bn_forward  # hack: override the forward function; see `sp_bn_forward` above for more details
57
+
58
+
59
class SparseConvNeXtLayerNorm(nn.LayerNorm):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).

    When ``sparse`` is True, statistics are computed only over active
    (unmasked) positions, and masked positions are written back as zeros.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", sparse=True):
        if data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        super().__init__(normalized_shape, eps, elementwise_affine=True)
        self.data_format = data_format
        self.sparse = sparse

    def forward(self, x):
        if x.ndim == 4:  # BHWC or BCHW
            if self.data_format == "channels_last":  # BHWC
                if self.sparse:
                    # Normalize only active positions; scatter back into zeros.
                    ii = _get_active_ex_or_ii(H=x.shape[1], W=x.shape[2], returning_active_ex=False)
                    nc = x[ii]
                    nc = super(SparseConvNeXtLayerNorm, self).forward(nc)

                    x = torch.zeros_like(x)
                    x[ii] = nc
                    return x
                else:
                    return super(SparseConvNeXtLayerNorm, self).forward(x)
            else:  # channels_first, BCHW
                if self.sparse:
                    # Same as above, but permute to BHWC for the gather/scatter.
                    ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False)
                    bhwc = x.permute(0, 2, 3, 1)
                    nc = bhwc[ii]
                    nc = super(SparseConvNeXtLayerNorm, self).forward(nc)

                    x = torch.zeros_like(bhwc)
                    x[ii] = nc
                    return x.permute(0, 3, 1, 2)
                else:
                    # Dense channels-first LayerNorm done manually
                    # (mean/variance over the channel dimension).
                    u = x.mean(1, keepdim=True)
                    s = (x - u).pow(2).mean(1, keepdim=True)
                    x = (x - u) / torch.sqrt(s + self.eps)
                    x = self.weight[:, None, None] * x + self.bias[:, None, None]
                    return x
        else:  # BLC or BC
            if self.sparse:
                raise NotImplementedError
            else:
                return super(SparseConvNeXtLayerNorm, self).forward(x)

    def __repr__(self):
        return super(SparseConvNeXtLayerNorm, self).__repr__()[:-1] + f', ch={self.data_format.split("_")[-1]}, sp={self.sparse})'
110
+
111
+
112
class SparseConvNeXtBlock(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, sparse=True, ks=7):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=ks, padding=ks//2, groups=dim)  # depthwise conv
        self.norm = SparseConvNeXtLayerNorm(dim, eps=1e-6, sparse=sparse)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Learnable per-channel layer scale (disabled when init value <= 0).
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path: nn.Module = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.sparse = sparse

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)  # GELU(0) == (0), so there is no need to mask x (no need to `x *= _get_active_ex_or_ii`)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        if self.sparse:
            # Re-mask: the pointwise linears have biases, so masked positions
            # are no longer exactly zero after them.
            x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=True)

        x = input + self.drop_path(x)
        return x

    def __repr__(self):
        return super(SparseConvNeXtBlock, self).__repr__()[:-1] + f', sp={self.sparse})'
156
+
157
+
158
+
models/metaformer.py ADDED
@@ -0,0 +1,1538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Garena Online Private Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2,
17
+ ConvFormer and CAFormer.
18
+ Some implementations are modified from timm (https://github.com/rwightman/pytorch-image-models).
19
+ """
20
+ from functools import partial
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+
25
+ from timm.layers import trunc_normal_, DropPath
26
+ from timm.models.registry import register_model
27
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
28
+ from timm.layers.helpers import to_2tuple
29
+
30
+
31
def _cfg(url='', **kwargs):
    """Build a timm-style default config dict for a pretrained checkpoint.

    Any keyword argument overrides the corresponding default entry.
    """
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': None,
        'crop_pct': 1.0,
        'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'classifier': 'head',
    }
    cfg.update(kwargs)
    return cfg
39
+
40
+
41
+ default_cfgs = {
42
+ 'identityformer_s12': _cfg(
43
+ url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'),
44
+ 'identityformer_s24': _cfg(
45
+ url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'),
46
+ 'identityformer_s36': _cfg(
47
+ url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'),
48
+ 'identityformer_m36': _cfg(
49
+ url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'),
50
+ 'identityformer_m48': _cfg(
51
+ url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'),
52
+
53
+
54
+ 'randformer_s12': _cfg(
55
+ url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'),
56
+ 'randformer_s24': _cfg(
57
+ url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'),
58
+ 'randformer_s36': _cfg(
59
+ url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'),
60
+ 'randformer_m36': _cfg(
61
+ url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'),
62
+ 'randformer_m48': _cfg(
63
+ url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'),
64
+
65
+ 'poolformerv2_s12': _cfg(
66
+ url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'),
67
+ 'poolformerv2_s24': _cfg(
68
+ url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'),
69
+ 'poolformerv2_s36': _cfg(
70
+ url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'),
71
+ 'poolformerv2_m36': _cfg(
72
+ url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'),
73
+ 'poolformerv2_m48': _cfg(
74
+ url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'),
75
+
76
+
77
+
78
+ 'convformer_s18': _cfg(
79
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'),
80
+ 'convformer_s18_384': _cfg(
81
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth',
82
+ input_size=(3, 384, 384)),
83
+ 'convformer_s18_in21ft1k': _cfg(
84
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'),
85
+ 'convformer_s18_384_in21ft1k': _cfg(
86
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth',
87
+ input_size=(3, 384, 384)),
88
+ 'convformer_s18_in21k': _cfg(
89
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth',
90
+ num_classes=21841),
91
+
92
+ 'convformer_s36': _cfg(
93
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'),
94
+ 'convformer_s36_384': _cfg(
95
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth',
96
+ input_size=(3, 384, 384)),
97
+ 'convformer_s36_in21ft1k': _cfg(
98
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'),
99
+ 'convformer_s36_384_in21ft1k': _cfg(
100
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth',
101
+ input_size=(3, 384, 384)),
102
+ 'convformer_s36_in21k': _cfg(
103
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth',
104
+ num_classes=21841),
105
+
106
+ 'convformer_m36': _cfg(
107
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'),
108
+ 'convformer_m36_384': _cfg(
109
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth',
110
+ input_size=(3, 384, 384)),
111
+ 'convformer_m36_in21ft1k': _cfg(
112
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'),
113
+ 'convformer_m36_384_in21ft1k': _cfg(
114
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth',
115
+ input_size=(3, 384, 384)),
116
+ 'convformer_m36_in21k': _cfg(
117
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth',
118
+ num_classes=21841),
119
+
120
+ 'convformer_b36': _cfg(
121
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'),
122
+ 'convformer_b36_384': _cfg(
123
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth',
124
+ input_size=(3, 384, 384)),
125
+ 'convformer_b36_in21ft1k': _cfg(
126
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'),
127
+ 'convformer_b36_384_in21ft1k': _cfg(
128
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth',
129
+ input_size=(3, 384, 384)),
130
+ 'convformer_b36_in21k': _cfg(
131
+ url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth',
132
+ num_classes=21841),
133
+
134
+
135
+ 'caformer_s18': _cfg(
136
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'),
137
+ 'caformer_s18_384': _cfg(
138
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth',
139
+ input_size=(3, 384, 384)),
140
+ 'caformer_s18_in21ft1k': _cfg(
141
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'),
142
+ 'caformer_s18_384_in21ft1k': _cfg(
143
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth',
144
+ input_size=(3, 384, 384)),
145
+ 'caformer_s18_in21k': _cfg(
146
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth',
147
+ num_classes=21841),
148
+
149
+ 'caformer_s36': _cfg(
150
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'),
151
+ 'caformer_s36_384': _cfg(
152
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth',
153
+ input_size=(3, 384, 384)),
154
+ 'caformer_s36_in21ft1k': _cfg(
155
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'),
156
+ 'caformer_s36_384_in21ft1k': _cfg(
157
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth',
158
+ input_size=(3, 384, 384)),
159
+ 'caformer_s36_in21k': _cfg(
160
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth',
161
+ num_classes=21841),
162
+
163
+ 'caformer_m36': _cfg(
164
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'),
165
+ 'caformer_m36_384': _cfg(
166
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth',
167
+ input_size=(3, 384, 384)),
168
+ 'caformer_m36_in21ft1k': _cfg(
169
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'),
170
+ 'caformer_m36_384_in21ft1k': _cfg(
171
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth',
172
+ input_size=(3, 384, 384)),
173
+ 'caformer_m36_in21k': _cfg(
174
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth',
175
+ num_classes=21841),
176
+
177
+ 'caformer_b36': _cfg(
178
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'),
179
+ 'caformer_b36_384': _cfg(
180
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth',
181
+ input_size=(3, 384, 384)),
182
+ 'caformer_b36_in21ft1k': _cfg(
183
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'),
184
+ 'caformer_b36_384_in21ft1k': _cfg(
185
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth',
186
+ input_size=(3, 384, 384)),
187
+ 'caformer_b36_in21k': _cfg(
188
+ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth',
189
+ num_classes=21841),
190
+ }
191
+
192
+
193
class Downsampling(nn.Module):
    """Spatial downsampling implemented by a single strided convolution.

    Optional norms can run before (``pre_norm``) and after (``post_norm``)
    the convolution.  The input may be channels-first [B, C, H, W] or,
    when ``pre_permute`` is True, channels-last [B, H, W, C]; the output
    is always channels-last [B, H, W, C].
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, stride=1, padding=0,
                 pre_norm=None, post_norm=None, pre_permute=False):
        super().__init__()
        self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity()
        self.pre_permute = pre_permute
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding)
        self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()

    def forward(self, x):
        x = self.pre_norm(x)
        if self.pre_permute:
            # Channels-last input [B, H, W, C] -> [B, C, H, W] expected by Conv2d.
            x = x.permute(0, 3, 1, 2)
        x = self.conv(x)
        # Hand the result back in channels-last layout.
        x = x.permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]
        return self.post_norm(x)
216
+
217
+
218
class Scale(nn.Module):
    """Element-wise scaling by a learnable (or frozen) per-channel vector."""

    def __init__(self, dim, init_value=1.0, trainable=True):
        super().__init__()
        # Initialized to a constant vector; frozen when trainable=False.
        self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable)

    def forward(self, x):
        return x * self.scale
228
+
229
+
230
class SquaredReLU(nn.Module):
    """Squared ReLU activation: relu(x) ** 2 (https://arxiv.org/abs/2109.08668)."""

    def __init__(self, inplace=False):
        super().__init__()
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        y = self.relu(x)
        return y * y
239
+
240
+
241
class StarReLU(nn.Module):
    """StarReLU activation: s * relu(x) ** 2 + b.

    ``s`` (scale) and ``b`` (bias) are scalar parameters, each optionally
    learnable.  ``mode`` is accepted for interface compatibility but unused.
    """

    def __init__(self, scale_value=1.0, bias_value=0.0,
                 scale_learnable=True, bias_learnable=True,
                 mode=None, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(scale_value * torch.ones(1),
                                  requires_grad=scale_learnable)
        self.bias = nn.Parameter(bias_value * torch.ones(1),
                                 requires_grad=bias_learnable)

    def forward(self, x):
        r = self.relu(x)
        return self.scale * r * r + self.bias
257
+
258
+
259
class Attention(nn.Module):
    """Vanilla multi-head self-attention (https://arxiv.org/abs/1706.03762)
    operating on channels-last [B, H, W, C] feature maps.  Modified from timm.

    The head count is derived from ``head_dim`` unless ``num_heads`` is given
    explicitly; at least one head is always used.
    """

    def __init__(self, dim, head_dim=32, num_heads=None, qkv_bias=False,
                 attn_drop=0., proj_drop=0., proj_bias=False, **kwargs):
        super().__init__()
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        self.num_heads = num_heads if num_heads else dim // head_dim
        if self.num_heads == 0:
            self.num_heads = 1  # guard against dim < head_dim

        self.attention_dim = self.num_heads * self.head_dim

        self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape
        N = H * W
        # (B, H, W, 3*A) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # torchscript-friendly tuple unpacking

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(attn.softmax(dim=-1))

        out = (attn @ v).transpose(1, 2).reshape(B, H, W, self.attention_dim)
        out = self.proj(out)
        return self.proj_drop(out)
297
+
298
+
299
class RandomMixing(nn.Module):
    """Token mixing with a fixed random matrix.

    The mixing matrix is sampled once at construction, row-wise
    softmax-normalized, and frozen (``requires_grad=False``); the input's
    spatial size H*W must therefore equal ``num_tokens``.
    """

    def __init__(self, num_tokens=196, **kwargs):
        super().__init__()
        self.random_matrix = nn.parameter.Parameter(
            data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1),
            requires_grad=False)

    def forward(self, x):
        B, H, W, C = x.shape
        tokens = x.reshape(B, H * W, C)
        # Each output token is a fixed convex combination of all input tokens.
        mixed = torch.einsum('mn, bnc -> bmc', self.random_matrix, tokens)
        return mixed.reshape(B, H, W, C)
311
+
312
+
313
class LayerNormGeneral(nn.Module):
    r"""General LayerNorm covering several normalization variants.

    Args:
        affine_shape (int, list or tuple): Shape of the affine weight/bias.
            Usually C, but e.g. (C, 1, 1) for channels-first input (unlike
            torch.nn.LayerNorm, it need not equal the normalized dims).
        normalized_dim (tuple or list): Dims over which mean/variance are taken.
        scale (bool): Whether to apply a learnable scale.
        bias (bool): Whether to apply a learnable bias.
        eps (float): Numerical-stability constant added to the variance.

    Examples:
        Classic LayerNorm (https://arxiv.org/abs/1607.06450):
            input (B, *, C): affine_shape=C, normalized_dim=(-1,), scale=True, bias=True;
            input (B, C, H, W): affine_shape=(C, 1, 1), normalized_dim=(1,), scale=True, bias=True.
        Modified LayerNorm (https://arxiv.org/abs/2111.11418), identical to
        partial(torch.nn.GroupNorm, num_groups=1):
            input (B, N, C): affine_shape=C, normalized_dim=(1, 2);
            input (B, H, W, C): affine_shape=C, normalized_dim=(1, 2, 3);
            input (B, C, H, W): affine_shape=(C, 1, 1), normalized_dim=(1, 2, 3).

    For the MetaFormer baselines, IdentityFormer, RandFormer and PoolFormerV2
    use the modified LayerNorm without bias; ConvFormer and CAFormer use
    classic LayerNorm without bias (bias=False in both cases).
    """

    def __init__(self, affine_shape=None, normalized_dim=(-1, ), scale=True,
                 bias=True, eps=1e-5):
        super().__init__()
        self.normalized_dim = normalized_dim
        self.use_scale = scale
        self.use_bias = bias
        self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None
        self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(self.normalized_dim, keepdim=True)
        var = centered.pow(2).mean(self.normalized_dim, keepdim=True)  # biased variance
        out = centered / torch.sqrt(var + self.eps)
        if self.use_scale:
            out = out * self.weight
        if self.use_bias:
            out = out + self.bias
        return out
365
+
366
+
367
class LayerNormWithoutBias(nn.Module):
    """LayerNorm with no bias term.

    Equivalent to partial(LayerNormGeneral, bias=False) but faster, since it
    dispatches to the optimized F.layer_norm.
    """

    def __init__(self, normalized_shape, eps=1e-5, **kwargs):
        super().__init__()
        self.eps = eps
        self.bias = None  # kept as an attribute so F.layer_norm gets bias=None
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        return F.layer_norm(x, self.normalized_shape,
                            weight=self.weight, bias=self.bias, eps=self.eps)
382
+
383
+
384
class SepConv(nn.Module):
    r"""Inverted separable convolution from MobileNetV2
    (https://arxiv.org/abs/1801.04381) for channels-last [B, H, W, C] input.

    Structure: 1x1 expand (Linear) -> act1 -> depthwise conv -> act2 ->
    1x1 project (Linear).
    """

    def __init__(self, dim, expansion_ratio=2,
                 act1_layer=StarReLU, act2_layer=nn.Identity,
                 bias=False, kernel_size=7, padding=3,
                 **kwargs, ):
        super().__init__()
        med_channels = int(expansion_ratio * dim)
        self.pwconv1 = nn.Linear(dim, med_channels, bias=bias)
        self.act1 = act1_layer()
        # Depthwise: one spatial filter per channel (groups == channels).
        self.dwconv = nn.Conv2d(
            med_channels, med_channels, kernel_size=kernel_size,
            padding=padding, groups=med_channels, bias=bias)
        self.act2 = act2_layer()
        self.pwconv2 = nn.Linear(med_channels, dim, bias=bias)

    def forward(self, x):
        x = self.act1(self.pwconv1(x))
        x = x.permute(0, 3, 1, 2)  # channels-first for the depthwise conv
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # back to channels-last
        return self.pwconv2(self.act2(x))
411
+
412
+
413
class Pooling(nn.Module):
    """PoolFormer token mixer (https://arxiv.org/abs/2111.11418), modified
    for channels-last [B, H, W, C] input.

    Returns pooled features minus the input, i.e. only the mixing residual.
    """

    def __init__(self, pool_size=3, **kwargs):
        super().__init__()
        self.pool = nn.AvgPool2d(
            pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)

    def forward(self, x):
        # AvgPool2d wants channels-first; permute there and back.
        pooled = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        return pooled - x
428
+
429
+
430
class Mlp(nn.Module):
    """Two-layer MLP as used in MetaFormer models (Transformer, MLP-Mixer,
    PoolFormer, MetaFormer baselines and related networks).
    Mostly copied from timm.
    """

    def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU,
                 drop=0., bias=False, **kwargs):
        super().__init__()
        in_features = dim
        out_features = out_features or in_features
        hidden_features = int(mlp_ratio * in_features)
        drop_probs = to_2tuple(drop)  # allows distinct drop rates per layer

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(x))
454
+
455
+
456
class MlpHead(nn.Module):
    """MLP classification head: Linear -> act -> norm -> dropout -> Linear."""

    def __init__(self, dim, num_classes=1000, mlp_ratio=4, act_layer=SquaredReLU,
                 norm_layer=nn.LayerNorm, head_dropout=0., bias=True):
        super().__init__()
        hidden_features = int(mlp_ratio * dim)
        self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
        self.act = act_layer()
        self.norm = norm_layer(hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
        self.head_dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        hidden = self.norm(self.act(self.fc1(x)))
        return self.fc2(self.head_dropout(hidden))
477
+
478
+
479
class MetaFormerBlock(nn.Module):
    """One MetaFormer block: token mixing followed by an MLP, each wrapped
    with a norm, stochastic depth, and optional layer-/residual-scaling.
    """

    def __init__(self, dim,
                 token_mixer=nn.Identity, mlp=Mlp,
                 norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.,
                 layer_scale_init_value=None, res_scale_init_value=None
                 ):
        super().__init__()

        def _scale_or_identity(init_value):
            # Scaling is only materialized when an init value is provided.
            return Scale(dim=dim, init_value=init_value) if init_value else nn.Identity()

        # First residual branch: token mixer.
        self.norm1 = norm_layer(dim)
        self.token_mixer = token_mixer(dim=dim, drop=drop)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale1 = _scale_or_identity(layer_scale_init_value)
        self.res_scale1 = _scale_or_identity(res_scale_init_value)

        # Second residual branch: MLP.
        self.norm2 = norm_layer(dim)
        self.mlp = mlp(dim=dim, drop=drop)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale2 = _scale_or_identity(layer_scale_init_value)
        self.res_scale2 = _scale_or_identity(res_scale_init_value)

    def forward(self, x):
        mixed = self.drop_path1(self.token_mixer(self.norm1(x)))
        x = self.res_scale1(x) + self.layer_scale1(mixed)
        fed = self.drop_path2(self.mlp(self.norm2(x)))
        return self.res_scale2(x) + self.layer_scale2(fed)
522
+
523
+
524
# Downsampling (stem) for the first stage: one conv with k7, s4, p2.
# Downsamplings for the last 3 stages: one conv with k3, s2, p1; their input
# arrives channels-last from the previous stage, hence pre_permute=True.
# DOWNSAMPLE_LAYERS_FOUR_STAGES format:
#     [Downsampling, Downsampling, Downsampling, Downsampling]
# `partial` pins the per-stage arguments.
_STEM_LAYER = partial(Downsampling,
                      kernel_size=7, stride=4, padding=2,
                      post_norm=partial(LayerNormGeneral, bias=False, eps=1e-6))
_STAGE_LAYER = partial(Downsampling,
                       kernel_size=3, stride=2, padding=1,
                       pre_norm=partial(LayerNormGeneral, bias=False, eps=1e-6),
                       pre_permute=True)
DOWNSAMPLE_LAYERS_FOUR_STAGES = [_STEM_LAYER] + [_STAGE_LAYER] * 3
538
+
539
+
540
class MetaFormer(nn.Module):
    r""" MetaFormer
    A PyTorch impl of : `MetaFormer Baselines for Vision` -
      https://arxiv.org/abs/2210.13452

    Args:
        in_chans (int): Number of input image channels. Default: 3.
        num_classes (int): Number of classes for classification head. Default: 1000.
        depths (list or tuple): Number of blocks at each stage. Default: (2, 2, 6, 2).
        dims (int, list or tuple): Feature dimension at each stage. Default: (64, 128, 320, 512).
        downsample_layers: (list or tuple): Downsampling layers before each stage.
        token_mixers (list, tuple or token_fcn): Token mixer for each stage. Default: nn.Identity.
        mlps (list, tuple or mlp_fcn): Mlp for each stage. Default: Mlp.
        norm_layers (list, tuple or norm_fcn): Norm layers for each stage.
            Default: partial(LayerNormWithoutBias, eps=1e-6)
            (equivalent to partial(LayerNormGeneral, eps=1e-6, bias=False) but faster).
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        head_dropout (float): dropout for MLP classifier. Default: 0.
        layer_scale_init_values (list, tuple, float or None): Init value for Layer Scale.
            Default: None (layer scale disabled). From: https://arxiv.org/abs/2103.17239.
        res_scale_init_values (list, tuple, float or None): Init value for residual scale.
            Default: (None, None, 1.0, 1.0). From: https://arxiv.org/abs/2110.09456.
        output_norm: norm before classifier head. Default: partial(nn.LayerNorm, eps=1e-6).
        head_fn: classification head. Default: nn.Linear.
    """
    # NOTE: sequence defaults are tuples (not lists) to avoid the
    # mutable-default-argument pitfall; lists and tuples are both accepted.
    def __init__(self, in_chans=3, num_classes=1000,
                 depths=(2, 2, 6, 2),
                 dims=(64, 128, 320, 512),
                 downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES,
                 token_mixers=nn.Identity,
                 mlps=Mlp,
                 norm_layers=partial(LayerNormWithoutBias, eps=1e-6),
                 drop_path_rate=0.,
                 head_dropout=0.0,
                 layer_scale_init_values=None,
                 res_scale_init_values=(None, None, 1.0, 1.0),
                 output_norm=partial(nn.LayerNorm, eps=1e-6),
                 head_fn=nn.Linear,
                 **kwargs,
                 ):
        super().__init__()
        self.num_classes = num_classes

        # Normalize all per-stage arguments to sequences of length num_stage.
        if not isinstance(depths, (list, tuple)):
            depths = [depths]  # the model has only one stage
        if not isinstance(dims, (list, tuple)):
            dims = [dims]

        num_stage = len(depths)
        self.num_stage = num_stage

        if not isinstance(downsample_layers, (list, tuple)):
            downsample_layers = [downsample_layers] * num_stage
        # list(dims) also fixes a TypeError when dims is passed as a tuple
        # (list + tuple concatenation is invalid).
        down_dims = [in_chans] + list(dims)
        self.downsample_layers = nn.ModuleList(
            [downsample_layers[i](down_dims[i], down_dims[i + 1]) for i in range(num_stage)]
        )

        if not isinstance(token_mixers, (list, tuple)):
            token_mixers = [token_mixers] * num_stage
        if not isinstance(mlps, (list, tuple)):
            mlps = [mlps] * num_stage
        if not isinstance(norm_layers, (list, tuple)):
            norm_layers = [norm_layers] * num_stage

        # Per-block stochastic-depth rates, increasing linearly with depth.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        if not isinstance(layer_scale_init_values, (list, tuple)):
            layer_scale_init_values = [layer_scale_init_values] * num_stage
        if not isinstance(res_scale_init_values, (list, tuple)):
            res_scale_init_values = [res_scale_init_values] * num_stage

        self.stages = nn.ModuleList()  # each stage consists of multiple metaformer blocks
        cur = 0
        for i in range(num_stage):
            stage = nn.Sequential(
                *[MetaFormerBlock(dim=dims[i],
                                  token_mixer=token_mixers[i],
                                  mlp=mlps[i],
                                  norm_layer=norm_layers[i],
                                  drop_path=dp_rates[cur + j],
                                  layer_scale_init_value=layer_scale_init_values[i],
                                  res_scale_init_value=res_scale_init_values[i],
                                  ) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = output_norm(dims[-1])

        if head_dropout > 0.0:
            self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout)
        else:
            self.head = head_fn(dims[-1], num_classes)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal init for conv/linear weights, zero-init bias.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Exclude the final norm's parameters from weight decay.
        return {'norm'}

    def forward_features(self, x):
        for i in range(self.num_stage):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        # Global average pool over spatial dims: (B, H, W, C) -> (B, C).
        return self.norm(x.mean([1, 2]))

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
657
+
658
+
659
+
660
@register_model
def identityformer_s12(pretrained=False, **kwargs):
    """IdentityFormer-S12: identity token mixer, depths [2, 2, 6, 2]."""
    cfg = default_cfgs['identityformer_s12']
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = cfg
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model
674
+
675
+
676
@register_model
def identityformer_s24(pretrained=False, **kwargs):
    """IdentityFormer-S24: identity token mixer, depths [4, 4, 12, 4]."""
    cfg = default_cfgs['identityformer_s24']
    model = MetaFormer(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = cfg
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model
690
+
691
+
692
@register_model
def identityformer_s36(pretrained=False, **kwargs):
    """IdentityFormer-S36: identity token mixer, depths [6, 6, 18, 6]."""
    cfg = default_cfgs['identityformer_s36']
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = cfg
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model
706
+
707
+
708
@register_model
def identityformer_m36(pretrained=False, **kwargs):
    """IdentityFormer-M36: identity token mixer, wider dims, depths [6, 6, 18, 6]."""
    cfg = default_cfgs['identityformer_m36']
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = cfg
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model
722
+
723
+
724
@register_model
def identityformer_m48(pretrained=False, **kwargs):
    """IdentityFormer-M48: identity token mixer, wider dims, depths [8, 8, 24, 8]."""
    cfg = default_cfgs['identityformer_m48']
    model = MetaFormer(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = cfg
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model
738
+
739
+
740
+ @register_model
741
+ def randformer_s12(pretrained=False, **kwargs):
742
+ model = MetaFormer(
743
+ depths=[2, 2, 6, 2],
744
+ dims=[64, 128, 320, 512],
745
+ token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
746
+ norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
747
+ **kwargs)
748
+ model.default_cfg = default_cfgs['randformer_s12']
749
+ if pretrained:
750
+ state_dict = torch.hub.load_state_dict_from_url(
751
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
752
+ model.load_state_dict(state_dict)
753
+ return model
754
+
755
+
756
+ @register_model
757
+ def randformer_s24(pretrained=False, **kwargs):
758
+ model = MetaFormer(
759
+ depths=[4, 4, 12, 4],
760
+ dims=[64, 128, 320, 512],
761
+ token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
762
+ norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
763
+ **kwargs)
764
+ model.default_cfg = default_cfgs['randformer_s24']
765
+ if pretrained:
766
+ state_dict = torch.hub.load_state_dict_from_url(
767
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
768
+ model.load_state_dict(state_dict)
769
+ return model
770
+
771
+
772
+ @register_model
773
+ def randformer_s36(pretrained=False, **kwargs):
774
+ model = MetaFormer(
775
+ depths=[6, 6, 18, 6],
776
+ dims=[64, 128, 320, 512],
777
+ token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
778
+ norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
779
+ **kwargs)
780
+ model.default_cfg = default_cfgs['randformer_s36']
781
+ if pretrained:
782
+ state_dict = torch.hub.load_state_dict_from_url(
783
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
784
+ model.load_state_dict(state_dict)
785
+ return model
786
+
787
+
788
+ @register_model
789
+ def randformer_m36(pretrained=False, **kwargs):
790
+ model = MetaFormer(
791
+ depths=[6, 6, 18, 6],
792
+ dims=[96, 192, 384, 768],
793
+ token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
794
+ norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
795
+ **kwargs)
796
+ model.default_cfg = default_cfgs['randformer_m36']
797
+ if pretrained:
798
+ state_dict = torch.hub.load_state_dict_from_url(
799
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
800
+ model.load_state_dict(state_dict)
801
+ return model
802
+
803
+
804
+ @register_model
805
+ def randformer_m48(pretrained=False, **kwargs):
806
+ model = MetaFormer(
807
+ depths=[8, 8, 24, 8],
808
+ dims=[96, 192, 384, 768],
809
+ token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
810
+ norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
811
+ **kwargs)
812
+ model.default_cfg = default_cfgs['randformer_m48']
813
+ if pretrained:
814
+ state_dict = torch.hub.load_state_dict_from_url(
815
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
816
+ model.load_state_dict(state_dict)
817
+ return model
818
+
819
+
820
+
821
+ @register_model
822
+ def poolformerv2_s12(pretrained=False, **kwargs):
823
+ model = MetaFormer(
824
+ depths=[2, 2, 6, 2],
825
+ dims=[64, 128, 320, 512],
826
+ token_mixers=Pooling,
827
+ norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
828
+ **kwargs)
829
+ model.default_cfg = default_cfgs['poolformerv2_s12']
830
+ if pretrained:
831
+ state_dict = torch.hub.load_state_dict_from_url(
832
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
833
+ model.load_state_dict(state_dict)
834
+ return model
835
+
836
+
837
+ @register_model
838
+ def poolformerv2_s24(pretrained=False, **kwargs):
839
+ model = MetaFormer(
840
+ depths=[4, 4, 12, 4],
841
+ dims=[64, 128, 320, 512],
842
+ token_mixers=Pooling,
843
+ norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
844
+ **kwargs)
845
+ model.default_cfg = default_cfgs['poolformerv2_s24']
846
+ if pretrained:
847
+ state_dict = torch.hub.load_state_dict_from_url(
848
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
849
+ model.load_state_dict(state_dict)
850
+ return model
851
+
852
+
853
+ @register_model
854
+ def poolformerv2_s36(pretrained=False, **kwargs):
855
+ model = MetaFormer(
856
+ depths=[6, 6, 18, 6],
857
+ dims=[64, 128, 320, 512],
858
+ token_mixers=Pooling,
859
+ norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
860
+ **kwargs)
861
+ model.default_cfg = default_cfgs['poolformerv2_s36']
862
+ if pretrained:
863
+ state_dict = torch.hub.load_state_dict_from_url(
864
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
865
+ model.load_state_dict(state_dict)
866
+ return model
867
+
868
+
869
+ @register_model
870
+ def poolformerv2_m36(pretrained=False, **kwargs):
871
+ model = MetaFormer(
872
+ depths=[6, 6, 18, 6],
873
+ dims=[96, 192, 384, 768],
874
+ token_mixers=Pooling,
875
+ norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
876
+ **kwargs)
877
+ model.default_cfg = default_cfgs['poolformerv2_m36']
878
+ if pretrained:
879
+ state_dict = torch.hub.load_state_dict_from_url(
880
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
881
+ model.load_state_dict(state_dict)
882
+ return model
883
+
884
+
885
+ @register_model
886
+ def poolformerv2_m48(pretrained=False, **kwargs):
887
+ model = MetaFormer(
888
+ depths=[8, 8, 24, 8],
889
+ dims=[96, 192, 384, 768],
890
+ token_mixers=Pooling,
891
+ norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
892
+ **kwargs)
893
+ model.default_cfg = default_cfgs['poolformerv2_m48']
894
+ if pretrained:
895
+ state_dict = torch.hub.load_state_dict_from_url(
896
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
897
+ model.load_state_dict(state_dict)
898
+ return model
899
+
900
+
901
+ @register_model
902
+ def convformer_s18(pretrained=False, **kwargs):
903
+ model = MetaFormer(
904
+ depths=[3, 3, 9, 3],
905
+ dims=[64, 128, 320, 512],
906
+ token_mixers=SepConv,
907
+ head_fn=MlpHead,
908
+ **kwargs)
909
+ model.default_cfg = default_cfgs['convformer_s18']
910
+ if pretrained:
911
+ state_dict = torch.hub.load_state_dict_from_url(
912
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
913
+ model.load_state_dict(state_dict)
914
+ return model
915
+
916
+
917
+ @register_model
918
+ def convformer_s18_384(pretrained=False, **kwargs):
919
+ model = MetaFormer(
920
+ depths=[3, 3, 9, 3],
921
+ dims=[64, 128, 320, 512],
922
+ token_mixers=SepConv,
923
+ head_fn=MlpHead,
924
+ **kwargs)
925
+ model.default_cfg = default_cfgs['convformer_s18_384']
926
+ if pretrained:
927
+ state_dict = torch.hub.load_state_dict_from_url(
928
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
929
+ model.load_state_dict(state_dict)
930
+ return model
931
+
932
+
933
+ @register_model
934
+ def convformer_s18_in21ft1k(pretrained=False, **kwargs):
935
+ model = MetaFormer(
936
+ depths=[3, 3, 9, 3],
937
+ dims=[64, 128, 320, 512],
938
+ token_mixers=SepConv,
939
+ head_fn=MlpHead,
940
+ **kwargs)
941
+ model.default_cfg = default_cfgs['convformer_s18_in21ft1k']
942
+ if pretrained:
943
+ state_dict = torch.hub.load_state_dict_from_url(
944
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
945
+ model.load_state_dict(state_dict)
946
+ return model
947
+
948
+
949
+ @register_model
950
+ def convformer_s18_384_in21ft1k(pretrained=False, **kwargs):
951
+ model = MetaFormer(
952
+ depths=[3, 3, 9, 3],
953
+ dims=[64, 128, 320, 512],
954
+ token_mixers=SepConv,
955
+ head_fn=MlpHead,
956
+ **kwargs)
957
+ model.default_cfg = default_cfgs['convformer_s18_384_in21ft1k']
958
+ if pretrained:
959
+ state_dict = torch.hub.load_state_dict_from_url(
960
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
961
+ model.load_state_dict(state_dict)
962
+ return model
963
+
964
+
965
+ @register_model
966
+ def convformer_s18_in21k(pretrained=False, **kwargs):
967
+ model = MetaFormer(
968
+ depths=[3, 3, 9, 3],
969
+ dims=[64, 128, 320, 512],
970
+ token_mixers=SepConv,
971
+ head_fn=MlpHead,
972
+ **kwargs)
973
+ model.default_cfg = default_cfgs['convformer_s18_in21k']
974
+ if pretrained:
975
+ state_dict = torch.hub.load_state_dict_from_url(
976
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
977
+ model.load_state_dict(state_dict)
978
+ return model
979
+
980
+
981
+ @register_model
982
+ def convformer_s36(pretrained=False, **kwargs):
983
+ model = MetaFormer(
984
+ depths=[3, 12, 18, 3],
985
+ dims=[64, 128, 320, 512],
986
+ token_mixers=SepConv,
987
+ head_fn=MlpHead,
988
+ **kwargs)
989
+ model.default_cfg = default_cfgs['convformer_s36']
990
+ if pretrained:
991
+ state_dict = torch.hub.load_state_dict_from_url(
992
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
993
+ model.load_state_dict(state_dict)
994
+ return model
995
+
996
+
997
+ @register_model
998
+ def convformer_s36_384(pretrained=False, **kwargs):
999
+ model = MetaFormer(
1000
+ depths=[3, 12, 18, 3],
1001
+ dims=[64, 128, 320, 512],
1002
+ token_mixers=SepConv,
1003
+ head_fn=MlpHead,
1004
+ **kwargs)
1005
+ model.default_cfg = default_cfgs['convformer_s36_384']
1006
+ if pretrained:
1007
+ state_dict = torch.hub.load_state_dict_from_url(
1008
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1009
+ model.load_state_dict(state_dict)
1010
+ return model
1011
+
1012
+
1013
+ @register_model
1014
+ def convformer_s36_in21ft1k(pretrained=False, **kwargs):
1015
+ model = MetaFormer(
1016
+ depths=[3, 12, 18, 3],
1017
+ dims=[64, 128, 320, 512],
1018
+ token_mixers=SepConv,
1019
+ head_fn=MlpHead,
1020
+ **kwargs)
1021
+ model.default_cfg = default_cfgs['convformer_s36_in21ft1k']
1022
+ if pretrained:
1023
+ state_dict = torch.hub.load_state_dict_from_url(
1024
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1025
+ model.load_state_dict(state_dict)
1026
+ return model
1027
+
1028
+
1029
+ @register_model
1030
+ def convformer_s36_384_in21ft1k(pretrained=False, **kwargs):
1031
+ model = MetaFormer(
1032
+ depths=[3, 12, 18, 3],
1033
+ dims=[64, 128, 320, 512],
1034
+ token_mixers=SepConv,
1035
+ head_fn=MlpHead,
1036
+ **kwargs)
1037
+ model.default_cfg = default_cfgs['convformer_s36_384_in21ft1k']
1038
+ if pretrained:
1039
+ state_dict = torch.hub.load_state_dict_from_url(
1040
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1041
+ model.load_state_dict(state_dict)
1042
+ return model
1043
+
1044
+
1045
+ @register_model
1046
+ def convformer_s36_in21k(pretrained=False, **kwargs):
1047
+ model = MetaFormer(
1048
+ depths=[3, 12, 18, 3],
1049
+ dims=[64, 128, 320, 512],
1050
+ token_mixers=SepConv,
1051
+ head_fn=MlpHead,
1052
+ **kwargs)
1053
+ model.default_cfg = default_cfgs['convformer_s36_in21k']
1054
+ if pretrained:
1055
+ state_dict = torch.hub.load_state_dict_from_url(
1056
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1057
+ model.load_state_dict(state_dict)
1058
+ return model
1059
+
1060
+
1061
+ @register_model
1062
+ def convformer_m36(pretrained=False, **kwargs):
1063
+ model = MetaFormer(
1064
+ depths=[3, 12, 18, 3],
1065
+ dims=[96, 192, 384, 576],
1066
+ token_mixers=SepConv,
1067
+ head_fn=MlpHead,
1068
+ **kwargs)
1069
+ model.default_cfg = default_cfgs['convformer_m36']
1070
+ if pretrained:
1071
+ state_dict = torch.hub.load_state_dict_from_url(
1072
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1073
+ model.load_state_dict(state_dict)
1074
+ return model
1075
+
1076
+
1077
+ @register_model
1078
+ def convformer_m36_384(pretrained=False, **kwargs):
1079
+ model = MetaFormer(
1080
+ depths=[3, 12, 18, 3],
1081
+ dims=[96, 192, 384, 576],
1082
+ token_mixers=SepConv,
1083
+ head_fn=MlpHead,
1084
+ **kwargs)
1085
+ model.default_cfg = default_cfgs['convformer_m36_384']
1086
+ if pretrained:
1087
+ state_dict = torch.hub.load_state_dict_from_url(
1088
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1089
+ model.load_state_dict(state_dict)
1090
+ return model
1091
+
1092
+
1093
+ @register_model
1094
+ def convformer_m36_in21ft1k(pretrained=False, **kwargs):
1095
+ model = MetaFormer(
1096
+ depths=[3, 12, 18, 3],
1097
+ dims=[96, 192, 384, 576],
1098
+ token_mixers=SepConv,
1099
+ head_fn=MlpHead,
1100
+ **kwargs)
1101
+ model.default_cfg = default_cfgs['convformer_m36_in21ft1k']
1102
+ if pretrained:
1103
+ state_dict = torch.hub.load_state_dict_from_url(
1104
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1105
+ model.load_state_dict(state_dict)
1106
+ return model
1107
+
1108
+
1109
+ @register_model
1110
+ def convformer_m36_384_in21ft1k(pretrained=False, **kwargs):
1111
+ model = MetaFormer(
1112
+ depths=[3, 12, 18, 3],
1113
+ dims=[96, 192, 384, 576],
1114
+ token_mixers=SepConv,
1115
+ head_fn=MlpHead,
1116
+ **kwargs)
1117
+ model.default_cfg = default_cfgs['convformer_m36_384_in21ft1k']
1118
+ if pretrained:
1119
+ state_dict = torch.hub.load_state_dict_from_url(
1120
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1121
+ model.load_state_dict(state_dict)
1122
+ return model
1123
+
1124
+
1125
+ @register_model
1126
+ def convformer_m36_in21k(pretrained=False, **kwargs):
1127
+ model = MetaFormer(
1128
+ depths=[3, 12, 18, 3],
1129
+ dims=[96, 192, 384, 576],
1130
+ token_mixers=SepConv,
1131
+ head_fn=MlpHead,
1132
+ **kwargs)
1133
+ model.default_cfg = default_cfgs['convformer_m36_in21k']
1134
+ if pretrained:
1135
+ state_dict = torch.hub.load_state_dict_from_url(
1136
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1137
+ model.load_state_dict(state_dict)
1138
+ return model
1139
+
1140
+
1141
+ @register_model
1142
+ def convformer_b36(pretrained=False, **kwargs):
1143
+ model = MetaFormer(
1144
+ depths=[3, 12, 18, 3],
1145
+ dims=[128, 256, 512, 768],
1146
+ token_mixers=SepConv,
1147
+ head_fn=MlpHead,
1148
+ **kwargs)
1149
+ model.default_cfg = default_cfgs['convformer_b36']
1150
+ if pretrained:
1151
+ state_dict = torch.hub.load_state_dict_from_url(
1152
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1153
+ model.load_state_dict(state_dict)
1154
+ return model
1155
+
1156
+
1157
+ @register_model
1158
+ def convformer_b36_384(pretrained=False, **kwargs):
1159
+ model = MetaFormer(
1160
+ depths=[3, 12, 18, 3],
1161
+ dims=[128, 256, 512, 768],
1162
+ token_mixers=SepConv,
1163
+ head_fn=MlpHead,
1164
+ **kwargs)
1165
+ model.default_cfg = default_cfgs['convformer_b36_384']
1166
+ if pretrained:
1167
+ state_dict = torch.hub.load_state_dict_from_url(
1168
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1169
+ model.load_state_dict(state_dict)
1170
+ return model
1171
+
1172
+
1173
+ @register_model
1174
+ def convformer_b36_in21ft1k(pretrained=False, **kwargs):
1175
+ model = MetaFormer(
1176
+ depths=[3, 12, 18, 3],
1177
+ dims=[128, 256, 512, 768],
1178
+ token_mixers=SepConv,
1179
+ head_fn=MlpHead,
1180
+ **kwargs)
1181
+ model.default_cfg = default_cfgs['convformer_b36_in21ft1k']
1182
+ if pretrained:
1183
+ state_dict = torch.hub.load_state_dict_from_url(
1184
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1185
+ model.load_state_dict(state_dict)
1186
+ return model
1187
+
1188
+
1189
+ @register_model
1190
+ def convformer_b36_384_in21ft1k(pretrained=False, **kwargs):
1191
+ model = MetaFormer(
1192
+ depths=[3, 12, 18, 3],
1193
+ dims=[128, 256, 512, 768],
1194
+ token_mixers=SepConv,
1195
+ head_fn=MlpHead,
1196
+ **kwargs)
1197
+ model.default_cfg = default_cfgs['convformer_b36_384_in21ft1k']
1198
+ if pretrained:
1199
+ state_dict = torch.hub.load_state_dict_from_url(
1200
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1201
+ model.load_state_dict(state_dict)
1202
+ return model
1203
+
1204
+
1205
+ @register_model
1206
+ def convformer_b36_in21k(pretrained=False, **kwargs):
1207
+ model = MetaFormer(
1208
+ depths=[3, 12, 18, 3],
1209
+ dims=[128, 256, 512, 768],
1210
+ token_mixers=SepConv,
1211
+ head_fn=MlpHead,
1212
+ **kwargs)
1213
+ model.default_cfg = default_cfgs['convformer_b36_in21k']
1214
+ if pretrained:
1215
+ state_dict = torch.hub.load_state_dict_from_url(
1216
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1217
+ model.load_state_dict(state_dict)
1218
+ return model
1219
+
1220
+
1221
+ @register_model
1222
+ def caformer_s18(pretrained=False, **kwargs):
1223
+ model = MetaFormer(
1224
+ depths=[3, 3, 9, 3],
1225
+ dims=[64, 128, 320, 512],
1226
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1227
+ head_fn=MlpHead,
1228
+ **kwargs)
1229
+ model.default_cfg = default_cfgs['caformer_s18']
1230
+ if pretrained:
1231
+ state_dict = torch.hub.load_state_dict_from_url(
1232
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1233
+ model.load_state_dict(state_dict)
1234
+ return model
1235
+
1236
+
1237
+ @register_model
1238
+ def caformer_s18_384(pretrained=False, **kwargs):
1239
+ model = MetaFormer(
1240
+ depths=[3, 3, 9, 3],
1241
+ dims=[64, 128, 320, 512],
1242
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1243
+ head_fn=MlpHead,
1244
+ **kwargs)
1245
+ model.default_cfg = default_cfgs['caformer_s18_384']
1246
+ if pretrained:
1247
+ state_dict = torch.hub.load_state_dict_from_url(
1248
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1249
+ model.load_state_dict(state_dict)
1250
+ return model
1251
+
1252
+
1253
+ @register_model
1254
+ def caformer_s18_in21ft1k(pretrained=False, **kwargs):
1255
+ model = MetaFormer(
1256
+ depths=[3, 3, 9, 3],
1257
+ dims=[64, 128, 320, 512],
1258
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1259
+ head_fn=MlpHead,
1260
+ **kwargs)
1261
+ model.default_cfg = default_cfgs['caformer_s18_in21ft1k']
1262
+ if pretrained:
1263
+ state_dict = torch.hub.load_state_dict_from_url(
1264
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1265
+ model.load_state_dict(state_dict)
1266
+ return model
1267
+
1268
+
1269
+ @register_model
1270
+ def caformer_s18_384_in21ft1k(pretrained=False, **kwargs):
1271
+ model = MetaFormer(
1272
+ depths=[3, 3, 9, 3],
1273
+ dims=[64, 128, 320, 512],
1274
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1275
+ head_fn=MlpHead,
1276
+ **kwargs)
1277
+ model.default_cfg = default_cfgs['caformer_s18_384_in21ft1k']
1278
+ if pretrained:
1279
+ state_dict = torch.hub.load_state_dict_from_url(
1280
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1281
+ model.load_state_dict(state_dict)
1282
+ return model
1283
+
1284
+
1285
+ @register_model
1286
+ def caformer_s18_in21k(pretrained=False, **kwargs):
1287
+ model = MetaFormer(
1288
+ depths=[3, 3, 9, 3],
1289
+ dims=[64, 128, 320, 512],
1290
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1291
+ head_fn=MlpHead,
1292
+ **kwargs)
1293
+ model.default_cfg = default_cfgs['caformer_s18_in21k']
1294
+ if pretrained:
1295
+ state_dict = torch.hub.load_state_dict_from_url(
1296
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1297
+ model.load_state_dict(state_dict)
1298
+ return model
1299
+
1300
+
1301
+ @register_model
1302
+ def caformer_s36(pretrained=False, **kwargs):
1303
+ model = MetaFormer(
1304
+ depths=[3, 12, 18, 3],
1305
+ dims=[64, 128, 320, 512],
1306
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1307
+ head_fn=MlpHead,
1308
+ **kwargs)
1309
+ model.default_cfg = default_cfgs['caformer_s36']
1310
+ if pretrained:
1311
+ state_dict = torch.hub.load_state_dict_from_url(
1312
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1313
+ model.load_state_dict(state_dict)
1314
+ return model
1315
+
1316
+
1317
+ @register_model
1318
+ def caformer_s36_384(pretrained=False, **kwargs):
1319
+ model = MetaFormer(
1320
+ depths=[3, 12, 18, 3],
1321
+ dims=[64, 128, 320, 512],
1322
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1323
+ head_fn=MlpHead,
1324
+ **kwargs)
1325
+ model.default_cfg = default_cfgs['caformer_s36_384']
1326
+ if pretrained:
1327
+ state_dict = torch.hub.load_state_dict_from_url(
1328
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1329
+ model.load_state_dict(state_dict)
1330
+ return model
1331
+
1332
+
1333
+ @register_model
1334
+ def caformer_s36_in21ft1k(pretrained=False, **kwargs):
1335
+ model = MetaFormer(
1336
+ depths=[3, 12, 18, 3],
1337
+ dims=[64, 128, 320, 512],
1338
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1339
+ head_fn=MlpHead,
1340
+ **kwargs)
1341
+ model.default_cfg = default_cfgs['caformer_s36_in21ft1k']
1342
+ if pretrained:
1343
+ state_dict = torch.hub.load_state_dict_from_url(
1344
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1345
+ model.load_state_dict(state_dict)
1346
+ return model
1347
+
1348
+
1349
+ @register_model
1350
+ def caformer_s36_384_in21ft1k(pretrained=False, **kwargs):
1351
+ model = MetaFormer(
1352
+ depths=[3, 12, 18, 3],
1353
+ dims=[64, 128, 320, 512],
1354
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1355
+ head_fn=MlpHead,
1356
+ **kwargs)
1357
+ model.default_cfg = default_cfgs['caformer_s36_384_in21ft1k']
1358
+ if pretrained:
1359
+ state_dict = torch.hub.load_state_dict_from_url(
1360
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1361
+ model.load_state_dict(state_dict)
1362
+ return model
1363
+
1364
+
1365
+ @register_model
1366
+ def caformer_s36_in21k(pretrained=False, **kwargs):
1367
+ model = MetaFormer(
1368
+ depths=[3, 12, 18, 3],
1369
+ dims=[64, 128, 320, 512],
1370
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1371
+ head_fn=MlpHead,
1372
+ **kwargs)
1373
+ model.default_cfg = default_cfgs['caformer_s36_in21k']
1374
+ if pretrained:
1375
+ state_dict = torch.hub.load_state_dict_from_url(
1376
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1377
+ model.load_state_dict(state_dict)
1378
+ return model
1379
+
1380
+
1381
+ @register_model
1382
+ def caformer_m36(pretrained=False, **kwargs):
1383
+ model = MetaFormer(
1384
+ depths=[3, 12, 18, 3],
1385
+ dims=[96, 192, 384, 576],
1386
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1387
+ head_fn=MlpHead,
1388
+ **kwargs)
1389
+ model.default_cfg = default_cfgs['caformer_m36']
1390
+ if pretrained:
1391
+ state_dict = torch.hub.load_state_dict_from_url(
1392
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1393
+ model.load_state_dict(state_dict)
1394
+ return model
1395
+
1396
+
1397
+ @register_model
1398
+ def caformer_m36_384(pretrained=False, **kwargs):
1399
+ model = MetaFormer(
1400
+ depths=[3, 12, 18, 3],
1401
+ dims=[96, 192, 384, 576],
1402
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1403
+ head_fn=MlpHead,
1404
+ **kwargs)
1405
+ model.default_cfg = default_cfgs['caformer_m36_384']
1406
+ if pretrained:
1407
+ state_dict = torch.hub.load_state_dict_from_url(
1408
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1409
+ model.load_state_dict(state_dict)
1410
+ return model
1411
+
1412
+
1413
+ @register_model
1414
+ def caformer_m36_in21ft1k(pretrained=False, **kwargs):
1415
+ model = MetaFormer(
1416
+ depths=[3, 12, 18, 3],
1417
+ dims=[96, 192, 384, 576],
1418
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1419
+ head_fn=MlpHead,
1420
+ **kwargs)
1421
+ model.default_cfg = default_cfgs['caformer_m36_in21ft1k']
1422
+ if pretrained:
1423
+ state_dict = torch.hub.load_state_dict_from_url(
1424
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1425
+ model.load_state_dict(state_dict)
1426
+ return model
1427
+
1428
+
1429
+ @register_model
1430
+ def caformer_m36_384_in21ft1k(pretrained=False, **kwargs):
1431
+ model = MetaFormer(
1432
+ depths=[3, 12, 18, 3],
1433
+ dims=[96, 192, 384, 576],
1434
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1435
+ head_fn=MlpHead,
1436
+ **kwargs)
1437
+ model.default_cfg = default_cfgs['caformer_m36_384_in21ft1k']
1438
+ if pretrained:
1439
+ state_dict = torch.hub.load_state_dict_from_url(
1440
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1441
+ model.load_state_dict(state_dict)
1442
+ return model
1443
+
1444
+
1445
+ @register_model
1446
+ def caformer_m364_in21k(pretrained=False, **kwargs):
1447
+ model = MetaFormer(
1448
+ depths=[3, 12, 18, 3],
1449
+ dims=[96, 192, 384, 576],
1450
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1451
+ head_fn=MlpHead,
1452
+ **kwargs)
1453
+ model.default_cfg = default_cfgs['caformer_m364_in21k']
1454
+ if pretrained:
1455
+ state_dict = torch.hub.load_state_dict_from_url(
1456
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1457
+ model.load_state_dict(state_dict)
1458
+ return model
1459
+
1460
+
1461
+ @register_model
1462
+ def caformer_b36(pretrained=False, **kwargs):
1463
+ model = MetaFormer(
1464
+ depths=[3, 12, 18, 3],
1465
+ dims=[128, 256, 512, 768],
1466
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1467
+ head_fn=MlpHead,
1468
+ **kwargs)
1469
+ model.default_cfg = default_cfgs['caformer_b36']
1470
+ if pretrained:
1471
+ state_dict = torch.hub.load_state_dict_from_url(
1472
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1473
+ model.load_state_dict(state_dict)
1474
+ return model
1475
+
1476
+
1477
+ @register_model
1478
+ def caformer_b36_384(pretrained=False, **kwargs):
1479
+ model = MetaFormer(
1480
+ depths=[3, 12, 18, 3],
1481
+ dims=[128, 256, 512, 768],
1482
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1483
+ head_fn=MlpHead,
1484
+ **kwargs)
1485
+ model.default_cfg = default_cfgs['caformer_b36_384']
1486
+ if pretrained:
1487
+ state_dict = torch.hub.load_state_dict_from_url(
1488
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1489
+ model.load_state_dict(state_dict)
1490
+ return model
1491
+
1492
+
1493
+ @register_model
1494
+ def caformer_b36_in21ft1k(pretrained=False, **kwargs):
1495
+ model = MetaFormer(
1496
+ depths=[3, 12, 18, 3],
1497
+ dims=[128, 256, 512, 768],
1498
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1499
+ head_fn=MlpHead,
1500
+ **kwargs)
1501
+ model.default_cfg = default_cfgs['caformer_b36_in21ft1k']
1502
+ if pretrained:
1503
+ state_dict = torch.hub.load_state_dict_from_url(
1504
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1505
+ model.load_state_dict(state_dict)
1506
+ return model
1507
+
1508
+
1509
+ @register_model
1510
+ def caformer_b36_384_in21ft1k(pretrained=False, **kwargs):
1511
+ model = MetaFormer(
1512
+ depths=[3, 12, 18, 3],
1513
+ dims=[128, 256, 512, 768],
1514
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1515
+ head_fn=MlpHead,
1516
+ **kwargs)
1517
+ model.default_cfg = default_cfgs['caformer_b36_384_in21ft1k']
1518
+ if pretrained:
1519
+ state_dict = torch.hub.load_state_dict_from_url(
1520
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1521
+ model.load_state_dict(state_dict)
1522
+ return model
1523
+
1524
+
1525
+ @register_model
1526
+ def caformer_b36_in21k(pretrained=False, **kwargs):
1527
+ model = MetaFormer(
1528
+ depths=[3, 12, 18, 3],
1529
+ dims=[128, 256, 512, 768],
1530
+ token_mixers=[SepConv, SepConv, Attention, Attention],
1531
+ head_fn=MlpHead,
1532
+ **kwargs)
1533
+ model.default_cfg = default_cfgs['caformer_b36_in21k']
1534
+ if pretrained:
1535
+ state_dict = torch.hub.load_state_dict_from_url(
1536
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
1537
+ model.load_state_dict(state_dict)
1538
+ return model
models/neuron.py ADDED
@@ -0,0 +1,1587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ from typing import Callable, overload
3
+ import torch
4
+ import torch.nn as nn
5
+ from spikingjelly.clock_driven import surrogate, base, lava_exchange
6
+ from spikingjelly import configure
7
+ import math
8
+ import numpy as np
9
+ import logging
10
+ import cupy
11
+ from spikingjelly.clock_driven import neuron_kernel, cu_kernel_opt
12
+
13
+
14
+ try:
15
+ import lava.lib.dl.slayer as slayer
16
+
17
+ except BaseException as e:
18
+ logging.info(f'spikingjelly.clock_driven.neuron: {e}')
19
+ slayer = None
20
+
21
+
22
def check_backend(backend: str):
    """Verify that the requested computation backend is usable.

    :param backend: one of ``'torch'``, ``'cupy'`` or ``'lava'``
    :raises AssertionError: if the backend's package is not installed
    :raises NotImplementedError: for an unknown backend name
    """
    # Guard-clause style: each recognized backend returns early.
    if backend == 'torch':
        return
    if backend == 'cupy':
        assert cupy is not None, 'CuPy is not installed! You can install it from "https://github.com/cupy/cupy".'
        return
    if backend == 'lava':
        assert slayer is not None, 'Lava-DL is not installed! You can install it from "https://github.com/lava-nc/lava-dl".'
        return
    raise NotImplementedError(backend)
31
+
32
+
33
class BaseNode(base.MemoryModule):
    def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
        """
        Base class of differentiable spiking neurons.

        :param v_threshold: threshold voltage of the neurons
        :param v_reset: reset voltage. If not ``None``, the membrane potential of
            neurons that just fired is set to ``v_reset`` (hard reset); if
            ``None``, ``v_threshold`` is subtracted instead (soft reset)
        :param surrogate_function: surrogate function replacing the gradient of
            the spiking step function during back-propagation
        :param detach_reset: whether to detach the reset term from the
            computation graph
        """
        assert isinstance(v_reset, float) or v_reset is None
        assert isinstance(v_threshold, float)
        assert isinstance(detach_reset, bool)
        super().__init__()

        # The membrane potential 'v' is registered as stateful memory so that
        # MemoryModule.reset() can restore it between samples/sequences.
        if v_reset is None:
            self.register_memory('v', 0.)
        else:
            self.register_memory('v', v_reset)

        self.register_memory('v_threshold', v_threshold)
        self.register_memory('v_reset', v_reset)

        self.detach_reset = detach_reset
        self.surrogate_function = surrogate_function

    @abstractmethod
    def neuronal_charge(self, x: torch.Tensor):
        """
        Define the charge difference equation. Sub-classes must implement this.
        """
        raise NotImplementedError

    def neuronal_fire(self):
        """
        Compute output spikes from the current membrane potential and the
        threshold voltage (via the surrogate function for differentiability).
        """

        return self.surrogate_function(self.v - self.v_threshold)

    def neuronal_reset(self, spike):
        """
        Reset the membrane potential according to the neurons' output spikes.
        """
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.v_reset is None:
            # soft reset: subtract the threshold where a spike occurred
            self.v = self.v - spike_d * self.v_threshold

        else:
            # hard reset: clamp fired neurons back to v_reset
            self.v = (1. - spike_d) * self.v + spike_d * self.v_reset

    def extra_repr(self):
        return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}'

    def forward(self, x: torch.Tensor):
        """
        Forward in the order: charge, fire, reset.

        :param x: increment of voltage injected into the neurons
        :type x: torch.Tensor

        :return: output spikes of the neurons
        :rtype: torch.Tensor
        """
        self.neuronal_charge(x)
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike
190
+
191
+
192
class AdaptiveBaseNode(BaseNode):
    def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
                 v_rest: float = 0., w_rest: float = 0., tau_w: float = 2., a: float = 0., b: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
        """
        Base class of adaptive spiking neurons that maintain an adaptation
        variable ``w`` in addition to the membrane potential ``v``.

        :param v_threshold: threshold voltage of the neurons
        :param v_reset: reset voltage; ``None`` selects soft reset (see ``BaseNode``)
        :param v_rest: resting potential used by the subthreshold coupling term
        :param w_rest: resting (initial) value of the adaptation variable ``w``
        :param tau_w: time constant of the adaptation variable
        :param a: subthreshold coupling strength
        :param b: spike-triggered jump amplitude of ``w``
        :param surrogate_function: surrogate gradient for the spiking function
        :param detach_reset: whether to detach the reset from the computation graph
        """
        # b: jump amplitudes
        # a: subthreshold coupling
        # Fix: the default of w_rest must be a float literal (0.), otherwise the
        # isinstance check below rejected the default value itself.
        assert isinstance(w_rest, float)
        assert isinstance(v_rest, float)
        assert isinstance(tau_w, float)
        assert isinstance(a, float)
        assert isinstance(b, float)

        # Fix: the original called ``super.__init__(...)`` (missing call
        # parentheses on ``super``), which raises at runtime.
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)

        self.register_memory('w', w_rest)

        self.w_rest = w_rest
        self.v_rest = v_rest
        self.tau_w = tau_w
        self.a = a
        self.b = b

    def neuronal_adaptation(self, spike):
        # w[t] = w[t-1] + (a * (v - v_rest) - w[t-1]) / tau_w + b * spike
        self.w = self.w + 1. / self.tau_w * (self.a * (self.v - self.v_rest) - self.w) + self.b * spike

    def extra_repr(self):
        return super().extra_repr() + f', v_rest={self.v_rest}, w_rest={self.w_rest}, tau_w={self.tau_w}, a={self.a}, b={self.b}'

    def forward(self, x: torch.Tensor):
        """
        Forward in the order: charge, fire, adapt, reset.

        Fix: the original decorated this method with ``typing.overload``, which
        only declares a typing stub and must never carry the runtime
        implementation; calling an overload stub fails at runtime.
        """
        self.neuronal_charge(x)
        spike = self.neuronal_fire()
        self.neuronal_adaptation(spike)
        self.neuronal_reset(spike)
        return spike
227
+
228
+
229
class IFNode(BaseNode):
    def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False,
                 cupy_fp32_inference=False):
        """
        The Integrate-and-Fire neuron, which can be seen as an ideal integrator:
        without input the voltage stays constant and does not decay as in the
        LIF neuron. Subthreshold dynamics:

        .. math::
            V[t] = V[t-1] + X[t]

        :param v_threshold: threshold voltage of the neurons
        :param v_reset: reset voltage; ``None`` selects soft reset (see ``BaseNode``)
        :param surrogate_function: surrogate gradient for the spiking function
        :param detach_reset: whether to detach the reset from the computation graph
        :param cupy_fp32_inference: if ``True``, and this module is in ``eval``
            mode, uses float32, runs on GPU, and ``cupy`` is installed, then the
            forward pass is accelerated with a raw CuPy kernel
        """
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)

        if cupy_fp32_inference:
            check_backend('cupy')
        self.cupy_fp32_inference = cupy_fp32_inference

    def neuronal_charge(self, x: torch.Tensor):
        # ideal integration: no leak term
        self.v = self.v + x

    def forward(self, x: torch.Tensor):
        if self.cupy_fp32_inference and cupy is not None and not self.training and x.dtype == torch.float32:
            # cupy is installed && eval mode && fp32
            device_id = x.get_device()
            if device_id < 0:
                # CPU tensor: fall back to the plain torch implementation
                return super().forward(x)

            # use cupy to accelerate
            # materialize the scalar initial potential as a tensor once
            if isinstance(self.v, float):
                v = torch.zeros_like(x)
                if self.v != 0.:
                    torch.fill_(v, self.v)
                self.v = v

            if self.v_reset is None:
                hard_reset = False
            else:
                hard_reset = True

            # Build the CUDA source; the kernel fuses charge, fire and reset
            # into a single element-wise pass.
            code = rf'''
            extern "C" __global__
            void IFNode_{'hard' if hard_reset else 'soft'}_reset_inference_forward(
            const float * x, const float & v_threshold, {'const float & v_reset,' if hard_reset else ''}
            float * spike, float * v,
            const int & numel)
            '''

            code += r'''
            {
                const int index = blockIdx.x * blockDim.x + threadIdx.x;
                if (index < numel)
                {
                    v[index] += x[index];
                    spike[index] = (float) (v[index] >= v_threshold);
            '''

            code += rf'''
                    {'v[index] = (1.0f - spike[index]) * v[index] + spike[index] * v_reset;' if hard_reset else 'v[index] -= spike[index] * v_threshold;'}
            '''

            code += r'''
                }
            }
            '''
            # Cache the compiled kernel; recompile only if the generated source
            # changed (e.g. the reset mode switched between hard and soft).
            if hasattr(self, 'cp_kernel'):
                if self.cp_kernel.code != code:
                    # replace codes
                    del self.cp_kernel
                    self.cp_kernel = cupy.RawKernel(code,
                                                    f"IFNode_{'hard' if hard_reset else 'soft'}_reset_inference_forward",
                                                    options=configure.cuda_compiler_options,
                                                    backend=configure.cuda_compiler_backend)
            else:
                self.cp_kernel = cupy.RawKernel(code,
                                                f"IFNode_{'hard' if hard_reset else 'soft'}_reset_inference_forward",
                                                options=configure.cuda_compiler_options,
                                                backend=configure.cuda_compiler_backend)

            with cu_kernel_opt.DeviceEnvironment(device_id):
                numel = x.numel()
                threads = configure.cuda_threads
                blocks = cu_kernel_opt.cal_blocks(numel)
                cp_numel = cupy.asarray(numel)
                cp_v_threshold = cupy.asarray(self.v_threshold, dtype=np.float32)
                if hard_reset:
                    cp_v_reset = cupy.asarray(self.v_reset, dtype=np.float32)

                spike = torch.zeros_like(x)
                # kernel arguments must be contiguous; the argument order must
                # match the kernel signature generated above
                if hard_reset:
                    x, cp_v_threshold, cp_v_reset, spike, self.v, cp_numel = cu_kernel_opt.get_contiguous(x,
                                                                                                          cp_v_threshold,
                                                                                                          cp_v_reset,
                                                                                                          spike, self.v,
                                                                                                          cp_numel)
                    kernel_args = [x, cp_v_threshold, cp_v_reset, spike, self.v, cp_numel]
                else:
                    x, cp_v_threshold, spike, self.v, cp_numel = cu_kernel_opt.get_contiguous(x, cp_v_threshold, spike,
                                                                                              self.v, cp_numel)
                    kernel_args = [x, cp_v_threshold, spike, self.v, cp_numel]
                self.cp_kernel(
                    (blocks,), (threads,),
                    cu_kernel_opt.wrap_args_to_raw_kernel(
                        device_id,
                        *kernel_args
                    )
                )
            return spike
        else:
            return super().forward(x)
385
+
386
+
387
class MultiStepIFNode(IFNode):
    def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, backend='torch',
                 lava_s_cale=1 << 6):
        """
        The multi-step version of :class:`spikingjelly.clock_driven.neuron.IFNode`.

        :param v_threshold: threshold voltage of the neurons
        :param v_reset: reset voltage; ``None`` selects soft reset (see ``BaseNode``)
        :param surrogate_function: surrogate gradient for the spiking function
        :param detach_reset: whether to detach the reset from the computation graph
        :param backend: which backend to use: ``'torch'``, ``'cupy'`` or
            ``'lava'``. ``'cupy'`` is faster but only supports GPU
        :param lava_s_cale: scale factor used when exporting to a Lava neuron

        .. admonition:: Tip
            :class: tip

            The input for multi-step neurons is ``x_seq.shape = [T, *]``. Besides
            ``.v`` and ``.spike`` (values at ``t = T - 1``), the full sequences
            over all ``T`` time-steps are available as ``.v_seq`` and ``.spike_seq``.
        """
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)

        # full membrane-potential history, filled by forward()
        self.register_memory('v_seq', None)

        check_backend(backend)

        self.backend = backend

        self.lava_s_cale = lava_s_cale

        # The lava neuron is created eagerly only for the 'lava' backend;
        # otherwise it is built lazily on first use.
        if backend == 'lava':
            self.lava_neuron = self.to_lava()
        else:
            self.lava_neuron = None

    def forward(self, x_seq: torch.Tensor):
        assert x_seq.dim() > 1
        # x_seq.shape = [T, *]

        if self.backend == 'torch':
            # step the single-step parent once per time-step, collecting outputs
            spike_seq = []
            self.v_seq = []
            for t in range(x_seq.shape[0]):
                spike_seq.append(super().forward(x_seq[t]).unsqueeze(0))
                self.v_seq.append(self.v.unsqueeze(0))
            spike_seq = torch.cat(spike_seq, 0)
            self.v_seq = torch.cat(self.v_seq, 0)
            return spike_seq

        elif self.backend == 'cupy':
            # materialize the scalar initial potential as a tensor once
            if isinstance(self.v, float):
                v_init = self.v
                self.v = torch.zeros_like(x_seq[0].data)
                if v_init != 0.:
                    torch.fill_(self.v, v_init)

            # fused CUDA kernel over all T time-steps at once
            spike_seq, self.v_seq = neuron_kernel.MultiStepIFNodePTT.apply(
                x_seq.flatten(1), self.v.flatten(0), self.v_threshold, self.v_reset, self.detach_reset,
                self.surrogate_function.cuda_code)

            spike_seq = spike_seq.reshape(x_seq.shape)
            self.v_seq = self.v_seq.reshape(x_seq.shape)

            # keep .v consistent with the single-step API (state at t = T - 1)
            self.v = self.v_seq[-1].clone()

            return spike_seq

        elif self.backend == 'lava':
            if self.lava_neuron is None:
                self.lava_neuron = self.to_lava()

            spike, self.v = lava_exchange.lava_neuron_forward(self.lava_neuron, x_seq, self.v)

            return spike

        else:
            raise NotImplementedError(self.backend)

    def extra_repr(self):
        return super().extra_repr() + f', backend={self.backend}'

    def to_lava(self):
        # Export this neuron as a Lava-DL slayer neuron for deployment.
        return lava_exchange.to_lava_neuron(self)

    def reset(self):
        super().reset()
        # the lava neuron keeps its own internal state; clear it as well
        if self.lava_neuron is not None:
            self.lava_neuron.current_state.zero_()
            self.lava_neuron.voltage_state.zero_()
528
+
529
+
530
class LIFNode(BaseNode):
    def __init__(self, tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
                 v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
                 detach_reset: bool = False, cupy_fp32_inference=False):
        """
        The Leaky Integrate-and-Fire neuron, which can be seen as a leaky
        integrator. Subthreshold dynamics:

        If ``decay_input == True``:

        .. math::
            V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))

        If ``decay_input == False``:

        .. math::
            V[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]

        :param tau: membrane time constant (must be > 1)
        :param decay_input: whether the input is divided by ``tau`` as well
        :param v_threshold: threshold voltage of the neurons
        :param v_reset: reset voltage; ``None`` selects soft reset (see ``BaseNode``)
        :param surrogate_function: surrogate gradient for the spiking function
        :param detach_reset: whether to detach the reset from the computation graph
        :param cupy_fp32_inference: if ``True``, and this module is in ``eval``
            mode, uses float32, runs on GPU, and ``cupy`` is installed, then the
            forward pass is accelerated with a raw CuPy kernel
        """
        assert isinstance(tau, float) and tau > 1.

        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
        self.tau = tau
        self.decay_input = decay_input

        if cupy_fp32_inference:
            check_backend('cupy')
        self.cupy_fp32_inference = cupy_fp32_inference

    def extra_repr(self):
        return super().extra_repr() + f', tau={self.tau}'

    def neuronal_charge(self, x: torch.Tensor):
        # The v_reset == 0 special cases avoid a redundant subtraction.
        if self.decay_input:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.v + (x - self.v) / self.tau
            else:
                self.v = self.v + (x - (self.v - self.v_reset)) / self.tau

        else:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.v * (1. - 1. / self.tau) + x
            else:
                self.v = self.v - (self.v - self.v_reset) / self.tau + x

    def forward(self, x: torch.Tensor):
        if self.cupy_fp32_inference and cupy is not None and not self.training and x.dtype == torch.float32:
            # cupy is installed && eval mode && fp32
            device_id = x.get_device()
            if device_id < 0:
                # CPU tensor: fall back to the plain torch implementation
                return super().forward(x)

            # use cupy to accelerate
            # materialize the scalar initial potential as a tensor once
            if isinstance(self.v, float):
                v = torch.zeros_like(x)
                if self.v != 0.:
                    torch.fill_(v, self.v)
                self.v = v

            if self.v_reset is None:
                hard_reset = False
            else:
                hard_reset = True

            # Build the CUDA source; the kernel fuses charge, fire and reset
            # into a single element-wise pass. The charge expression is chosen
            # below to match neuronal_charge() for the current configuration.
            code = rf'''
            extern "C" __global__
            void LIFNode_{'hard' if hard_reset else 'soft'}_reset_decayInput_{self.decay_input}_inference_forward(
            const float * x, const float & v_threshold, {'const float & v_reset,' if hard_reset else ''} const float & tau,
            float * spike, float * v,
            const int & numel)
            '''

            code += r'''
            {
                const int index = blockIdx.x * blockDim.x + threadIdx.x;
                if (index < numel)
                {

            '''

            if self.decay_input:
                if hard_reset:
                    code += r'''
                    v[index] += (x[index] - (v[index] - v_reset)) / tau;
                    '''
                else:
                    code += r'''
                    v[index] += (x[index] - v[index]) / tau;
                    '''
            else:
                if hard_reset:
                    code += r'''
                    v[index] = x[index] + v[index] - (v[index] - v_reset) / tau;
                    '''
                else:
                    code += r'''
                    v[index] = x[index] + v[index] * (1.0f - 1.0f / tau);
                    '''

            code += rf'''
                    spike[index] = (float) (v[index] >= v_threshold);
                    {'v[index] = (1.0f - spike[index]) * v[index] + spike[index] * v_reset;' if hard_reset else 'v[index] -= spike[index] * v_threshold;'}
            '''

            code += r'''
                }
            }
            '''
            # Cache the compiled kernel; recompile only if the generated source
            # changed (e.g. the reset mode or decay_input switched).
            if hasattr(self, 'cp_kernel'):
                if self.cp_kernel.code != code:
                    # replace codes
                    del self.cp_kernel
                    self.cp_kernel = cupy.RawKernel(code,
                                                    f"LIFNode_{'hard' if hard_reset else 'soft'}_reset_decayInput_{self.decay_input}_inference_forward",
                                                    options=configure.cuda_compiler_options,
                                                    backend=configure.cuda_compiler_backend)
            else:
                self.cp_kernel = cupy.RawKernel(code,
                                                f"LIFNode_{'hard' if hard_reset else 'soft'}_reset_decayInput_{self.decay_input}_inference_forward",
                                                options=configure.cuda_compiler_options,
                                                backend=configure.cuda_compiler_backend)

            with cu_kernel_opt.DeviceEnvironment(device_id):
                numel = x.numel()
                threads = configure.cuda_threads
                blocks = cu_kernel_opt.cal_blocks(numel)
                cp_numel = cupy.asarray(numel)
                cp_v_threshold = cupy.asarray(self.v_threshold, dtype=np.float32)
                if hard_reset:
                    cp_v_reset = cupy.asarray(self.v_reset, dtype=np.float32)
                cp_tau = cupy.asarray(self.tau, dtype=np.float32)
                spike = torch.zeros_like(x)
                # kernel arguments must be contiguous; the argument order must
                # match the kernel signature generated above
                if hard_reset:
                    x, cp_v_threshold, cp_v_reset, cp_tau, spike, self.v, cp_numel = cu_kernel_opt.get_contiguous(x,
                                                                                                                  cp_v_threshold,
                                                                                                                  cp_v_reset,
                                                                                                                  cp_tau,
                                                                                                                  spike,
                                                                                                                  self.v,
                                                                                                                  cp_numel)
                    kernel_args = [x, cp_v_threshold, cp_v_reset, cp_tau, spike, self.v, cp_numel]
                else:
                    x, cp_v_threshold, cp_tau, spike, self.v, cp_numel = cu_kernel_opt.get_contiguous(x, cp_v_threshold,
                                                                                                      cp_tau, spike,
                                                                                                      self.v, cp_numel)
                    kernel_args = [x, cp_v_threshold, cp_tau, spike, self.v, cp_numel]

                self.cp_kernel(
                    (blocks,), (threads,),
                    cu_kernel_opt.wrap_args_to_raw_kernel(
                        device_id,
                        *kernel_args
                    )
                )
            return spike
        else:
            return super().forward(x)
762
+
763
+
764
class MultiStepLIFNode(LIFNode):
    def __init__(self, tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
                 v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
                 detach_reset: bool = False, backend='torch', lava_s_cale=1 << 6):
        """
        The multi-step version of :class:`spikingjelly.clock_driven.neuron.LIFNode`.

        :param tau: membrane time constant (must be > 1)
        :param decay_input: whether the input is divided by ``tau`` as well
        :param v_threshold: threshold voltage of the neurons
        :param v_reset: reset voltage; ``None`` selects soft reset (see ``BaseNode``)
        :param surrogate_function: surrogate gradient for the spiking function
        :param detach_reset: whether to detach the reset from the computation graph
        :param backend: which backend to use: ``'torch'``, ``'cupy'`` or
            ``'lava'``. ``'cupy'`` is faster but only supports GPU
        :param lava_s_cale: scale factor used when exporting to a Lava neuron

        .. admonition:: Tip
            :class: tip

            The input for multi-step neurons is ``x_seq.shape = [T, *]``. Besides
            ``.v`` and ``.spike`` (values at ``t = T - 1``), the full sequences
            over all ``T`` time-steps are available as ``.v_seq`` and ``.spike_seq``.
        """
        super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset)
        # full membrane-potential history, filled by forward()
        self.register_memory('v_seq', None)

        check_backend(backend)

        self.backend = backend

        self.lava_s_cale = lava_s_cale

        # The lava neuron is created eagerly only for the 'lava' backend;
        # otherwise it is built lazily on first use.
        if backend == 'lava':
            self.lava_neuron = self.to_lava()
        else:
            self.lava_neuron = None

    def forward(self, x_seq: torch.Tensor):
        assert x_seq.dim() > 1
        # x_seq.shape = [T, *]

        if self.backend == 'torch':
            # step the single-step parent once per time-step, collecting outputs
            spike_seq = []
            self.v_seq = []
            for t in range(x_seq.shape[0]):
                spike_seq.append(super().forward(x_seq[t]).unsqueeze(0))
                self.v_seq.append(self.v.unsqueeze(0))
            spike_seq = torch.cat(spike_seq, 0)
            self.v_seq = torch.cat(self.v_seq, 0)
            return spike_seq

        elif self.backend == 'cupy':
            # materialize the scalar initial potential as a tensor once
            if isinstance(self.v, float):
                v_init = self.v
                self.v = torch.zeros_like(x_seq[0].data)
                if v_init != 0.:
                    torch.fill_(self.v, v_init)

            # fused CUDA kernel over all T time-steps at once
            spike_seq, self.v_seq = neuron_kernel.MultiStepLIFNodePTT.apply(
                x_seq.flatten(1), self.v.flatten(0), self.decay_input, self.tau, self.v_threshold, self.v_reset,
                self.detach_reset, self.surrogate_function.cuda_code)

            spike_seq = spike_seq.reshape(x_seq.shape)
            self.v_seq = self.v_seq.reshape(x_seq.shape)

            # keep .v consistent with the single-step API (state at t = T - 1)
            self.v = self.v_seq[-1].clone()

            return spike_seq

        elif self.backend == 'lava':
            if self.lava_neuron is None:
                self.lava_neuron = self.to_lava()

            spike, self.v = lava_exchange.lava_neuron_forward(self.lava_neuron, x_seq, self.v)

            return spike

        else:
            raise NotImplementedError(self.backend)

    def extra_repr(self):
        return super().extra_repr() + f', backend={self.backend}'

    def to_lava(self):
        # Export this neuron as a Lava-DL slayer neuron for deployment.
        return lava_exchange.to_lava_neuron(self)

    def reset(self):
        super().reset()
        # the lava neuron keeps its own internal state; clear it as well
        if self.lava_neuron is not None:
            self.lava_neuron.current_state.zero_()
            self.lava_neuron.voltage_state.zero_()
916
+
917
+
918
class ParametricLIFNode(BaseNode):
    def __init__(self, init_tau: float = 2.0, decay_input: bool = True, v_threshold: float = 1.,
                 v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
                 detach_reset: bool = False):
        """
        The Parametric Leaky Integrate-and-Fire (PLIF) neuron, proposed by
        `Incorporating Learnable Membrane Time Constant to Enhance Learning of
        Spiking Neural Networks <https://arxiv.org/abs/2007.05785>`_, which can
        be seen as a leaky integrator with a learnable time constant.
        Subthreshold dynamics:

        If ``decay_input == True``:

        .. math::
            V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))

        If ``decay_input == False``:

        .. math::
            V[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]

        where :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)` and :math:`w` is a
        learnable parameter.

        :param init_tau: initial value of the membrane time constant (must be > 1)
        :param decay_input: whether the input is scaled by :math:`1/\\tau` as well
        :param v_threshold: threshold voltage of the neurons
        :param v_reset: reset voltage; ``None`` selects soft reset (see ``BaseNode``)
        :param surrogate_function: surrogate gradient for the spiking function
        :param detach_reset: whether to detach the reset from the computation graph
        """

        assert isinstance(init_tau, float) and init_tau > 1.
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
        self.decay_input = decay_input
        # invert sigmoid so that sigmoid(init_w) == 1 / init_tau
        init_w = - math.log(init_tau - 1.)
        self.w = nn.Parameter(torch.as_tensor(init_w))

    def extra_repr(self):
        with torch.no_grad():
            # report the effective time constant implied by the learned w
            tau = 1. / self.w.sigmoid()
        return super().extra_repr() + f', tau={tau}'

    def neuronal_charge(self, x: torch.Tensor):
        # sigmoid(w) plays the role of 1/tau in the LIF update; the
        # v_reset == 0 special cases avoid a redundant subtraction.
        if self.decay_input:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.v + (x - self.v) * self.w.sigmoid()
            else:
                self.v = self.v + (x - (self.v - self.v_reset)) * self.w.sigmoid()
        else:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.v * (1. - self.w.sigmoid()) + x
            else:
                self.v = self.v - (self.v - self.v_reset) * self.w.sigmoid() + x
1022
+
1023
+
1024
class MultiStepParametricLIFNode(ParametricLIFNode):
    def __init__(self, init_tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
                 v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
                 detach_reset: bool = False, backend='torch'):
        """
        The multi-step version of the Parametric Leaky Integrate-and-Fire (PLIF) neuron,
        proposed by `Incorporating Learnable Membrane Time Constant to Enhance Learning
        of Spiking Neural Networks <https://arxiv.org/abs/2007.05785>`_, which can be
        seen as a leaky integrator. The subthreshold neural dynamics are:

        .. math::
            V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))

        where :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)` and :math:`w` is a learnable
        parameter.

        :param init_tau: initial value of the membrane time constant
        :type init_tau: float

        :param decay_input: whether the input is scaled by the decay factor 1/tau
        :type decay_input: bool

        :param v_threshold: threshold voltage of the neuron
        :type v_threshold: float

        :param v_reset: reset voltage. If not ``None``, the voltage of neurons that just
            fired is set to ``v_reset``; if ``None``, ``v_threshold`` is subtracted instead
        :type v_reset: float

        :param surrogate_function: surrogate function used to compute the gradient of the
            spiking function during back-propagation
        :type surrogate_function: Callable

        :param detach_reset: whether to detach the reset computation from the graph
        :type detach_reset: bool

        :param backend: which backend to use, ``'torch'`` or ``'cupy'``; ``'cupy'`` is
            faster but only supports GPU
        :type backend: str

        .. admonition:: Tip
            :class: tip

            The input for multi-step neurons is ``x_seq.shape = [T, *]``. Membrane
            potential and spikes at time-step ``t = T - 1`` are available via ``.v``
            and ``.spike``; the full ``T``-step sequences via ``.v_seq`` and
            ``.spike_seq``.
        """
        super().__init__(init_tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset)
        self.register_memory('v_seq', None)

        check_backend(backend)

        self.backend = backend

    def forward(self, x_seq: torch.Tensor):
        """Run the neuron over a whole input sequence; returns spikes shaped like ``x_seq``."""
        assert x_seq.dim() > 1
        # x_seq.shape = [T, *]

        if self.backend == 'torch':
            # Pure-PyTorch path: step through time with the single-step parent forward.
            spike_seq = []
            self.v_seq = []
            for t in range(x_seq.shape[0]):
                spike_seq.append(super().forward(x_seq[t]).unsqueeze(0))
                self.v_seq.append(self.v.unsqueeze(0))
            spike_seq = torch.cat(spike_seq, 0)
            self.v_seq = torch.cat(self.v_seq, 0)
            return spike_seq

        elif self.backend == 'cupy':
            if isinstance(self.v, float):
                # v is still the scalar init value; materialize it as a tensor
                # shaped like one time-step of the input.
                v_init = self.v
                self.v = torch.zeros_like(x_seq[0].data)
                if v_init != 0.:
                    torch.fill_(self.v, v_init)

            spike_seq, self.v_seq = neuron_kernel.MultiStepParametricLIFNodePTT.apply(
                x_seq.flatten(1), self.v.flatten(0), self.w.sigmoid(), self.decay_input, self.v_threshold, self.v_reset,
                self.detach_reset, self.surrogate_function.cuda_code)

            spike_seq = spike_seq.reshape(x_seq.shape)
            self.v_seq = self.v_seq.reshape(x_seq.shape)

            # Keep single-step state consistent: v holds the last time-step's potential.
            self.v = self.v_seq[-1].clone()

            return spike_seq
        else:
            raise NotImplementedError

    def extra_repr(self):
        return super().extra_repr() + f', backend={self.backend}'
1159
+
1160
+
1161
class QIFNode(BaseNode):
    def __init__(self, tau: float = 2., v_c: float = 0.8, a0: float = 1., v_threshold: float = 1., v_rest: float = 0.,
                 v_reset: float = -0.1,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
        """
        The Quadratic Integrate-and-Fire (QIF) neuron, a nonlinear integrate-and-fire
        model and also an approximation of the Exponential Integrate-and-Fire model.
        The subthreshold neural dynamics are:

        .. math::
            V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] + a_0 (V[t-1] - V_{rest})(V[t-1] - V_c))

        :param tau: membrane time constant (must be a float > 1)
        :type tau: float

        :param v_c: critical voltage
        :type v_c: float

        :param a0: strength of the quadratic term (must be > 0)
        :type a0: float

        :param v_threshold: threshold voltage of the neuron
        :type v_threshold: float

        :param v_rest: resting potential
        :type v_rest: float

        :param v_reset: reset voltage. If not ``None``, the voltage of neurons that just
            fired is set to ``v_reset``; if ``None``, ``v_threshold`` is subtracted instead
        :type v_reset: float

        :param surrogate_function: surrogate function used to compute the gradient of the
            spiking function during back-propagation
        :type surrogate_function: Callable

        :param detach_reset: whether to detach the reset computation from the graph
        :type detach_reset: bool
        """

        assert isinstance(tau, float) and tau > 1.
        if v_reset is not None:
            assert v_threshold > v_reset
            assert v_rest >= v_reset
        assert a0 > 0

        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
        self.tau = tau
        self.v_c = v_c
        self.v_rest = v_rest
        self.a0 = a0

    def extra_repr(self):
        return super().extra_repr() + f', tau={self.tau}, v_c={self.v_c}, a0={self.a0}, v_rest={self.v_rest}'

    def neuronal_charge(self, x: torch.Tensor):
        # QIF membrane update: quadratic term a0*(v - v_rest)*(v - v_c) plus input,
        # scaled by 1/tau.
        self.v = self.v + (x + self.a0 * (self.v - self.v_rest) * (self.v - self.v_c)) / self.tau
1254
+
1255
+
1256
class EIFNode(BaseNode):
    def __init__(self, tau: float = 2., delta_T: float = 1., theta_rh: float = .8, v_threshold: float = 1.,
                 v_rest: float = 0., v_reset: float = -0.1,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
        """
        The Exponential Integrate-and-Fire (EIF) neuron, a nonlinear integrate-and-fire
        model and a one-dimensional model derived from the Hodgkin-Huxley model. It
        degenerates to the LIF model when :math:`\\Delta_T \\to 0`. The subthreshold
        neural dynamics are:

        .. math::
            V[t] = V[t-1] + \\frac{1}{\\tau}\\left(X[t] - (V[t-1] - V_{rest}) + \\Delta_T\\exp\\left(\\frac{V[t-1] - \\theta_{rh}}{\\Delta_T}\\right)\\right)

        :param tau: membrane time constant (must be a float > 1)
        :type tau: float

        :param delta_T: sharpness parameter (must be > 0)
        :type delta_T: float

        :param theta_rh: rheobase threshold
        :type theta_rh: float

        :param v_threshold: threshold voltage of the neuron
        :type v_threshold: float

        :param v_rest: resting potential
        :type v_rest: float

        :param v_reset: reset voltage. If not ``None``, the voltage of neurons that just
            fired is set to ``v_reset``; if ``None``, ``v_threshold`` is subtracted instead
        :type v_reset: float

        :param surrogate_function: surrogate function used to compute the gradient of the
            spiking function during back-propagation
        :type surrogate_function: Callable

        :param detach_reset: whether to detach the reset computation from the graph
        :type detach_reset: bool
        """

        assert isinstance(tau, float) and tau > 1.
        if v_reset is not None:
            assert v_threshold > v_reset
            assert v_rest >= v_reset
        assert delta_T > 0

        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
        self.tau = tau
        self.delta_T = delta_T
        self.v_rest = v_rest
        self.theta_rh = theta_rh

    def extra_repr(self):
        return super().extra_repr() + f', tau={self.tau}, delta_T={self.delta_T}, theta_rh={self.theta_rh}'

    def neuronal_charge(self, x: torch.Tensor):

        with torch.no_grad():
            # v may still be the scalar init value; promote it to a tensor on the
            # input's device before the exponential update.
            if not isinstance(self.v, torch.Tensor):
                self.v = torch.as_tensor(self.v, device=x.device)

        self.v = self.v + (x + self.v_rest - self.v + self.delta_T * torch.exp(
            (self.v - self.theta_rh) / self.delta_T)) / self.tau
1355
+
1356
+
1357
class MultiStepEIFNode(EIFNode):
    def __init__(self, tau: float = 2., delta_T: float = 1., theta_rh: float = .8, v_threshold: float = 1.,
                 v_rest: float = 0., v_reset: float = -0.1,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, backend='torch'):
        """
        The multi-step version of :class:`EIFNode`.

        :param tau: membrane time constant
        :type tau: float

        :param delta_T: sharpness parameter
        :type delta_T: float

        :param theta_rh: rheobase threshold
        :type theta_rh: float

        :param v_threshold: threshold voltage of the neuron
        :type v_threshold: float

        :param v_rest: resting potential
        :type v_rest: float

        :param v_reset: reset voltage. If not ``None``, the voltage of neurons that just
            fired is set to ``v_reset``; if ``None``, ``v_threshold`` is subtracted instead
        :type v_reset: float

        :param surrogate_function: surrogate function used to compute the gradient of the
            spiking function during back-propagation
        :type surrogate_function: Callable

        :param detach_reset: whether to detach the reset computation from the graph
        :type detach_reset: bool

        :param backend: which backend to use, ``'torch'`` or ``'cupy'``; ``'cupy'`` is
            faster but only supports GPU
        :type backend: str

        .. admonition:: Tip
            :class: tip

            The input for multi-step neurons is ``x_seq.shape = [T, *]``. Membrane
            potential and spikes at time-step ``t = T - 1`` are available via ``.v``
            and ``.spike``; the full ``T``-step sequences via ``.v_seq`` and
            ``.spike_seq``.
        """
        super().__init__(tau, delta_T, theta_rh, v_threshold, v_rest, v_reset,
                         surrogate_function, detach_reset)
        self.register_memory('v_seq', None)

        check_backend(backend)

        self.backend = backend

    def forward(self, x_seq: torch.Tensor):
        """Run the neuron over a whole input sequence; returns spikes shaped like ``x_seq``."""
        assert x_seq.dim() > 1
        # x_seq.shape = [T, *]

        if self.backend == 'torch':
            # Pure-PyTorch path: step through time with the single-step parent forward.
            spike_seq = []
            self.v_seq = []
            for t in range(x_seq.shape[0]):
                spike_seq.append(super().forward(x_seq[t]).unsqueeze(0))
                self.v_seq.append(self.v.unsqueeze(0))
            spike_seq = torch.cat(spike_seq, 0)
            self.v_seq = torch.cat(self.v_seq, 0)
            return spike_seq

        elif self.backend == 'cupy':
            if isinstance(self.v, float):
                # v is still the scalar init value; materialize it as a tensor
                # shaped like one time-step of the input.
                v_init = self.v
                self.v = torch.zeros_like(x_seq[0].data)
                if v_init != 0.:
                    torch.fill_(self.v, v_init)

            spike_seq, self.v_seq = neuron_kernel.MultiStepEIFNodePTT.apply(
                x_seq.flatten(1), self.v.flatten(0), self.tau, self.v_threshold, self.v_reset, self.v_rest,
                self.theta_rh, self.delta_T, self.detach_reset, self.surrogate_function.cuda_code)

            spike_seq = spike_seq.reshape(x_seq.shape)
            self.v_seq = self.v_seq.reshape(x_seq.shape)

            # Keep single-step state consistent: v holds the last time-step's potential.
            self.v = self.v_seq[-1].clone()

            return spike_seq
        else:
            raise NotImplementedError

    def extra_repr(self):
        return super().extra_repr() + f', backend={self.backend}'
1491
+
1492
+
1493
class GeneralNode(BaseNode):
    """A neuron whose membrane potential follows the affine recurrence
    ``v <- a * v + b * x + c``, with the coefficients stored as buffers."""

    def __init__(self, a: float or torch.Tensor, b: float or torch.Tensor, c: float or torch.Tensor = 0.,
                 v_threshold: float = 1., v_reset: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
        """
        :param a: multiplicative coefficient on the previous membrane potential
        :param b: multiplicative coefficient on the input
        :param c: additive constant of the membrane update
        (the remaining parameters have the same meaning as in the base node)
        """
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
        # Bug fix: the original did ``self.a = self.register_buffer('a', ...)``,
        # which overwrote the freshly registered buffer with None because
        # register_buffer returns None. Registering alone is sufficient: the
        # tensor then becomes accessible as ``self.a`` and follows the module
        # across devices and into the state dict.
        self.register_buffer('a', torch.as_tensor(a))
        self.register_buffer('b', torch.as_tensor(b))
        self.register_buffer('c', torch.as_tensor(c))

    def neuronal_charge(self, x: torch.Tensor):
        # Affine membrane update: v <- a*v + b*x + c.
        self.v = self.a * self.v + self.b * x + self.c
1504
+
1505
+
1506
class MultiStepGeneralNode(GeneralNode):
    """Multi-step wrapper around :class:`GeneralNode`; only the ``'torch'`` backend is implemented."""

    def __init__(self, a: float, b: float, c: float, v_threshold: float = 1., v_reset: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, backend='torch'):
        # Bug fix: the original called
        # ``super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)``,
        # which dropped a/b/c entirely and bound v_threshold/v_reset to the
        # GeneralNode parameters ``a``/``b``. Forward all arguments in order.
        super().__init__(a, b, c, v_threshold, v_reset, surrogate_function, detach_reset)

        self.register_memory('v_seq', None)

        check_backend(backend)

        self.backend = backend

    def forward(self, x_seq: torch.Tensor):
        """Run the neuron over a sequence ``x_seq`` of shape ``[T, *]``; returns the spike sequence."""
        assert x_seq.dim() > 1
        # x_seq.shape = [T, *]

        if self.backend == 'torch':
            # Step through time with the single-step parent forward.
            spike_seq = []
            self.v_seq = []
            for t in range(x_seq.shape[0]):
                spike_seq.append(super().forward(x_seq[t]).unsqueeze(0))
                self.v_seq.append(self.v.unsqueeze(0))
            spike_seq = torch.cat(spike_seq, 0)
            self.v_seq = torch.cat(self.v_seq, 0)
            return spike_seq
        elif self.backend == 'cupy':
            # No fused CuPy kernel exists for the general affine neuron.
            # (The original also raised here; the state-init and reshape code that
            # followed the raise was unreachable and has been removed.)
            raise NotImplementedError
        else:
            raise NotImplementedError

    def extra_repr(self):
        return super().extra_repr() + f', backend={self.backend}'
1552
+
1553
+
1554
class LIAFNode(LIFNode):
    def __init__(self, act: Callable, threshold_related: bool, *args, **kwargs):
        """
        The LIAF neuron proposed in `LIAF-Net: Leaky Integrate and Analog Fire Network
        for Lightweight and Efficient Spatiotemporal Information Processing
        <https://arxiv.org/abs/2011.06176>`_.

        :param act: the activation function applied to the membrane potential
        :type act: Callable
        :param threshold_related: whether the neuron uses threshold related (TR mode).
            If ``True``, the analog output is ``act(h - v_th)``; otherwise ``act(h)``
        :type threshold_related: bool

        The remaining ``*args, **kwargs`` are forwarded to :class:`LIFNode`.

        .. admonition:: Warning
            :class: warning

            The outputs of this neuron are not binary spikes.

        """
        super().__init__(*args, **kwargs)
        self.act = act
        self.threshold_related = threshold_related

    def forward(self, x: torch.Tensor):
        # Charge first, then read the analog output from the pre-reset potential.
        self.neuronal_charge(x)
        membrane = self.v - self.v_threshold if self.threshold_related else self.v
        analog_out = self.act(membrane)
        # The spiking dynamics still run so the membrane resets like a LIF neuron,
        # but the binary spike itself is discarded.
        fired = self.neuronal_fire()
        self.neuronal_reset(fired)
        return analog_out
1586
+
1587
+
models/q_vit/Quant.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch.nn.modules.linear import Linear
4
+ import math
5
+ from torch.nn.parameter import Parameter
6
+ from ._quan_base import _Conv2dQ, Qmodes, _LinearQ, _ActQ
7
+
8
+
9
+ __all__ = ['Conv2dQ', 'LinearQ', 'ActQ']
10
+
11
+
12
class FunQ(torch.autograd.Function):
    """LSQ-style fake quantization of ``weight`` with step size ``alpha``:
    custom gradients for both tensors over the integer range [Qn, Qp],
    with gradient scale ``g`` applied to the step-size gradient."""

    @staticmethod
    def forward(ctx, weight, alpha, g, Qn, Qp):
        assert alpha > 0, 'alpha = {}'.format(alpha)
        ctx.save_for_backward(weight, alpha)
        ctx.other = g, Qn, Qp
        # Quantize-dequantize: snap to the integer grid, clamp to range, rescale.
        levels = (weight / alpha).round().clamp(Qn, Qp)
        return levels * alpha

    @staticmethod
    def backward(ctx, grad_weight):
        weight, alpha = ctx.saved_tensors
        g, Qn, Qp = ctx.other
        scaled = weight / alpha
        below_range = (scaled < Qn).float()
        above_range = (scaled > Qp).float()
        in_range = 1.0 - below_range - above_range  # Thanks to @haolibai
        # Step-size gradient: clipped values contribute Qn/Qp, in-range values
        # contribute the rounding residual round(w/a) - w/a.
        grad_alpha = ((below_range * Qn + above_range * Qp
                       + in_range * (scaled.round() - scaled)) * grad_weight * g).sum().unsqueeze(dim=0)
        # Straight-through estimator for the weight, zeroed where clipping occurred.
        # (Clamping grad_alpha to +/- alpha could further stabilize updates — FYI.)
        return in_range * grad_weight, grad_alpha, None, None, None
38
+
39
+
40
def grad_scale(x, scale):
    """Pass ``x`` through unchanged in the forward pass while scaling its gradient by ``scale``."""
    scaled = x * scale
    return (x - scaled).detach() + scaled
44
+
45
+
46
def round_pass(x):
    """Round ``x`` with a straight-through gradient estimator (gradient acts as identity)."""
    return (x.round() - x).detach() + x
50
+
51
+
52
class Conv2dQ(_Conv2dQ):
    """Conv2d with LSQ (Learned Step Size Quantization) of weights and input activations."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, nbits_w=8, mode=Qmodes.kernel_wise, **kwargs):
        super(Conv2dQ, self).__init__(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias,
            nbits=nbits_w, mode=mode)
        # Quantize the input activations with the same bit-width as the weights.
        self.act = ActQ(in_features=in_channels, nbits_a=nbits_w)

    def forward(self, x):
        # alpha is None => quantization disabled; fall back to a plain float conv.
        if self.alpha is None:
            return F.conv2d(x, self.weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)
        # w_reshape = self.weight.reshape([self.weight.shape[0], -1]).transpose(0, 1)
        # Signed integer range for nbits-bit weights.
        Qn = -2 ** (self.nbits - 1)
        Qp = 2 ** (self.nbits - 1) - 1
        if self.training and self.init_state == 0:
            # One-time step-size init on the first training batch
            # (LSQ heuristic: 2 * mean|w| / sqrt(Qp)).
            # self.alpha.data.copy_(self.weight.abs().max() / 2 ** (self.nbits - 1))
            self.alpha.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp))
            # self.alpha.data.copy_(self.weight.abs().max() * 2)
            self.init_state.fill_(1)
        """
        Implementation according to paper.
        Feels wrong ...
        When we initialize the alpha as a big number (e.g., self.weight.abs().max() * 2),
        the clamp function can be skipped.
        Then we get w_q = w / alpha * alpha = w, and $\frac{\partial w_q}{\partial \alpha} = 0$
        As a result, I don't think the pseudo-code in the paper echoes the formula.

        Please see jupyter/STE_LSQ.ipynb fo detailed comparison.
        """
        # Gradient scale for alpha: 1 / sqrt(numel * Qp).
        g = 1.0 / math.sqrt(self.weight.numel() * Qp)

        # Method1: 31GB GPU memory (AlexNet w4a4 bs 2048) 17min/epoch
        alpha = grad_scale(self.alpha, g)
        # print(alpha.shape)
        # print(self.weight.shape)
        # Broadcast the per-kernel step size over (out_ch, in_ch, kH, kW).
        alpha = alpha.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        w_q = round_pass((self.weight / alpha).clamp(Qn, Qp)) * alpha

        x = self.act(x)
        # w = w.clamp(Qn, Qp)
        # q_w = round_pass(w)
        # w_q = q_w * alpha

        # Method2: 25GB GPU memory (AlexNet w4a4 bs 2048) 32min/epoch
        # w_q = FunLSQ.apply(self.weight, self.alpha, g, Qn, Qp)
        # wq = y.transpose(0, 1).reshape(self.weight.shape).detach() + self.weight - self.weight.detach()
        return F.conv2d(x, w_q, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
102
+
103
+
104
class LinearQ(_LinearQ):
    """Linear layer with LSQ (Learned Step Size Quantization) of weights and input activations."""

    def __init__(self, in_features, out_features, bias=True, nbits_w=4, **kwargs):
        super(LinearQ, self).__init__(in_features=in_features,
                                      out_features=out_features, bias=bias, nbits=nbits_w, mode=Qmodes.kernel_wise)
        # Quantize the input activations with the same bit-width as the weights.
        self.act = ActQ(in_features=in_features, nbits_a=nbits_w)

    def forward(self, x):
        # alpha is None => quantization disabled; plain float linear layer.
        if self.alpha is None:
            return F.linear(x, self.weight, self.bias)
        # Signed integer range for nbits-bit weights.
        Qn = -2 ** (self.nbits - 1)
        Qp = 2 ** (self.nbits - 1) - 1
        if self.training and self.init_state == 0:
            # One-time step-size init on the first training batch
            # (LSQ heuristic: 2 * mean|w| / sqrt(Qp)).
            self.alpha.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp))
            # self.alpha.data.copy_(self.weight.abs().max() / 2 ** (self.nbits - 1))
            self.init_state.fill_(1)
        # Gradient scale for alpha: 1 / sqrt(numel * Qp).
        g = 1.0 / math.sqrt(self.weight.numel() * Qp)

        # Method1:
        alpha = grad_scale(self.alpha, g)
        # Broadcast the per-output-channel step size over (out_features, in_features).
        alpha = alpha.unsqueeze(1)
        w_q = round_pass((self.weight / alpha).clamp(Qn, Qp)) * alpha

        x = self.act(x)
        # w = self.weight / alpha
        # w = w.clamp(Qn, Qp)
        # q_w = round_pass(w)
        # w_q = q_w * alpha

        # Method2:
        # w_q = FunLSQ.apply(self.weight, self.alpha, g, Qn, Qp)
        return F.linear(x, w_q, self.bias)
135
+
136
+
137
class ActQ(_ActQ):
    """Activation fake-quantizer with a learnable step size and zero-point
    (LSQ-style; range switches to signed if negative inputs are observed)."""

    def __init__(self, in_features, nbits_a=4, mode=Qmodes.kernel_wise, **kwargs):
        super(ActQ, self).__init__(in_features=in_features, nbits=nbits_a, mode=mode)
        # print(self.alpha.shape, self.zero_point.shape)
    def forward(self, x):
        # alpha is None => quantization disabled; pass activations through untouched.
        if self.alpha is None:
            return x

        if self.training and self.init_state == 0:
            # The init alpha for activation is very very important as the experimental results shows.
            # Please select a init_rate for activation.
            # self.alpha.data.copy_(x.max() / 2 ** (self.nbits - 1) * self.init_rate)
            # Switch to a signed integer range if the first batch contains negatives.
            if x.min() < -1e-5:
                self.signed.data.fill_(1)
            if self.signed == 1:
                Qn = -2 ** (self.nbits - 1)
                Qp = 2 ** (self.nbits - 1) - 1
            else:
                Qn = 0
                Qp = 2 ** self.nbits - 1
            self.alpha.data.copy_(2 * x.abs().mean() / math.sqrt(Qp))
            # EMA-style pull of the zero point toward the observed activation minimum.
            self.zero_point.data.copy_(self.zero_point.data * 0.9 + 0.1 * (torch.min(x.detach()) - self.alpha.data * Qn))
            self.init_state.fill_(1)

        # Recompute the integer range (signedness may have been set above or restored
        # from a checkpoint).
        if self.signed == 1:
            Qn = -2 ** (self.nbits - 1)
            Qp = 2 ** (self.nbits - 1) - 1
        else:
            Qn = 0
            Qp = 2 ** self.nbits - 1

        # Gradient scale for alpha / zero_point: 1 / sqrt(numel * Qp).
        g = 1.0 / math.sqrt(x.numel() * Qp)

        # Method1:
        # Round the zero point with a straight-through estimator.
        zero_point = (self.zero_point.round() - self.zero_point).detach() + self.zero_point
        alpha = grad_scale(self.alpha, g)
        zero_point = grad_scale(zero_point, g)
        # x = round_pass((x / alpha).clamp(Qn, Qp)) * alpha
        # Broadcast the per-feature parameters to the activation layout
        # ((N, C) for 2-D inputs, (N, C, H, W) for 4-D inputs).
        if len(x.shape)==2:
            alpha = alpha.unsqueeze(0)
            zero_point = zero_point.unsqueeze(0)
        elif len(x.shape)==4:
            alpha = alpha.unsqueeze(0).unsqueeze(2).unsqueeze(3)
            zero_point = zero_point.unsqueeze(0).unsqueeze(2).unsqueeze(3)

        # Quantize-dequantize: q = clamp(round(x/alpha + zp)); x' = (q - zp) * alpha.
        x = round_pass((x / alpha + zero_point).clamp(Qn, Qp))
        x = (x - zero_point) * alpha

        return x
models/q_vit/__init__.py ADDED
File without changes
models/q_vit/__pycache__/Quant.cpython-311.pyc ADDED
Binary file (11.1 kB). View file
 
models/q_vit/__pycache__/Quant.cpython-312.pyc ADDED
Binary file (10.5 kB). View file
 
models/q_vit/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (207 Bytes). View file
 
models/q_vit/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (218 Bytes). View file
 
models/q_vit/__pycache__/_quan_base.cpython-311.pyc ADDED
Binary file (12.8 kB). View file
 
models/q_vit/__pycache__/_quan_base.cpython-312.pyc ADDED
Binary file (11.3 kB). View file
 
models/q_vit/__pycache__/quant_vision_transformer.cpython-311.pyc ADDED
Binary file (33.1 kB). View file
 
models/q_vit/__pycache__/quant_vision_transformer.cpython-312.pyc ADDED
Binary file (30.1 kB). View file
 
models/q_vit/_quan_base.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Quantized modules: the base class
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn.parameter import Parameter
7
+
8
+ import math
9
+ from enum import Enum
10
+
11
+ __all__ = ['Qmodes', '_Conv2dQ', '_LinearQ', '_ActQ',
12
+ 'truncation', 'get_sparsity_mask', 'FunStopGradient', 'round_pass', 'grad_scale']
13
+
14
+
15
class Qmodes(Enum):
    """Quantization granularity mode.

    layer_wise: one learned step size shared by the whole layer.
    kernel_wise: one learned step size per output channel / feature.
    """
    layer_wise = 1
    kernel_wise = 2
18
+
19
+
20
def grad_scale(x, scale):
    """Identity in the forward pass; scale the gradient by *scale*.

    Straight-through trick: the returned value equals x, but the
    differentiable path carries x * scale, so dL/dx is multiplied by scale.
    """
    scaled = x * scale
    return (x - scaled).detach() + scaled
24
+
25
+
26
def get_sparsity_mask(param, sparsity):
    """Return a {0,1} mask that prunes the smallest-magnitude fraction of *param*.

    Args:
        param: tensor of weights.
        sparsity: fraction (0..1) of elements to zero out, by magnitude.

    Returns:
        A tensor of param's dtype with 1.0 where the weight is kept.
    """
    num_pruned = int(sparsity * param.numel())
    if num_pruned == 0:
        # Bug fix: torch.topk(k=0) returns an empty tensor, so bottomk[-1]
        # raised IndexError whenever sparsity rounded down to zero elements.
        return torch.ones_like(param)
    bottomk, _ = torch.topk(param.abs().view(-1), num_pruned, largest=False, sorted=True)
    threshold = bottomk.data[-1]  # largest magnitude among the pruned group
    return torch.gt(torch.abs(param), threshold).type(param.type())
30
+
31
+
32
def round_pass(x):
    """Round to nearest with a straight-through gradient (d round/dx == 1)."""
    rounded = x.round()
    return x + (rounded - x).detach()
36
+
37
+
38
class FunStopGradient(torch.autograd.Function):
    """Pass the input through unchanged; gate gradients element-wise by a mask.

    backward multiplies the incoming gradient by ``stopGradientMask``, so a
    zero in the mask stops the gradient for that element.
    """

    @staticmethod
    def forward(ctx, weight, stopGradientMask):
        ctx.save_for_backward(stopGradientMask)
        return weight

    @staticmethod
    def backward(ctx, grad_outputs):
        (mask,) = ctx.saved_tensors
        return grad_outputs * mask, None
50
+
51
+
52
def log_shift(value_fp):
    """Round *value_fp* up to the nearest power of two."""
    return 2 ** torch.log2(value_fp).ceil()
55
+
56
+
57
def clamp(input, min, max, inplace=False):
    """Clamp *input* to [min, max]; mutates and returns it when inplace=True."""
    if not inplace:
        return torch.clamp(input, min, max)
    input.clamp_(min, max)
    return input
62
+
63
+
64
def get_quantized_range(num_bits, signed=True):
    """Return the inclusive (min, max) integer range for *num_bits* quantization."""
    if not signed:
        return 0, (1 << num_bits) - 1
    half = 1 << (num_bits - 1)
    return -half, half - 1
69
+
70
+
71
def linear_quantize(input, scale_factor, inplace=False):
    """Multiply by *scale_factor* and round to the nearest integer value."""
    if inplace:
        input.mul_(scale_factor).round_()
        return input
    return torch.round(input * scale_factor)
76
+
77
+
78
def linear_quantize_clamp(input, scale_factor, clamp_min, clamp_max, inplace=False):
    """Quantize with *scale_factor*, then clamp into [clamp_min, clamp_max]."""
    quantized = linear_quantize(input, scale_factor, inplace)
    return clamp(quantized, clamp_min, clamp_max, inplace)
81
+
82
+
83
def linear_dequantize(input, scale_factor, inplace=False):
    """Inverse of linear_quantize: divide *input* by *scale_factor*."""
    if not inplace:
        return input / scale_factor
    input.div_(scale_factor)
    return input
88
+
89
+
90
def truncation(fp_data, nbits=8):
    """Quantize *fp_data* to a dynamic fixed-point format with *nbits* bits.

    Returns the dequantized (fake-quantized) tensor together with the
    fractional bit count ``qcode``.
    """
    # Integer bits needed to cover the data's magnitude (epsilon guards the
    # exact-power-of-two edge of log2).
    il = torch.log2(torch.max(fp_data.max(), fp_data.min().abs())) + 1
    il = math.ceil(il - 1e-5)
    qcode = nbits - il  # remaining fractional bits
    scale_factor = 2 ** qcode
    clamp_min, clamp_max = get_quantized_range(nbits, signed=True)
    q_data = linear_quantize_clamp(fp_data, scale_factor, clamp_min, clamp_max)
    return linear_dequantize(q_data, scale_factor), qcode
99
+
100
+
101
def get_default_kwargs_q(kwargs_q, layer_type):
    """Fill missing quantization kwargs with per-layer-type defaults.

    Mutates and returns *kwargs_q*. Only keys absent from *kwargs_q* are
    filled, so caller-supplied values always win.

    Raises:
        NotImplementedError: if *layer_type* is none of the supported
            quantized base classes.
    """
    default = {
        'nbits': 4
    }
    if isinstance(layer_type, _Conv2dQ):
        default.update({
            'mode': Qmodes.layer_wise})
    elif isinstance(layer_type, _LinearQ):
        pass
    elif isinstance(layer_type, _ActQ):
        pass
        # default.update({
        #     'signed': 'Auto'})
    else:
        # Bug fix: `assert NotImplementedError` was always truthy (asserting
        # the exception *class*), so unsupported layer types silently fell
        # through and returned None. Fail loudly instead.
        raise NotImplementedError(
            'no quantization defaults for {}'.format(type(layer_type).__name__))
    for k, v in default.items():
        kwargs_q.setdefault(k, v)
    return kwargs_q
121
+
122
+
123
class _Conv2dQ(nn.Conv2d):
    """Base conv layer for learned-step-size (LSQ-style) weight quantization.

    Adds a learnable step size ``alpha`` — one per output channel in
    kernel-wise mode, a single scalar in layer-wise mode — plus an
    ``init_state`` buffer that subclasses use to lazily initialize ``alpha``
    on the first forward pass. ``nbits < 0`` disables quantization entirely.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, **kwargs_q):
        super(_Conv2dQ, self).__init__(in_channels, out_channels, kernel_size, stride=stride,
                                       padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.kwargs_q = get_default_kwargs_q(kwargs_q, layer_type=self)
        self.nbits = kwargs_q['nbits']
        if self.nbits < 0:
            # Quantization disabled: alpha=None makes extra_repr report 'fake'.
            self.register_parameter('alpha', None)
            return
        self.q_mode = kwargs_q['mode']
        if self.q_mode == Qmodes.kernel_wise:
            self.alpha = Parameter(torch.Tensor(out_channels))
        else:  # layer-wise quantization
            self.alpha = Parameter(torch.Tensor(1))
        # Stays 0 until the first forward initializes alpha (in subclasses).
        self.register_buffer('init_state', torch.zeros(1))

    def add_param(self, param_k, param_v):
        # Attach an extra quantization hyper-parameter after construction.
        self.kwargs_q[param_k] = param_v

    def set_bit(self, nbits):
        # NOTE(review): updates kwargs_q only; self.nbits is not refreshed —
        # confirm callers re-read kwargs_q.
        self.kwargs_q['nbits'] = nbits

    def extra_repr(self):
        s_prefix = super(_Conv2dQ, self).extra_repr()
        if self.alpha is None:
            return '{}, fake'.format(s_prefix)
        return '{}, {}'.format(s_prefix, self.kwargs_q)
151
+
152
+
153
class _LinearQ(nn.Linear):
    """Base linear layer for learned-step-size weight quantization.

    Mirrors _Conv2dQ: learnable step size ``alpha`` (per output feature in
    kernel-wise mode, scalar otherwise) plus an ``init_state`` buffer for
    lazy alpha initialization. ``nbits < 0`` disables quantization.
    """
    def __init__(self, in_features, out_features, bias=True, **kwargs_q):
        super(_LinearQ, self).__init__(in_features=in_features, out_features=out_features, bias=bias)
        self.kwargs_q = get_default_kwargs_q(kwargs_q, layer_type=self)
        self.nbits = kwargs_q['nbits']
        if self.nbits < 0:
            self.register_parameter('alpha', None)
            return
        # NOTE(review): no default 'mode' is injected for linear layers, so
        # callers must always pass one — confirm against get_default_kwargs_q.
        self.q_mode = kwargs_q['mode']
        self.alpha = Parameter(torch.Tensor(1))
        if self.q_mode == Qmodes.kernel_wise:
            # Rebinds per output feature; the scalar parameter above is discarded.
            self.alpha = Parameter(torch.Tensor(out_features))
        self.register_buffer('init_state', torch.zeros(1))

    def add_param(self, param_k, param_v):
        # Attach an extra quantization hyper-parameter after construction.
        self.kwargs_q[param_k] = param_v

    def extra_repr(self):
        s_prefix = super(_LinearQ, self).extra_repr()
        if self.alpha is None:
            return '{}, fake'.format(s_prefix)
        return '{}, {}'.format(s_prefix, self.kwargs_q)
175
+
176
+
177
class _ActQ(nn.Module):
    """Base activation quantizer with learnable step size and zero point.

    ``alpha`` (step size) and ``zero_point`` are learned parameters;
    ``signed`` and ``init_state`` are buffers that subclasses set lazily on
    the first forward pass. ``nbits < 0`` disables quantization.
    """
    def __init__(self, in_features, **kwargs_q):
        super(_ActQ, self).__init__()
        self.kwargs_q = get_default_kwargs_q(kwargs_q, layer_type=self)
        self.nbits = kwargs_q['nbits']
        if self.nbits < 0:
            self.register_parameter('alpha', None)
            self.register_parameter('zero_point', None)
            return
        # self.signed = kwargs_q['signed']
        # NOTE(review): 'mode' has no injected default for activations —
        # callers must pass it; confirm against get_default_kwargs_q.
        self.q_mode = kwargs_q['mode']
        self.alpha = Parameter(torch.Tensor(1))
        self.zero_point = Parameter(torch.Tensor([0]))
        if self.q_mode == Qmodes.kernel_wise:
            # Per-channel step size / zero point (rebinds the scalars above).
            self.alpha = Parameter(torch.Tensor(in_features))
            self.zero_point = Parameter(torch.Tensor(in_features))
            torch.nn.init.zeros_(self.zero_point)
        # self.zero_point = Parameter(torch.Tensor([0]))
        self.register_buffer('init_state', torch.zeros(1))
        # 0 = unsigned range, nonzero = signed; decided at first forward by subclasses.
        self.register_buffer('signed', torch.zeros(1))

    def add_param(self, param_k, param_v):
        # Attach an extra quantization hyper-parameter after construction.
        self.kwargs_q[param_k] = param_v

    def set_bit(self, nbits):
        # NOTE(review): updates kwargs_q only; self.nbits is left stale — confirm.
        self.kwargs_q['nbits'] = nbits

    def extra_repr(self):
        # s_prefix = super(_ActQ, self).extra_repr()
        if self.alpha is None:
            return 'fake'
        return '{}'.format(self.kwargs_q)
models/q_vit/quant_vision_transformer.py ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import logging
3
+ from functools import partial
4
+ from collections import OrderedDict
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
11
+ from timm.models.helpers import load_pretrained
12
+ from timm.models.layers import Mlp
13
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
14
+ from timm.models.resnet import resnet26d, resnet50d
15
+ from timm.models.registry import register_model
16
+
17
+ import numpy as np
18
+ from .Quant import *
19
+ from ._quan_base import *
20
+
21
+
22
+ _logger = logging.getLogger(__name__)
23
+
24
+
25
def _cfg(url='', **kwargs):
    """Build a timm-style default model config dict; kwargs override defaults."""
    cfg = {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
    }
    cfg.update(kwargs)
    return cfg
34
+
35
+
36
# Registry of per-variant pretrained-model configs (URL, input size,
# normalization stats, classifier head name) consumed by
# _create_vision_transformer. Data table only — no logic.
default_cfgs = {
    # patch models (my experiments)
    'vit_small_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
    ),

    # patch models (weights ported from official Google JAX impl)
    'vit_base_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
    'vit_base_patch32_224': _cfg(
        url='',  # no official model weights for this combo, only for in21k
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_base_patch16_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
    'vit_base_patch32_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
    'vit_large_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_large_patch32_224': _cfg(
        url='',  # no official model weights for this combo, only for in21k
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_large_patch16_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
    'vit_large_patch32_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),

    # patch models, imagenet21k (weights ported from official Google JAX impl)
    'vit_base_patch16_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_base_patch32_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_large_patch16_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_large_patch32_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_huge_patch14_224_in21k': _cfg(
        url='',  # FIXME I have weights for this but > 2GB limit for github release binaries
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),

    # hybrid models (weights ported from official Google JAX impl)
    'vit_base_resnet50_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'),
    'vit_base_resnet50_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'),

    # hybrid models (my experiments)
    'vit_small_resnet26d_224': _cfg(),
    'vit_small_resnet50d_s3_224': _cfg(),
    'vit_base_resnet26d_224': _cfg(),
    'vit_base_resnet50d_224': _cfg(),

    # deit models (FB weights)
    'vit_deit_tiny_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
    'vit_deit_small_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
    'vit_deit_base_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
    'vit_deit_base_patch16_384': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_deit_tiny_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'),
    'vit_deit_small_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'),
    'vit_deit_base_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ),
    'vit_deit_base_distilled_patch16_384': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
}
120
+
121
class Q_Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks,
    with both linear layers replaced by kernel-wise quantized LinearQ.
    """
    def __init__(self, nbits, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs = to_2tuple(drop)

        self.fc1 = LinearQ(in_features, hidden_features, nbits_w=nbits, mode=Qmodes.kernel_wise)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = LinearQ(hidden_features, out_features, nbits_w=nbits, mode=Qmodes.kernel_wise)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        # print(torch.max(x), torch.min(x))
        x = self.act(x)

        # Hard-clip activations to [-10, 10] before the quantized fc2.
        # NOTE(review): clip bounds are magic numbers — confirm they match training.
        x = torch.clip(x, -10., 10.)
        # print(torch.clip(x, -10., 10.))
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
147
+
148
+
149
class Q_Attention(nn.Module):
    """Multi-head self-attention with quantized projections and activations.

    q and k are LayerNorm-ed per head before quantization; q, k, v and the
    softmax-ed attention map each pass through their own ActQ quantizer.
    """

    def __init__(self, nbits, dim, num_heads=8, quantize_attn=True, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.quantize_attn = quantize_attn

        self.norm_q = nn.LayerNorm(head_dim)
        self.norm_k = nn.LayerNorm(head_dim)


        if self.quantize_attn:

            self.qkv = LinearQ(dim, dim * 3, bias=qkv_bias, nbits_w=nbits, mode=Qmodes.kernel_wise)

            self.attn_drop = nn.Dropout(attn_drop)

            self.proj = LinearQ(dim, dim, nbits_w=nbits, mode=Qmodes.kernel_wise)
            self.q_act = ActQ(nbits_a=nbits, in_features=self.num_heads)
            self.k_act = ActQ(nbits_a=nbits, in_features=self.num_heads)
            self.v_act = ActQ(nbits_a=nbits, in_features=self.num_heads)
            self.attn_act = ActQ(nbits_a=nbits, in_features=self.num_heads)
        else:
            # NOTE(review): this branch still quantizes activations via ActQ
            # even though quantize_attn=False — only the linear layers stay
            # full-precision; confirm intended.
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = nn.Linear(dim, dim)
            self.q_act = ActQ(nbits_a=nbits, in_features=self.num_heads)
            self.k_act = ActQ(nbits_a=nbits, in_features=self.num_heads)
            self.v_act = ActQ(nbits_a=nbits, in_features=self.num_heads)
            self.attn_act = ActQ(nbits_a=nbits, in_features=self.num_heads)

        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
        q = self.norm_q(q)
        k = self.norm_k(k)

        q = self.q_act(q)
        k = self.k_act(k)
        v = self.v_act(v)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        attn = self.attn_act(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
207
+
208
+
209
class Q_Block(nn.Module):
    """Pre-norm transformer block built from quantized attention and MLP."""

    def __init__(self, nbits, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Q_Attention(nbits, dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Q_Mlp(nbits=nbits, in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        # Two residual branches, each with stochastic depth.
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
226
+
227
class Q_PatchEmbed(nn.Module):
    """ Image to Patch Embedding via a quantized strided conv.

    NOTE(review): the ``nbits`` argument is accepted but never forwarded to
    Conv2dQ, so the projection falls back to Conv2dQ's default bit-width —
    confirm whether it should be passed through.
    """
    def __init__(self, nbits=4, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = Conv2dQ(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        # nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, C, H, W) -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
249
+
250
class lowbit_VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877

    Low-bit variant: patch embedding, block linears and classifier heads use
    the nbits-quantized layers (heads fixed at 8 bits below).
    NOTE(review): ``distilled`` defaults to True here, unlike the timm
    original (False) — confirm intended.
    """

    def __init__(self, nbits, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=Q_PatchEmbed, norm_layer=None,
                 act_layer=None, weight_init=''):
        """
        Args:
            nbits: bit-width used by the quantized layers
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            weight_init: (str): weight init scheme
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(
            nbits=nbits, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Q_Block(
                nbits=nbits, dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head(s) — quantized at a fixed 8 bits regardless of nbits.
        self.head = LinearQ(self.num_features, num_classes, nbits_w=8) if num_classes > 0 else nn.Identity()
        # nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = LinearQ(self.embed_dim, self.num_classes, nbits_w=8) if num_classes > 0 else nn.Identity()
            # self.head = LinearQ(self.embed_dim, self.num_classes, nbits_w=8) if num_classes > 0 else nn.Identity()
            # nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        trunc_normal_(self.pos_embed, std=.02)
        if self.dist_token is not None:
            trunc_normal_(self.dist_token, std=.02)
        if mode.startswith('jax'):
            # leave cls token as zeros to match jax impl
            # NOTE(review): `named_apply` is not imported in this module — the
            # 'jax' init modes would raise NameError; confirm.
            named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
        else:
            trunc_normal_(self.cls_token, std=.02)
            self.apply(_init_vit_weights)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        _init_vit_weights(m)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        # NOTE(review): `_load_weights` is not defined or imported in this
        # module — calling this would raise NameError; confirm.
        _load_weights(self, checkpoint_path, prefix)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'dist_token'}

    def get_classifier(self):
        if self.dist_token is None:
            return self.head
        else:
            return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=''):
        # NOTE(review): replaces the quantized heads with plain nn.Linear — confirm.
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        if self.num_tokens == 2:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            # Distilled: return (cls token, dist token) features.
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
            if self.training and not torch.jit.is_scripting():
                # training: return both heads so the caller can apply the
                # distillation loss separately
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x
392
+
393
def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
    """ ViT weight initialization
    * When called without n, head_bias, jax_impl args it will behave exactly the same
      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
    * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl

    NOTE(review): `lecun_normal_` is not imported in this module, so the
    'pre_logits' and jax-conv branches would raise NameError — confirm.
    """
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
        elif name.startswith('pre_logits'):
            lecun_normal_(module.weight)
            nn.init.zeros_(module.bias)
        else:
            if jax_impl:
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    if 'mlp' in name:
                        nn.init.normal_(module.bias, std=1e-6)
                    else:
                        nn.init.zeros_(module.bias)
            else:
                trunc_normal_(module.weight, std=.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
    elif jax_impl and isinstance(module, nn.Conv2d):
        # NOTE conv was left to pytorch default in my original init
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)
427
def resize_pos_embed(posemb, posemb_new):
    """Rescale a grid of position embeddings to a new token count.

    Adapted from
    https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    Keeps the leading class token untouched and bilinearly resizes the
    spatial grid to match *posemb_new*.
    """
    _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
    # One class token is split off; only the spatial grid is resized.
    ntok_new = posemb_new.shape[1] - 1
    posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
    gs_old = int(math.sqrt(len(posemb_grid)))
    gs_new = int(math.sqrt(ntok_new))
    _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)
    grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(gs_new, gs_new), mode='bilinear')
    grid = grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
    return torch.cat([posemb_tok, grid], dim=1)
445
+
446
+
447
def checkpoint_filter_fn(state_dict, model):
    """Adapt a pretrained checkpoint to this model.

    Reshapes old manual-patchify linear weights into conv form and resizes
    the position embedding when its shape does not match *model*.
    """
    if 'model' in state_dict:
        # deit checkpoints nest the weights under a 'model' key
        state_dict = state_dict['model']
    out = {}
    for key, val in state_dict.items():
        if 'patch_embed.proj.weight' in key and len(val.shape) < 4:
            # For old models that I trained prior to conv based patchification
            O, I, H, W = model.patch_embed.proj.weight.shape
            val = val.reshape(O, -1, H, W)
        elif key == 'pos_embed' and val.shape != model.pos_embed.shape:
            # To resize pos embedding when using model at different size from pretrained weights
            val = resize_pos_embed(val, model.pos_embed)
        out[key] = val
    return out
463
+
464
+
465
def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
    """Build a (possibly distilled) ViT from its default_cfgs entry.

    NOTE(review): references VisionTransformer / DistilledVisionTransformer,
    which are neither defined nor imported in this module — calling this
    would raise NameError. Appears to be dead code kept from the timm
    original; confirm before use.
    """
    default_cfg = default_cfgs[variant]
    default_num_classes = default_cfg['num_classes']
    default_img_size = default_cfg['input_size'][-1]

    num_classes = kwargs.pop('num_classes', default_num_classes)
    img_size = kwargs.pop('img_size', default_img_size)
    repr_size = kwargs.pop('representation_size', None)
    if repr_size is not None and num_classes != default_num_classes:
        # Remove representation layer if fine-tuning. This may not always be the desired action,
        # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
        _logger.warning("Removing representation layer for fine-tuning.")
        repr_size = None

    model_cls = DistilledVisionTransformer if distilled else VisionTransformer
    model = model_cls(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs)
    model.default_cfg = default_cfg

    if pretrained:
        load_pretrained(
            model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3),
            filter_fn=partial(checkpoint_filter_fn, model=model))
    return model
488
+
489
+
490
@register_model
def fourbits_deit_small_patch16_224(pretrained=False, **kwargs):
    """4-bit quantized DeiT, patch 16, 224x224.

    NOTE(review): embed_dim=192 / num_heads=3 are DeiT-*tiny* dimensions even
    though the name says small (the 3/2-bit variants below use 384/6) — confirm.
    """
    model = lowbit_VisionTransformer(
        nbits=4, patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): the checkpoint is downloaded but never loaded into
        # `model` (the return value is discarded) — confirm intended.
        torch.hub.load_state_dict_from_url(
            url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
            map_location="cpu", check_hash=True
        )
    return model
502
+
503
@register_model
def threebits_deit_small_patch16_224(pretrained=False, **kwargs):
    """3-bit quantized DeiT-small, patch 16, 224x224."""
    model = lowbit_VisionTransformer(
        nbits=3, patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): the checkpoint is downloaded but never loaded into
        # `model` (the return value is discarded) — confirm intended.
        torch.hub.load_state_dict_from_url(
            url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
            map_location="cpu", check_hash=True
        )
    return model
515
+
516
@register_model
def twobits_deit_small_patch16_224(pretrained=False, **kwargs):
    """2-bit quantized DeiT-small, patch 16, 224x224."""
    model = lowbit_VisionTransformer(
        nbits=2, patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): the checkpoint is downloaded but never loaded into
        # `model` (the return value is discarded) — confirm intended.
        torch.hub.load_state_dict_from_url(
            url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
            map_location="cpu", check_hash=True
        )
    return model
models/qk_model_v1_1003.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from spikingjelly.clock_driven.neuron import MultiStepParametricLIFNode, MultiStepLIFNode
4
+ from timm.models.layers import to_2tuple, trunc_normal_, DropPath
5
+ from timm.models.registry import register_model
6
+ from timm.models.vision_transformer import _cfg
7
+ from functools import partial
8
+ from timm.models import create_model
9
+
10
+ __all__ = ['QKFormer']
11
+
12
class MLP(nn.Module):
    """Spiking two-layer MLP: (1x1 Conv2d -> BatchNorm2d -> LIF) twice.

    Operates on tensors of shape (T, B, C, H, W); the time and batch axes
    are flattened together for the conv/BN layers and restored before each
    LIF neuron. The `drop` argument is accepted for API compatibility but
    not used by this implementation.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features

        # First projection: in -> hidden.
        self.mlp1_conv = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1)
        self.mlp1_bn = nn.BatchNorm2d(hidden_features)
        self.mlp1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        # Second projection: hidden -> out.
        self.mlp2_conv = nn.Conv2d(hidden_features, out_features, kernel_size=1, stride=1)
        self.mlp2_bn = nn.BatchNorm2d(out_features)
        self.mlp2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        T, B, C, H, W = x.shape

        out = x.flatten(0, 1)  # merge time and batch for conv/BN
        out = self.mlp1_conv(out)
        out = self.mlp1_bn(out).reshape(T, B, self.c_hidden, H, W)
        out = self.mlp1_lif(out)

        out = self.mlp2_conv(out.flatten(0, 1))
        out = self.mlp2_bn(out).reshape(T, B, C, H, W)
        out = self.mlp2_lif(out)
        return out
39
+
40
class Token_QK_Attention(nn.Module):
    """Q-K token attention from QKFormer, implemented with spiking neurons.

    A channel-wise spike mask is derived from Q (summed over the head
    dimension, then passed through a LIF with threshold 0.5) and multiplied
    element-wise into K — no Q@K matmul is performed. The qkv_bias/qk_scale/
    attn_drop/proj_drop/sr_ratio arguments are accepted for signature
    compatibility but unused here.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads

        # Q branch: pointwise Conv1d -> BN -> LIF (binary spikes out).
        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        # K branch: same structure as Q.
        self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        # Lower threshold (0.5) so the summed-Q attention signal spikes more easily.
        self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='torch')

        # Output projection.
        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

    def forward(self, x):
        # x: (T, B, C, H, W) spike tensor.
        T, B, C, H, W = x.shape

        x = x.flatten(3)          # (T, B, C, N) with N = H*W
        T, B, C, N = x.shape
        x_for_qkv = x.flatten(0, 1)  # (T*B, C, N) for Conv1d/BN1d

        q_conv_out = self.q_conv(x_for_qkv)
        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N)
        q_conv_out = self.q_lif(q_conv_out)
        # Split channels into heads: (T, B, heads, C//heads, N).
        q = q_conv_out.unsqueeze(2).reshape(T, B, self.num_heads, C // self.num_heads, N)

        k_conv_out = self.k_conv(x_for_qkv)
        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N)
        k_conv_out = self.k_lif(k_conv_out)
        k = k_conv_out.unsqueeze(2).reshape(T, B, self.num_heads, C // self.num_heads, N)

        # Collapse Q over the per-head channel dim -> (T, B, heads, 1, N),
        # spike it, and gate K with the result (broadcast multiply).
        q = torch.sum(q, dim=3, keepdim=True)
        attn = self.attn_lif(q)
        x = torch.mul(attn, k)

        # Merge heads back to C channels, project, and restore (T, B, C, H, W).
        x = x.flatten(2, 3)
        x = self.proj_bn(self.proj_conv(x.flatten(0, 1))).reshape(T, B, C, H, W)
        # print(f"proj_conv out shape: {x.shape}")
        x = self.proj_lif(x)
        return x
88
+
89
class Spiking_Self_Attention(nn.Module):
    """Spiking self-attention (SSA): spike-form Q, K, V with linear attention.

    Computes q @ (k^T @ v) * scale — the K^T@V product first, so cost is
    linear in token count N. The scale is fixed at 0.125 regardless of
    head_dim. qkv_bias/qk_scale/attn_drop/proj_drop/sr_ratio are accepted
    for signature compatibility but unused; `qkv_mp` is constructed but
    never applied in forward.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads  # computed but not used below
        self.scale = 0.125
        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')
        # Threshold 0.5 on the attention output before the final projection.
        self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='torch')

        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        self.qkv_mp = nn.MaxPool1d(4)  # NOTE(review): unused in forward

    def forward(self, x):
        # x: (T, B, C, H, W) spike tensor.
        T, B, C, H, W = x.shape

        x = x.flatten(3)          # (T, B, C, N), N = H*W
        T, B, C, N = x.shape
        x_for_qkv = x.flatten(0, 1)  # (T*B, C, N)

        q_conv_out = self.q_conv(x_for_qkv)
        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous()
        q_conv_out = self.q_lif(q_conv_out)
        # (T, B, heads, N, head_dim)
        q = q_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2,
                                                                                                       4).contiguous()

        k_conv_out = self.k_conv(x_for_qkv)
        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous()
        k_conv_out = self.k_lif(k_conv_out)
        k = k_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2,
                                                                                                       4).contiguous()

        v_conv_out = self.v_conv(x_for_qkv)
        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous()
        v_conv_out = self.v_lif(v_conv_out)
        v = v_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2,
                                                                                                       4).contiguous()

        # Linear attention: (k^T @ v) is (head_dim x head_dim), then q @ (.)
        x = k.transpose(-2, -1) @ v
        x = (q @ x) * self.scale

        # Back to channel-first (T, B, C, N), spike, project, restore H,W.
        x = x.transpose(3, 4).reshape(T, B, C, N).contiguous()
        x = self.attn_lif(x)
        x = x.flatten(0, 1)
        x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T, B, C, H, W)
        return x
149
+
150
class TokenSpikingTransformer(nn.Module):
    """Residual block: Token-QK attention followed by a spiking MLP.

    Most constructor arguments (qkv_bias, qk_scale, drop-path, norm_layer,
    sr_ratio, ...) are accepted for signature compatibility with standard
    transformer blocks but are not forwarded to the submodules.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.tssa = Token_QK_Attention(dim, num_heads)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)

    def forward(self, x):
        # Two residual sub-layers, pre-activation style omitted (spiking nets).
        out = x + self.tssa(x)
        out = out + self.mlp(out)
        return out
164
+
165
class SpikingTransformer(nn.Module):
    """Residual block: spiking self-attention followed by a spiking MLP.

    Extra constructor arguments are accepted for signature compatibility
    but not forwarded to the submodules.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.ssa = Spiking_Self_Attention(dim, num_heads)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)

    def forward(self, x):
        out = x + self.ssa(x)
        out = out + self.mlp(out)
        return out
179
+
180
class PatchEmbedInit(nn.Module):
    """Initial spiking patch-embedding stem.

    Four conv stages grow channels (in -> dims/8 -> dims/4 -> dims/2 -> dims)
    with three stride-2 max-pools (8x spatial downsampling overall), plus a
    strided 1x1 residual path from the dims/4 feature map. Input/output are
    (T, B, C, H, W) spike tensors.
    """

    def __init__(self, img_size_h=128, img_size_w=128, patch_size=4, in_channels=2, embed_dims=256):
        super().__init__()
        self.image_size = [img_size_h, img_size_w]
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.C = in_channels
        # Nominal patch grid; actual output size is set by the pools below.
        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1]
        self.num_patches = self.H * self.W

        # Stage 0: in_channels -> dims/8, full resolution.
        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        # Stage 1: dims/8 -> dims/4, then /2 spatial.
        self.proj1_conv = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj1_bn = nn.BatchNorm2d(embed_dims // 4)
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        # Stage 2: dims/4 -> dims/2, then /2 spatial.
        self.proj2_conv = nn.Conv2d(embed_dims//4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj2_bn = nn.BatchNorm2d(embed_dims // 2)
        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        # Stage 3: dims/2 -> dims, then /2 spatial.
        self.proj3_conv = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj3_bn = nn.BatchNorm2d(embed_dims)
        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj3_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        # Residual path: dims/4 feature (post-stage-1) -> dims, stride 4
        # matches the two remaining /2 pools of the main path.
        self.proj_res_conv = nn.Conv2d(embed_dims // 4, embed_dims, kernel_size=1, stride=4, padding=0, bias=False)
        self.proj_res_bn = nn.BatchNorm2d(embed_dims)
        self.proj_res_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')


    def forward(self, x):
        # x: (T, B, C, H, W)
        T, B, C, H, W = x.shape
        # Downsampling + Res
        # x_feat = x.flatten(0, 1)
        x = self.proj_conv(x.flatten(0, 1))
        x = self.proj_bn(x).reshape(T, B, -1, H, W)
        x = self.proj_lif(x).flatten(0, 1).contiguous()

        x = self.proj1_conv(x)
        x = self.proj1_bn(x)
        x = self.maxpool1(x)
        _, _, H1, W1 = x.shape
        x = x.reshape(T, B, -1, H1, W1).contiguous()
        x = self.proj1_lif(x).flatten(0, 1).contiguous()

        # Branch point for the residual shortcut (dims/4 channels, H/2 x W/2).
        x_feat = x
        x = self.proj2_conv(x)
        x = self.proj2_bn(x)
        x = self.maxpool2(x)
        _, _, H2, W2 = x.shape
        x = x.reshape(T, B, -1, H2, W2).contiguous()
        x = self.proj2_lif(x).flatten(0, 1).contiguous()

        x = self.proj3_conv(x)
        x = self.proj3_bn(x)
        x = self.maxpool3(x)
        _, _, H3, W3 = x.shape
        x = x.reshape(T, B, -1, H3, W3).contiguous()
        x = self.proj3_lif(x)

        # Residual path catches up in channels (-> dims) and stride (x4).
        x_feat = self.proj_res_conv(x_feat)
        x_feat = self.proj_res_bn(x_feat)
        _, _, Hres, Wres = x_feat.shape
        x_feat = x_feat.reshape(T, B, -1, Hres, Wres).contiguous()
        x_feat = self.proj_res_lif(x_feat)
        x = x + x_feat  # shortcut

        return x
252
+
253
class PatchEmbeddingStage(nn.Module):
    """Intermediate spiking downsampling stage: dims/2 -> dims channels, /2 spatial.

    Main path: 3x3 conv (dims/2 -> dims) + 3x3 conv + max-pool.
    Residual path: strided 1x1 conv directly from the input.
    NOTE(review): the input is expected to carry embed_dims//2 channels, not
    `in_channels` — the in_channels/patch_size arguments only feed the unused
    H/W bookkeeping attributes.
    """

    def __init__(self, img_size_h=128, img_size_w=128, patch_size=4, in_channels=2, embed_dims=256):
        super().__init__()
        self.image_size = [img_size_h, img_size_w]
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.C = in_channels
        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1]
        self.num_patches = self.H * self.W

        self.proj_conv = nn.Conv2d(embed_dims//2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn = nn.BatchNorm2d(embed_dims)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        self.proj4_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj4_bn = nn.BatchNorm2d(embed_dims)
        self.proj4_maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj4_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        # Residual: stride-2 1x1 conv matches the main path's pooling.
        self.proj_res_conv = nn.Conv2d(embed_dims//2, embed_dims, kernel_size=1, stride=2, padding=0, bias=False)
        self.proj_res_bn = nn.BatchNorm2d(embed_dims)
        self.proj_res_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

    def forward(self, x):
        # x: (T, B, C, H, W) with C == embed_dims // 2.
        T, B, C, H, W = x.shape
        # Downsampling + Res

        x = x.flatten(0, 1).contiguous()
        x_feat = x  # raw input feeds the residual path

        x = self.proj_conv(x)
        x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous()
        x = self.proj_lif(x).flatten(0, 1).contiguous()

        x = self.proj4_conv(x)
        x = self.proj4_bn(x)
        x = self.proj4_maxpool(x)
        _, _, H4, W4 = x.shape
        x = x.reshape(T, B, -1, H4, W4).contiguous()
        x = self.proj4_lif(x)

        x_feat = self.proj_res_conv(x_feat)
        x_feat = self.proj_res_bn(x_feat)
        _, _, Hres, Wres = x_feat.shape
        x_feat = x_feat.reshape(T, B, -1, Hres, Wres).contiguous()
        x_feat = self.proj_res_lif(x_feat)

        x = x + x_feat  # shortcut

        return x
303
+
304
+
305
class vit_snn(nn.Module):
    """Two-stage spiking transformer (QKFormer-style) classifier.

    Stage 1: PatchEmbedInit (-> embed_dims//2 channels) + one
    TokenSpikingTransformer block. Stage 2: PatchEmbeddingStage
    (-> embed_dims channels, /2 spatial) + one SpikingTransformer block.
    Head: mean over spatial positions and time steps, then a linear layer.

    NOTE(review): despite the list-style defaults, the registered config
    (QKFormer_1003) passes scalars (embed_dims=256, mlp_ratios=1, depths=4,
    sr_ratios=1) and the code below relies on the scalar form
    (embed_dims // 2, torch.linspace(..., depths)) — confirm before calling
    with the list defaults.
    """

    def __init__(self,
                 img_size_h=128, img_size_w=128, patch_size=16, in_channels=2, num_classes=11,
                 embed_dims=[64, 128, 256], num_heads=[1, 2, 4], mlp_ratios=[4, 4, 4], qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[6, 8, 6], sr_ratios=[8, 4, 2], T=4, pretrained_cfg=None, in_chans = 3, no_weight_decay = None
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.T = T
        # The num_heads argument is deliberately overridden here.
        num_heads = [16, 16, 16]
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]  # stochastic depth decay rule

        #
        # Stage 1 embedding produces embed_dims // 2 channels.
        patch_embed1 = PatchEmbedInit(img_size_h=img_size_h,
                                      img_size_w=img_size_w,
                                      patch_size=patch_size,
                                      in_channels=in_channels,
                                      embed_dims=embed_dims // 2)

        # Single Token-QK block (range(1)); dpr[j] is computed but the block
        # ignores drop_path internally.
        stage1 = nn.ModuleList([TokenSpikingTransformer(
            dim=embed_dims // 2, num_heads=num_heads[0], mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,
            qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],
            norm_layer=norm_layer, sr_ratio=sr_ratios)
            for j in range(1)])


        patch_embed2 = PatchEmbeddingStage(img_size_h=img_size_h,
                                           img_size_w=img_size_w,
                                           patch_size=patch_size,
                                           in_channels=in_channels,
                                           embed_dims=embed_dims)


        # Single spiking self-attention block.
        stage2 = nn.ModuleList([SpikingTransformer(
            dim=embed_dims, num_heads=num_heads[1], mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,
            qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],
            norm_layer=norm_layer, sr_ratio=sr_ratios)
            for j in range(1)])


        setattr(self, f"patch_embed1", patch_embed1)
        setattr(self, f"stage1", stage1)
        setattr(self, f"patch_embed2", patch_embed2)
        setattr(self, f"stage2", stage2)


        # classification head
        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        # NOTE(review): likely a typo for 'pos_embed'; no parameter by either
        # name exists in this model, so the exclusion set is effectively inert.
        return {'pose_embed'}

    @torch.jit.ignore
    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        # No positional embedding is used by this model.
        return None

    def _init_weights(self, m):
        # trunc-normal init for linear weights; standard LN init.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        # x: (T, B, C, H, W) -> (T, B, embed_dims) after spatial mean.
        stage1 = getattr(self, f"stage1")
        patch_embed1 = getattr(self, f"patch_embed1")
        stage2 = getattr(self, f"stage2")
        patch_embed2 = getattr(self, f"patch_embed2")

        x = patch_embed1(x)
        for blk in stage1:
            x = blk(x)

        x = patch_embed2(x)
        for blk in stage2:
            x = blk(x)

        return x.flatten(3).mean(3)

    def forward(self, x):
        # Input is (B, T, C, H, W); move time first for the multi-step neurons.
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        x = self.forward_features(x)
        # Average over time steps before the classifier.
        x = self.head(x.mean(0))
        return x
395
+
396
+
397
@register_model
def QKFormer_1003(pretrained=False, **kwargs):
    """Build the two-stage QKFormer used here: 2-channel input, 101 classes.

    `pretrained` is accepted for timm compatibility but no checkpoint is
    loaded. Extra keyword arguments are forwarded to `vit_snn`.
    """
    model = vit_snn(
        in_channels=2,
        num_classes=101,
        patch_size=16,
        embed_dims=256,
        num_heads=16,
        mlp_ratios=1,
        qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=4,
        sr_ratios=1,
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model
407
+
408
+
409
# NOTE(review): duplicate of the module-level `create_model` import at the
# top of this file; harmless but redundant.
from timm.models import create_model

if __name__ == '__main__':
    # Smoke test: batch 1, T=1 time step, 2 input channels, 128x128 frames.
    x = torch.randn(1, 1, 2, 128, 128).cuda()
    model = create_model(
        'QKFormer_1003',
        pretrained=False,
        drop_rate=0,
        drop_path_rate=0.1,
        drop_block_rate=None,
    ).cuda()
    model.eval()

    from torchinfo import summary
    summary(model, input_size=(1, 1, 2, 128, 128))
    y = model(x)
    print(y.shape)  # expected (1, 101): (batch, num_classes)
    print('Test Good!')
models/qk_model_with_delay/__init__.py ADDED
File without changes
models/qk_model_with_delay/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (221 Bytes). View file
 
models/qk_model_with_delay/__pycache__/delay_synaptic_func_inter.cpython-311.pyc ADDED
Binary file (11.3 kB). View file
 
models/qk_model_with_delay/__pycache__/delay_synaptic_inter_model.cpython-311.pyc ADDED
Binary file (30.3 kB). View file
 
models/qk_model_with_delay/delay_synaptic_func_inter.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+
6
def set_sigma_for_DCLS(model, s):
    """Set the `sigma` attribute of every DelayConv submodule of `model` to `s`.

    Matching is done on the class *name* (string compare) so the DelayConv
    class itself need not be importable here. Prints a confirmation line
    regardless of how many modules matched.
    """
    for _, sub in model.named_modules():
        if type(sub).__name__ == 'DelayConv' and hasattr(sub, 'sigma'):
            sub.sigma = s
    print('Set sigma to ', s)
12
+
13
class DropoutNd(nn.Module):
    """Dropout for (batch, dim, lengths...) tensors.

    When `tie` is True one mask is shared across all trailing length
    dimensions (like Dropout1d/2d/3d). With `transposed=False` the channel
    dimension is last, and the input is rearranged before/after masking.
    Identity in eval mode.
    """

    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
        """
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError("dropout probability has to be in [0, 1), but got {}".format(p))
        self.p = p
        self.tie = tie
        self.transposed = transposed
        # Kept for compatibility; masks are drawn with torch.rand directly on
        # the input's device, which avoids slow CPU->GPU sampling.
        self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...)."""
        if not self.training:
            return X
        if not self.transposed:
            X = rearrange(X, 'b ... d -> b d ...')
        if self.tie:
            mask_shape = X.shape[:2] + (1,) * (X.ndim - 2)
        else:
            mask_shape = X.shape
        keep = torch.rand(*mask_shape, device=X.device) < 1. - self.p
        # Inverted dropout: rescale so the expectation is unchanged.
        X = X * keep * (1.0 / (1 - self.p))
        if not self.transposed:
            X = rearrange(X, 'b d ... -> b ... d')
        return X
38
+
39
class DelayConv(nn.Module):
    """Learnable-delay temporal convolution over the time axis.

    Each (out-channel, in-channel) pair owns `n_delay` learnable real-valued
    delays `d`. A soft "bump" kernel (Gaussian or triangular) centered at each
    delay is multiplied with a per-tap weight to produce a causal Conv1d
    kernel of length `k` that is rebuilt from the current delays on every
    forward pass.
    """

    def __init__(
            self,
            in_c,
            k,
            dropout=0.0,
            n_delay=1,
            dilation=1,
            kernel_type='triangle_r_temp'
    ):
        super().__init__()
        self.C = in_c  # number of input and output channels (equal)
        self.win_len = k
        self.dilation = dilation
        self.n_delay = n_delay
        self.kernel_type = kernel_type

        # Plain tensor attribute (not a buffer) — moved to the input's device
        # inside update_kernel() on every call.
        self.t = torch.arange(self.win_len).float().unsqueeze(0)  # [1, k]
        self.sigma = self.win_len // 2  # bump width; can be annealed via set_sigma_for_DCLS

        self.delay_kernel = None
        self.bump = None

        # ========== changed: d has shape [C_out, C_in, n_delay] ==========
        # Delays are initialized to distinct random integers in [1, win_len - 2].
        # NOTE(review): requires n_delay <= win_len - 2 or the assignment
        # below fails with a shape mismatch.
        d = torch.rand(self.C, self.C, self.n_delay)
        with torch.no_grad():
            for co in range(self.C):
                for ci in range(self.C):
                    d[co, ci, :] = torch.randperm(self.win_len - 2)[:self.n_delay] + 1
        self.register("d", d, lr=1e-2)

        # Initialize weights: [C_out, C_in, k], decaying by half per step back
        # in time so the most recent tap carries the largest weight.
        weight = torch.ones([self.C, self.C, k])
        with torch.no_grad():
            for co in range(self.C):  # output channel
                for ci in range(self.C):  # input channel
                    for i in range(k - 2, -1, -1):
                        weight[co, ci, i] = weight[co, ci, i + 1] / 2

        self.weight = nn.Parameter(weight)

        self.dropout = nn.Dropout(dropout / 5) if dropout > 0.0 else nn.Identity()

    def register(self, name, tensor, lr=None):
        """Register a trainable parameter (or fixed buffer when lr == 0.0),
        attaching a per-parameter `_optim` dict (lr, no weight decay) for
        optimizers that honor it."""
        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))
            optim = {"weight_decay": 0}
            if lr is not None:
                optim["lr"] = lr
            setattr(getattr(self, name), "_optim", optim)

    def update_kernel(self, device):
        """
        Rebuild the delay kernel from current delays: shape [C_out, C_in, k].
        """
        t = self.t.to(device).view(1, 1, 1, -1)  # [1,1,1,k]
        d = self.d.to(device)  # [C_out, C_in, n_delay]

        # ---------- compute the bump ----------
        if self.kernel_type == 'gauss':
            # Gaussian bump centered win_len-1-d steps from the kernel end.
            bump = torch.exp(-0.5 * ((t - self.win_len + d.unsqueeze(-1) + 1) / self.sigma) ** 2)
            bump = (bump - 1e-3).relu() + 1e-3
            bump = bump / (bump.sum(dim=-1, keepdim=True) + 1e-7)

        elif self.kernel_type == 'triangle':
            bump = torch.relu(1 - torch.abs((t - self.win_len + d.unsqueeze(-1) + 1) / self.sigma))
            bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7)

        elif self.kernel_type == 'triangle_r':
            # Straight-through rounding: integer delay forward, real gradient back.
            d_int = (d.round() - d).detach() + d
            bump = torch.relu(1 - torch.abs((t - self.win_len + d_int.unsqueeze(-1) + 1) / self.sigma))
            bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7)

        elif self.kernel_type == 'triangle_r_temp':
            # Temperature-scaled straight-through rounding: as sigma shrinks,
            # the forward delay approaches its rounded value.
            scale = min(1.0, 1.0 / self.sigma)
            d_int = (d.round() - d).detach() * scale + d
            bump = torch.relu(1 - torch.abs((t - self.win_len + d_int.unsqueeze(-1) + 1) / self.sigma))
            bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7)  # [C_out, C_in, n_delay, k]
            # ------ harden the bump in eval mode (one-hot at the peak) ------
            if not self.training:
                max_idx = bump.argmax(dim=-1, keepdim=True)  # index of the max
                hard_mask = torch.zeros_like(bump)
                hard_mask.scatter_(-1, max_idx, 1.0)
                bump = bump * hard_mask
            # --------------------------------
        else:
            raise ValueError(f"Unknown kernel_type: {self.kernel_type}")

        # bump: [C_out, C_in, n_delay, k]
        self.bump = bump.detach().clone().to(device)

        # ---------- sum over the n_delay dimension: [C_out, C_in, k] ----------
        bump_sum = bump.sum(dim=2)

        # ---------- build the final convolution kernel ----------
        # weight: [C_out, C_in, k]
        self.delay_kernel = (self.weight * bump_sum).to(device)  # [C_out, C_in, k]

    def forward(self, x):
        """
        x: (T, B, C, N) at the call sites in this repo — the permute below
           yields (T, B, N, C). (Original docstring said (T, B, N, C);
           corrected per the callers in delay_synaptic_inter_model.py.)
        return: (T*B, C, N)
        """
        # rearrange dimensions
        x = x.permute(0, 1, 3, 2).contiguous()  # (T, B, N, C)
        T, B, N, C = x.shape
        assert C == self.C, f"Input channel mismatch: {C} vs {self.C}"
        x = x.permute(1, 2, 3, 0).contiguous()  # (B, N, C, T)

        # merge B*N as the batch dimension
        x_reshaped = x.view(B * N, C, T)  # (B*N, C, T)
        device = x.device

        # rebuild the kernel from the current delays/weights
        self.update_kernel(device)  # -> [C_out, C_in, k]
        kernel = self.delay_kernel

        # left-pad so the convolution is causal
        pad_left = (self.win_len - 1) * self.dilation
        x_padded = F.pad(x_reshaped, (pad_left, 0))  # (B*N, C, T+pad)

        # full cross-channel convolution: groups=1
        y = F.conv1d(x_padded, kernel, stride=1, dilation=self.dilation, groups=1)  # (B*N, C, T)

        # restore the original layout
        y = y.view(B, N, C, T).permute(3, 0, 2, 1).contiguous().view(-1, C, N)  # (T*B, C, N)

        return self.dropout(y)
models/qk_model_with_delay/delay_synaptic_inter_model.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from spikingjelly.clock_driven.neuron import MultiStepParametricLIFNode, MultiStepLIFNode
4
+ from timm.models.layers import to_2tuple, trunc_normal_, DropPath
5
+ from timm.models.registry import register_model
6
+ from timm.models.vision_transformer import _cfg
7
+ from functools import partial
8
+ from timm.models import create_model
9
+ from .delay_synaptic_func_inter import DelayConv
10
+
11
+ __all__ = ['delay_QKFormer']
12
+
13
class MLP(nn.Module):
    """Spiking two-layer MLP: (1x1 Conv2d -> BatchNorm2d -> LIF) twice.

    Same structure as the MLP in qk_model_v1_1003.py but with the 'cupy'
    neuron backend. Operates on (T, B, C, H, W); time and batch are merged
    for conv/BN and restored before each LIF. `drop` is accepted but unused.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.mlp1_conv = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1)
        self.mlp1_bn = nn.BatchNorm2d(hidden_features)
        self.mlp1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        self.mlp2_conv = nn.Conv2d(hidden_features, out_features, kernel_size=1, stride=1)
        self.mlp2_bn = nn.BatchNorm2d(out_features)
        self.mlp2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        T, B, C, H, W = x.shape

        x = self.mlp1_conv(x.flatten(0, 1))
        x = self.mlp1_bn(x).reshape(T, B, self.c_hidden, H, W)
        x = self.mlp1_lif(x)

        x = self.mlp2_conv(x.flatten(0, 1))
        # Reshape to C assumes out_features == in_features (the default).
        x = self.mlp2_bn(x).reshape(T, B, C, H, W)
        x = self.mlp2_lif(x)
        return x
40
+
41
class Token_QK_Attention(nn.Module):
    """Token-QK attention with a learnable-delay K projection.

    Same Q-gates-K scheme as the baseline model, but the K branch's 1x1 conv
    is replaced by a DelayConv over the time axis (kernel length `k`).
    qkv_bias/qk_scale/attn_drop/proj_drop/sr_ratio are accepted for
    signature compatibility but unused.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 sr_ratio=1,
                 k=16):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads

        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        # K projection replaced by a temporal delay convolution.
        self.k_proj_delay = DelayConv(in_c=self.dim, k=k)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # Lower threshold (0.5) for the summed-Q attention signal.
        self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='cupy')

        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

    def forward(self, x):
        # x: (T, B, C, H, W) spike tensor.
        T, B, C, H, W = x.shape

        x = x.flatten(3)          # (T, B, C, N), N = H*W
        T, B, C, N = x.shape
        x_for_qkv = x.flatten(0, 1)  # (T*B, C, N)

        q_conv_out = self.q_conv(x_for_qkv)
        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N)
        q_conv_out = self.q_lif(q_conv_out)
        q = q_conv_out.unsqueeze(2).reshape(T, B, self.num_heads, C // self.num_heads, N)

        # k_conv_out = self.k_conv(x_for_qkv)
        # DelayConv takes (T, B, C, N) and returns (T*B, C, N).
        k_conv_out = self.k_proj_delay(x_for_qkv.reshape(T,B,C,N))
        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N)
        k_conv_out = self.k_lif(k_conv_out)
        k = k_conv_out.unsqueeze(2).reshape(T, B, self.num_heads, C // self.num_heads, N)

        # Gate K by the spiked, head-collapsed Q (broadcast multiply).
        q = torch.sum(q, dim=3, keepdim=True)
        attn = self.attn_lif(q)
        x = torch.mul(attn, k)

        x = x.flatten(2, 3)
        x = self.proj_bn(self.proj_conv(x.flatten(0, 1))).reshape(T, B, C, H, W)
        x = self.proj_lif(x)
        return x
98
+
99
class Spiking_Self_Attention(nn.Module):
    """Spiking self-attention with learnable-delay K and V projections.

    Linear attention q @ (k^T @ v) * 0.125; K and V 1x1 convs are replaced
    by DelayConv modules over the time axis. qkv_bias/qk_scale/attn_drop/
    proj_drop/sr_ratio are accepted for signature compatibility but unused;
    `qkv_mp` is constructed but never applied in forward.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 sr_ratio=1,
                 k=16):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads  # computed but not used below
        self.scale = 0.125
        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.k_proj_delay = DelayConv(in_c=self.dim, k=k)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.v_proj_delay = DelayConv(in_c=self.dim, k=k)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')
        self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='cupy')

        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        self.qkv_mp = nn.MaxPool1d(4)  # NOTE(review): unused in forward

    def forward(self, x):
        # x: (T, B, C, H, W) spike tensor.
        T, B, C, H, W = x.shape

        x = x.flatten(3)          # (T, B, C, N), N = H*W
        T, B, C, N = x.shape
        x_for_qkv = x.flatten(0, 1)  # (T*B, C, N)

        q_conv_out = self.q_conv(x_for_qkv)
        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous()
        q_conv_out = self.q_lif(q_conv_out)
        # (T, B, heads, N, head_dim)
        q = q_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2,
                                                                                                       4).contiguous()

        # DelayConv takes (T, B, C, N) and returns (T*B, C, N).
        k_conv_out = self.k_proj_delay(x_for_qkv.reshape(T,B,C,N))
        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous()
        k_conv_out = self.k_lif(k_conv_out)
        k = k_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2,
                                                                                                       4).contiguous()

        v_conv_out = self.v_proj_delay(x_for_qkv.reshape(T,B,C,N))
        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous()
        v_conv_out = self.v_lif(v_conv_out)
        v = v_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2,
                                                                                                       4).contiguous()

        # Linear attention: K^T@V first (head_dim x head_dim), then Q @ (.)
        x = k.transpose(-2, -1) @ v
        x = (q @ x) * self.scale

        x = x.transpose(3, 4).reshape(T, B, C, N).contiguous()
        x = self.attn_lif(x)
        x = x.flatten(0, 1)
        x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T, B, C, H, W)
        return x
169
+
170
class TokenSpikingTransformer(nn.Module):
    """Residual block: delayed Token-QK attention followed by a spiking MLP.

    Extra constructor arguments are accepted for signature compatibility
    but not forwarded to the submodules.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.tssa = Token_QK_Attention(dim, num_heads)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)

    def forward(self, x):
        out = x + self.tssa(x)
        out = out + self.mlp(out)
        return out
184
+
185
class SpikingTransformer(nn.Module):
    """Stage-2 transformer block: Spiking_Self_Attention + MLP, each wrapped
    in a residual (membrane shortcut) connection.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        # Only dim, num_heads, mlp_ratio and drop are consumed; the remaining
        # parameters exist for signature parity with the other block types.
        self.ssa = Spiking_Self_Attention(dim, num_heads)
        hidden = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=hidden, drop=drop)

    def forward(self, x):
        # Residual around self-attention, then residual around the MLP.
        out = x + self.ssa(x)
        return out + self.mlp(out)
199
+
200
class PatchEmbedInit(nn.Module):
    """Initial spiking patch-embedding stem.

    Four conv-BN(-maxpool)-LIF stages downsample the input 8x overall while
    raising channels to ``embed_dims``; a strided 1x1-conv branch taken from
    the first downsampled feature map is added back as a residual shortcut.
    Tensors flow in time-major layout (T, B, C, H, W).
    """

    def __init__(self, img_size_h=128, img_size_w=128, patch_size=4, in_channels=2, embed_dims=256):
        super().__init__()
        self.image_size = [img_size_h, img_size_w]
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.C = in_channels
        # Nominal token-grid size implied by patch_size (bookkeeping only;
        # the forward pass derives actual sizes from the tensors).
        self.H = self.image_size[0] // patch_size[0]
        self.W = self.image_size[1] // patch_size[1]
        self.num_patches = self.H * self.W

        # Stage 0: in_channels -> embed_dims/8, resolution preserved.
        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # Stage 1: embed_dims/8 -> embed_dims/4, 2x downsample via maxpool.
        self.proj1_conv = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj1_bn = nn.BatchNorm2d(embed_dims // 4)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # Stage 2: embed_dims/4 -> embed_dims/2, 2x downsample.
        self.proj2_conv = nn.Conv2d(embed_dims // 4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj2_bn = nn.BatchNorm2d(embed_dims // 2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # Stage 3: embed_dims/2 -> embed_dims, 2x downsample.
        self.proj3_conv = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj3_bn = nn.BatchNorm2d(embed_dims)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj3_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # Residual branch: stage-1 features -> embed_dims via a stride-4
        # 1x1 conv, matching the combined 4x downsampling of stages 2-3.
        self.proj_res_conv = nn.Conv2d(embed_dims // 4, embed_dims, kernel_size=1, stride=4, padding=0, bias=False)
        self.proj_res_bn = nn.BatchNorm2d(embed_dims)
        self.proj_res_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

    def forward(self, x):
        T, B, _, H, W = x.shape

        # Stage 0 — fold time into batch for the 2-D convolution.
        out = self.proj_bn(self.proj_conv(x.flatten(0, 1))).reshape(T, B, -1, H, W)
        out = self.proj_lif(out).flatten(0, 1).contiguous()

        # Stage 1.
        out = self.maxpool1(self.proj1_bn(self.proj1_conv(out)))
        _, _, h, w = out.shape
        out = self.proj1_lif(out.reshape(T, B, -1, h, w).contiguous()).flatten(0, 1).contiguous()

        shortcut = out  # residual branch taps the stage-1 spikes

        # Stage 2.
        out = self.maxpool2(self.proj2_bn(self.proj2_conv(out)))
        _, _, h, w = out.shape
        out = self.proj2_lif(out.reshape(T, B, -1, h, w).contiguous()).flatten(0, 1).contiguous()

        # Stage 3 — left time-major for the residual add.
        out = self.maxpool3(self.proj3_bn(self.proj3_conv(out)))
        _, _, h, w = out.shape
        out = self.proj3_lif(out.reshape(T, B, -1, h, w).contiguous())

        # Residual branch.
        shortcut = self.proj_res_bn(self.proj_res_conv(shortcut))
        _, _, h, w = shortcut.shape
        shortcut = self.proj_res_lif(shortcut.reshape(T, B, -1, h, w).contiguous())

        return out + shortcut
272
+
273
class PatchEmbeddingStage(nn.Module):
    """Intermediate spiking downsampling stage (embed_dims/2 -> embed_dims, 2x).

    A conv-BN-LIF channel expansion followed by a conv-BN-maxpool-LIF
    downsample, plus a strided 1x1-conv residual branch from the stage input.
    Tensors flow in time-major layout (T, B, C, H, W).
    """

    def __init__(self, img_size_h=128, img_size_w=128, patch_size=4, in_channels=2, embed_dims=256):
        super().__init__()
        self.image_size = [img_size_h, img_size_w]
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.C = in_channels
        # Nominal token-grid size implied by patch_size (bookkeeping only).
        self.H = self.image_size[0] // patch_size[0]
        self.W = self.image_size[1] // patch_size[1]
        self.num_patches = self.H * self.W

        # Channel expansion: embed_dims/2 -> embed_dims, resolution preserved.
        self.proj_conv = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn = nn.BatchNorm2d(embed_dims)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # Downsample: 2x via stride-2 maxpool.
        self.proj4_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj4_bn = nn.BatchNorm2d(embed_dims)
        self.proj4_maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj4_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # Residual branch: stride-2 1x1 conv matches the main-path downsample.
        self.proj_res_conv = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=1, stride=2, padding=0, bias=False)
        self.proj_res_bn = nn.BatchNorm2d(embed_dims)
        self.proj_res_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

    def forward(self, x):
        T, B, _, H, W = x.shape

        flat = x.flatten(0, 1).contiguous()  # fold time into batch
        shortcut = flat                      # residual branch taps the raw stage input

        out = self.proj_bn(self.proj_conv(flat)).reshape(T, B, -1, H, W).contiguous()
        out = self.proj_lif(out).flatten(0, 1).contiguous()

        out = self.proj4_maxpool(self.proj4_bn(self.proj4_conv(out)))
        _, _, h, w = out.shape
        out = self.proj4_lif(out.reshape(T, B, -1, h, w).contiguous())

        shortcut = self.proj_res_bn(self.proj_res_conv(shortcut))
        _, _, h, w = shortcut.shape
        shortcut = self.proj_res_lif(shortcut.reshape(T, B, -1, h, w).contiguous())

        return out + shortcut
323
+
324
+
325
class vit_snn(nn.Module):
    """Two-stage spiking vision transformer (QKFormer-style).

    Stage 1: PatchEmbedInit stem + TokenSpikingTransformer block(s) at
    ``embed_dims // 2`` channels. Stage 2: PatchEmbeddingStage downsample +
    SpikingTransformer block(s) at ``embed_dims``. The head averages over
    spatial positions and time steps before classification.

    Input: spike tensor of shape (B, T, C, H, W); it is permuted to
    time-major (T, B, C, H, W) internally.

    NOTE(review): hyperparameters are consumed as scalars throughout
    (``embed_dims // 2``, ``torch.linspace(0, drop_path_rate, depths)``,
    ``int(dim * mlp_ratio)``). The previous list defaults
    (embed_dims=[64, 128, 256], depths=[6, 8, 6], ...) crashed on default
    construction, so the defaults now match the registered delay_QKFormer
    configuration; explicit scalar arguments behave exactly as before.
    """

    def __init__(self,
                 img_size_h=128, img_size_w=128, patch_size=16, in_channels=2, num_classes=11,
                 embed_dims=256, num_heads=16, mlp_ratios=4, qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=4, sr_ratios=1, T=4, pretrained_cfg=None, in_chans=3, no_weight_decay=None
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.T = T
        # NOTE(review): the num_heads argument is deliberately overridden —
        # both stages always use 16 heads (preserves original behavior).
        num_heads = [16, 16, 16]
        # Stochastic depth decay rule: one drop-path rate per potential block.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]

        # Stage 1: stem + token-attention blocks at half width.
        self.patch_embed1 = PatchEmbedInit(img_size_h=img_size_h,
                                           img_size_w=img_size_w,
                                           patch_size=patch_size,
                                           in_channels=in_channels,
                                           embed_dims=embed_dims // 2)
        self.stage1 = nn.ModuleList([TokenSpikingTransformer(
            dim=embed_dims // 2, num_heads=num_heads[0], mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,
            qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],
            norm_layer=norm_layer, sr_ratio=sr_ratios)
            for j in range(1)])

        # Stage 2: downsample + self-attention blocks at full width.
        self.patch_embed2 = PatchEmbeddingStage(img_size_h=img_size_h,
                                                img_size_w=img_size_w,
                                                patch_size=patch_size,
                                                in_channels=in_channels,
                                                embed_dims=embed_dims)
        self.stage2 = nn.ModuleList([SpikingTransformer(
            dim=embed_dims, num_heads=num_heads[1], mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,
            qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],
            norm_layer=norm_layer, sr_ratio=sr_ratios)
            for j in range(1)])

        # Classification head.
        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        # NOTE(review): 'pose_embed' looks like a typo for 'pos_embed', but no
        # parameter of either name exists in this model, so the set is inert.
        return {'pose_embed'}

    @torch.jit.ignore
    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        # Positional embeddings are not used; spatial structure comes from convs.
        return None

    def _init_weights(self, m):
        """Truncated-normal init for Linear layers; unit affine for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """Run both stages; returns (T, B, C) features after spatial pooling."""
        x = self.patch_embed1(x)
        for blk in self.stage1:
            x = blk(x)

        x = self.patch_embed2(x)
        for blk in self.stage2:
            x = blk(x)

        # (T, B, C, H, W) -> (T, B, C): mean over all spatial positions.
        return x.flatten(3).mean(3)

    def forward(self, x):
        # (B, T, C, H, W) -> (T, B, C, H, W): time-major for spiking layers.
        x = x.permute(1, 0, 2, 3, 4)
        x = self.forward_features(x)
        # Average over time steps, then classify.
        x = self.head(x.mean(0))
        return x
418
+
419
+
420
@register_model
def delay_QKFormer(pretrained=False, **kwargs):
    """timm registry entry: QKFormer with delayed K/V projections.

    Fixed configuration (256-dim, 16 heads, 101 classes, 2-channel DVS
    input); any keyword overrides are forwarded to vit_snn.
    """
    net = vit_snn(
        patch_size=16,
        embed_dims=256,
        num_heads=16,
        mlp_ratios=4,
        in_channels=2,
        num_classes=101,
        qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=4,
        sr_ratios=1,
        **kwargs
    )
    net.default_cfg = _cfg()
    return net
430
+
431
+
432
+ from timm.models import create_model
433
+
434
if __name__ == '__main__':
    # Smoke test: build the registered model through timm so keyword handling
    # matches the training scripts, then print a layer-by-layer summary.
    # (Removed an unused CUDA input tensor; the direct forward call it fed
    # was already commented out.)
    model = create_model(
        'delay_QKFormer',
        pretrained=False,
        drop_rate=0,
        drop_path_rate=0.1,
        drop_block_rate=None,
    ).cuda()
    model.eval()

    # Summary for an input of shape (1, 1, 2, 128, 128); the model's forward
    # permutes dims 0/1, so batch and time are both 1 here.
    from torchinfo import summary
    summary(model, input_size=(1, 1, 2, 128, 128))
    # y = model(x)
    # print(y.shape)
    # print('Test Good!')
450
+
451
+
452
+
453
+
454
+
455
+
456
+
457
+
458
+
459
+