CurHarsh commited on
Commit
0ff5906
·
1 Parent(s): 3781a48

Update models/context_cluster.py

Browse files
Files changed (1) hide show
  1. models/context_cluster.py +847 -0
models/context_cluster.py CHANGED
@@ -0,0 +1,847 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ContextCluster implementation
3
+ # --------------------------------------------------------
4
+ # Context Cluster -- Image as Set of Points, ICLR'23 Oral
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Written by Xu Ma (ma.xu1@northeastern.com)
7
+ # --------------------------------------------------------
8
+ """
9
+ import os
10
+ import copy
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
15
+ from timm.models.layers import DropPath, trunc_normal_
16
+ from timm.models.registry import register_model
17
+ from timm.models.layers.helpers import to_2tuple
18
+ from einops import rearrange
19
+ import torch.nn.functional as F
20
+
21
+ try:
22
+ from mmseg.models.builder import BACKBONES as seg_BACKBONES
23
+ from mmseg.utils import get_root_logger
24
+ from mmcv.runner import _load_checkpoint
25
+
26
+ has_mmseg = True
27
+ except ImportError:
28
+ print("If for semantic segmentation, please install mmsegmentation first")
29
+ has_mmseg = False
30
+
31
+ try:
32
+ from mmdet.models.builder import BACKBONES as det_BACKBONES
33
+ from mmdet.utils import get_root_logger
34
+ from mmcv.runner import _load_checkpoint
35
+
36
+ has_mmdet = True
37
+ except ImportError:
38
+ print("If for detection, please install mmdetection first")
39
+ has_mmdet = False
40
+
41
+
42
+ def _cfg(url='', **kwargs):
43
+ return {
44
+ 'url': url,
45
+ 'num_classes': 1000, 'input_size': (3, 224, 224),
46
+ 'crop_pct': .95, 'interpolation': 'bicubic',
47
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
48
+ 'classifier': 'head',
49
+ **kwargs
50
+ }
51
+
52
+
53
+ default_cfgs = {
54
+ 'model_small': _cfg(crop_pct=0.9),
55
+ 'model_medium': _cfg(crop_pct=0.95),
56
+ }
57
+
58
+
59
+ class PointRecuder(nn.Module):
60
+ """
61
+ Point Reducer is implemented by a layer of conv since it is mathmatically equal.
62
+ Input: tensor in shape [B, in_chans, H, W]
63
+ Output: tensor in shape [B, embed_dim, H/stride, W/stride]
64
+ """
65
+
66
+ def __init__(self, patch_size=16, stride=16, padding=0,
67
+ in_chans=3, embed_dim=768, norm_layer=None):
68
+ super().__init__()
69
+ patch_size = to_2tuple(patch_size)
70
+ stride = to_2tuple(stride)
71
+ padding = to_2tuple(padding)
72
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
73
+ stride=stride, padding=padding)
74
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
75
+
76
+ def forward(self, x):
77
+ x = self.proj(x)
78
+ x = self.norm(x)
79
+ return x
80
+
81
+
82
+ class GroupNorm(nn.GroupNorm):
83
+ """
84
+ Group Normalization with 1 group.
85
+ Input: tensor in shape [B, C, H, W]
86
+ """
87
+
88
+ def __init__(self, num_channels, **kwargs):
89
+ super().__init__(1, num_channels, **kwargs)
90
+
91
+
92
+ def pairwise_cos_sim(x1: torch.Tensor, x2: torch.Tensor):
93
+ """
94
+ return pair-wise similarity matrix between two tensors
95
+ :param x1: [B,...,M,D]
96
+ :param x2: [B,...,N,D]
97
+ :return: similarity matrix [B,...,M,N]
98
+ """
99
+ x1 = F.normalize(x1, dim=-1)
100
+ x2 = F.normalize(x2, dim=-1)
101
+
102
+ sim = torch.matmul(x1, x2.transpose(-2, -1))
103
+ return sim
104
+
105
+
106
+ class Cluster(nn.Module):
107
+ def __init__(self, dim, out_dim, proposal_w=2, proposal_h=2, fold_w=2, fold_h=2, heads=4, head_dim=24,
108
+ return_center=False):
109
+ """
110
+
111
+ :param dim: channel nubmer
112
+ :param out_dim: channel nubmer
113
+ :param proposal_w: the sqrt(proposals) value, we can also set a different value
114
+ :param proposal_h: the sqrt(proposals) value, we can also set a different value
115
+ :param fold_w: the sqrt(number of regions) value, we can also set a different value
116
+ :param fold_h: the sqrt(number of regions) value, we can also set a different value
117
+ :param heads: heads number in context cluster
118
+ :param head_dim: dimension of each head in context cluster
119
+ :param return_center: if just return centers instead of dispatching back (deprecated).
120
+ """
121
+ super().__init__()
122
+ self.heads = heads
123
+ self.head_dim = head_dim
124
+ self.f = nn.Conv2d(dim, heads * head_dim, kernel_size=1) # for similarity
125
+ self.proj = nn.Conv2d(heads * head_dim, out_dim, kernel_size=1) # for projecting channel number
126
+ self.v = nn.Conv2d(dim, heads * head_dim, kernel_size=1) # for value
127
+ self.sim_alpha = nn.Parameter(torch.ones(1))
128
+ self.sim_beta = nn.Parameter(torch.zeros(1))
129
+ self.centers_proposal = nn.AdaptiveAvgPool2d((proposal_w, proposal_h))
130
+ self.fold_w = fold_w
131
+ self.fold_h = fold_h
132
+ self.return_center = return_center
133
+
134
+ def forward(self, x): # [b,c,w,h]
135
+ value = self.v(x)
136
+ x = self.f(x)
137
+ x = rearrange(x, "b (e c) w h -> (b e) c w h", e=self.heads)
138
+ value = rearrange(value, "b (e c) w h -> (b e) c w h", e=self.heads)
139
+ if self.fold_w > 1 and self.fold_h > 1:
140
+ # split the big feature maps to small local regions to reduce computations.
141
+ b0, c0, w0, h0 = x.shape
142
+ assert w0 % self.fold_w == 0 and h0 % self.fold_h == 0, \
143
+ f"Ensure the feature map size ({w0}*{h0}) can be divided by fold {self.fold_w}*{self.fold_h}"
144
+ x = rearrange(x, "b c (f1 w) (f2 h) -> (b f1 f2) c w h", f1=self.fold_w,
145
+ f2=self.fold_h) # [bs*blocks,c,ks[0],ks[1]]
146
+ value = rearrange(value, "b c (f1 w) (f2 h) -> (b f1 f2) c w h", f1=self.fold_w, f2=self.fold_h)
147
+ b, c, w, h = x.shape
148
+ centers = self.centers_proposal(x) # [b,c,C_W,C_H], we set M = C_W*C_H and N = w*h
149
+ value_centers = rearrange(self.centers_proposal(value), 'b c w h -> b (w h) c') # [b,C_W,C_H,c]
150
+ b, c, ww, hh = centers.shape
151
+ sim = torch.sigmoid(
152
+ self.sim_beta +
153
+ self.sim_alpha * pairwise_cos_sim(
154
+ centers.reshape(b, c, -1).permute(0, 2, 1),
155
+ x.reshape(b, c, -1).permute(0, 2, 1)
156
+ )
157
+ ) # [B,M,N]
158
+ # we use mask to sololy assign each point to one center
159
+ sim_max, sim_max_idx = sim.max(dim=1, keepdim=True)
160
+ mask = torch.zeros_like(sim) # binary #[B,M,N]
161
+ mask.scatter_(1, sim_max_idx, 1.)
162
+ sim = sim * mask
163
+ value2 = rearrange(value, 'b c w h -> b (w h) c') # [B,N,D]
164
+ # aggregate step, out shape [B,M,D]
165
+ ###
166
+ # Update Comment: Mar/26/2022
167
+ # a small bug: mask.sum should be sim.sum according to Eq. (1), mask can be considered as a hard version of sim in out implementation.
168
+ # We will update all checkpoints and the bug once all models are re-trained.
169
+ ###
170
+ out = ((value2.unsqueeze(dim=1) * sim.unsqueeze(dim=-1)).sum(dim=2) + value_centers) / (
171
+ mask.sum(dim=-1, keepdim=True) + 1.0) # [B,M,D]
172
+
173
+ if self.return_center:
174
+ out = rearrange(out, "b (w h) c -> b c w h", w=ww)
175
+ else:
176
+ # dispatch step, return to each point in a cluster
177
+ out = (out.unsqueeze(dim=2) * sim.unsqueeze(dim=-1)).sum(dim=1) # [B,N,D]
178
+ out = rearrange(out, "b (w h) c -> b c w h", w=w)
179
+
180
+ if self.fold_w > 1 and self.fold_h > 1:
181
+ # recover the splited regions back to big feature maps if use the region partition.
182
+ out = rearrange(out, "(b f1 f2) c w h -> b c (f1 w) (f2 h)", f1=self.fold_w, f2=self.fold_h)
183
+ out = rearrange(out, "(b e) c w h -> b (e c) w h", e=self.heads)
184
+ out = self.proj(out)
185
+ return out
186
+
187
+
188
+ class Mlp(nn.Module):
189
+ """
190
+ Implementation of MLP with nn.Linear (would be slightly faster in both training and inference).
191
+ Input: tensor with shape [B, C, H, W]
192
+ """
193
+
194
+ def __init__(self, in_features, hidden_features=None,
195
+ out_features=None, act_layer=nn.GELU, drop=0.):
196
+ super().__init__()
197
+ out_features = out_features or in_features
198
+ hidden_features = hidden_features or in_features
199
+ self.fc1 = nn.Linear(in_features, hidden_features)
200
+ self.act = act_layer()
201
+ self.fc2 = nn.Linear(hidden_features, out_features)
202
+ self.drop = nn.Dropout(drop)
203
+ self.apply(self._init_weights)
204
+
205
+ def _init_weights(self, m):
206
+ if isinstance(m, nn.Linear):
207
+ trunc_normal_(m.weight, std=.02)
208
+ if m.bias is not None:
209
+ nn.init.constant_(m.bias, 0)
210
+
211
+ def forward(self, x):
212
+ x = self.fc1(x.permute(0, 2, 3, 1))
213
+ x = self.act(x)
214
+ x = self.drop(x)
215
+ x = self.fc2(x).permute(0, 3, 1, 2)
216
+ x = self.drop(x)
217
+ return x
218
+
219
+
220
+ class ClusterBlock(nn.Module):
221
+ """
222
+ Implementation of one block.
223
+ --dim: embedding dim
224
+ --mlp_ratio: mlp expansion ratio
225
+ --act_layer: activation
226
+ --norm_layer: normalization
227
+ --drop: dropout rate
228
+ --drop path: Stochastic Depth,
229
+ refer to https://arxiv.org/abs/1603.09382
230
+ --use_layer_scale, --layer_scale_init_value: LayerScale,
231
+ refer to https://arxiv.org/abs/2103.17239
232
+ """
233
+
234
+ def __init__(self, dim, mlp_ratio=4.,
235
+ act_layer=nn.GELU, norm_layer=GroupNorm,
236
+ drop=0., drop_path=0.,
237
+ use_layer_scale=True, layer_scale_init_value=1e-5,
238
+ # for context-cluster
239
+ proposal_w=2, proposal_h=2, fold_w=2, fold_h=2, heads=4, head_dim=24, return_center=False):
240
+
241
+ super().__init__()
242
+
243
+ self.norm1 = norm_layer(dim)
244
+ # dim, out_dim, proposal_w=2,proposal_h=2, fold_w=2, fold_h=2, heads=4, head_dim=24, return_center=False
245
+ self.token_mixer = Cluster(dim=dim, out_dim=dim, proposal_w=proposal_w, proposal_h=proposal_h,
246
+ fold_w=fold_w, fold_h=fold_h, heads=heads, head_dim=head_dim, return_center=False)
247
+ self.norm2 = norm_layer(dim)
248
+ mlp_hidden_dim = int(dim * mlp_ratio)
249
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
250
+ act_layer=act_layer, drop=drop)
251
+
252
+ # The following two techniques are useful to train deep ContextClusters.
253
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
254
+ self.use_layer_scale = use_layer_scale
255
+ if use_layer_scale:
256
+ self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
257
+ self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
258
+
259
+ def forward(self, x):
260
+ if self.use_layer_scale:
261
+ x = x + self.drop_path(
262
+ self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
263
+ * self.token_mixer(self.norm1(x)))
264
+ x = x + self.drop_path(
265
+ self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
266
+ * self.mlp(self.norm2(x)))
267
+ else:
268
+ x = x + self.drop_path(self.token_mixer(self.norm1(x)))
269
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
270
+ return x
271
+
272
+
273
+ def basic_blocks(dim, index, layers,
274
+ mlp_ratio=4.,
275
+ act_layer=nn.GELU, norm_layer=GroupNorm,
276
+ drop_rate=.0, drop_path_rate=0.,
277
+ use_layer_scale=True, layer_scale_init_value=1e-5,
278
+ # for context-cluster
279
+ proposal_w=2, proposal_h=2, fold_w=2, fold_h=2, heads=4, head_dim=24, return_center=False):
280
+ blocks = []
281
+ for block_idx in range(layers[index]):
282
+ block_dpr = drop_path_rate * ( block_idx + sum(layers[:index])) / (sum(layers) - 1)
283
+ blocks.append(ClusterBlock(
284
+ dim, mlp_ratio=mlp_ratio,
285
+ act_layer=act_layer, norm_layer=norm_layer,
286
+ drop=drop_rate, drop_path=block_dpr,
287
+ use_layer_scale=use_layer_scale,
288
+ layer_scale_init_value=layer_scale_init_value,
289
+ proposal_w=proposal_w, proposal_h=proposal_h, fold_w=fold_w, fold_h=fold_h,
290
+ heads=heads, head_dim=head_dim, return_center=False
291
+ ))
292
+ blocks = nn.Sequential(*blocks)
293
+
294
+ return blocks
295
+
296
+
297
+ class ContextCluster(nn.Module):
298
+ """
299
+ ContextCluster, the main class of our model
300
+ --layers: [x,x,x,x], number of blocks for the 4 stages
301
+ --embed_dims, --mlp_ratios, the embedding dims, mlp ratios
302
+ --downsamples: flags to apply downsampling or not
303
+ --norm_layer, --act_layer: define the types of normalization and activation
304
+ --num_classes: number of classes for the image classification
305
+ --in_patch_size, --in_stride, --in_pad: specify the patch embedding
306
+ for the input image
307
+ --down_patch_size --down_stride --down_pad:
308
+ specify the downsample (patch embed.)
309
+ --fork_feat: whether output features of the 4 stages, for dense prediction
310
+ --init_cfg, --pretrained:
311
+ for mmdetection and mmsegmentation to load pretrained weights
312
+ """
313
+
314
+ def __init__(self, layers, embed_dims=None,
315
+ mlp_ratios=None, downsamples=None,
316
+ norm_layer=nn.BatchNorm2d, act_layer=nn.GELU,
317
+ num_classes=1000,
318
+ in_patch_size=4, in_stride=4, in_pad=0,
319
+ down_patch_size=2, down_stride=2, down_pad=0,
320
+ drop_rate=0., drop_path_rate=0.,
321
+ use_layer_scale=True, layer_scale_init_value=1e-5,
322
+ fork_feat=False,
323
+ init_cfg=None,
324
+ pretrained=None,
325
+ # the parameters for context-cluster
326
+ proposal_w=[2, 2, 2, 2], proposal_h=[2, 2, 2, 2], fold_w=[8, 4, 2, 1], fold_h=[8, 4, 2, 1],
327
+ heads=[2, 4, 6, 8], head_dim=[16, 16, 32, 32],
328
+ **kwargs):
329
+
330
+ super().__init__()
331
+
332
+ if not fork_feat:
333
+ self.num_classes = num_classes
334
+ self.fork_feat = fork_feat
335
+
336
+ self.patch_embed = PointRecuder(
337
+ patch_size=in_patch_size, stride=in_stride, padding=in_pad,
338
+ in_chans=5, embed_dim=embed_dims[0])
339
+
340
+ # set the main block in network
341
+ network = []
342
+ for i in range(len(layers)):
343
+ stage = basic_blocks(embed_dims[i], i, layers,
344
+ mlp_ratio=mlp_ratios[i],
345
+ act_layer=act_layer, norm_layer=norm_layer,
346
+ drop_rate=drop_rate,
347
+ drop_path_rate=drop_path_rate,
348
+ use_layer_scale=use_layer_scale,
349
+ layer_scale_init_value=layer_scale_init_value,
350
+ proposal_w=proposal_w[i], proposal_h=proposal_h[i],
351
+ fold_w=fold_w[i], fold_h=fold_h[i], heads=heads[i], head_dim=head_dim[i],
352
+ return_center=False
353
+ )
354
+ network.append(stage)
355
+ if i >= len(layers) - 1:
356
+ break
357
+ if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
358
+ # downsampling between two stages
359
+ network.append(
360
+ PointRecuder(
361
+ patch_size=down_patch_size, stride=down_stride,
362
+ padding=down_pad,
363
+ in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
364
+ )
365
+ )
366
+
367
+ self.network = nn.ModuleList(network)
368
+
369
+ if self.fork_feat:
370
+ # add a norm layer for each output
371
+ self.out_indices = [0, 2, 4, 6]
372
+ for i_emb, i_layer in enumerate(self.out_indices):
373
+ if i_emb == 0 and os.environ.get('FORK_LAST3', None):
374
+ # TODO: more elegant way
375
+ """For RetinaNet, `start_level=1`. The first norm layer will not used.
376
+ cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...`
377
+ """
378
+ layer = nn.Identity()
379
+ else:
380
+ layer = norm_layer(embed_dims[i_emb])
381
+ layer_name = f'norm{i_layer}'
382
+ self.add_module(layer_name, layer)
383
+ else:
384
+ # Classifier head
385
+ self.norm = norm_layer(embed_dims[-1])
386
+ self.head = nn.Linear(
387
+ embed_dims[-1], num_classes) if num_classes > 0 \
388
+ else nn.Identity()
389
+
390
+ self.apply(self.cls_init_weights)
391
+
392
+ self.init_cfg = copy.deepcopy(init_cfg)
393
+ # load pre-trained model
394
+ if self.fork_feat and (
395
+ self.init_cfg is not None or pretrained is not None):
396
+ self.init_weights()
397
+
398
+ # init for classification
399
+ def cls_init_weights(self, m):
400
+ if isinstance(m, nn.Linear):
401
+ trunc_normal_(m.weight, std=.02)
402
+ if isinstance(m, nn.Linear) and m.bias is not None:
403
+ nn.init.constant_(m.bias, 0)
404
+
405
+ # init for mmdetection or mmsegmentation by loading
406
+ # imagenet pre-trained weights
407
+ def init_weights(self, pretrained=None):
408
+ logger = get_root_logger()
409
+ if self.init_cfg is None and pretrained is None:
410
+ logger.warn(f'No pre-trained weights for '
411
+ f'{self.__class__.__name__}, '
412
+ f'training start from scratch')
413
+ pass
414
+ else:
415
+ assert 'checkpoint' in self.init_cfg, f'Only support ' \
416
+ f'specify `Pretrained` in ' \
417
+ f'`init_cfg` in ' \
418
+ f'{self.__class__.__name__} '
419
+ if self.init_cfg is not None:
420
+ ckpt_path = self.init_cfg['checkpoint']
421
+ elif pretrained is not None:
422
+ ckpt_path = pretrained
423
+
424
+ ckpt = _load_checkpoint(
425
+ ckpt_path, logger=logger, map_location='cpu')
426
+ if 'state_dict' in ckpt:
427
+ _state_dict = ckpt['state_dict']
428
+ elif 'model' in ckpt:
429
+ _state_dict = ckpt['model']
430
+ else:
431
+ _state_dict = ckpt
432
+
433
+ state_dict = _state_dict
434
+ missing_keys, unexpected_keys = \
435
+ self.load_state_dict(state_dict, False)
436
+
437
+ # show for debug
438
+ # print('missing_keys: ', missing_keys)
439
+ # print('unexpected_keys: ', unexpected_keys)
440
+
441
+ def get_classifier(self):
442
+ return self.head
443
+
444
+ def reset_classifier(self, num_classes):
445
+ self.num_classes = num_classes
446
+ self.head = nn.Linear(
447
+ self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
448
+
449
+ def forward_embeddings(self, x):
450
+ _, c, img_w, img_h = x.shape
451
+ # print(f"det img size is {img_w} * {img_h}")
452
+ # register positional information buffer.
453
+ range_w = torch.arange(0, img_w, step=1) / (img_w - 1.0)
454
+ range_h = torch.arange(0, img_h, step=1) / (img_h - 1.0)
455
+ fea_pos = torch.stack(torch.meshgrid(range_w, range_h, indexing='ij'), dim=-1).float()
456
+ fea_pos = fea_pos.to(x.device)
457
+ fea_pos = fea_pos - 0.5
458
+ pos = fea_pos.permute(2, 0, 1).unsqueeze(dim=0).expand(x.shape[0], -1, -1, -1)
459
+ x = self.patch_embed(torch.cat([x, pos], dim=1))
460
+ return x
461
+
462
+ def forward_tokens(self, x):
463
+ outs = []
464
+ for idx, block in enumerate(self.network):
465
+ x = block(x)
466
+ if self.fork_feat and idx in self.out_indices:
467
+ norm_layer = getattr(self, f'norm{idx}')
468
+ x_out = norm_layer(x)
469
+ outs.append(x_out)
470
+ if self.fork_feat:
471
+ # output the features of four stages for dense prediction
472
+ return outs
473
+ # output only the features of last layer for image classification
474
+ return x
475
+
476
+ def forward(self, x):
477
+ # input embedding
478
+ x = self.forward_embeddings(x)
479
+ # through backbone
480
+ x = self.forward_tokens(x)
481
+ if self.fork_feat:
482
+ # otuput features of four stages for dense prediction
483
+ return x
484
+ x = self.norm(x)
485
+ cls_out = self.head(x.mean([-2, -1]))
486
+ # for image classification
487
+ return cls_out
488
+
489
+
490
+ @register_model
491
+ def coc_tiny(pretrained=False, **kwargs):
492
+ layers = [3, 4, 5, 2]
493
+ norm_layer = GroupNorm
494
+ embed_dims = [32, 64, 196, 320]
495
+ mlp_ratios = [8, 8, 4, 4]
496
+ downsamples = [True, True, True, True]
497
+ proposal_w = [2, 2, 2, 2]
498
+ proposal_h = [2, 2, 2, 2]
499
+ fold_w = [8, 4, 2, 1]
500
+ fold_h = [8, 4, 2, 1]
501
+ heads = [4, 4, 8, 8]
502
+ head_dim = [24, 24, 24, 24]
503
+ down_patch_size = 3
504
+ down_pad = 1
505
+ model = ContextCluster(
506
+ layers, embed_dims=embed_dims, norm_layer=norm_layer,
507
+ mlp_ratios=mlp_ratios, downsamples=downsamples,
508
+ down_patch_size=down_patch_size, down_pad=down_pad,
509
+ proposal_w=proposal_w, proposal_h=proposal_h, fold_w=fold_w, fold_h=fold_h,
510
+ heads=heads, head_dim=head_dim,
511
+ **kwargs)
512
+ model.default_cfg = default_cfgs['model_small']
513
+ return model
514
+
515
+
516
+ @register_model
517
+ def coc_tiny2(pretrained=False, **kwargs):
518
+ layers = [3, 4, 5, 2]
519
+ norm_layer = GroupNorm
520
+ embed_dims = [32, 64, 196, 320]
521
+ mlp_ratios = [8, 8, 4, 4]
522
+ downsamples = [True, True, True, True]
523
+ proposal_w = [4, 2, 7, 4]
524
+ proposal_h = [4, 2, 7, 4]
525
+ fold_w = [7, 7, 1, 1]
526
+ fold_h = [7, 7, 1, 1]
527
+ heads = [4, 4, 8, 8]
528
+ head_dim = [24, 24, 24, 24]
529
+ down_patch_size = 3
530
+ down_pad = 1
531
+ model = ContextCluster(
532
+ layers, embed_dims=embed_dims, norm_layer=norm_layer,
533
+ mlp_ratios=mlp_ratios, downsamples=downsamples,
534
+ down_patch_size=down_patch_size, down_pad=down_pad,
535
+ proposal_w=proposal_w, proposal_h=proposal_h, fold_w=fold_w, fold_h=fold_h,
536
+ heads=heads, head_dim=head_dim,
537
+ **kwargs)
538
+ model.default_cfg = default_cfgs['model_small']
539
+ return model
540
+
541
+
542
+ @register_model
543
+ def coc_small(pretrained=False, **kwargs):
544
+ layers = [2, 2, 6, 2]
545
+ norm_layer = GroupNorm
546
+ embed_dims = [64, 128, 320, 512]
547
+ mlp_ratios = [8, 8, 4, 4]
548
+ downsamples = [True, True, True, True]
549
+ proposal_w = [2, 2, 2, 2]
550
+ proposal_h = [2, 2, 2, 2]
551
+ fold_w = [8, 4, 2, 1]
552
+ fold_h = [8, 4, 2, 1]
553
+ heads = [4, 4, 8, 8]
554
+ head_dim = [32, 32, 32, 32]
555
+ down_patch_size = 3
556
+ down_pad = 1
557
+ model = ContextCluster(
558
+ layers, embed_dims=embed_dims, norm_layer=norm_layer,
559
+ mlp_ratios=mlp_ratios, downsamples=downsamples,
560
+ down_patch_size=down_patch_size, down_pad=down_pad,
561
+ proposal_w=proposal_w, proposal_h=proposal_h, fold_w=fold_w, fold_h=fold_h,
562
+ heads=heads, head_dim=head_dim,
563
+ **kwargs)
564
+ model.default_cfg = default_cfgs['model_small']
565
+ return model
566
+
567
+
568
+ @register_model
569
+ def coc_medium(pretrained=False, **kwargs):
570
+ layers = [4, 4, 12, 4]
571
+ norm_layer = GroupNorm
572
+ embed_dims = [64, 128, 320, 512]
573
+ mlp_ratios = [8, 8, 4, 4]
574
+ downsamples = [True, True, True, True]
575
+ proposal_w = [2, 2, 2, 2]
576
+ proposal_h = [2, 2, 2, 2]
577
+ fold_w = [8, 4, 2, 1]
578
+ fold_h = [8, 4, 2, 1]
579
+ heads = [6, 6, 12, 12]
580
+ head_dim = [32, 32, 32, 32]
581
+ down_patch_size = 3
582
+ down_pad = 1
583
+ model = ContextCluster(
584
+ layers, embed_dims=embed_dims, norm_layer=norm_layer,
585
+ mlp_ratios=mlp_ratios, downsamples=downsamples,
586
+ down_patch_size=down_patch_size, down_pad=down_pad,
587
+ proposal_w=proposal_w, proposal_h=proposal_h, fold_w=fold_w, fold_h=fold_h,
588
+ heads=heads, head_dim=head_dim,
589
+ **kwargs)
590
+ model.default_cfg = default_cfgs['model_small']
591
+ return model
592
+
593
+
594
+ @register_model
595
+ def coc_base_dim64(pretrained=False, **kwargs):
596
+ layers = [6, 6, 24, 6]
597
+ norm_layer = GroupNorm
598
+ embed_dims = [64, 128, 320, 512]
599
+ mlp_ratios = [8, 8, 4, 4]
600
+ downsamples = [True, True, True, True]
601
+ proposal_w = [2, 2, 2, 2]
602
+ proposal_h = [2, 2, 2, 2]
603
+ fold_w = [8, 4, 2, 1]
604
+ fold_h = [8, 4, 2, 1]
605
+ heads = [8, 8, 16, 16]
606
+ head_dim = [32, 32, 32, 32]
607
+ down_patch_size = 3
608
+ down_pad = 1
609
+ model = ContextCluster(
610
+ layers, embed_dims=embed_dims, norm_layer=norm_layer,
611
+ mlp_ratios=mlp_ratios, downsamples=downsamples,
612
+ down_patch_size=down_patch_size, down_pad=down_pad,
613
+ proposal_w=proposal_w, proposal_h=proposal_h, fold_w=fold_w, fold_h=fold_h,
614
+ heads=heads, head_dim=head_dim,
615
+ **kwargs)
616
+ model.default_cfg = default_cfgs['model_small']
617
+ return model
618
+
619
+
620
+ @register_model
621
+ def coc_base_dim96(pretrained=False, **kwargs):
622
+ layers = [4, 4, 12, 4]
623
+ norm_layer = GroupNorm
624
+ embed_dims = [96, 192, 384, 768]
625
+ mlp_ratios = [8, 8, 4, 4]
626
+ downsamples = [True, True, True, True]
627
+ proposal_w = [2, 2, 2, 2]
628
+ proposal_h = [2, 2, 2, 2]
629
+ fold_w = [8, 4, 2, 1]
630
+ fold_h = [8, 4, 2, 1]
631
+ heads = [8, 8, 16, 16]
632
+ head_dim = [32, 32, 32, 32]
633
+ down_patch_size = 3
634
+ down_pad = 1
635
+ model = ContextCluster(
636
+ layers, embed_dims=embed_dims, norm_layer=norm_layer,
637
+ mlp_ratios=mlp_ratios, downsamples=downsamples,
638
+ down_patch_size=down_patch_size, down_pad=down_pad,
639
+ proposal_w=proposal_w, proposal_h=proposal_h, fold_w=fold_w, fold_h=fold_h,
640
+ heads=heads, head_dim=head_dim,
641
+ **kwargs)
642
+ model.default_cfg = default_cfgs['model_small']
643
+ return model
644
+
645
+
646
+ """
647
+ Updated: add plain models (without region partition) for tiny, small, and base , etc.
648
+ Re-trained with new implementation (PWconv->MLP for faster training and inference), achieve slightly better performance.
649
+ """
650
+ @register_model
651
+ def coc_tiny_plain(pretrained=False, **kwargs):
652
+ # sharing same parameters as coc_tiny, without region partition.
653
+ layers = [3, 4, 5, 2]
654
+ norm_layer = GroupNorm
655
+ embed_dims = [32, 64, 196, 320]
656
+ mlp_ratios = [8, 8, 4, 4]
657
+ downsamples = [True, True, True, True]
658
+ proposal_w = [4, 4, 2, 2]
659
+ proposal_h = [4, 4, 2, 2]
660
+ fold_w = [1, 1, 1, 1]
661
+ fold_h = [1, 1, 1, 1]
662
+ heads = [4, 4, 8, 8]
663
+ head_dim = [24, 24, 24, 24]
664
+ down_patch_size = 3
665
+ down_pad = 1
666
+ model = ContextCluster(
667
+ layers, embed_dims=embed_dims, norm_layer=norm_layer,
668
+ mlp_ratios=mlp_ratios, downsamples=downsamples,
669
+ down_patch_size=down_patch_size, down_pad=down_pad,
670
+ proposal_w=proposal_w, proposal_h=proposal_h, fold_w=fold_w, fold_h=fold_h,
671
+ heads=heads, head_dim=head_dim,
672
+ **kwargs)
673
+ model.default_cfg = default_cfgs['model_small']
674
+ return model
675
+
676
+
677
+ if has_mmdet:
678
+ @seg_BACKBONES.register_module()
679
+ @det_BACKBONES.register_module()
680
+ class context_cluster_small_feat2(ContextCluster):
681
+ def __init__(self, **kwargs):
682
+ layers = [2, 2, 6, 2]
683
+ norm_layer=GroupNorm
684
+ embed_dims = [64, 128, 320, 512]
685
+ mlp_ratios = [8, 8, 4, 4]
686
+ downsamples = [True, True, True, True]
687
+ proposal_w=[2,2,2,2]
688
+ proposal_h=[2,2,2,2]
689
+ fold_w=[8,4,2,1]
690
+ fold_h=[8,4,2,1]
691
+ heads=[4,4,8,8]
692
+ head_dim=[32,32,32,32]
693
+ down_patch_size=3
694
+ down_pad = 1
695
+ super().__init__(
696
+ layers, embed_dims=embed_dims, norm_layer=norm_layer,
697
+ mlp_ratios=mlp_ratios, downsamples=downsamples,
698
+ down_patch_size = down_patch_size, down_pad=down_pad,
699
+ proposal_w=proposal_w, proposal_h=proposal_h, fold_w=fold_w, fold_h=fold_h,
700
+ heads=heads, head_dim=head_dim,
701
+ fork_feat=True,
702
+ **kwargs)
703
+
704
+
705
+ @seg_BACKBONES.register_module()
706
+ @det_BACKBONES.register_module()
707
+ class context_cluster_small_feat5(ContextCluster):
708
+ def __init__(self, **kwargs):
709
+ layers = [2, 2, 6, 2]
710
+ norm_layer=GroupNorm
711
+ embed_dims = [64, 128, 320, 512]
712
+ mlp_ratios = [8, 8, 4, 4]
713
+ downsamples = [True, True, True, True]
714
+ proposal_w=[5,5,5,5]
715
+ proposal_h=[5,5,5,5]
716
+ fold_w=[8,4,2,1]
717
+ fold_h=[8,4,2,1]
718
+ heads=[4,4,8,8]
719
+ head_dim=[32,32,32,32]
720
+ down_patch_size=3
721
+ down_pad = 1
722
+ super().__init__(
723
+ layers, embed_dims=embed_dims, norm_layer=norm_layer,
724
+ mlp_ratios=mlp_ratios, downsamples=downsamples,
725
+ down_patch_size = down_patch_size, down_pad=down_pad,
726
+ proposal_w=proposal_w, proposal_h=proposal_h, fold_w=fold_w, fold_h=fold_h,
727
+ heads=heads, head_dim=head_dim,
728
+ fork_feat=True,
729
+ **kwargs)
730
+
731
+
732
+ @seg_BACKBONES.register_module()
733
+ @det_BACKBONES.register_module()
734
+ class context_cluster_small_feat7(ContextCluster):
735
+ def __init__(self, **kwargs):
736
+ layers = [2, 2, 6, 2]
737
+ norm_layer=GroupNorm
738
+ embed_dims = [64, 128, 320, 512]
739
+ mlp_ratios = [8, 8, 4, 4]
740
+ downsamples = [True, True, True, True]
741
+ proposal_w=[7,7,7,7]
742
+ proposal_h=[7,7,7,7]
743
+ fold_w=[8,4,2,1]
744
+ fold_h=[8,4,2,1]
745
+ heads=[4,4,8,8]
746
+ head_dim=[32,32,32,32]
747
+ down_patch_size=3
748
+ down_pad = 1
749
+ super().__init__(
750
+ layers, embed_dims=embed_dims, norm_layer=norm_layer,
751
+ mlp_ratios=mlp_ratios, downsamples=downsamples,
752
+ down_patch_size = down_patch_size, down_pad=down_pad,
753
+ proposal_w=proposal_w, proposal_h=proposal_h, fold_w=fold_w, fold_h=fold_h,
754
+ heads=heads, head_dim=head_dim,
755
+ fork_feat=True,
756
+ **kwargs)
757
+
758
+
759
+ @seg_BACKBONES.register_module()
760
+ @det_BACKBONES.register_module()
761
+ class context_cluster_medium_feat2(ContextCluster):
762
+ def __init__(self, **kwargs):
763
+ layers = [4, 4, 12, 4]
764
+ norm_layer=GroupNorm
765
+ embed_dims = [64, 128, 320, 512]
766
+ mlp_ratios = [8, 8, 4, 4]
767
+ downsamples = [True, True, True, True]
768
+ proposal_w=[2,2,2,2]
769
+ proposal_h=[2,2,2,2]
770
+ fold_w=[8,4,2,1]
771
+ fold_h=[8,4,2,1]
772
+ heads=[6,6,12,12]
773
+ head_dim=[32,32,32,32]
774
+ down_patch_size=3
775
+ down_pad = 1
776
+ super().__init__(
777
+ layers, embed_dims=embed_dims, norm_layer=norm_layer,
778
+ mlp_ratios=mlp_ratios, downsamples=downsamples,
779
+ down_patch_size = down_patch_size, down_pad=down_pad,
780
+ proposal_w=proposal_w, proposal_h=proposal_h, fold_w=fold_w, fold_h=fold_h,
781
+ heads=heads, head_dim=head_dim,
782
+ fork_feat=True,
783
+ **kwargs)
784
+
785
+
786
+ @seg_BACKBONES.register_module()
787
+ @det_BACKBONES.register_module()
788
+ class context_cluster_medium_feat5(ContextCluster):
789
+ def __init__(self, **kwargs):
790
+ layers = [4, 4, 12, 4]
791
+ norm_layer=GroupNorm
792
+ embed_dims = [64, 128, 320, 512]
793
+ mlp_ratios = [8, 8, 4, 4]
794
+ downsamples = [True, True, True, True]
795
+ proposal_w=[5, 5, 5, 5]
796
+ proposal_h=[5, 5, 5, 5]
797
+ fold_w=[8,4,2,1]
798
+ fold_h=[8,4,2,1]
799
+ heads=[6,6,12,12]
800
+ head_dim=[32,32,32,32]
801
+ down_patch_size=3
802
+ down_pad = 1
803
+ super().__init__(
804
+ layers, embed_dims=embed_dims, norm_layer=norm_layer,
805
+ mlp_ratios=mlp_ratios, downsamples=downsamples,
806
+ down_patch_size = down_patch_size, down_pad=down_pad,
807
+ proposal_w=proposal_w, proposal_h=proposal_h, fold_w=fold_w, fold_h=fold_h,
808
+ heads=heads, head_dim=head_dim,
809
+ fork_feat=True,
810
+ **kwargs)
811
+
812
+
813
+ @seg_BACKBONES.register_module()
814
+ @det_BACKBONES.register_module()
815
+ class context_cluster_medium_feat7(ContextCluster):
816
+ def __init__(self, **kwargs):
817
+ layers = [4, 4, 12, 4]
818
+ norm_layer=GroupNorm
819
+ embed_dims = [64, 128, 320, 512]
820
+ mlp_ratios = [8, 8, 4, 4]
821
+ downsamples = [True, True, True, True]
822
+ proposal_w=[7,7,7,7]
823
+ proposal_h=[7,7,7,7]
824
+ fold_w=[8,4,2,1]
825
+ fold_h=[8,4,2,1]
826
+ heads=[6,6,12,12]
827
+ head_dim=[32,32,32,32]
828
+ down_patch_size=3
829
+ down_pad = 1
830
+ super().__init__(
831
+ layers, embed_dims=embed_dims, norm_layer=norm_layer,
832
+ mlp_ratios=mlp_ratios, downsamples=downsamples,
833
+ down_patch_size = down_patch_size, down_pad=down_pad,
834
+ proposal_w=proposal_w, proposal_h=proposal_h, fold_w=fold_w, fold_h=fold_h,
835
+ heads=heads, head_dim=head_dim,
836
+ fork_feat=True,
837
+ **kwargs)
838
+
839
+
840
+ if __name__ == '__main__':
841
+ input = torch.rand(2, 3, 224, 224)
842
+ model = coc_base_dim64()
843
+ out = model(input)
844
+ print(model)
845
+ print(out.shape)
846
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
847
+ print("number of params: {:.2f}M".format(n_parameters/1024**2))