stefanosgikas committed
Commit 514055c · verified · 1 Parent(s): d2d4110

Upload painformer.py

Files changed (1)
  1. architecture/painformer.py +664 -0
architecture/painformer.py ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import math
import numpy as np
from pytorch_wavelets import DWTForward, DWTInverse  # (or import DWT, IDWT); not used in this file


class SpectralGatingNetwork(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # The learnable filter is sized for the stage's feature-map resolution:
        # h = H and w = (W // 2) + 1, the one-sided width produced by rfft2.
        if dim == 64:      # stage 1 of the small/base models
            self.h = 56    # H
            self.w = 29    # (W/2)+1
        elif dim == 128:   # stage 2 of the small/base models
            self.h = 28
            self.w = 15
        elif dim == 96:    # stage 1 of the large model
            self.h = 56
            self.w = 29
        elif dim == 192:   # stage 2 of the large model
            self.h = 28
            self.w = 15
        self.complex_weight = nn.Parameter(
            torch.randn(self.h, self.w, dim, 2, dtype=torch.float32) * 0.02)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.view(B, H, W, C)
        # Cast to float32 before the FFT; under AMP a HalfTensor input raises
        # "Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same".
        x = x.to(torch.float32)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
        x = x.reshape(B, N, C)  # reshape, not permute: tokens go back to (B, N, C) order
        return x
        # return x, weight  # optionally also expose the spectral filter


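# --- usage sketch (illustrative, not part of the original file) --------------
# The gating module filters tokens in the 2-D Fourier domain; N must equal
# H * W, and (H, W) must match the resolution the weights were sized for.
# For dim=64 that is a 56x56 token grid:
#
#   sgn = SpectralGatingNetwork(dim=64)
#   tokens = torch.randn(2, 56 * 56, 64)   # (B, N, C)
#   out = sgn(tokens, H=56, W=56)          # same shape as the input
# ------------------------------------------------------------------------------

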
def rand_bbox(size, lam, scale=1):
    # `size` is the shape of a (B, H, W, C) feature map; the box covers a
    # (1 - lam) fraction of the (optionally pooled) spatial area.
    W = size[1] // scale
    H = size[2] // scale
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)  # np.int was removed in NumPy 1.24; use the builtin
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

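# --- usage sketch (illustrative) ----------------------------------------------
#   lam = np.random.beta(1.0, 1.0)
#   bbx1, bby1, bbx2, bby2 = rand_bbox((8, 56, 56, 64), lam, scale=8)
#   # coordinates live on the pooled 7x7 grid (56 // scale)
# --------------------------------------------------------------------------------
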
class ClassAttention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        self.kv = nn.Linear(dim, dim * 2)
        self.q = nn.Linear(dim, dim)
        self.proj = nn.Linear(dim, dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        B, N, C = x.shape
        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        # Only the class token (position 0) issues a query; all tokens serve as keys/values.
        q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim)
        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        cls_embed = (attn @ v).transpose(1, 2).reshape(B, 1, self.head_dim * self.num_heads)
        cls_embed = self.proj(cls_embed)
        return cls_embed

class FFN(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

class ClassBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = ClassAttention(dim, num_heads)
        self.mlp = FFN(dim, int(dim * mlp_ratio))
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        # Only the class token is refined; the patch tokens pass through unchanged.
        cls_embed = x[:, :1]
        cls_embed = cls_embed + self.attn(self.norm1(x))
        cls_embed = cls_embed + self.mlp(self.norm2(cls_embed))
        return torch.cat([cls_embed, x[:, 1:]], dim=1)

class PVT2FFN(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.fc2(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x, attn  # attn is returned for visualization/analysis

class Block(nn.Module):
    def __init__(self,
        dim,
        num_heads,
        mlp_ratio,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
        sr_ratio=1,
        block_type='wave'
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)

        if block_type == 'std_att':
            self.attn = Attention(dim, num_heads)
        else:
            self.attn = SpectralGatingNetwork(dim)
        self.mlp = PVT2FFN(in_features=dim, hidden_features=int(dim * mlp_ratio))
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        # Attention blocks return (output, attention weights); spectral gating
        # blocks return only the output.
        if isinstance(self.attn, Attention):
            attn_output, attn_weights = self.attn(self.norm1(x), H, W)
        else:
            attn_output, attn_weights = self.attn(self.norm1(x), H, W), None
        x = x + self.drop_path(attn_output)
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        # Optionally return attention weights for visualization or analysis
        return (x, attn_weights) if attn_weights is not None else x

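# --- note (illustrative) --------------------------------------------------------
# Because a Block returns a bare tensor for 'wave' blocks but an (output, weights)
# tuple for 'std_att' blocks, callers unpack conditionally:
#
#   out = blk(x, H, W)
#   x, attn = out if isinstance(out, tuple) else (out, None)
# ----------------------------------------------------------------------------------
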
class DownSamples(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.norm = nn.LayerNorm(out_channels)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W

class Stem(nn.Module):
    def __init__(self, in_channels, stem_hidden_dim, out_channels):
        super().__init__()
        hidden_dim = stem_hidden_dim
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, kernel_size=7, stride=2,
                      padding=3, bias=False),  # 112x112
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,
                      padding=1, bias=False),  # 112x112
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,
                      padding=1, bias=False),  # 112x112
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
        )
        self.proj = nn.Conv2d(hidden_dim,
                              out_channels,
                              kernel_size=3,
                              stride=2,
                              padding=1)
        self.norm = nn.LayerNorm(out_channels)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv(x)
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W

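# --- shape trace (illustrative, not part of the original file) -------------------
# For a 224x224 input the stem downsamples by 4 overall (two stride-2 convs):
#
#   stem = Stem(in_channels=3, stem_hidden_dim=32, out_channels=64)
#   tokens, H, W = stem(torch.randn(1, 3, 224, 224))
#   # tokens: (1, 56*56, 64) with H = W = 56, which is exactly the grid that
#   # SpectralGatingNetwork(dim=64) sizes its weights for.
# ----------------------------------------------------------------------------------
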
class SpectFormer(nn.Module):
    def __init__(self,
        in_chans=3,
        num_classes=1000,
        stem_hidden_dim=32,
        embed_dims=[64, 128, 320, 448],
        num_heads=[2, 4, 10, 14],
        mlp_ratios=[8, 8, 4, 4],
        drop_path_rate=0.,
        norm_layer=nn.LayerNorm,
        depths=[3, 4, 6, 3],
        sr_ratios=[4, 2, 1, 1],
        num_stages=4,
        token_label=False,
        **kwargs
    ):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0

        for i in range(num_stages):
            if i == 0:
                patch_embed = Stem(in_chans, stem_hidden_dim, embed_dims[i])
            else:
                patch_embed = DownSamples(embed_dims[i - 1], embed_dims[i])

            # First two stages use spectral gating, the last two standard attention.
            block = nn.ModuleList([Block(
                dim=embed_dims[i],
                num_heads=num_heads[i],
                mlp_ratio=mlp_ratios[i],
                drop_path=dpr[cur + j],
                norm_layer=norm_layer,
                sr_ratio=sr_ratios[i],
                block_type='wave' if i < 2 else 'std_att')
                for j in range(depths[i])])

            norm = norm_layer(embed_dims[i])
            cur += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

        post_layers = ['ca']
        self.post_network = nn.ModuleList([
            ClassBlock(
                dim=embed_dims[-1],
                num_heads=num_heads[-1],
                mlp_ratio=mlp_ratios[-1],
                norm_layer=norm_layer)
            for _ in range(len(post_layers))
        ])

        # classification head
        self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
        ##################################### token_label #####################################
        self.return_dense = token_label
        self.mix_token = token_label
        self.beta = 1.0
        self.pooling_scale = 8
        if self.return_dense:
            self.aux_head = nn.Linear(
                embed_dims[-1],
                num_classes) if num_classes > 0 else nn.Identity()
        ##################################### token_label #####################################

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward_cls(self, x):
        B, N, C = x.shape
        # The class token is initialized as the mean of all patch tokens.
        cls_tokens = x.mean(dim=1, keepdim=True)
        x = torch.cat((cls_tokens, x), dim=1)
        for block in self.post_network:
            x = block(x)
        return x

    ########## Normal block without Attention Maps ##########
    # def forward_features(self, x):
    #     B = x.shape[0]
    #     for i in range(self.num_stages):
    #         patch_embed = getattr(self, f"patch_embed{i + 1}")
    #         block = getattr(self, f"block{i + 1}")
    #         x, H, W = patch_embed(x)
    #         for blk in block:
    #             x = blk(x, H, W)
    #         tokens = x
    #
    #         if i != self.num_stages - 1:
    #             norm = getattr(self, f"norm{i + 1}")
    #             x = norm(x)
    #             x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
    #
    #     x = self.forward_cls(x)[:, 0]
    #     norm = getattr(self, f"norm{self.num_stages}")
    #     x = norm(x)
    #     return x, tokens

    ########## You can create Attention Maps with this block ##########
    def forward_features(self, x):
        B = x.shape[0]
        attention_maps = []  # Collect attention maps if available
        tokens = None        # Ensure `tokens` is defined even before the first stage runs

        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            x, H, W = patch_embed(x)

            for blk in block:
                outputs = blk(x, H, W)
                if isinstance(outputs, tuple):
                    x, attn_weights = outputs
                    attention_maps.append(attn_weights)  # Store attention maps
                else:
                    x = outputs

            tokens = x  # Update tokens with the latest block output

            if i != self.num_stages - 1:
                norm = getattr(self, f"norm{i + 1}")
                x = norm(x)
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        x = self.forward_cls(x)[:, 0]  # Keep only the refined class token
        norm = getattr(self, f"norm{self.num_stages}")
        x = norm(x)
        return x, tokens, attention_maps


    ########## Normal block without Attention Maps ##########
    # def forward(self, x):
    #     if not self.return_dense:
    #         x, tokens = self.forward_features(x)
    #         x = self.head(x)
    #         return x, tokens
    #     else:
    #         x, H, W = self.forward_embeddings(x)
    #         # mix token, see token labeling for details.
    #         if self.mix_token and self.training:
    #             lam = np.random.beta(self.beta, self.beta)
    #             patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[2] // self.pooling_scale
    #             bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale)
    #             temp_x = x.clone()
    #             sbbx1, sbby1, sbbx2, sbby2 = self.pooling_scale * bbx1, self.pooling_scale * bby1, \
    #                 self.pooling_scale * bbx2, self.pooling_scale * bby2
    #             temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :]
    #             x = temp_x
    #         else:
    #             bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0
    #
    #         x = self.forward_tokens(x, H, W)
    #         x_cls = self.head(x[:, 0])
    #         x_aux = self.aux_head(
    #             x[:, 1:]
    #         )  # generate classes in all feature tokens, see token labeling
    #
    #         if not self.training:
    #             return x_cls + 0.5 * x_aux.max(1)[0]
    #
    #         if self.mix_token and self.training:  # reverse "mix token", see token labeling for details.
    #             x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1])
    #
    #             temp_x = x_aux.clone()
    #             temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :]
    #             x_aux = temp_x
    #
    #             x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1])
    #
    #         return x_cls, x_aux, (bbx1, bby1, bbx2, bby2)

    ########## You can create Attention Maps with this block ##########
    def forward(self, x):
        attention_maps = []  # Collect attention maps from all blocks

        if not self.return_dense:
            # Retrieve main output, tokens, and attention maps
            x, tokens, new_attention_maps = self.forward_features(x)
            attention_maps.extend(new_attention_maps)
            x = self.head(x)
            return x, tokens, attention_maps
        else:
            # Dense token labeling: embed first, apply mix token on the spatial
            # map, then run the transformer stages.
            x, H, W = self.forward_embeddings(x)
            # Mix token (see token labeling for details). This must happen while
            # x is still a (B, H, W, C) map, before forward_tokens flattens it.
            if self.mix_token and self.training:
                lam = np.random.beta(self.beta, self.beta)
                patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[2] // self.pooling_scale
                bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale)
                sbbx1, sbby1, sbbx2, sbby2 = self.pooling_scale * bbx1, self.pooling_scale * bby1, \
                    self.pooling_scale * bbx2, self.pooling_scale * bby2
                temp_x = x.clone()
                temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :]
                x = temp_x
            else:
                bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0  # Default to zero if no mixing

            x, new_attention_maps = self.forward_tokens(x, H, W)
            attention_maps.extend(new_attention_maps)

            x_cls = self.head(x[:, 0])
            x_aux = self.aux_head(x[:, 1:])  # Class prediction for all feature tokens

            if not self.training:
                return x_cls + 0.5 * x_aux.max(1)[0], attention_maps

            if self.mix_token and self.training:
                # Reverse "mix token" on the auxiliary predictions, see token labeling.
                x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1])
                temp_x = x_aux.clone()
                temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :]
                x_aux = temp_x
                x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1])

            return x_cls, x_aux, (bbx1, bby1, bbx2, bby2), attention_maps

    def forward_tokens(self, x, H, W):
        # Returns attention maps alongside the tokens so forward() can collect them.
        B = x.shape[0]
        attention_maps = []
        x = x.view(B, -1, x.size(-1))

        for i in range(self.num_stages):
            if i != 0:
                patch_embed = getattr(self, f"patch_embed{i + 1}")
                x, H, W = patch_embed(x)

            block = getattr(self, f"block{i + 1}")
            for blk in block:
                outputs = blk(x, H, W)
                if isinstance(outputs, tuple):  # attention blocks also return weights
                    x, attn_weights = outputs
                    attention_maps.append(attn_weights)
                else:
                    x = outputs

            if i != self.num_stages - 1:
                norm = getattr(self, f"norm{i + 1}")
                x = norm(x)
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        x = self.forward_cls(x)
        norm = getattr(self, f"norm{self.num_stages}")
        x = norm(x)
        return x, attention_maps

    def forward_embeddings(self, x):
        patch_embed = getattr(self, "patch_embed1")
        x, H, W = patch_embed(x)
        x = x.view(x.size(0), H, W, -1)
        return x, H, W


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        return x


@register_model
def painformer(pretrained=False, **kwargs):
    model = SpectFormer(
        stem_hidden_dim=64,
        embed_dims=[64, 128, 320, 160],
        num_heads=[2, 4, 10, 16],
        mlp_ratios=[8, 8, 4, 4],
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=[3, 4, 12, 3],
        sr_ratios=[4, 2, 1, 1],
        **kwargs)
    model.default_cfg = _cfg()
    return model
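

# --- usage sketch (illustrative, not part of the original file) ------------------
# Builds the model via the timm registry and pulls per-block attention maps from
# the last two (standard-attention) stages. Assumes a 224x224 input, which is the
# resolution the spectral gating weights are sized for.
if __name__ == "__main__":
    model = painformer(num_classes=1000)
    model.eval()
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        logits, tokens, attention_maps = model(dummy)
    print(logits.shape)             # torch.Size([1, 1000])
    print(len(attention_maps))      # one entry per 'std_att' block (12 + 3 = 15)
    print(attention_maps[0].shape)  # (B, num_heads, N, N): (1, 10, 196, 196) for stage 3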