sosigikiller committed on
Commit
5df9707
1 Parent(s): 760e62e

initial push

example.py ADDED
@@ -0,0 +1,33 @@
+ import torch
+ from torch import nn
+ 
+ from model.ResNet18 import ResNet18
+ from model.CSAT import CSAT
+ from model.CSATv2 import CSATv2
+ 
+ img_size = 224
+ 
+ # CSAT: use CSAT_RCKD.pth.tar instead for pathological image analysis
+ path = r'./weight/CSAT_ImageNet.pth.tar'
+ model = CSAT(img_size=img_size)
+ state = torch.load(path, map_location='cpu')
+ model.load_state_dict(state)
+ data = torch.zeros((1, 3, img_size, img_size))  # b, c, h, w = 1, 3, 224, 224
+ model.head = nn.Identity()  # drop the classifier to expose the pooled features
+ output = model(data)  # b, c = 1, 176
+ print(output.shape)
+ 
+ path = r'./weight/ResNet18_RCKD.pth.tar'
+ model = ResNet18()
+ state = torch.load(path, map_location='cpu')
+ model.load_state_dict(state)
+ data = torch.zeros((1, 3, img_size, img_size))  # b, c, h, w = 1, 3, 224, 224
+ model.fc = nn.Identity()
+ output = model(data)  # b, c = 1, 512
+ print(output.shape)
+ 
+ path = r'./weight/CSAT_v2_ImageNet.pth.tar'
+ model = CSATv2(img_size=img_size)
+ state = torch.load(path, map_location='cpu')
+ model.load_state_dict(state['state_dict'])
+ data = torch.zeros((1, 3, img_size, img_size))  # b, c, h, w = 1, 3, 224, 224
+ model.head = nn.Identity()  # CSATv2's classifier is `head`, not `fc`
+ output = model(data)  # b, c = 1, 386
+ print(output.shape)
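Note that the snippets above run the models in their default training mode. A minimal sketch of the same CSAT feature extraction with `eval()` and `torch.no_grad()`, which is good practice for deterministic inference even though the default drop rates here are zero:

import torch
from torch import nn
from model.CSAT import CSAT

model = CSAT(img_size=224)
model.load_state_dict(torch.load('./weight/CSAT_ImageNet.pth.tar', map_location='cpu'))
model.head = nn.Identity()  # expose the 176-dim pooled features
model.eval()                # fix DropPath/Dropout behaviour
with torch.no_grad():
    features = model(torch.zeros(1, 3, 224, 224))
print(features.shape)       # torch.Size([1, 176])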
model/CSAT.py ADDED
@@ -0,0 +1,490 @@
+ import math
+ import warnings
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from einops.layers.torch import Rearrange
+ 
+ 
+ class Block(nn.Module):
+     """ ConvNeXtV2 Block.
+ 
+     Args:
+         dim (int): Number of input channels.
+         drop_path (float): Stochastic depth rate. Default: 0.0
+     """
+ 
+     def __init__(self, dim, drop_path=0., img_size=None):
+         super().__init__()
+         self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
+         self.norm = LayerNorm(dim, eps=1e-6)
+         self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
+         self.act = nn.GELU()
+         self.grn = GRN(4 * dim)
+         self.pwconv2 = nn.Linear(4 * dim, dim)
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.attention = Spatial_Attention()
+ 
+     def forward(self, x):
+         shortcut = x
+         x = self.dwconv(x)
+         x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+         x = self.norm(x)
+         x = self.pwconv1(x)
+         x = self.act(x)
+         x = self.grn(x)
+         x = self.pwconv2(x)
+ 
+         x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+         attention = self.attention(x)
+         # bilinearly upsample the 7x7 attention map back to the feature resolution
+         x = x * F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=True)
+         x = shortcut + self.drop_path(x)
+         return x
+ 
+ 
+ class Spatial_Attention(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
+         self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)
+         self.attention = TransformerBlock(1, 1, heads=1, dim_head=1, img_size=[7, 7])
+ 
+     def forward(self, x):
+         x_avg = x.mean([1]).unsqueeze(1)
+         x_max = x.max(dim=1).values.unsqueeze(1)
+         x = torch.cat([x_avg, x_max], dim=1)
+         x = self.avgpool(x)
+         x = self.conv(x)
+         x = self.attention(x)
+         return x
+ 
+ 
+ class TransformerBlock(nn.Module):
+     def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=False, dropout=0.):
+         super().__init__()
+         hidden_dim = int(inp * 4)
+ 
+         self.downsample = downsample
+         self.ih, self.iw = img_size
+ 
+         if self.downsample:
+             self.pool1 = nn.MaxPool2d(3, 2, 1)
+             self.pool2 = nn.MaxPool2d(3, 2, 1)
+             self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
+ 
+         self.attn = Attention(inp, oup, heads, dim_head, dropout)
+         self.ff = FeedForward(oup, hidden_dim, dropout)
+ 
+         self.attn = nn.Sequential(
+             Rearrange('b c ih iw -> b (ih iw) c'),
+             PreNorm(inp, self.attn, nn.LayerNorm),
+             Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
+         )
+ 
+         self.ff = nn.Sequential(
+             Rearrange('b c ih iw -> b (ih iw) c'),
+             PreNorm(oup, self.ff, nn.LayerNorm),
+             Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
+         )
+ 
+     def forward(self, x):
+         if self.downsample:
+             x = self.proj(self.pool1(x)) + self.attn(self.pool2(x))
+         else:
+             x = x + self.attn(x)
+         x = x + self.ff(x)
+         return x
+ 
+ 
+ class CSAT(nn.Module):
+     def __init__(self,
+                  img_size=384,
+                  num_classes=1000,
+                  drop_path_rate=0,
+                  head_init_scale=1,
+                  weight=None
+                  ):
+         super().__init__()
+         dims = [32, 48, 96, 176]
+         channel_order = "channels_first"
+         depths = [2, 2, 6, 4]
+         dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+ 
+         self.stem = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=dims[0], kernel_size=4, stride=4),
+                                   LayerNorm(normalized_shape=dims[0], data_format=channel_order))
+ 
+         self.stages1 = nn.Sequential(
+             Block(dim=dims[0], drop_path=dp_rates[0], img_size=[int(img_size / 4), int(img_size / 4)]),
+             Block(dim=dims[0], drop_path=dp_rates[1], img_size=[int(img_size / 4), int(img_size / 4)]),
+             LayerNorm(dims[0], eps=1e-6, data_format=channel_order),
+             nn.Conv2d(dims[0], dims[1], kernel_size=2, stride=2),
+         )
+ 
+         self.stages2 = nn.Sequential(
+             Block(dim=dims[1], drop_path=dp_rates[0], img_size=[int(img_size / 8), int(img_size / 8)]),
+             Block(dim=dims[1], drop_path=dp_rates[1], img_size=[int(img_size / 8), int(img_size / 8)]),
+             LayerNorm(dims[1], eps=1e-6, data_format=channel_order),
+             nn.Conv2d(dims[1], dims[2], kernel_size=2, stride=2),
+         )
+ 
+         self.stages3 = nn.Sequential(
+             Block(dim=dims[2], drop_path=dp_rates[0], img_size=[int(img_size / 16), int(img_size / 16)]),
+             Block(dim=dims[2], drop_path=dp_rates[1], img_size=[int(img_size / 16), int(img_size / 16)]),
+             Block(dim=dims[2], drop_path=dp_rates[2], img_size=[int(img_size / 16), int(img_size / 16)]),
+             Block(dim=dims[2], drop_path=dp_rates[3], img_size=[int(img_size / 16), int(img_size / 16)]),
+             Block(dim=dims[2], drop_path=dp_rates[4], img_size=[int(img_size / 16), int(img_size / 16)]),
+             Block(dim=dims[2], drop_path=dp_rates[5], img_size=[int(img_size / 16), int(img_size / 16)]),
+             TransformerBlock(inp=dims[2], oup=dims[2], img_size=[int(img_size / 16), int(img_size / 16)]),
+             TransformerBlock(inp=dims[2], oup=dims[2], img_size=[int(img_size / 16), int(img_size / 16)]),
+             LayerNorm(dims[2], eps=1e-6, data_format=channel_order),
+             nn.Conv2d(dims[2], dims[3], kernel_size=2, stride=2),
+         )
+ 
+         self.stages4 = nn.Sequential(
+             Block(dim=dims[3], drop_path=dp_rates[0], img_size=[int(img_size / 32), int(img_size / 32)]),
+             Block(dim=dims[3], drop_path=dp_rates[1], img_size=[int(img_size / 32), int(img_size / 32)]),
+             Block(dim=dims[3], drop_path=dp_rates[2], img_size=[int(img_size / 32), int(img_size / 32)]),
+             Block(dim=dims[3], drop_path=dp_rates[3], img_size=[int(img_size / 32), int(img_size / 32)]),
+             TransformerBlock(inp=dims[3], oup=dims[3], img_size=[int(img_size / 32), int(img_size / 32)]),
+             TransformerBlock(inp=dims[3], oup=dims[3], img_size=[int(img_size / 32), int(img_size / 32)]),
+         )
+ 
+         self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
+         self.head = nn.Linear(dims[-1], num_classes)
+ 
+         self.apply(self._init_weights)
+         self.head.weight.data.mul_(head_init_scale)
+         self.head.bias.data.mul_(head_init_scale)
+ 
+         if weight is not None:
+             self.load_checkpoint(checkpoint=weight)
+             self.freeze_weight()
+ 
+     def _init_weights(self, m):
+         if isinstance(m, (nn.Conv2d, nn.Linear)):
+             trunc_normal_(m.weight, std=.02)
+             if m.bias is not None:  # e.g. the bias-free qkv projections
+                 nn.init.constant_(m.bias, 0)
+ 
+     def freeze_weight(self):
+         for name, param in self.named_parameters():
+             if param.requires_grad and 'pos_embed' in name:
+                 param.requires_grad = False
+ 
+     def load_checkpoint(self, checkpoint=None):
+         state = torch.load(checkpoint, map_location='cpu')
+         if 'state_dict' in state:
+             state_dict = state['state_dict']
+         elif 'model' in state:
+             state_dict = state['model']
+             for key in list(state_dict.keys()):
+                 state_dict[key.replace('module.', '')] = state_dict.pop(key)
+         elif 'q_state_dict' in state:
+             state_dict = state['q_state_dict']
+         else:
+             state_dict = state  # assume the checkpoint is a raw state dict
+ 
+         for key in list(state_dict.keys()):
+             state_dict[key.replace('backbone.', '')] = state_dict.pop(key)
+ 
+         model_dict = self.state_dict()
+         weights = {k: v for k, v in state_dict.items() if k in model_dict}
+ 
+         model_dict.update(weights)
+         del model_dict['head.weight']
+         del model_dict['head.bias']
+         self.load_state_dict(model_dict, strict=False)
+ 
+     def forward(self, x):
+         outputs = self.encoder(x)
+         # x, low_level, mid_level, high_level = self.seg_encoder(x)
+         return outputs
+ 
+     def encoder(self, x):
+         x = self.stem(x)
+         for i, layer in enumerate(self.stages1):
+             if i == len(self.stages1) - 1:
+                 x1 = x  # features before the stage's downsampling conv
+             x = layer(x)
+ 
+         for i, layer in enumerate(self.stages2):
+             if i == len(self.stages2) - 1:
+                 x2 = x
+             x = layer(x)
+ 
+         for i, layer in enumerate(self.stages3):
+             if i == len(self.stages3) - 1:
+                 x3 = x
+             x = layer(x)
+ 
+         x = self.stages4(x)
+         x = self.norm(x.mean([-2, -1]))  # global average pool, then norm
+         x = self.head(x)
+         return x
+ 
+     def seg_encoder(self, x):
+         org_img = x
+         x = self.stem(x)
+         for i, layer in enumerate(self.stages1):
+             if i == len(self.stages1) - 2:
+                 low_level = x
+             x = layer(x)
+ 
+         x = self.stages2(x)
+ 
+         for i, layer in enumerate(self.stages3):
+             if i == len(self.stages3) - 2:
+                 mid_level = x
+             x = layer(x)
+ 
+         for i, layer in enumerate(self.stages4):
+             x = layer(x)
+         high_level = x
+ 
+         return org_img, low_level, mid_level, high_level
+ 
+ 
+ class LayerNorm(nn.Module):
+     """ LayerNorm that supports two data formats: channels_last (default) or channels_first.
+     The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+     shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+     with shape (batch_size, channels, height, width).
+     """
+ 
+     def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+         self.bias = nn.Parameter(torch.zeros(normalized_shape))
+         self.eps = eps
+         self.data_format = data_format
+         if self.data_format not in ["channels_last", "channels_first"]:
+             raise NotImplementedError
+         self.normalized_shape = (normalized_shape,)
+ 
+     def forward(self, x):
+         if self.data_format == "channels_last":
+             return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+         elif self.data_format == "channels_first":
+             u = x.mean(1, keepdim=True)
+             s = (x - u).pow(2).mean(1, keepdim=True)
+             x = (x - u) / torch.sqrt(s + self.eps)
+             x = self.weight[:, None, None] * x + self.bias[:, None, None]
+             return x
+ 
+ 
+ class GRN(nn.Module):
+     """ GRN (Global Response Normalization) layer """
+ 
+     def __init__(self, dim):
+         super().__init__()
+         self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
+         self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
+ 
+     def forward(self, x):
+         Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+         Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+         return self.gamma * (x * Nx) + self.beta + x
+ 
+ 
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ 
+     This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+     the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+     changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+     'survival rate' as the argument.
+     """
+     if drop_prob == 0. or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+     random_tensor.floor_()  # binarize
+     output = x.div(keep_prob) * random_tensor
+     return output
+ 
+ 
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+ 
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+ 
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
+ 
+ 
+ class FeedForward(nn.Module):
+     def __init__(self, dim, hidden_dim, dropout=0.):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(dim, hidden_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden_dim, dim),
+             nn.Dropout(dropout)
+         )
+ 
+     def forward(self, x):
+         return self.net(x)
+ 
+ 
+ class PreNorm(nn.Module):
+     def __init__(self, dim, fn, norm):
+         super().__init__()
+         self.norm = norm(dim)
+         self.fn = fn
+ 
+     def forward(self, x, **kwargs):
+         return self.fn(self.norm(x), **kwargs)
+ 
+ 
+ class Attention(nn.Module):
+     def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0.):
+         super().__init__()
+         inner_dim = dim_head * heads
+         project_out = not (heads == 1 and dim_head == inp)
+ 
+         self.heads = heads
+         self.scale = dim_head ** -0.5
+ 
+         self.attend = nn.Softmax(dim=-1)
+         self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)
+ 
+         self.to_out = nn.Sequential(
+             nn.Linear(inner_dim, oup),
+             nn.Dropout(dropout)
+         ) if project_out else nn.Identity()
+         self.pos_embed = PosCNN(in_chans=inp)
+ 
+     def forward(self, x):
+         x = self.pos_embed(x)
+         qkv = self.to_qkv(x).chunk(3, dim=-1)
+         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
+         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+         attn = self.attend(dots)
+         out = torch.matmul(attn, v)
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         out = self.to_out(out)
+         return out
+ 
+ 
+ # PEG from https://arxiv.org/abs/2102.10882
+ class PosCNN(nn.Module):
+     def __init__(self, in_chans):
+         super(PosCNN, self).__init__()
+         self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans)
+ 
+     def forward(self, x):
+         B, N, C = x.shape
+         feat_token = x
+         H, W = int(N ** 0.5), int(N ** 0.5)  # assumes a square token grid
+         cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
+         x = self.proj(cnn_feat) + cnn_feat
+         x = x.flatten(2).transpose(1, 2)
+         return x
+ 
+ 
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+     # type: (Tensor, float, float, float, float) -> Tensor
+     r"""Fills the input Tensor with values drawn from a truncated
+     normal distribution. The values are effectively drawn from the
+     normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+     with values outside :math:`[a, b]` redrawn until they are within
+     the bounds. The method used for generating the random values works
+     best when :math:`a \leq \text{mean} \leq b`.
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff value
+         b: the maximum cutoff value
+     Examples:
+         >>> w = torch.empty(3, 5)
+         >>> nn.init.trunc_normal_(w)
+     """
+     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+ 
+ 
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1. + math.erf(x / math.sqrt(2.))) / 2.
+ 
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                       "The distribution of values may be incorrect.",
+                       stacklevel=2)
+ 
+     with torch.no_grad():
+         # Values are generated by using a truncated uniform distribution and
+         # then using the inverse CDF for the normal distribution.
+         # Get upper and lower cdf values
+         l = norm_cdf((a - mean) / std)
+         u = norm_cdf((b - mean) / std)
+ 
+         # Uniformly fill tensor with values from [l, u], then translate to
+         # [2l-1, 2u-1].
+         tensor.uniform_(2 * l - 1, 2 * u - 1)
+ 
+         # Use inverse cdf transform for normal distribution to get truncated
+         # standard normal
+         tensor.erfinv_()
+ 
+         # Transform to proper mean, std
+         tensor.mul_(std * math.sqrt(2.))
+         tensor.add_(mean)
+ 
+         # Clamp to ensure it's in the proper range
+         tensor.clamp_(min=a, max=b)
+         return tensor
+ 
+ 
+ class DoubleConv(nn.Module):
+     """(convolution => [BN] => ReLU) * 2"""
+ 
+     def __init__(self, in_channels, out_channels, mid_channels=None):
+         super().__init__()
+         if not mid_channels:
+             mid_channels = out_channels
+         self.double_conv = nn.Sequential(
+             nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(mid_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True)
+         )
+ 
+     def forward(self, x):
+         return self.double_conv(x)
+ 
+ 
+ class Up(nn.Module):
+     """Upscaling then double conv"""
+ 
+     def __init__(self, in_channels, out_channels, bilinear=True):
+         super().__init__()
+ 
+         # if bilinear, use the normal convolutions to reduce the number of channels
+         if bilinear:
+             self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+             self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+         else:
+             self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
+             self.conv = DoubleConv(in_channels, out_channels)
+ 
+     def forward(self, x1, x2):
+         x1 = self.up(x1)
+         # input is CHW
+         diffY = x2.size()[2] - x1.size()[2]
+         diffX = x2.size()[3] - x1.size()[3]
+ 
+         x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
+                         diffY // 2, diffY - diffY // 2])
+         # if you have padding issues, see
+         # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
+         # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
+         x = torch.cat([x2, x1], dim=1)
+         return self.conv(x)
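For orientation, a quick shape check of the building blocks defined above (the sizes here are illustrative, not from the repo): both `Block` and a non-downsampling `TransformerBlock` preserve their input shape, with `Block`'s spatial attention computed on a pooled 7x7 grid and bilinearly upsampled back.

import torch
from model.CSAT import Block, TransformerBlock

x = torch.randn(1, 32, 56, 56)
print(Block(dim=32, img_size=[56, 56])(x).shape)                     # torch.Size([1, 32, 56, 56])
print(TransformerBlock(inp=32, oup=32, img_size=[56, 56])(x).shape)  # torch.Size([1, 32, 56, 56])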
model/CSATv2.py ADDED
@@ -0,0 +1,396 @@
+ import math
+ import warnings
+ 
+ import cv2
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from einops.layers.torch import Rearrange
+ 
+ from .DCT import Learnable_DCT2D  # learnable variant, for H&E slides
+ # from .DCT import Static_DCT2D   # static variant, for ImageNet
+ 
+ 
+ class Block(nn.Module):
+     """ ConvNeXtV2 Block.
+ 
+     Args:
+         dim (int): Number of input channels.
+         drop_path (float): Stochastic depth rate. Default: 0.0
+     """
+ 
+     def __init__(self, dim, drop_path=0.):
+         super().__init__()
+         self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
+         self.norm = LayerNorm(dim, eps=1e-6)
+         self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
+         self.act = nn.GELU()
+         self.grn = GRN(4 * dim)
+         self.pwconv2 = nn.Linear(4 * dim, dim)
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.attention = Spatial_Attention()
+ 
+     def forward(self, x):
+         shortcut = x
+         x = self.dwconv(x)
+         x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+         x = self.norm(x)
+         x = self.pwconv1(x)
+         x = self.act(x)
+         x = self.grn(x)
+         x = self.pwconv2(x)
+ 
+         x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+         attention = self.attention(x)
+         x = x * F.interpolate(attention, size=x.shape[2:], mode='bilinear', align_corners=True)
+         x = shortcut + self.drop_path(x)
+         return x
+ 
+ 
+ class Spatial_Attention(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
+         self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)
+         self.attention = TransformerBlock(1, 1, heads=1, dim_head=1, img_size=[7, 7])
+ 
+     def forward(self, x):
+         x_avg = x.mean([1]).unsqueeze(1)
+         x_max = x.max(dim=1).values.unsqueeze(1)
+         x = torch.cat([x_avg, x_max], dim=1)
+         x = self.avgpool(x)
+         x = self.conv(x)
+         x = self.attention(x)
+         return x
+ 
+ 
+ class TransformerBlock(nn.Module):
+     def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=False, dropout=0.):
+         super().__init__()
+         hidden_dim = int(inp * 4)
+ 
+         self.downsample = downsample
+         self.ih, self.iw = img_size
+ 
+         if self.downsample:
+             self.pool1 = nn.MaxPool2d(3, 2, 1)
+             self.pool2 = nn.MaxPool2d(3, 2, 1)
+             self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
+ 
+         self.attn = Attention(inp, oup, heads, dim_head, dropout)
+         self.ff = FeedForward(oup, hidden_dim, dropout)
+ 
+         self.attn = nn.Sequential(
+             Rearrange('b c ih iw -> b (ih iw) c'),
+             PreNorm(inp, self.attn, nn.LayerNorm),
+             Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
+         )
+ 
+         self.ff = nn.Sequential(
+             Rearrange('b c ih iw -> b (ih iw) c'),
+             PreNorm(oup, self.ff, nn.LayerNorm),
+             Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
+         )
+ 
+     def forward(self, x):
+         if self.downsample:
+             x = self.proj(self.pool1(x)) + self.attn(self.pool2(x))
+         else:
+             x = x + self.attn(x)
+         x = x + self.ff(x)
+         return x
+ 
+ 
+ class CSATv2(nn.Module):
+     def __init__(self, img_size=None, num_classes=1000, drop_path_rate=0, head_init_scale=1):
+         super().__init__()
+         dims = [32, 72, 168, 386]
+         channel_order = "channels_first"
+         depths = [2, 2, 6, 4]
+         dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+ 
+         # The 8x8 block DCT (self.dct below) replaces the usual conv stem:
+         # self.stem = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=dims[0], kernel_size=4, stride=4),
+         #                           LayerNorm(normalized_shape=dims[0], data_format=channel_order))
+ 
+         self.stages1 = nn.Sequential(
+             Block(dim=dims[0], drop_path=dp_rates[0]),
+             Block(dim=dims[0], drop_path=dp_rates[1]),
+             LayerNorm(dims[0], eps=1e-6, data_format=channel_order),
+             nn.Conv2d(dims[0], dims[1], kernel_size=2, stride=2),
+         )
+ 
+         self.stages2 = nn.Sequential(
+             Block(dim=dims[1], drop_path=dp_rates[0]),
+             Block(dim=dims[1], drop_path=dp_rates[1]),
+             LayerNorm(dims[1], eps=1e-6, data_format=channel_order),
+             nn.Conv2d(dims[1], dims[2], kernel_size=2, stride=2),
+         )
+ 
+         self.stages3 = nn.Sequential(
+             Block(dim=dims[2], drop_path=dp_rates[0]),
+             Block(dim=dims[2], drop_path=dp_rates[1]),
+             Block(dim=dims[2], drop_path=dp_rates[2]),
+             Block(dim=dims[2], drop_path=dp_rates[3]),
+             Block(dim=dims[2], drop_path=dp_rates[4]),
+             Block(dim=dims[2], drop_path=dp_rates[5]),
+             TransformerBlock(inp=dims[2], oup=dims[2], img_size=[int(img_size / 32), int(img_size / 32)]),
+             TransformerBlock(inp=dims[2], oup=dims[2], img_size=[int(img_size / 32), int(img_size / 32)]),
+             LayerNorm(dims[2], eps=1e-6, data_format=channel_order),
+             nn.Conv2d(dims[2], dims[3], kernel_size=2, stride=2),
+         )
+ 
+         self.stages4 = nn.Sequential(
+             Block(dim=dims[3], drop_path=dp_rates[0]),
+             Block(dim=dims[3], drop_path=dp_rates[1]),
+             Block(dim=dims[3], drop_path=dp_rates[2]),
+             Block(dim=dims[3], drop_path=dp_rates[3]),
+             TransformerBlock(inp=dims[3], oup=dims[3], img_size=[int(img_size / 64), int(img_size / 64)]),
+             TransformerBlock(inp=dims[3], oup=dims[3], img_size=[int(img_size / 64), int(img_size / 64)]),
+         )
+ 
+         self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
+         self.head = nn.Linear(dims[-1], num_classes)
+ 
+         self.apply(self._init_weights)
+         self.head.weight.data.mul_(head_init_scale)
+         self.head.bias.data.mul_(head_init_scale)
+         self.dct = Learnable_DCT2D(8)
+         # self.dct = Static_DCT2D(8)
+ 
+     def load_checkpoint(self, checkpoint):
+         state = torch.load(checkpoint, map_location='cpu')
+         try:
+             state_dict = state['state_dict']
+         except KeyError:
+             state_dict = state['model']
+         for key in list(state_dict.keys()):
+             state_dict[key.replace('module.backbone.', '').replace('resnet.', '')] = state_dict.pop(key)
+ 
+         model_dict = self.state_dict()
+         weights = {k: v for k, v in state_dict.items() if k in model_dict}
+ 
+         model_dict.update(weights)
+         del model_dict['head.bias']
+         del model_dict['head.weight']
+         self.load_state_dict(model_dict, strict=False)
+ 
+     def preprocess(self, x):
+         # expects a BGR uint8 numpy image (unused in forward)
+         x = cv2.cvtColor(x, cv2.COLOR_BGR2YCR_CB)
+         return x
+ 
+     def _init_weights(self, m):
+         if isinstance(m, (nn.Conv2d, nn.Linear)):
+             trunc_normal_(m.weight, std=.02)
+             if m.bias is not None:  # e.g. the bias-free qkv projections
+                 nn.init.constant_(m.bias, 0)
+ 
+     def forward(self, x):
+         # x = self.preprocess(x)
+         x = self.dct(x)  # (b, 3, h, w) -> (b, 32, h/8, w/8)
+         x = self.stages1(x)
+         x = self.stages2(x)
+         x = self.stages3(x)
+         x = self.stages4(x)
+         x = self.norm(x.mean([-2, -1]))  # global average pool, then norm
+         x = self.head(x)
+         return x
+ 
+ 
+ class LayerNorm(nn.Module):
+     """ LayerNorm that supports two data formats: channels_last (default) or channels_first.
+     The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+     shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+     with shape (batch_size, channels, height, width).
+     """
+ 
+     def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+         self.bias = nn.Parameter(torch.zeros(normalized_shape))
+         self.eps = eps
+         self.data_format = data_format
+         if self.data_format not in ["channels_last", "channels_first"]:
+             raise NotImplementedError
+         self.normalized_shape = (normalized_shape,)
+ 
+     def forward(self, x):
+         if self.data_format == "channels_last":
+             return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+         elif self.data_format == "channels_first":
+             u = x.mean(1, keepdim=True)
+             s = (x - u).pow(2).mean(1, keepdim=True)
+             x = (x - u) / torch.sqrt(s + self.eps)
+             x = self.weight[:, None, None] * x + self.bias[:, None, None]
+             return x
+ 
+ 
+ class GRN(nn.Module):
+     """ GRN (Global Response Normalization) layer """
+ 
+     def __init__(self, dim):
+         super().__init__()
+         self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
+         self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
+ 
+     def forward(self, x):
+         Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+         Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+         return self.gamma * (x * Nx) + self.beta + x
+ 
+ 
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ 
+     This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+     the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+     changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+     'survival rate' as the argument.
+     """
+     if drop_prob == 0. or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+     random_tensor.floor_()  # binarize
+     output = x.div(keep_prob) * random_tensor
+     return output
+ 
+ 
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+ 
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+ 
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
+ 
+ 
+ class FeedForward(nn.Module):
+     def __init__(self, dim, hidden_dim, dropout=0.):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(dim, hidden_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden_dim, dim),
+             nn.Dropout(dropout)
+         )
+ 
+     def forward(self, x):
+         return self.net(x)
+ 
+ 
+ class PreNorm(nn.Module):
+     def __init__(self, dim, fn, norm):
+         super().__init__()
+         self.norm = norm(dim)
+         self.fn = fn
+ 
+     def forward(self, x, **kwargs):
+         return self.fn(self.norm(x), **kwargs)
+ 
+ 
+ class Attention(nn.Module):
+     def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0.):
+         super().__init__()
+         inner_dim = dim_head * heads
+         project_out = not (heads == 1 and dim_head == inp)
+ 
+         self.heads = heads
+         self.scale = dim_head ** -0.5
+ 
+         self.attend = nn.Softmax(dim=-1)
+         self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)
+ 
+         self.to_out = nn.Sequential(
+             nn.Linear(inner_dim, oup),
+             nn.Dropout(dropout)
+         ) if project_out else nn.Identity()
+         self.pos_embed = PosCNN(in_chans=inp)
+ 
+     def forward(self, x):
+         x = self.pos_embed(x)
+         qkv = self.to_qkv(x).chunk(3, dim=-1)
+         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
+ 
+         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+         attn = self.attend(dots)
+         out = torch.matmul(attn, v)
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         out = self.to_out(out)
+         return out
+ 
+ 
+ # PEG from https://arxiv.org/abs/2102.10882
+ class PosCNN(nn.Module):
+     def __init__(self, in_chans):
+         super(PosCNN, self).__init__()
+         self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans)
+ 
+     def forward(self, x):
+         B, N, C = x.shape
+         feat_token = x
+         H, W = int(N ** 0.5), int(N ** 0.5)  # assumes a square token grid
+         cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
+         x = self.proj(cnn_feat) + cnn_feat
+         x = x.flatten(2).transpose(1, 2)
+         return x
+ 
+ 
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+     # type: (Tensor, float, float, float, float) -> Tensor
+     r"""Fills the input Tensor with values drawn from a truncated
+     normal distribution. The values are effectively drawn from the
+     normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+     with values outside :math:`[a, b]` redrawn until they are within
+     the bounds. The method used for generating the random values works
+     best when :math:`a \leq \text{mean} \leq b`.
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff value
+         b: the maximum cutoff value
+     Examples:
+         >>> w = torch.empty(3, 5)
+         >>> nn.init.trunc_normal_(w)
+     """
+     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+ 
+ 
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1. + math.erf(x / math.sqrt(2.))) / 2.
+ 
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                       "The distribution of values may be incorrect.",
+                       stacklevel=2)
+ 
+     with torch.no_grad():
+         # Values are generated by using a truncated uniform distribution and
+         # then using the inverse CDF for the normal distribution.
+         # Get upper and lower cdf values
+         l = norm_cdf((a - mean) / std)
+         u = norm_cdf((b - mean) / std)
+ 
+         # Uniformly fill tensor with values from [l, u], then translate to
+         # [2l-1, 2u-1].
+         tensor.uniform_(2 * l - 1, 2 * u - 1)
+ 
+         # Use inverse cdf transform for normal distribution to get truncated
+         # standard normal
+         tensor.erfinv_()
+ 
+         # Transform to proper mean, std
+         tensor.mul_(std * math.sqrt(2.))
+         tensor.add_(mean)
+ 
+         # Clamp to ensure it's in the proper range
+         tensor.clamp_(min=a, max=b)
+         return tensor
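The main difference from CSAT is the front end: instead of a 4x4 conv stem, an 8x8 block DCT converts the image into YCbCr frequency features and keeps 24 Y, 4 Cb and 4 Cr coefficient channels, so the network enters stages1 with dims[0] = 32 channels at 1/8 resolution. A small sketch of that arithmetic:

import torch
from model.DCT import Learnable_DCT2D

dct = Learnable_DCT2D(8)
out = dct(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 32, 28, 28]): 24 (Y) + 4 (Cb) + 4 (Cr) channels at 224/8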
model/DCT.py ADDED
@@ -0,0 +1,265 @@
+ import math
+ from typing import List
+ 
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ 
+ __all__ = ['DCT2D', 'Learnable_DCT2D', 'Static_DCT2D']
+ 
+ 
+ # Helper data: per-channel (Y, Cb, Cr) mean and variance of the 64
+ # zigzag-ordered DCT coefficients, used by frequency_normalize below.
+ mean=[[932.42657,-0.00260,0.33415,-0.02840,0.00003,-0.02792,-0.00183,0.00006,0.00032,0.03402,-0.00571,0.00020,0.00006,-0.00038,-0.00558,-0.00116,-0.00000,-0.00047,-0.00008,-0.00030,0.00942,0.00161,-0.00009,-0.00006,-0.00014,-0.00035,0.00001,-0.00220,0.00033,-0.00002,-0.00003,-0.00020,0.00007,-0.00000,0.00005,0.00293,-0.00004,0.00006,0.00019,0.00004,0.00006,-0.00015,-0.00002,0.00007,0.00010,-0.00004,0.00008,0.00000,0.00008,-0.00001,0.00015,0.00002,0.00007,0.00003,0.00004,-0.00001,0.00004,-0.00000,0.00002,-0.00000,-0.00008,-0.00000,-0.00003,0.00003],
+ [962.34735,-0.00428,0.09835,0.00152,-0.00009,0.00312,-0.00141,-0.00001,-0.00013,0.01050,0.00065,0.00006,-0.00000,0.00003,0.00264,0.00000,0.00001,0.00007,-0.00006,0.00003,0.00341,0.00163,0.00004,0.00003,-0.00001,0.00008,-0.00000,0.00090,0.00018,-0.00006,-0.00001,0.00007,-0.00003,-0.00001,0.00006,0.00084,-0.00000,-0.00001,0.00000,0.00004,-0.00001,-0.00002,0.00000,0.00001,0.00002,0.00001,0.00004,0.00011,0.00000,-0.00003,0.00011,-0.00002,0.00001,0.00001,0.00001,0.00001,-0.00007,-0.00003,0.00001,0.00000,0.00001,0.00002,0.00001,0.00000],
+ [1053.16101,-0.00213,-0.09207,0.00186,0.00013,0.00034,-0.00119,0.00002,0.00011,-0.00984,0.00046,-0.00007,-0.00001,-0.00005,0.00180,0.00042,0.00002,-0.00010,0.00004,0.00003,-0.00301,0.00125,-0.00002,-0.00003,-0.00001,-0.00001,-0.00001,0.00056,0.00021,0.00001,-0.00001,0.00002,-0.00001,-0.00001,0.00005,-0.00070,-0.00002,-0.00002,0.00005,-0.00004,-0.00000,0.00002,-0.00002,0.00001,0.00000,-0.00003,0.00004,0.00007,0.00001,0.00000,0.00013,-0.00000,0.00000,0.00002,-0.00000,-0.00001,-0.00004,-0.00003,0.00000,0.00001,-0.00001,0.00001,-0.00000,0.00000]]
+ 
+ var=[[270372.37500,6287.10645,5974.94043,1653.10889,1463.91748,1832.58997,755.92468,692.41528,648.57184,641.46881,285.79288,301.62100,380.43405,349.84027,374.15891,190.30960,190.76746,221.64578,200.82646,145.87979,126.92046,62.14622,67.75562,102.42001,129.74922,130.04631,103.12189,97.76417,53.17402,54.81048,73.48712,81.04342,69.35100,49.06024,33.96053,37.03279,20.48858,24.94830,33.90822,44.54912,47.56363,40.03160,30.43313,22.63899,26.53739,26.57114,21.84404,17.41557,15.18253,10.69678,11.24111,12.97229,15.08971,15.31646,8.90409,7.44213,6.66096,6.97719,4.17834,3.83882,4.51073,2.36646,2.41363,1.48266],
+ [18839.21094,321.70932,300.15259,77.47830,76.02293,89.04748,33.99642,34.74807,32.12333,28.19588,12.04675,14.26871,18.45779,16.59588,15.67892,7.37718,8.56312,10.28946,9.41013,6.69090,5.16453,2.55186,3.03073,4.66765,5.85418,5.74644,4.33702,3.66948,1.95107,2.26034,3.06380,3.50705,3.06359,2.19284,1.54454,1.57860,0.97078,1.13941,1.48653,1.89996,1.95544,1.64950,1.24754,0.93677,1.09267,1.09516,0.94163,0.78966,0.72489,0.50841,0.50909,0.55664,0.63111,0.64125,0.38847,0.33378,0.30918,0.33463,0.20875,0.19298,0.21903,0.13380,0.13444,0.09554],
+ [17127.39844,292.81421,271.45209,66.64056,63.60253,76.35437,28.06587,27.84831,25.96656,23.60370,9.99173,11.34992,14.46955,12.92553,12.69353,5.91537,6.60187,7.90891,7.32825,5.32785,4.29660,2.13459,2.44135,3.66021,4.50335,4.38959,3.34888,2.97181,1.60633,1.77010,2.35118,2.69018,2.38189,1.74596,1.26014,1.31684,0.79327,0.92046,1.17670,1.47609,1.50914,1.28725,0.99898,0.74832,0.85736,0.85800,0.74663,0.63508,0.58748,0.41098,0.41121,0.44663,0.50277,0.51519,0.31729,0.27336,0.25399,0.27241,0.17353,0.16255,0.18440,0.11602,0.11511,0.08450]]
+ 
+ 
+ def _zigzag_permutation(rows: int, cols: int) -> List[int]:
+     idx_matrix = np.arange(0, rows * cols, 1).reshape(rows, cols).tolist()
+     dia = [[] for _ in range(rows + cols - 1)]
+     zigzag = []
+     for i in range(rows):
+         for j in range(cols):
+             s = i + j
+             if s % 2 == 0:
+                 dia[s].insert(0, idx_matrix[i][j])
+             else:
+                 dia[s].append(idx_matrix[i][j])
+     for d in dia:
+         zigzag.extend(d)
+     return zigzag
+ 
+ 
+ # Kernels
+ 
+ def _dct_kernel_type_2(
+         kernel_size: int,
+         orthonormal: bool,
+         device=None,
+         dtype=None,
+ ) -> torch.Tensor:
+     factory_kwargs = dict(device=device, dtype=dtype)
+     x = torch.eye(kernel_size, **factory_kwargs)
+     v = x.clone().contiguous().view(-1, kernel_size)
+     v = torch.cat([v, v.flip([1])], dim=-1)
+     v = torch.fft.fft(v, dim=-1)[:, :kernel_size]
+     k = torch.tensor(-1j, **factory_kwargs) * math.pi * torch.arange(kernel_size, **factory_kwargs)[None, :]
+     k = torch.exp(k / (kernel_size * 2))
+     v = v * k
+     v = v.real
+     if orthonormal:
+         v[:, 0] = v[:, 0] * torch.sqrt(torch.tensor(1 / (kernel_size * 4), **factory_kwargs))
+         v[:, 1:] = v[:, 1:] * torch.sqrt(torch.tensor(1 / (kernel_size * 2), **factory_kwargs))
+     v = v.contiguous().view(*x.shape)
+     return v
+ 
+ 
+ def _dct_kernel_type_3(
+         kernel_size: int,
+         orthonormal: bool,
+         device=None,
+         dtype=None,
+ ) -> torch.Tensor:
+     return torch.linalg.inv(_dct_kernel_type_2(kernel_size, orthonormal, device, dtype))
+ 
+ 
+ # Modules
+ 
+ class _DCT1D(nn.Module):
+ 
+     def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True,
+                  device=None, dtype=None) -> None:
+         factory_kwargs = dict(device=device, dtype=dtype)
+         super(_DCT1D, self).__init__()
+         kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3}
+         self.weights = nn.Parameter(kernel[f'{kernel_type}'](kernel_size, orthonormal, **factory_kwargs).T,
+                                     requires_grad=False)
+         self.register_parameter('bias', None)
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return nn.functional.linear(x, self.weights, self.bias)
+ 
+ 
+ class _DCT2D(nn.Module):
+ 
+     def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True,
+                  device=None, dtype=None) -> None:
+         factory_kwargs = dict(device=device, dtype=dtype)
+         super(_DCT2D, self).__init__()
+         self.transform = _DCT1D(kernel_size, kernel_type, orthonormal, **factory_kwargs)
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # [..., H, W] @ DCT_Kernel.T -> [..., W, H] @ DCT_Kernel.T -> [..., H, W]
+         return self.transform(self.transform(x).transpose(-1, -2)).transpose(-1, -2)
+ 
+ 
+ # Discrete Cosine Transforms (DCT)
+ 
+ class Learnable_DCT2D(nn.Module):
+     r"""Computes the two-dimensional block-wise discrete cosine transform of :attr:`input`.
+ 
+     Args:
+         kernel_size (int): Size of the coefficient kernel
+         kernel_type (int): Type of the DCT (see Notes). Default: 2
+         orthonormal (bool): A boolean makes the corresponding matrix of coefficients orthonormal. Default: True
+     """
+ 
+     def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True,
+                  device=None, dtype=None) -> None:
+         factory_kwargs = dict(device=device, dtype=dtype)
+         super(Learnable_DCT2D, self).__init__()
+         self.k = kernel_size
+         self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size))
+         self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs)
+         self.permutation = _zigzag_permutation(kernel_size, kernel_size)
+         self.Y_Conv = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0)
+         self.Cb_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0)
+         self.Cr_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0)
+         self.mean = torch.tensor(mean, requires_grad=False)
+         self.var = torch.tensor(var, requires_grad=False)
+         self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], requires_grad=False)
+         self.imagenet_var = torch.tensor([0.229, 0.224, 0.225], requires_grad=False)
+ 
+     def denormalize(self, x):
+         # undo ImageNet normalization and rescale to [0, 255]
+         x = x.multiply(self.imagenet_var.to(x.device)).add_(self.imagenet_mean.to(x.device)) * 255
+         return x
+ 
+     def rgb2ycbcr(self, x):
+         y = (x[:, :, :, 0] * 0.299) + (x[:, :, :, 1] * 0.587) + (x[:, :, :, 2] * 0.114)
+         cb = 0.564 * (x[:, :, :, 2] - y) + 128
+         cr = 0.713 * (x[:, :, :, 0] - y) + 128
+         x = torch.stack([y, cb, cr], dim=-1)
+         return x
+ 
+     def frequency_normalize(self, x):
+         x[:, 0].sub_(self.mean.to(x.device)[0]).div_(self.var.to(x.device)[0] ** 0.5 + 1e-8)
+         x[:, 1].sub_(self.mean.to(x.device)[1]).div_(self.var.to(x.device)[1] ** 0.5 + 1e-8)
+         x[:, 2].sub_(self.mean.to(x.device)[2]).div_(self.var.to(x.device)[2] ** 0.5 + 1e-8)
+         return x
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         b, c, h, w = x.shape
+         x = x.permute(0, 2, 3, 1)                  # b, h, w, c
+         x = self.denormalize(x)
+         x = self.rgb2ycbcr(x)
+         x = x.permute(0, 3, 1, 2)                  # b, c, h, w
+         x = self.unfold(x).transpose(-1, -2)       # -> b, blocks, c*k*k
+         x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k)
+         x = self.transform(x)
+         x = x.reshape(-1, c, self.k * self.k)
+         x = x[:, :, self.permutation]              # zigzag coefficient order
+         x = self.frequency_normalize(x)
+         x = x.reshape(b, h // self.k, w // self.k, c, -1)
+         x = x.permute(0, 3, 4, 1, 2).contiguous()  # b, c, k*k, h/k, w/k
+         x_Y = self.Y_Conv(x[:, 0])
+         x_Cb = self.Cb_Conv(x[:, 1])
+         x_Cr = self.Cr_Conv(x[:, 2])
+         x = torch.cat([x_Y, x_Cb, x_Cr], dim=1)
+         return x
+ 
+ 
+ class Static_DCT2D(nn.Module):
+     r"""Computes the two-dimensional block-wise discrete cosine transform of :attr:`input`.
+ 
+     Args:
+         kernel_size (int): Size of the coefficient kernel
+         kernel_type (int): Type of the DCT (see Notes). Default: 2
+         orthonormal (bool): A boolean makes the corresponding matrix of coefficients orthonormal. Default: True
+     """
+ 
+     def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True,
+                  device=None, dtype=None) -> None:
+         factory_kwargs = dict(device=device, dtype=dtype)
+         super(Static_DCT2D, self).__init__()
+         self.k = kernel_size
+         self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size))
+         self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs)
+         self.permutation = _zigzag_permutation(kernel_size, kernel_size)
+         # NOTE (assumption): forward() references these channel projections, but the
+         # original file omitted them; they are mirrored from Learnable_DCT2D so the
+         # module can run.
+         self.Y_Conv = nn.Conv2d(kernel_size ** 2, 24, kernel_size=1, padding=0)
+         self.Cb_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0)
+         self.Cr_Conv = nn.Conv2d(kernel_size ** 2, 4, kernel_size=1, padding=0)
+         self.mean = torch.tensor(mean, requires_grad=False)
+         self.var = torch.tensor(var, requires_grad=False)
+         self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], requires_grad=False)
+         self.imagenet_var = torch.tensor([0.229, 0.224, 0.225], requires_grad=False)
+ 
+     def denormalize(self, x):
+         # undo ImageNet normalization and rescale to [0, 255]
+         x = x.multiply(self.imagenet_var.to(x.device)).add_(self.imagenet_mean.to(x.device)) * 255
+         return x
+ 
+     def rgb2ycbcr(self, x):
+         y = (x[:, :, :, 0] * 0.299) + (x[:, :, :, 1] * 0.587) + (x[:, :, :, 2] * 0.114)
+         cb = 0.564 * (x[:, :, :, 2] - y) + 128
+         cr = 0.713 * (x[:, :, :, 0] - y) + 128
+         x = torch.stack([y, cb, cr], dim=-1)
+         return x
+ 
+     def frequency_normalize(self, x):
+         x[:, 0].sub_(self.mean.to(x.device)[0]).div_(self.var.to(x.device)[0] ** 0.5 + 1e-8)
+         x[:, 1].sub_(self.mean.to(x.device)[1]).div_(self.var.to(x.device)[1] ** 0.5 + 1e-8)
+         x[:, 2].sub_(self.mean.to(x.device)[2]).div_(self.var.to(x.device)[2] ** 0.5 + 1e-8)
+         return x
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         b, c, h, w = x.shape
+         x = x.permute(0, 2, 3, 1)                  # b, h, w, c
+         x = self.denormalize(x)
+         x = self.rgb2ycbcr(x)
+         x = x.permute(0, 3, 1, 2)                  # b, c, h, w
+         x = self.unfold(x).transpose(-1, -2)       # -> b, blocks, c*k*k
+         x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k)
+         x = self.transform(x)
+         x = x.reshape(-1, c, self.k * self.k)
+         x = x[:, :, self.permutation]              # zigzag coefficient order
+         x = self.frequency_normalize(x)
+         x = x.reshape(b, h // self.k, w // self.k, c, -1)
+         x = x.permute(0, 3, 4, 1, 2).contiguous()  # b, c, k*k, h/k, w/k
+         x_Y = self.Y_Conv(x[:, 0])
+         x_Cb = self.Cb_Conv(x[:, 1])
+         x_Cr = self.Cr_Conv(x[:, 2])
+         x = torch.cat([x_Y, x_Cb, x_Cr], dim=1)
+         return x
+ 
+ 
+ class DCT2D(nn.Module):
+     r"""Computes the two-dimensional block-wise discrete cosine transform of :attr:`input`
+     and returns the per-channel mean and variance of the zigzag-ordered coefficients
+     (the statistics used for the `mean`/`var` tables above).
+ 
+     Args:
+         kernel_size (int): Size of the coefficient kernel
+         kernel_type (int): Type of the DCT (see Notes). Default: 2
+         orthonormal (bool): A boolean makes the corresponding matrix of coefficients orthonormal. Default: True
+     """
+ 
+     def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True,
+                  device=None, dtype=None) -> None:
+         factory_kwargs = dict(device=device, dtype=dtype)
+         super(DCT2D, self).__init__()
+         self.k = kernel_size
+         self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size))
+         self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs)
+         self.permutation = _zigzag_permutation(kernel_size, kernel_size)
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         b, c, h, w = x.shape
+         x = self.unfold(x).transpose(-1, -2)  # -> b, blocks, c*k*k
+         x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k)
+         x = self.transform(x)
+         x = x.reshape(-1, c, self.k * self.k)
+         x = x[:, :, self.permutation]
+         x = x.reshape(b * (h // self.k) * (w // self.k), c, -1)
+ 
+         # assumes c == 3 and kernel_size == 8 (64 coefficients per block)
+         mean_list = torch.zeros([3, 64])
+         var_list = torch.zeros([3, 64])
+         mean_list[0] = torch.mean(x[:, 0], dim=0)
+         mean_list[1] = torch.mean(x[:, 1], dim=0)
+         mean_list[2] = torch.mean(x[:, 2], dim=0)
+         var_list[0] = torch.var(x[:, 0], dim=0)
+         var_list[1] = torch.var(x[:, 1], dim=0)
+         var_list[2] = torch.var(x[:, 2], dim=0)
+         return mean_list, var_list
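The `_zigzag_permutation` helper reproduces the JPEG-style zigzag scan, so low-frequency DCT coefficients come first in each flattened block. For a 4x4 block:

from model.DCT import _zigzag_permutation

print(_zigzag_permutation(4, 4))
# [0, 1, 4, 8, 5, 2, 3, 6, 9, 12, 13, 10, 7, 11, 14, 15]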
model/ResNet18.py ADDED
@@ -0,0 +1,9 @@
+ import torchvision
+ 
+ 
+ class ResNet18(torchvision.models.ResNet):
+     def __init__(self, num_classes=1000, weight=None):
+         # zero_init_residual must be passed to the parent constructor to take effect
+         super(ResNet18, self).__init__(block=torchvision.models.resnet.BasicBlock,
+                                        layers=[2, 2, 2, 2], num_classes=num_classes,
+                                        zero_init_residual=True)
+ 
+     def forward(self, x):
+         return self._forward_impl(x)
model/__pycache__/CSAT.cpython-38.pyc ADDED
Binary file (17.3 kB).
model/__pycache__/CSATv2.cpython-38.pyc ADDED
Binary file (14.5 kB).
model/__pycache__/DCT.cpython-38.pyc ADDED
Binary file (12.7 kB).
model/__pycache__/ResNet18.cpython-38.pyc ADDED
Binary file (804 Bytes).
model/__pycache__/SPTCNN.cpython-38.pyc ADDED
Binary file (17.3 kB).
weight/CSAT_ImageNet.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f5ee778e954aa1a4f60a4a2f1376dbaa917bb42f9835d2170f2fd266f485d21c
+ size 12417421
weight/CSAT_RCKD.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:efc69c12e2e11aec9487b06baf76cef4fd523bbdc34444b78240e19bb45337ce
+ size 12417421
weight/CSAT_v2_ImageNet.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8c34eb9183e3b4c89c0197ea870197b001313af420cd31f3f5304ed0e73a76e7
+ size 44578564
weight/ResNet18_RCKD.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dfc3ca681cd4be87c0d4e4ed78669dd9b94e081eaab543d9e475666b7c1fee3b
+ size 46836189
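The four files above are Git LFS pointers rather than the checkpoints themselves; after cloning, fetch the actual weights (roughly 12-47 MB each, per the `size` fields) with `git lfs pull`.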