heyoujue commited on
Commit
322161a
·
1 Parent(s): 103bef1

add submission code

Browse files
README.md CHANGED
@@ -1,7 +1,34 @@
1
- # Cross-Domain Few-Shot Segmentation
2
 
3
- The work will be released as soon as review period is over
4
 
5
- --> Release Date: 26.02.2024
 
 
 
 
 
 
6
 
7
- Stay tuned :star2:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapt Before Comparison - A New Perspective on Cross-Domain Few-Shot Segmentation
2
 
3
+ Code for reproducing the paper
4
 
5
+ ## Preparing Data
6
+ Because we follow PATNet and RtD, please refer to their work for preparation of the following datasets:
7
+ - Deepglobe (PAT)
8
+ - ISIC (PAT)
9
+ - Chest X-Ray (Lung) (PAT)
10
+ - FSS-1000 (PAT)
11
+ - SUIM (RtD)
12
 
13
+ You do not need to get all datasets. Just prepare the one you want to test our method with.
14
+
15
+ ## Python package prerequisites
16
+ 1. torch
17
+ 2. torchvision
18
+ 3. cv2
19
+ 4. numpy
20
+ 5. for others, follow the console output
21
+
22
+ ## Run it
23
+ Call
24
+ `python main.py --benchmark {} --datapath {} --nshot {}`
25
+
26
+ for example
27
+ `python main.py --benchmark deepglobe --datapath ./datasets/deepglobe/ --nshot 1`
28
+
29
+ Available `benchmark` strings: deepglobe,isic,lung,fss,suim
30
+ Easiest to prepare should be Lung or FSS.
31
+
32
+ Default is quick-infer mode.
33
+ To change this, set `config.featext.fit_every_episode = True` in the main file.
34
+ You can change all other parameters likewise, check the available parameters in runner.makeConfig.
core/backbone.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import reduce
2
+ from operator import add
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torchvision.models import resnet
7
+
8
class Backbone(nn.Module):
    """Frozen ResNet-50 feature extractor.

    Runs the torchvision ResNet-50 bottlenecks by hand so the intermediate
    feature map of every bottleneck (3 + 4 + 6 + 3 = 16 for ResNet-50) can
    be collected.
    """

    def __init__(self, typestr):
        super(Backbone, self).__init__()

        self.backbone = typestr

        # Feature-extractor initialization; only ResNet-50 is supported.
        if typestr == 'resnet50':
            self.feature_extractor = resnet.resnet50(weights=resnet.ResNet50_Weights.DEFAULT)
            self.feat_channels = [256, 512, 1024, 2048]
            self.nlayers = [3, 4, 6, 3]          # bottlenecks per stage
            self.feat_ids = list(range(0, 17))   # which bottleneck outputs to keep
        else:
            raise Exception('Unavailable backbone: %s' % typestr)
        self.feature_extractor.eval()  # backbone stays frozen

        # lids[i] is the stage (1..4) that bottleneck i belongs to, e.g.
        # [1,1,1,2,2,2,2,3,...]; stack_ids are the cumulative stage sizes.
        self.lids = reduce(add, [[i + 1] * x for i, x in enumerate(self.nlayers)])
        self.stack_ids = torch.tensor(self.lids).bincount()[-4:].cumsum(dim=0)

        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def extract_feats(self, img):
        r"""Extract input image features, one tensor per selected bottleneck."""
        feats = []
        # Bottleneck index within its stage, e.g. [0,1,2,0,1,2,3,0,...].
        bottleneck_ids = reduce(add, [list(range(n)) for n in self.nlayers])

        # Stem (layer 0): conv -> bn -> relu -> maxpool
        feat = self.feature_extractor.conv1(img)
        feat = self.feature_extractor.bn1(feat)
        feat = self.feature_extractor.relu(feat)
        feat = self.feature_extractor.maxpool(feat)

        # Stages 1-4, bottleneck by bottleneck.
        for hid, (bid, lid) in enumerate(zip(bottleneck_ids, self.lids)):
            block = getattr(self.feature_extractor, 'layer%d' % lid)[bid]
            residual = feat
            feat = block.conv1(feat)
            feat = block.bn1(feat)
            feat = block.relu(feat)
            feat = block.conv2(feat)
            feat = block.bn2(feat)
            feat = block.relu(feat)
            feat = block.conv3(feat)
            feat = block.bn3(feat)

            # The first bottleneck of a stage changes resolution/width, so
            # the skip connection goes through the downsample branch.
            if bid == 0:
                residual = block.downsample(residual)

            feat += residual

            # Snapshot the pre-activation residual sum for requested ids.
            if hid + 1 in self.feat_ids:
                feats.append(feat.clone())

            feat = block.relu(feat)

        return feats
core/contrastivehead.py ADDED
@@ -0,0 +1,582 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+ import torch
3
+ import torch.nn as nn
4
+ from utils import segutils
5
+ import core.denseaffinity as dautils
6
+
7
+ identity_mapping = lambda x, *args, **kwargs: x
8
+
9
+
10
class ContrastiveConfig:
    """Attribute-style view over a nested settings dictionary.

    ``cfg.fitting.lr`` reads ``_data['fitting']['lr']``; nested dicts are
    wrapped on the fly and share the underlying dict, so writes through a
    nested view are visible from the root. Constructing without arguments
    yields the default settings.
    """

    def __init__(self, config=None):
        if config is not None:
            self._data = config
            return
        # Default settings.
        self._data = {
            'aug': {
                'n_transformed_imgs': 2,
                'blurkernelsize': [1],  # chooses one of this kernel sizes
                'maxjitter': 0.0,
                'maxangle': 0,  # rotation
                # 'translate': (0,0), # BE CAREFUL WITH TRANSLATE - if you apply it on the feature volume that has smaller spatial dims correspondences break
                'maxscale': 1.0,  # 1.0 = No scaling
                'maxshear': 20,
                'randomhflip': False,
                'apply_affine': True,
                'debug': False
            },
            'model': {
                'out_channels': 64,
                'kernel_size': 1,
                'prepend_relu': False,
                'append_normalize': False,
                'debug': False
            },
            'fitting': {
                'lr': 1e-2,
                'optimizer': torch.optim.SGD,
                'num_epochs': 25,
                'nce': {
                    'temperature': 0.5,
                    'debug': False
                },
                'normalize_after_fwd_pass': True,
                'q_nceloss': True,
                's_nceloss': True,
                'protoloss': False,
                'keepvarloss': True,
                'symmetricloss': False,
                'selfattentionloss': False,
                'o_t_contr_proto_loss': True,
                'debug': False
            },
            'featext': {
                'l0': 3,  # the first resnet bottleneck id to consider (0,1,2,3,4,5...15)
                'fit_every_episode': False
            }
        }

    def __getattr__(self, key):
        """Look up *key* in the settings dict; wrap nested dicts so chained access works."""
        # Only called when normal attribute lookup fails; guard against
        # access before _data exists (e.g. during unpickling).
        data = self.__dict__.get('_data')
        if data is None or key not in data:
            raise AttributeError(f"No setting named {key}")
        value = data[key]
        return ContrastiveConfig(value) if isinstance(value, dict) else value

    def __setattr__(self, key, value):
        """Store everything except '_data' itself inside the settings dict."""
        if key == '_data':
            super().__setattr__(key, value)
            return
        data = self.__dict__.get('_data')
        if data is None:
            # Should not happen once __init__ has run.
            raise AttributeError("Unexpected")
        data[key] = value

    def __repr__(self):
        """Show the raw settings dict for debugging."""
        return str(self._data)
90
+
91
+
92
def dense_info_nce_loss(original_features, transformed_features, config_nce):
    """Dense per-pixel InfoNCE loss between two feature volumes.

    Each pixel of the (broadcast) original volume forms a positive pair with
    the pixel at the same location of the transformed volume; every other
    pixel of the same transformed image acts as a negative.

    original_features:    broadcastable to [B,C,H,W]
    transformed_features: [B,C,H,W]
    config_nce:           provides .temperature and .debug
    Returns a scalar loss.
    """
    B, C, H, W = transformed_features.shape
    # Flatten both volumes to [B, H*W, C] pixel-feature matrices.
    o_features = original_features.expand(B, C, H, W).permute(0, 2, 3, 1).reshape(B, H * W, C)
    t_features = transformed_features.permute(0, 2, 3, 1).reshape(B, H * W, C)

    # Positive logits: matching pixel pairs.
    positive_logits = torch.einsum('bik,bik->bi', o_features, t_features) / config_nce.temperature

    # Negatives: every original pixel vs every transformed pixel.
    all_logits = torch.einsum('bik,bjk->bij', o_features, t_features) / config_nce.temperature

    if config_nce.debug:
        print('pos/neg:', positive_logits.mean().detach(), all_logits.mean().detach())

    # Log-sum-exp with max subtraction for numerical stability.
    max_logits = all_logits.max(dim=-1, keepdim=True).values
    log_sum_exp = max_logits + (all_logits - max_logits).exp().sum(dim=-1, keepdim=True).log()

    # InfoNCE: negative log-softmax of the positive logit.
    loss = log_sum_exp.squeeze() - positive_logits
    return loss.mean()  # [B=k*aug] or [B=k] -> scalar
112
+
113
+
114
def ssim(a, b):
    """Cosine similarity between a and b along the channel dimension (dim 1)."""
    return F.cosine_similarity(a, b, dim=1)
116
+
117
def augwise_proto(feat_vol, mask, k, aug):
    """Per-augmentation fg/bg prototypes, pooled over all k shots.

    feat_vol: [k*aug, c, h, w]; mask is downsampled to (h, w) here.
    Returns (fg_proto, bg_proto), each of shape [aug, c].
    """
    c, h, w = feat_vol.shape[-3:]
    # Regroup [k*aug, ...] -> [aug, c, k*h*w] so each augmentation pools over every shot.
    feats_by_aug = torch.cat(feat_vol.view(k, aug, c, h * w).unbind(0), dim=-1)
    masks_by_aug = torch.cat(segutils.downsample_mask(mask, h, w).view(k, aug, h * w).unbind(0), dim=-1)
    assert feats_by_aug.shape == (aug, c, k * h * w) and masks_by_aug.shape == (
        aug, k * h * w), "of transformed"

    fg_proto, bg_proto = segutils.fg_bg_proto(feats_by_aug, masks_by_aug)
    assert fg_proto.shape == bg_proto.shape == (aug, c)

    return fg_proto, bg_proto
128
+
129
+
130
def calc_q_pred_coarse_nodetach(qft, sft, s_mask, l0=3):
    """Coarse query prediction from dense affinity to the support set.

    qft:    [bsz, c, hq, wq] query features
    sft:    [bsz, k, c, hs, ws] support features
    s_mask: [bsz, k, H, W] support masks (downsampled here)
    l0 is unused; kept for interface compatibility. Gradients are NOT
    detached (hence the name). Returns [bsz, hq, wq].
    """
    bsz, _, hq, wq = qft.shape
    hs, ws = sft.shape[-2:]

    # Paste the k supports side by side: [bsz,k,c,h,w] -> [bsz,c,h,w*k].
    sft_row = torch.cat(sft.unbind(1), -1)
    smask_row = torch.cat([segutils.downsample_mask(m, hs, ws) for m in s_mask.unbind(1)], -1)

    affinity = dautils.buildDenseAffinityMat(qft, sft_row)
    return dautils.filterDenseAffinityMap(affinity, smask_row).view(bsz, hq, wq)
142
+
143
+
144
+ # input k*aug,c,h,w
145
def self_attention_loss(f_base, f_transformed, mask_base, mask_transformed, k, aug):
    """BCE between a dense-affinity prediction and the base mask.

    Base features act as pseudo-query, transformed features as pseudo-support.
    Feature inputs are [k*aug, c, h, w]; masks are [k*aug, h, w].
    """
    c, h, w = f_base.shape[-3:]
    # [k*aug,...] -> pseudo-query [aug, c, h, w*k] with its mask.
    pseudoquery = torch.cat(f_base.view(k, aug, c, h, w).unbind(0), -1)
    pseudoquerymask = torch.cat(mask_base.view(k, aug, h, w).unbind(0), -1)
    # [k*aug,...] -> pseudo-support [aug, k, c, h, w] with its mask.
    pseudosupport = f_transformed.view(k, aug, c, h, w).transpose(0, 1)
    pseudosupportmask = mask_transformed.view(k, aug, h, w).transpose(0, 1)

    pred_map = calc_q_pred_coarse_nodetach(pseudoquery, pseudosupport, pseudosupportmask, l0=0)
    return torch.nn.BCELoss()(pred_map.float(), pseudoquerymask.float()).mean()
156
+
157
+
158
+ # features of base, transformed: [b,c,h,w]
159
+ # if base features are aligned with transformed features, pass both same
160
def ctrstive_prototype_loss(base, transformed, mask_base, mask_transformed, k, aug):
    """Contrastive prototype loss between base and transformed features.

    Pulls the foreground prototypes of base and transformed together while
    pushing the base foreground away from the transformed background.
    Feature inputs are [k*aug, c, h, w]; masks are [k*aug, h, w].
    """
    assert transformed.shape == base.shape, ".."
    b, c, h, w = base.shape
    assert b == k * aug, 'provide correct k and aug such that dim0=k*aug'
    assert mask_base.shape == mask_transformed.shape == (b, h, w), ".."

    fg_proto_o, bg_proto_o = augwise_proto(base, mask_base, k, aug)
    fg_proto_t, bg_proto_t = augwise_proto(transformed, mask_transformed, k, aug)

    # InfoNCE over prototypes, one term per augmentation:
    # positive = fg(base) vs fg(transformed), negative = fg(base) vs bg(transformed).
    pos = torch.exp(ssim(fg_proto_o, fg_proto_t))
    neg = torch.exp(ssim(fg_proto_o, bg_proto_t))
    assert pos.shape == (pos + neg).shape == torch.Size([aug]), 'you want to calculate one prototype for each augmentation'
    return -torch.log(pos / (pos + neg)).mean()
176
+
177
+
178
def opposite_proto_sim_in_aug(transformed_features, mapped_s_masks, k, aug):
    """Mean fg<->bg prototype similarity of the transformed features (to be minimized)."""
    fg_proto, bg_proto = augwise_proto(transformed_features, mapped_s_masks, k, aug)
    return ssim(fg_proto, bg_proto).mean()
182
+
183
+
184
def proto_align_val_measure(original_features, transformed_features, mapped_s_masks, k, aug):
    """Validation metric: fg-prototype similarity between original and transformed features."""
    fg_o, _ = augwise_proto(original_features, mapped_s_masks, k, aug)
    fg_t, _ = augwise_proto(transformed_features, mapped_s_masks, k, aug)
    return ssim(fg_o, fg_t).mean()
189
+
190
+
191
def atest():
    """Smoke test: run self_attention_loss on random inputs and return the loss."""
    k, aug, c, h, w = 2, 5, 8, 20, 20
    base = torch.rand(k * aug, c, h, w).float()
    base.requires_grad = True
    transformed = torch.rand(k * aug, c, h, w).float()
    base_mask = torch.randint(0, 2, (k * aug, h, w)).float()
    transformed_mask = torch.randint(0, 2, (k * aug, h, w)).float()

    return self_attention_loss(base, transformed, base_mask, transformed_mask, k, aug)
200
+
201
def keep_var_loss(original_features, transformed_features):
    """Penalize shifts of per-channel spatial statistics between the two volumes.

    Sum of the mean absolute differences of the spatial means and spatial
    variances, computed over the trailing (h, w) dims. Returns a scalar.
    """
    mean_shift = (original_features.mean((-2, -1)) - transformed_features.mean((-2, -1))).abs().mean()
    var_shift = (original_features.var((-2, -1)) - transformed_features.var((-2, -1))).abs().mean()
    return mean_shift + var_shift
207
+
208
class ContrastiveFeatureTransformer(nn.Module):
    """Small conv head (conv -> bn -> relu -> 1x1 conv), fitted per episode
    with contrastive losses to adapt backbone features to the target domain."""

    def __init__(self, in_channels, config_model):
        super(ContrastiveFeatureTransformer, self).__init__()

        out_channels, kernel_size = config_model.out_channels, config_model.kernel_size
        # Learnable projection: conv + batchnorm, followed by a 1x1 "linear" conv.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2)
        self.bn = nn.BatchNorm2d(out_channels)
        self.linear = nn.Conv2d(out_channels, out_channels, 1)

        self.prepend_relu = config_model.prepend_relu
        self.append_normalize = config_model.append_normalize
        self.debug = config_model.debug

    def forward(self, x):
        """Project [b,in_c,h,w] -> [b,out_c,h,w]; optional leading relu / trailing L2-normalize."""
        if self.prepend_relu:
            x = nn.ReLU()(x)
        x = self.linear(nn.ReLU()(self.bn(self.conv(x))))
        if self.append_normalize:
            x = F.normalize(x, p=2, dim=1)
        return x

    # fits the model for one semantic class, therefore does not work with batches
    # mapped_qfeat_vol, aug_qfeat_vols: [aug,c,h,w]
    # mapped_sfeat_vol, aug_sfeat_vols: [k*aug,c,h,w]
    # augmented_smasks: [k*aug,h,w]
    def fit(self, mapped_qfeat_vol, aug_qfeat_vols, mapped_sfeat_vol, aug_sfeat_vols, augmented_smasks, config_fit):
        """Optimize the head so original and augmented features agree, per config_fit's loss switches."""
        f_norm = F.normalize if config_fit.normalize_after_fwd_pass else identity_mapping
        optimizer = config_fit.optimizer(self.parameters(), lr=config_fit.lr)
        for epoch in range(config_fit.num_epochs):
            # Debug logging only on the first and the last epoch.
            log_now = config_fit.debug and epoch in (0, config_fit.num_epochs - 1)

            # --- query branch: non-augmented (geometrically mapped) vs augmented ---
            q_orig = f_norm(self(mapped_qfeat_vol), p=2, dim=1)
            q_aug = f_norm(self(aug_qfeat_vols), p=2, dim=1)
            qloss = dense_info_nce_loss(q_orig, q_aug, config_fit.nce) if config_fit.q_nceloss else 0
            if config_fit.keepvarloss:
                # 1. idea: let query and support keep the same feature distribution (mean/var per channel)
                qloss += keep_var_loss(q_orig, q_aug)

            # --- support branch ---
            s_orig = f_norm(self(mapped_sfeat_vol), p=2, dim=1)
            s_aug = f_norm(self(aug_sfeat_vols), p=2, dim=1)
            sloss = dense_info_nce_loss(s_orig, s_aug, config_fit.nce) if config_fit.s_nceloss else 0
            if config_fit.keepvarloss:
                sloss += keep_var_loss(s_orig, s_aug)

            # 2. class-aware losses: prototypes are computed per augmentation,
            # pooling over all k shots ([k*aug,c,h,w] -> [aug,c]).
            aug = aug_qfeat_vols.shape[0]
            k = s_aug.shape[0] // aug
            if config_fit.protoloss:
                assert not config_fit.o_t_contr_proto_loss, 'only one of the proto losses should be used'
                proto_loss = opposite_proto_sim_in_aug(s_aug, augmented_smasks, k, aug)
                if log_now:
                    print('proto-sim intER-class transf<->transf', proto_loss.item())
            elif config_fit.selfattentionloss:
                proto_loss = self_attention_loss(s_orig, s_aug, augmented_smasks,
                                                 augmented_smasks, k, aug)
                if log_now:
                    print('self-att non-transf<->transformed bce', proto_loss.item())
            elif config_fit.o_t_contr_proto_loss:
                proto_loss = ctrstive_prototype_loss(s_orig, s_aug,
                                                     augmented_smasks, augmented_smasks, k, aug)
                if log_now:
                    print('proto-contr non-transf<->transformed', proto_loss.item())
            else:
                proto_loss = 0

            if log_now:
                proto_align_val = proto_align_val_measure(s_orig, s_aug, augmented_smasks, k, aug)
                print('proto-sim intRA-class non-transf<->transformed (for validation)', proto_align_val.item())

            # 3. regularization: do not let only one branch fit well
            q_s_loss_diff = torch.abs(qloss - sloss) if config_fit.symmetricloss else 0

            # Aggregate loss
            loss = qloss + sloss + q_s_loss_diff + proto_loss
            assert loss.isfinite().all(), f"invalid contrastive loss:{loss}"

            # Debug: per-term gradient magnitudes on the two weight tensors.
            if log_now:
                def gradient_magnitude(loss_term):
                    optimizer.zero_grad()
                    loss_term.backward(retain_graph=True)
                    return torch.abs(self.conv.weight.grad.mean()) + torch.abs(self.linear.weight.grad.mean())

                q_loss_grad_magnitude = gradient_magnitude(qloss)
                s_loss_grad_magnitude = gradient_magnitude(sloss)
                proto_loss_grad_magnitude = gradient_magnitude(proto_loss)
                q_s_loss_diff_grad_magnitude = gradient_magnitude(q_s_loss_diff)
                display(segutils.tensor_table(q_loss_grad_magnitude=q_loss_grad_magnitude,
                                              s_loss_grad_magnitude=s_loss_grad_magnitude,
                                              proto_loss_grad_magnitude=proto_loss_grad_magnitude,
                                              q_s_loss_diff_grad_magnitude=q_s_loss_diff_grad_magnitude))

            # Backpropagation and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if config_fit.debug and epoch % 10 == 0:
                print('loss', loss.detach())
321
+
322
+
323
+ import numpy as np
324
+ import torch.nn.functional as F
325
+ from torchvision.transforms.functional import affine
326
+ from torchvision.transforms import GaussianBlur, ColorJitter
327
+
328
+
329
class AffineProxy:
    """Stores one set of affine parameters and replays the identical transform.

    Reusing the same parameters on images, masks and feature volumes keeps
    their spatial correspondence intact.
    """

    def __init__(self, angle, translate, scale, shear):
        # Keyword arguments for torchvision.transforms.functional.affine.
        self.affine_params = dict(angle=angle, translate=translate, scale=scale, shear=shear)

    def apply(self, img):
        """Apply the stored affine transform to *img* ([..., h, w] tensor)."""
        return affine(img, **self.affine_params)
341
+
342
+
343
+ # def affine_proxy(angle, translate, scale, shear):
344
+ # def inner(img):
345
+ # return affine(img, angle=angle, translate=translate, scale=scale, shear=shear)
346
+
347
+ # return inner
348
+
349
class Augmen:
    """Samples a fixed bank of photometric + affine augmentations and applies
    them consistently to images, masks and feature volumes."""

    def __init__(self, config_aug):
        self.config = config_aug
        self.blurs, self.jitters, self.affines = self.setup_augmentations()

    def copy_construct(self, blurs, jitters, affines, config_aug):
        """Adopt an existing augmentation bank instead of sampling a new one."""
        self.config = config_aug
        self.blurs, self.jitters, self.affines = blurs, jitters, affines

    def setup_augmentations(self):
        """Draw one (blur, jitter, affine) triple per transformed image.

        Returns a tuple of three lists, each of length n_transformed_imgs.
        """
        cfg = self.config
        # Translation stays off: applied on smaller feature volumes it would
        # break the pixel correspondences.
        translate = (0, 0)

        blurs, jitters, affine_trans = [], [], []
        for _ in range(cfg.n_transformed_imgs):
            # Gaussian blur with a randomly chosen kernel size.
            ksize = np.random.choice(torch.tensor(cfg.blurkernelsize), (1,)).item()
            blurs.append(GaussianBlur(ksize))

            # Color jitter with strengths up to cfg.maxjitter.
            brightness = torch.rand(1).item() * cfg.maxjitter
            contrast = torch.rand(1).item() * cfg.maxjitter
            saturation = torch.rand(1).item() * cfg.maxjitter
            jitters.append(ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation))

            # Random affine parameters for this augmentation.
            angle = torch.randint(-cfg.maxangle, cfg.maxangle + 1, (1,)).item()
            shear = [torch.randint(-cfg.maxshear, cfg.maxshear + 1, (1,)).item() for _ in range(2)]
            scale = torch.rand(1).item() * (1 - cfg.maxscale) + cfg.maxscale
            affine_trans.append(AffineProxy(angle=angle, translate=translate, scale=scale, shear=shear))

        return (blurs, jitters, affine_trans)  # tuple of lists

    def augment(self, original_image, orignal_mask):
        """Apply every stored augmentation to one image/mask pair.

        Returns (imgs, masks) stacked along a new dim 1 of size n_transformed_imgs.
        """
        transformed_imgs, transformed_masks = [], []
        for blur, jitter, affine_trans in zip(self.blurs, self.jitters, self.affines):
            # Photometric transforms touch only the image.
            t_img = jitter(blur(original_image))
            t_mask = orignal_mask.clone()

            # Geometric transforms must hit image and mask identically.
            if self.config.apply_affine:
                t_img = affine_trans.apply(t_img)
                t_mask = affine_trans.apply(t_mask)

            transformed_imgs.append(t_img)
            transformed_masks.append(t_mask)
        return torch.stack(transformed_imgs, dim=1), torch.stack(transformed_masks, dim=1)

    # [bsz,ch,h,w] -> [bsz,aug,ch,h,w], where aug is the number of augmented images
    def applyAffines(self, feat_vol):
        """Apply each stored affine transform to a feature volume."""
        return torch.stack([trans.apply(feat_vol) for trans in self.affines], dim=1)
411
+
412
+
413
class CTrBuilder:
    """Builds and fits per-layer ContrastiveFeatureTransformer heads for one class.

    Usage order: __init__ -> makeAugmented -> build_and_fit -> getTaskAdaptedFeats.
    """

    # call init 1st, pass all config parameters (initatiate a ContrastiveConfig class in your code)
    def __init__(self, config, augmentator=None):
        if augmentator is None:
            augmentator = Augmen(config.aug)
        self.augmentator = augmentator

        self.augimgs = self.AugImgStack(augmentator)

        # Fitted per-layer heads; set by build_and_fit. Initialized here so
        # getTaskAdaptedFeats can detect a missing fit instead of raising
        # AttributeError.
        self.ctrs = None
        self.hasfit = False
        self.config = config

    class AugImgStack():
        """Holds the augmented query/support images and masks of one episode."""

        def __init__(self, augmentator):
            self.augmentator = augmentator
            self.q, self.s, self.s_mask = None, None, None

        def init(self, s_img):
            """Allocate empty buffers shaped after s_img ([bsz,k,c,h,w])."""
            # c is color channels here, not feature channels
            bsz, k, aug, c, h, w = *s_img.shape[:2], self.augmentator.config.n_transformed_imgs, *s_img.shape[-3:]
            self.q = torch.empty(bsz, aug, c, h, w).to(s_img.device)
            self.s = torch.empty(bsz, k, aug, c, h, w).to(s_img.device)
            self.s_mask = torch.empty(bsz, k, aug, h, w).to(s_img.device)

        def show(self):
            """Debug view of the augmented images (needs IPython `display`)."""
            bsz_, k_, aug_ = self.s.shape[:3]
            for b in range(bsz_):
                display('aug x queries', segutils.pilImageRow(*[segutils.norm(img) for img in self.q[b]]))
                for k in range(k_):
                    print('k=', k, ' aug x (s, smask):')
                    display(segutils.pilImageRow(*[segutils.norm(img) for img in self.s[b, k]]))
                    display(segutils.pilImageRow(*self.s_mask[b, k]))

    def showAugmented(self):
        self.augimgs.show()

    # 2nd call makeAugmented
    def makeAugmented(self, q_img, s_img, s_mask):
        """Step 2: create the augmented image/mask stack for this episode."""
        # 2.1 Apply transformations to images
        self.augimgs.init(s_img)
        self.augimgs.q, _ = self.augmentator.augment(q_img, s_mask)

        for k in range(s_img.shape[1]):
            s_aug_imgs, s_aug_masks = self.augmentator.augment(s_img[:, k], s_mask[:, k])
            self.augimgs.s[:, k] = s_aug_imgs
            self.augimgs.s_mask[:, k] = s_aug_masks
        if self.config.aug.debug:
            self.augimgs.show()

    # 3rd call build_and_fit
    def build_and_fit(self, q_feat, s_feat, q_feataug, s_feataug, s_maskaug=None):
        """Step 3: fit one contrastive head per feature layer."""
        if s_maskaug is None:
            s_maskaug = self.augimgs.s_mask
        self.ctrs = self.buildContrastiveTransformers(q_feat, s_feat, q_feataug, s_feataug, s_maskaug)
        self.hasfit = True

    def buildContrastiveTransformers(self, qfeat_alllayers, sfeat_alllayers, query_feats_aug, support_feats_aug,
                                     supp_aug_mask, s_mask=None):
        """Fit a ContrastiveFeatureTransformer for each feature layer >= l0.

        qfeat_alllayers:   per-layer [bsz,c,h,w] query features
        sfeat_alllayers:   per-layer [bsz,k,c,h,w] support features
        query_feats_aug:   per-layer [bsz,aug,c,h,w]
        support_feats_aug: per-layer [bsz,k,aug,c,h,w]
        supp_aug_mask:     [bsz,k,aug,h,w]
        s_mask:            optional support mask; when given, debug affinity
                           maps are displayed after each fit.
        """
        contrastive_transformers = []
        l0 = self.config.featext.l0
        # [bsz,k,aug,h,w] -> [k*aug,h,w]
        s_aug_mask = supp_aug_mask.view(-1, *supp_aug_mask.shape[-2:])
        # iterate over feature layers
        for (qfeat, sfeat, qfeataug, sfeataug) in zip(qfeat_alllayers[l0:], sfeat_alllayers[l0:],
                                                      query_feats_aug[l0:], support_feats_aug[l0:]):
            bsz, k, aug, ch, h, w = sfeataug.shape
            # we fit it for exactly one class, so use no batches
            assert bsz == 1, "bsz should be 1"
            assert supp_aug_mask.shape[1] == sfeat.shape[
                1] == k, f'augmented support shot-dimension mismatch:{s_aug_mask.shape[1]=},{sfeat.shape[1]=},(bsz,k,aug,ch,h,w)={bsz, k, aug, ch, h, w}'
            assert supp_aug_mask.shape[2] == qfeataug.shape[1] == aug, 'augmented shot-dimension mismatch'
            # [bsz,c,h,w] -> [1,c,h,w]
            qfeat = qfeat.view(-1, *qfeat.shape[-3:])
            # [bsz,k,c,h,w] -> [k,c,h,w]
            sfeat = sfeat.view(-1, *sfeat.shape[-3:])
            # [bsz,aug,c,h,w] -> [aug,c,h,w]
            qfeataug = qfeataug.view(-1, *qfeataug.shape[-3:])
            # [bsz,k,aug,c,h,w] -> [k*aug,c,h,w]
            # BUGFIX: reshape with sfeataug's own trailing dims (was qfeataug's;
            # same values today, but wrong tensor to derive them from).
            sfeataug = sfeataug.view(-1, *sfeataug.shape[-3:])

            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            contrastive_head = ContrastiveFeatureTransformer(in_channels=ch, config_model=self.config.model).to(device)

            # Features of the untransformed images must be mapped with the same
            # affines as the augmented images so dense matching lines up.
            mapped_qfeat = self.augmentator.applyAffines(qfeat)
            assert mapped_qfeat.shape[1] == aug, "should be 1,aug,c,h,w"
            mapped_qfeat = mapped_qfeat.view(-1, *qfeat.shape[-3:])  # ->[aug,c,h,w]
            mapped_sfeat = self.augmentator.applyAffines(sfeat)
            assert mapped_sfeat.shape[1] == aug and mapped_sfeat.shape[0] == k, "should be k,aug,c,h,w"
            mapped_sfeat = mapped_sfeat.view(-1, *sfeat.shape[-3:])  # ->[k*aug,c,h,w]

            contrastive_head.fit(mapped_qfeat, qfeataug, mapped_sfeat, sfeataug,
                                 segutils.downsample_mask(s_aug_mask, h, w), self.config.fitting)

            contrastive_transformers.append(contrastive_head)
            # show how support image and its augmentations would produce a affinity map
            # BUGFIX: identity check instead of `!= None` (avoids elementwise
            # tensor comparison semantics).
            if s_mask is not None:
                display(segutils.to_pil(segutils.norm(dautils.filterDenseAffinityMap(
                    dautils.buildDenseAffinityMat(contrastive_head(sfeat), contrastive_head(sfeataug[:1])),
                    segutils.downsample_mask(s_mask, h, w)).view(1, h, w))))
                display(segutils.to_pil(segutils.norm(dautils.filterDenseAffinityMap(
                    dautils.buildDenseAffinityMat(contrastive_head(qfeat), contrastive_head(sfeat)),
                    segutils.downsample_mask(s_mask, h, w)).view(1, h, w))))
        return contrastive_transformers

    # You have fitted the contrastive transformers, now apply the transform and then pass to the downstream DCAMA
    # you just need to append the empty layers you exluded ([:3]), they're also skipped in dcama
    def getTaskAdaptedFeats(self, layerwise_feats):
        """Pass each layer's features through its fitted head; layers below l0 map to None."""
        # BUGFIX: the old `self.ctrs == None` raised AttributeError when ctrs
        # had never been assigned, so the guard could not fire.
        if getattr(self, 'ctrs', None) is None:
            print("error: call buildContrastiveTransformers() first")
        task_adapted_feats = []

        for idx in range(len(layerwise_feats)):
            if idx < self.config.featext.l0:
                task_adapted_feats.append(None)
            else:
                input_shape = layerwise_feats[idx].shape
                idxth_feat = layerwise_feats[idx].view(-1, *input_shape[-3:])
                forward_pass_res = self.ctrs[idx - self.config.featext.l0](idxth_feat)
                # Borrow channel dims from the result, leading (bsz[,k]) dims from the input.
                target_shape = *input_shape[:-3], *forward_pass_res.shape[-3:]
                task_adapted_feats.append(forward_pass_res.view(target_shape))

        return task_adapted_feats
536
+
537
+
538
class FeatureMaker:
    """Extracts frozen backbone features and task-adapts them with per-class contrastive heads."""

    def __init__(self, feat_extraction_method, class_ids, config=None):
        # BUGFIX: was `config=ContrastiveConfig()` — a single shared default
        # instance across all FeatureMakers; build a fresh one per call.
        if config is None:
            config = ContrastiveConfig()
        self.featextractor = feat_extraction_method
        # One contrastive-transformer builder per semantic class.
        self.c_trs = {ctr: CTrBuilder(config) for ctr in class_ids}
        self.norm_bb_feats = False

    def extract_bb_feats(self, img):
        """Backbone forward pass without gradients (backbone stays frozen)."""
        with torch.no_grad():
            return self.featextractor(img)

    def create_and_fit(self, c_tr, q_img, s_img, s_mask, q_feat, s_feat):
        """Augment the episode, extract features of the augmented images, fit the heads."""
        print('doing contrastive')
        c_tr.makeAugmented(q_img, s_img, s_mask)

        bsz, k, c, h, w = s_img.shape
        aug = c_tr.augmentator.config.n_transformed_imgs
        # [bsz,aug,c,h,w] -> [bsz*aug,c,h,w] for the forward pass, then restore.
        q_feataug = self.extract_bb_feats(c_tr.augimgs.q.view(-1, c, h, w))  # returns layer-list
        q_feataug = [l.view(bsz, aug, *l.shape[1:]) for l in q_feataug]
        # [bsz,k,aug,c,h,w] -> [bsz*k*aug,c,h,w] -> [bsz,k,aug,c,h,w]
        s_feataug = self.extract_bb_feats(c_tr.augimgs.s.view(-1, c, h, w))
        s_feataug = [l.view(bsz, k, aug, *l.shape[1:]) for l in s_feataug]

        c_tr.build_and_fit(q_feat, s_feat, q_feataug, s_feataug)

    def taskAdapt(self, q_img, s_img, s_mask, class_id):
        """Return task-adapted (query, support) layer-lists for one episode."""
        # Per-pixel channel normalization. BUGFIX: keepdim=True so the norm
        # broadcasts over the channel axis instead of misaligning with batch.
        ch_norm = lambda t: t / torch.linalg.norm(t, dim=1, keepdim=True)
        q_feat = self.extract_bb_feats(q_img)
        bsz, k, c, h, w = s_img.shape
        s_feat = self.extract_bb_feats(s_img.view(-1, c, h, w))
        if self.norm_bb_feats:
            q_feat = [ch_norm(l) for l in q_feat]
            # BUGFIX: was `for l in q_feat` — support features were replaced
            # by normalized *query* features.
            s_feat = [ch_norm(l) for l in s_feat]
        s_feat = [l.view(bsz, k, *l.shape[1:]) for l in s_feat]

        c_tr = self.c_trs[class_id]  # select the relevant ctr for this class

        # Fit once per class (or every episode in fit_every_episode mode).
        if c_tr.hasfit is False or c_tr.config.featext.fit_every_episode:
            self.create_and_fit(c_tr, q_img, s_img, s_mask, q_feat, s_feat)

        q_feat_t = c_tr.getTaskAdaptedFeats(q_feat)
        s_feat_t = c_tr.getTaskAdaptedFeats(s_feat)  # tocheck: do they require_grad here?
        return q_feat_t, s_feat_t
core/denseaffinity.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ from utils import segutils
5
+
6
+
7
def buildHyperCol(feat_pyram):
    """Fuse a feature pyramid into one hypercolumn volume.

    Every layer is bilinearly resized to the spatial size of the first
    (largest) layer, then all layers are concatenated along channels.
    """
    ref_hw = feat_pyram[0].shape[-2:]
    resized = [
        F.interpolate(vol, size=ref_hw, mode='bilinear', align_corners=False)
        for vol in feat_pyram
    ]
    return torch.cat(resized, dim=1)
16
+
17
+
18
+ # accepts both:
19
+ # s_feat_vol: [bsz,k,c,h,w]->[bsz,c,h,w*k]
20
+ # s_mask: [bsz,k,h,w]->[bsz,h,w*k]
21
def paste_supports_together(supports):
    """Lay the k support items side by side along the last (width) axis.

    Accepts both:
      s_feat_vol: [bsz,k,c,h,w] -> [bsz,c,h,w*k]
      s_mask:     [bsz,k,h,w]   -> [bsz,h,w*k]
    """
    per_shot = [supports[:, i] for i in range(supports.shape[1])]
    return torch.cat(per_shot, dim=-1)
23
+
24
+
25
+ # Attention regular:
26
+ # 1. Dot product
27
+ # 2. Divide by square root of key length (#nchannels)
28
+ # 3. Softmax
29
+ # 4. Multiply with V (mask)
30
+
31
def buildDenseAffinityMat(qfeat_volume, sfeat_volume, softmax_arg2=True):  # bsz,C,H,W
    """Dense pairwise affinity between query and support feature volumes.

    Scaled-dot-product attention style:
      1. dot product of every query pixel with every support pixel
      2. divide by sqrt(C)
      3. softmax over support pixels (unless softmax_arg2 is False)

    Returns [bsz, H*Wq, H*Ws].
    """
    q = qfeat_volume.permute(0, 2, 3, 1)
    s = sfeat_volume.permute(0, 2, 3, 1)
    bsz, H, Wq, C = q.shape
    Ws = s.shape[2]
    # [px,C] @ [C,px] = [px,px]
    affinity = torch.matmul(q.flatten(1, 2), s.flatten(1, 2).transpose(1, 2))
    if softmax_arg2 is False:
        return affinity
    # Each query pixel's affinities sum to 1 over support pixels.
    return (affinity / math.sqrt(C)).softmax(dim=-1)
42
+
43
+
44
+ # filter with support mask following DAM
45
def filterDenseAffinityMap(dense_affinity_mat, downsampled_smask):
    """Filter the affinity map with the support mask (following DAM).

    For each query pixel, aggregate its affinities over all support pixels
    where the support mask == 1:  [px,px] @ [px,1] = [px,1].

    dense_affinity_mat: [bsz, HWq, HWs]
    downsampled_smask:  support mask flattened/able to HWs entries per batch.
    Returns [bsz, HWq].
    """
    bsz, n_q, n_s = dense_affinity_mat.shape
    mask_col = downsampled_smask.view(bsz, n_s, 1)
    return torch.bmm(dense_affinity_mat, mask_col).squeeze(-1)
53
+
54
+
55
def upsample(volume, h, w):
    """Bilinearly resize `volume` to spatial size (h, w)."""
    target = (h, w)
    return F.interpolate(volume, size=target, mode='bilinear', align_corners=False)
57
+
58
class DAMatComparison:
    """Compares task-adapted query/support features via dense-affinity maps
    and fuses the per-layer coarse predictions into one logit mask."""

    def algo_mean(self, q_pred_coarses_t, s_mask=None):
        """Fuse per-layer coarse maps by averaging over the layer dim."""
        return q_pred_coarses_t.mean(1)

    def calc_q_pred_coarses(self, q_feat_t, s_feat_t, s_mask, l0=3):
        """Compute one coarse query prediction per feature layer (from layer l0 on).

        q_feat_t / s_feat_t: per-layer feature lists; s layers are [bsz,k,c,h,w].
        Returns a stack [bsz, n_layers, h0, w0] at the layer-l0 resolution.
        """
        q_pred_coarses = []
        h0, w0 = q_feat_t[l0].shape[-2:]
        for (qft, sft) in zip(q_feat_t[l0:], s_feat_t[l0:]):
            qft, sft = qft.detach(), sft.detach()
            bsz, c, hq, wq = qft.shape
            hs, ws = sft.shape[-2:]

            # Paste the k shots side by side so one affinity map covers them all.
            sft_row = torch.cat(sft.unbind(1), -1)  # bsz,k,c,h,w -> bsz,c,h,w*k
            smasks_downsampled = [segutils.downsample_mask(m, hs, ws) for m in s_mask.unbind(1)]
            smask_row = torch.cat(smasks_downsampled, -1)

            damat = buildDenseAffinityMat(qft, sft_row)
            filtered = filterDenseAffinityMap(damat, smask_row)
            # Bring every layer's coarse map to the common l0 resolution.
            q_pred_coarse = upsample(filtered.view(bsz, 1, hq, wq), h0, w0).squeeze(1)
            q_pred_coarses.append(q_pred_coarse)
        return torch.stack(q_pred_coarses, dim=1)

    def forward(self, q_feat_t, s_feat_t, s_mask, upsample=True, debug=False):
        """Full comparison pass: per-layer coarse maps -> fused logit mask.

        upsample=True resizes the result to the support-mask resolution; if
        query and support differ in shape, callers must resize themselves.
        """
        q_pred_coarses_t = self.calc_q_pred_coarses(q_feat_t, s_feat_t, s_mask)

        # NOTE(review): `display` is not defined in this module — debug=True
        # presumably only works in a notebook environment; confirm.
        if debug: display(segutils.pilImageRow(*q_pred_coarses_t.unbind(1), q_pred_coarses_t.mean(1)))

        # select the algorithm
        postprocessing_algorithm = self.algo_mean
        # do the postprocessing
        logit_mask = postprocessing_algorithm(q_pred_coarses_t, s_mask)
        if upsample:  # if query and support have different shape, then you must do upsampling yourself afterwards
            # downsample_mask is used here as a generic resize to the mask's h,w.
            logit_mask = segutils.downsample_mask(logit_mask, *s_mask.shape[-2:])

        return logit_mask
core/runner.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data.dataset import FSSDataset
2
+ from core.backbone import Backbone
3
+ from eval.logger import Logger, AverageMeter
4
+ from eval.evaluation import Evaluator
5
+ from utils import commonutils as utils
6
+ import utils.segutils as segutils
7
+ import core.contrastivehead as ctrutils
8
+ import core.denseaffinity as dautils
9
+ import torch
10
+
11
class args:
    """Static run configuration (stands in for a CLI argument namespace)."""
    backbone = 'resnet50'               # torchvision backbone name
    logpath = '/kaggle/working/logs'    # where Logger writes its output
    nworker = 0                         # dataloader workers (test always forces 0)
    bsz = 1                             # batch size; evaluation assumes 1
    benchmark = ''                      # e.g. deepglobe,isic,etc.
    datapath = ''                       # path to the selected dataset
    fold = 0                            # cross-validation fold index
    nshot = 1                           # number of support shots per episode
20
+
21
class SingleSampleEval:
    """Runs the full pipeline (task-adapt -> compare -> threshold) on one episode
    and keeps all intermediate results as attributes for inspection."""

    def __init__(self, batch, feat_maker, debug=False):
        self.damat_comp = dautils.DAMatComparison()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.batch = batch
        self.feat_maker = feat_maker
        self.debug = debug
        # Default binarization strategy; see calcthresh for the options.
        self.thresh_method = 'pred_mean'

    def taskAdapt(self, detach=True):
        """Unpack the episode batch and produce task-adapted features."""
        b = self.batch
        if self.device.type == 'cuda': b = utils.to_cuda(b)
        self.q_img, self.s_img, self.s_mask, self.class_id = b['query_img'], b['support_imgs'], b['support_masks'], b[
            'class_id'].item()
        # (q_feat_t, s_feat_t) tuple of per-layer feature lists.
        self.task_adapted = self.feat_maker.taskAdapt(self.q_img, self.s_img, self.s_mask, self.class_id)

    def compare_feats(self):
        """Compare adapted features into a soft logit mask (requires taskAdapt first)."""
        if self.task_adapted is None:
            print("error, do task adaption first")
            return None
        self.logit_mask = self.damat_comp.forward(self.task_adapted[0], self.task_adapted[1], self.s_mask)
        return self.logit_mask

    def threshold(self, method=None):
        """Binarize the logit mask; returns (threshold, binary prediction)."""
        if self.logit_mask is None:
            print("error, calculate logit mask first (do forward pass)")
        if method is None:
            method = self.thresh_method
        self.thresh = calcthresh(self.logit_mask, self.s_mask, method)
        self.pred_mask = (self.logit_mask > self.thresh).float()
        return self.thresh, self.pred_mask

    def apply_crf(self):
        """Optionally refine the soft prediction with a CRF (module-level helper)."""
        return apply_crf(self.q_img, self.logit_mask, thresh_fn(self.thresh_method))

    # this method calls above components sequentially
    def forward(self):
        self.taskAdapt()

        self.logit_mask = self.compare_feats()

        self.thresh, self.pred_mask = self.threshold()

        return self.logit_mask, self.pred_mask

    def calc_metrics(self):
        """Compute intersection/union against the GT query mask; returns fg-IoU."""
        # assert torch.logical_or(self.logit_mask<0, self.logit_mask>1).sum()==0, display(tensor_table(logit_mask=self.logit_mask))
        self.area_inter, self.area_union = Evaluator.classify_prediction(self.pred_mask, self.batch)
        self.fgratio_pred = self.pred_mask.float().mean()
        self.fgratio_gt = self.batch['query_mask'].float().mean()
        return self.area_inter[1] / self.area_union[1]  # fg-iou

    def plots(self):
        """Debug visualization.

        NOTE(review): `display`, `pilImageRow` and `norm` are not defined in
        this module — presumably notebook-only helpers; confirm before use.
        """
        display(pilImageRow(norm(self.logit_mask[0]), (self.logit_mask[0] > self.thresh).float(), self.pred_mask,
                            self.batch['query_mask'][:1], norm(self.q_img[0]), norm(self.s_img[0, 0])))
        display(segutils.tensor_table(probs=self.logit_mask))

        print('s_mask.mean, pred_mask.mean, thresh:', self.s_mask.mean().item(), self.logit_mask.mean().item(),
              self.thresh.item())
80
+
81
class AverageMeterWrapper:
    """Thin convenience wrapper around the eval AverageMeter + Logger."""

    def __init__(self, dataloader, device='cpu', initlogger=True):
        # Logger.initialize writes into args.logpath; skip if already set up.
        if initlogger: Logger.initialize(args, training=False)
        self.average_meter = AverageMeter(dataloader.dataset, device)
        self.device = device
        self.dataloader = dataloader
        # Progress is printed every this many batches.
        self.write_batch_idx = 50

    def update(self, sseval):
        """Accumulate the metrics of a finished SingleSampleEval."""
        self.average_meter.update(sseval.area_inter, sseval.area_union, torch.tensor(sseval.class_id).to(self.device), loss=None)

    def update_manual(self, area_inter, area_union, class_id):
        """Accumulate externally computed metrics; accepts int or tensor class id."""
        if isinstance(class_id, int): class_id = torch.tensor(class_id).to(self.device)
        self.average_meter.update(area_inter, area_union, class_id, loss=None)

    def write(self, i):
        """Print progress for batch index i."""
        self.average_meter.write_process(i, len(self.dataloader), 0, self.write_batch_idx)
95
+
96
def makeDataloader():
    """Build the test-split dataloader for the benchmark configured in `args`."""
    FSSDataset.initialize(img_size=400, datapath=args.datapath)
    return FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker,
                                       args.fold, 'test', args.nshot)
101
+
102
+
103
def makeConfig():
    """Build the default ContrastiveConfig used by the paper's quick-infer mode.

    Three groups are set: fitting (which losses / optimizer settings are used
    when fitting the contrastive transformer), model (transformer output
    width) and aug (episode augmentation parameters).
    """
    config = ctrutils.ContrastiveConfig()
    # --- loss selection for fitting the contrastive transformer ---
    config.fitting.protoloss = False
    config.fitting.o_t_contr_proto_loss = True
    config.fitting.selfattentionloss = False
    config.fitting.keepvarloss = True
    config.fitting.symmetricloss = False
    config.fitting.q_nceloss = True
    config.fitting.s_nceloss = True
    # --- optimization ---
    config.fitting.num_epochs = 25
    config.fitting.lr = 1e-2
    config.fitting.debug = False
    # --- model ---
    config.model.out_channels = 64
    # False = quick-infer mode: fit once per class, reuse across episodes.
    config.featext.fit_every_episode = False
    # --- augmentation ---
    config.aug.blurkernelsize = [1]
    config.aug.n_transformed_imgs = 2
    config.aug.maxjitter = 0.0
    config.aug.maxangle = 0
    config.aug.maxscale = 1
    config.aug.maxshear = 20
    config.aug.apply_affine = True
    config.aug.debug = False
    return config
126
+
127
+
128
def makeFeatureMaker(dataset, config, device='cpu', randseed=2, feat_extr_method=None):
    """Construct the FeatureMaker (backbone extractor + per-class transformers).

    Seeds are fixed before and after construction so backbone init and the
    subsequent episode sampling are reproducible.
    """
    utils.fix_randseed(randseed)
    if feat_extr_method is None:
        feat_extr_method = Backbone(args.backbone).to(device).extract_feats
    feat_maker = ctrutils.FeatureMaker(feat_extr_method, dataset.class_ids, config)
    utils.fix_randseed(randseed)
    # Backbone features are used unnormalized by default.
    feat_maker.norm_bb_feats = False
    return feat_maker
136
def apply_crf(rgb_img, fg_pred, thresh_fn, iterations=5):  # 5 on deployment, 1 on support-aug test for speedup
    """Refine a soft foreground prediction with a dense CRF; returns a hard mask.

    thresh_fn: callable mapping the soft prediction to a binarization threshold
    (see the thresh_fn factory below).
    """
    crf = segutils.CRF(gaussian_stdxy=(1, 1), gaussian_compat=2,
                       bilateral_stdxy=(35, 35), bilateral_compat=1, stdrgb=(13, 13, 13))
    q = crf.iterrefine(iterations, rgb_img, fg_pred, thresh_fn)
    # argmax over the 2-class CRF output gives the binary prediction.
    return q.argmax(1)
141
+
142
def calcthresh(fused_pred, s_masks, method='otsus'):
    """Compute a binarization threshold for a fused soft prediction map.

    fused_pred: soft foreground prediction tensor.
    s_masks:    support masks (only used by the iterative-Otsu variants).
    method:     'iterotsus' | '1iterotsus' | 'otsus' | 'pred_mean'.

    Raises:
        ValueError: for an unknown method. (Previously an unknown method fell
        through every branch and crashed with UnboundLocalError.)
    """
    if method == 'iterotsus':
        return segutils.iterative_otsus(fused_pred, s_masks, maxiters=5)[0]
    elif method == '1iterotsus':
        return segutils.iterative_otsus(fused_pred, s_masks, maxiters=1)[0]
    elif method == 'otsus':
        return segutils.otsus(fused_pred)[0]
    elif method == 'pred_mean':
        # Otsu threshold, but never below the prediction mean.
        otsu_thresh = segutils.otsus(fused_pred)[0]
        return torch.max(otsu_thresh, fused_pred.mean())
    raise ValueError("unknown thresholding method: %r" % (method,))
161
+
162
def thresh_fn(method):
    """Return a callable that applies calcthresh with `method` pre-bound."""
    def bound(fused_pred, s_masks=None):
        return calcthresh(fused_pred, s_masks, method)
    return bound
data/coco.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" COCO-20i few-shot semantic segmentation dataset """
2
+ import os
3
+ import pickle
4
+
5
+ from torch.utils.data import Dataset
6
+ import torch.nn.functional as F
7
+ import torch
8
+ import PIL.Image as Image
9
+ import numpy as np
10
+
11
+
12
class DatasetCOCO(Dataset):
    """COCO-20i few-shot segmentation dataset; episodes are sampled uniformly
    over object classes rather than indexed deterministically."""

    def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize=False):
        # Any non-'trn' request ('val'/'test') maps to the val split.
        self.split = 'val' if split in ['val', 'test'] else 'trn'
        self.fold = fold
        self.nfolds = 4
        self.nclass = 80
        self.benchmark = 'coco'
        self.shot = shot
        # NOTE(review): split is 'trn'/'val'/'test' here, never 'val2014', so
        # this always resolves to 'train2014' — confirm that is intended.
        self.split_coco = split if split == 'val2014' else 'train2014'
        self.base_path = os.path.join(datapath, 'COCO2014')
        self.transform = transform
        self.use_original_imgsize = use_original_imgsize

        self.class_ids = self.build_class_ids()
        self.img_metadata_classwise = self.build_img_metadata_classwise()
        self.img_metadata = self.build_img_metadata()

    def __len__(self):
        # Evaluation uses a fixed budget of 1000 random episodes.
        return len(self.img_metadata) if self.split == 'trn' else 1000

    def __getitem__(self, idx):
        # ignores idx during training & testing and perform uniform sampling over object classes to form an episode
        # (due to the large size of the COCO dataset)
        query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize = self.load_frame()

        query_img = self.transform(query_img)
        query_mask = query_mask.float()
        if not self.use_original_imgsize:
            # Resize the GT mask to the transformed query image size.
            query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()

        support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])
        for midx, smask in enumerate(support_masks):
            support_masks[midx] = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
        support_masks = torch.stack(support_masks)

        batch = {'query_img': query_img,
                 'query_mask': query_mask,
                 'query_name': query_name,

                 'org_query_imsize': org_qry_imsize,

                 'support_imgs': support_imgs,
                 'support_masks': support_masks,
                 'support_names': support_names,
                 'class_id': torch.tensor(class_sample)}

        return batch

    def build_class_ids(self):
        """Partition the 80 classes into fold-wise train/val id lists."""
        nclass_trn = self.nclass // self.nfolds
        class_ids_val = [self.fold + self.nfolds * v for v in range(nclass_trn)]
        class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val]
        class_ids = class_ids_trn if self.split == 'trn' else class_ids_val

        return class_ids

    def build_img_metadata_classwise(self):
        """Load the precomputed class -> image-list mapping for this fold."""
        with open('./data/splits/coco/%s/fold%d.pkl' % (self.split, self.fold), 'rb') as f:
            img_metadata_classwise = pickle.load(f)
        return img_metadata_classwise

    def build_img_metadata(self):
        """Flatten the classwise mapping into one deduplicated, sorted list."""
        img_metadata = []
        for k in self.img_metadata_classwise.keys():
            img_metadata += self.img_metadata_classwise[k]
        return sorted(list(set(img_metadata)))

    def read_mask(self, name):
        """Read the annotation PNG corresponding to image `name` (a .jpg path)."""
        mask_path = os.path.join(self.base_path, 'annotations', name)
        mask = torch.tensor(np.array(Image.open(mask_path[:mask_path.index('.jpg')] + '.png')))
        return mask

    def load_frame(self):
        """Sample one random episode: a class, a query, and `shot` distinct supports.

        Masks are binarized to {0,1} for the sampled class (stored as id+1).
        """
        class_sample = np.random.choice(self.class_ids, 1, replace=False)[0]
        query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
        query_img = Image.open(os.path.join(self.base_path, query_name)).convert('RGB')
        query_mask = self.read_mask(query_name)

        org_qry_imsize = query_img.size

        query_mask[query_mask != class_sample + 1] = 0
        query_mask[query_mask == class_sample + 1] = 1

        support_names = []
        while True:  # keep sampling support set if query == support
            support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
            if query_name != support_name: support_names.append(support_name)
            if len(support_names) == self.shot: break

        support_imgs = []
        support_masks = []
        for support_name in support_names:
            support_imgs.append(Image.open(os.path.join(self.base_path, support_name)).convert('RGB'))
            support_mask = self.read_mask(support_name)
            support_mask[support_mask != class_sample + 1] = 0
            support_mask[support_mask == class_sample + 1] = 1
            support_masks.append(support_mask)

        return query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize
111
+
data/dataset.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" Dataloader builder for few-shot semantic segmentation dataset """
2
+ from torch.utils.data.distributed import DistributedSampler as Sampler
3
+ from torch.utils.data import DataLoader
4
+ from torchvision import transforms
5
+
6
+ from data.pascal import DatasetPASCAL
7
+ from data.coco import DatasetCOCO
8
+ from data.fss import DatasetFSS
9
+ from data.deepglobe import DatasetDeepglobe
10
+ from data.isic import DatasetISIC
11
+ from data.lung import DatasetLung
12
+ from data.fss import DatasetFSS
13
+ from data.suim import DatasetSUIM
14
+
15
+
16
class FSSDataset:
    """Registry/factory for the few-shot segmentation benchmarks.

    Call initialize() once (sets transforms, paths, benchmark registry), then
    build_dataloader() per run.
    """

    @classmethod
    def initialize(cls, img_size, datapath):

        # Benchmark name -> dataset class.
        cls.datasets = {
            'pascal': DatasetPASCAL,
            'coco': DatasetCOCO,
            'fss': DatasetFSS,
            'deepglobe': DatasetDeepglobe,
            'isic': DatasetISIC,
            'lung': DatasetLung,
            'suim': DatasetSUIM
        }

        # ImageNet normalization statistics.
        cls.img_mean = [0.485, 0.456, 0.406]
        cls.img_std = [0.229, 0.224, 0.225]
        cls.datapath = datapath

        cls.transform = transforms.Compose([transforms.Resize(size=(img_size, img_size)),
                                            transforms.ToTensor(),
                                            transforms.Normalize(cls.img_mean, cls.img_std)])

    @classmethod
    def build_dataloader(cls, benchmark, bsz, nworker, fold, split, shot=1):
        """Instantiate the chosen benchmark and wrap it in a DataLoader."""
        # Worker processes only during training; testing stays single-process.
        nworker = nworker if split == 'trn' else 0

        dataset = cls.datasets[benchmark](cls.datapath, fold=fold,
                                          transform=cls.transform,
                                          split=split, shot=shot)
        # Force randomness during training for diverse episode combinations
        # Freeze randomness during testing for reproducibility
        #train_sampler = Sampler(dataset) if split == 'trn' else None
        dataloader = DataLoader(dataset, batch_size=bsz, shuffle=split=='trn', num_workers=nworker,
                                pin_memory=True)

        return dataloader
data/deepglobe.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" FSS-1000 few-shot semantic segmentation dataset """
2
+ import os
3
+ import glob
4
+
5
+ from torch.utils.data import Dataset
6
+ import torch.nn.functional as F
7
+ import torch
8
+ import PIL.Image as Image
9
+ import numpy as np
10
+
11
+
12
class DatasetDeepglobe(Dataset):
    """Deepglobe land-cover few-shot segmentation dataset (6 binary classes);
    episodes cycle through classes by index, images are sampled randomly."""

    def __init__(self, datapath, fold, transform, split, shot, num_val=600):
        self.split = split
        self.benchmark = 'deepglobe'
        self.shot = shot
        self.num_val = num_val

        self.base_path = os.path.join(datapath)
        # Maps an image path to its annotation path.
        # NOTE(review): defined but not used below — load_frame builds the
        # annotation path itself; confirm whether this is dead code.
        self.to_annpath = lambda p: p.replace('jpg', 'png').replace('origin', 'groundtruth')

        # Class folders are named '1'..'6'.
        self.categories = ['1','2','3','4','5','6']

        self.class_ids = range(0, 6)
        self.img_metadata_classwise, self.num_images = self.build_img_metadata_classwise()

        self.transform = transform

    def __len__(self):
        # if it is the target domain, then also test on entire dataset
        return self.num_images if self.split !='val' else self.num_val

    def __getitem__(self, idx):
        query_name, support_names, class_sample = self.sample_episode(idx)
        query_img, query_mask, support_imgs, support_masks = self.load_frame(query_name, support_names)

        query_img = self.transform(query_img)
        query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()

        support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])

        support_masks_tmp = []
        for smask in support_masks:
            smask = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
            support_masks_tmp.append(smask)
        support_masks = torch.stack(support_masks_tmp)

        batch = {'query_img': query_img,
                 'query_mask': query_mask,
                 'support_set': (support_imgs, support_masks),
                 'support_classes': torch.tensor([class_sample]), # adapt to Nway

                 'query_name': query_name, # REMOVE
                 'support_imgs': support_imgs, # REMOVE
                 'support_masks': support_masks, # REMOVE
                 'support_names': support_names, # REMOVE
                 'class_id': torch.tensor(class_sample)} # REMOVE

        return batch


    def load_frame(self, query_name, support_names):
        """Open query/support images and their binarized groundtruth masks."""
        query_img = Image.open(query_name).convert('RGB')
        support_imgs = [Image.open(name).convert('RGB') for name in support_names]

        # Derive the annotation path: <base>/<class>/test/groundtruth/<id>.png
        query_id = query_name.split('/')[-1].split('.')[0]
        ann_path = os.path.join(self.base_path, query_name.split('/')[-4], 'test', 'groundtruth')
        query_name = os.path.join(ann_path, query_id) + '.png'
        support_ids = [name.split('/')[-1].split('.')[0] for name in support_names]
        support_names = [os.path.join(ann_path, sid) + '.png' for name, sid in zip(support_names, support_ids)]

        query_mask = self.read_mask(query_name)
        support_masks = [self.read_mask(name) for name in support_names]

        return query_img, query_mask, support_imgs, support_masks

    def read_mask(self, img_name):
        """Read a grayscale mask PNG and binarize at 128."""
        mask = torch.tensor(np.array(Image.open(img_name).convert('L')))
        mask[mask < 128] = 0
        mask[mask >= 128] = 1
        return mask

    def sample_episode(self, idx):
        """Pick the class by idx (round-robin) and sample query + supports randomly."""
        class_id = idx % len(self.class_ids)
        class_sample = self.categories[class_id]

        query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
        support_names = []
        while True:  # keep sampling support set if query == support
            support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
            if query_name != support_name: support_names.append(support_name)
            if len(support_names) == self.shot: break

        return query_name, support_names, class_id


    # def build_img_metadata(self):
    #     img_metadata = []
    #     for cat in self.categories:
    #         os.path.join(self.base_path, cat)
    #         img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.base_path, cat, 'test', 'origin'))])
    #         for img_path in img_paths:
    #             if os.path.basename(img_path).split('.')[1] == 'jpg':
    #                 img_metadata.append(img_path)
    #     return img_metadata

    def build_img_metadata_classwise(self):
        """Scan <base>/<class>/test/origin for jpgs; returns (mapping, total count)."""
        num_images=0
        img_metadata_classwise = {}
        for cat in self.categories:
            img_metadata_classwise[cat] = []

        for cat in self.categories:
            img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.base_path, cat, 'test', 'origin'))])
            for img_path in img_paths:
                if os.path.basename(img_path).split('.')[1] == 'jpg':
                    img_metadata_classwise[cat] += [img_path]
                    num_images += 1
        return img_metadata_classwise, num_images
data/fss.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" FSS-1000 few-shot semantic segmentation dataset """
2
+ import os
3
+ import glob
4
+
5
+ from torch.utils.data import Dataset
6
+ import torch.nn.functional as F
7
+ import torch
8
+ import PIL.Image as Image
9
+ import numpy as np
10
+
11
+
12
class DatasetFSS(Dataset):
    """FSS-1000 few-shot segmentation dataset; episodes are indexed
    deterministically by query image, supports sampled from the same folder."""

    def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize=False):
        self.split = split
        self.benchmark = 'fss'
        self.shot = shot

        self.base_path = os.path.join(datapath, 'FSS-1000')

        # Given predefined test split, load randomly generated training/val splits:
        # (reference regarding trn/val/test splits: https://github.com/HKUSTCV/FSS-1000/issues/7))
        with open('./data/splits/fss/%s.txt' % split, 'r') as f:
            self.categories = f.read().split('\n')[:-1]
        self.categories = sorted(self.categories)

        self.class_ids = self.build_class_ids()
        self.img_metadata = self.build_img_metadata()

        self.transform = transform

    def __len__(self):
        return len(self.img_metadata)

    def __getitem__(self, idx):
        query_name, support_names, class_sample = self.sample_episode(idx)
        query_img, query_mask, support_imgs, support_masks = self.load_frame(query_name, support_names)

        query_img = self.transform(query_img)
        query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()

        support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])

        support_masks_tmp = []
        for smask in support_masks:
            smask = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
            support_masks_tmp.append(smask)
        support_masks = torch.stack(support_masks_tmp)

        batch = {'query_img': query_img,
                 'query_mask': query_mask,
                 'query_name': query_name,

                 'support_imgs': support_imgs,
                 'support_masks': support_masks,
                 'support_names': support_names,

                 'class_id': torch.tensor(class_sample)}

        return batch

    def load_frame(self, query_name, support_names):
        """Open query/support jpgs and their sibling .png masks."""
        query_img = Image.open(query_name).convert('RGB')
        support_imgs = [Image.open(name).convert('RGB') for name in support_names]

        query_id = query_name.split('/')[-1].split('.')[0]
        query_name = os.path.join(os.path.dirname(query_name), query_id) + '.png'
        support_ids = [name.split('/')[-1].split('.')[0] for name in support_names]
        support_names = [os.path.join(os.path.dirname(name), sid) + '.png' for name, sid in zip(support_names, support_ids)]

        query_mask = self.read_mask(query_name)
        support_masks = [self.read_mask(name) for name in support_names]

        return query_img, query_mask, support_imgs, support_masks

    def read_mask(self, img_name):
        """Read a grayscale mask PNG and binarize at 128."""
        mask = torch.tensor(np.array(Image.open(img_name).convert('L')))
        mask[mask < 128] = 0
        mask[mask >= 128] = 1
        return mask

    def sample_episode(self, idx):
        """Deterministic query by idx; supports sampled from images 1..10 of the class."""
        query_name = self.img_metadata[idx]
        class_sample = self.categories.index(query_name.split('/')[-2])
        # Offset class ids so trn/val/test occupy disjoint ranges (see build_class_ids).
        if self.split == 'val':
            class_sample += 520
        elif self.split == 'test':
            class_sample += 760

        support_names = []
        while True:  # keep sampling support set if query == support
            support_name = np.random.choice(range(1, 11), 1, replace=False)[0]
            support_name = os.path.join(os.path.dirname(query_name), str(support_name)) + '.jpg'
            if query_name != support_name: support_names.append(support_name)
            if len(support_names) == self.shot: break

        return query_name, support_names, class_sample

    def build_class_ids(self):
        """FSS-1000 split: 520 train / 240 val / 240 test classes."""
        if self.split == 'trn':
            class_ids = range(0, 520)
        elif self.split == 'val':
            class_ids = range(520, 760)
        elif self.split == 'test':
            class_ids = range(760, 1000)
        return class_ids

    def build_img_metadata(self):
        """Collect all jpg image paths across the split's category folders."""
        img_metadata = []
        for cat in self.categories:
            img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.base_path, cat))])
            for img_path in img_paths:
                if os.path.basename(img_path).split('.')[1] == 'jpg':
                    img_metadata.append(img_path)
        return img_metadata
data/isic.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" ISIC few-shot semantic segmentation dataset """
2
+ import os
3
+ import glob
4
+
5
+ from torch.utils.data import Dataset
6
+ import torch.nn.functional as F
7
+ import torch
8
+ import PIL.Image as Image
9
+ import numpy as np
10
+
11
+
12
class DatasetISIC(Dataset):
    """ISIC-2018 skin-lesion few-shot segmentation dataset (3 class folders);
    episodes cycle through classes by index, images are sampled randomly."""

    def __init__(self, datapath, fold, transform, split, shot, num_val=600):
        self.split = split
        self.benchmark = 'isic'
        self.shot = shot
        self.num_val = num_val

        self.base_path = os.path.join(datapath)
        self.categories = ['1', '2', '3']

        self.class_ids = range(0, 3)
        self.img_metadata_classwise,self.num_images = self.build_img_metadata_classwise()

        self.transform = transform

    def __len__(self):
        # Target-domain testing sweeps the whole dataset; val uses a fixed budget.
        return self.num_images if self.split != 'val' else self.num_val

    def __getitem__(self, idx):
        query_name, support_names, class_sample = self.sample_episode(idx)
        query_img, query_mask, support_imgs, support_masks = self.load_frame(query_name, support_names)

        query_img = self.transform(query_img)
        query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()

        support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])

        support_masks_tmp = []
        for smask in support_masks:
            smask = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
            support_masks_tmp.append(smask)
        support_masks = torch.stack(support_masks_tmp)

        batch = {'query_img': query_img,
                 'query_mask': query_mask,
                 'query_name': query_name,

                 'support_imgs': support_imgs,
                 'support_masks': support_masks,
                 'support_names': support_names,

                 'class_id': torch.tensor(class_sample)}

        return batch

    def load_frame(self, query_name, support_names):
        """Open query/support images and their '_segmentation.png' groundtruths."""
        query_img = Image.open(query_name).convert('RGB')
        support_imgs = [Image.open(name).convert('RGB') for name in support_names]

        query_id = query_name.split('/')[-1].split('.')[0]
        ann_path = os.path.join(self.base_path, 'ISIC2018_Task1_Training_GroundTruth')
        query_name = os.path.join(ann_path, query_id) + '_segmentation.png'
        support_ids = [name.split('/')[-1].split('.')[0] for name in support_names]
        support_names = [os.path.join(ann_path, sid) + '_segmentation.png' for name, sid in zip(support_names, support_ids)]

        query_mask = self.read_mask(query_name)
        support_masks = [self.read_mask(name) for name in support_names]

        return query_img, query_mask, support_imgs, support_masks

    def read_mask(self, img_name):
        """Read a grayscale mask PNG and binarize at 128."""
        mask = torch.tensor(np.array(Image.open(img_name).convert('L')))
        mask[mask < 128] = 0
        mask[mask >= 128] = 1
        return mask

    def sample_episode(self, idx):
        """Pick the class by idx (round-robin) and sample query + supports randomly."""
        class_id = idx % len(self.class_ids)
        class_sample = self.categories[class_id]

        query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
        support_names = []
        while True:  # keep sampling support set if query == support
            support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
            if query_name != support_name: support_names.append(support_name)
            if len(support_names) == self.shot: break

        return query_name, support_names, class_id

    def build_img_metadata(self):
        """Flat jpg listing across class folders.

        NOTE(review): appears unused — build_img_metadata_classwise is what
        __init__ calls; confirm before removing.
        """
        img_metadata = []
        for cat in self.categories:
            os.path.join(self.base_path, cat)
            img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.base_path, 'ISIC2018_Task1-2_Training_Input', cat))])
            for img_path in img_paths:
                if os.path.basename(img_path).split('.')[1] == 'jpg':
                    img_metadata.append(img_path)
        return img_metadata

    def build_img_metadata_classwise(self):
        """Scan the Task1-2 input folders per class; returns (mapping, total count)."""
        num_images=0
        img_metadata_classwise = {}
        for cat in self.categories:
            img_metadata_classwise[cat] = []

        for cat in self.categories:
            img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.base_path, 'ISIC2018_Task1-2_Training_Input', cat))])
            for img_path in img_paths:
                if os.path.basename(img_path).split('.')[1] == 'jpg':
                    img_metadata_classwise[cat] += [img_path]
                    num_images += 1
        return img_metadata_classwise, num_images
data/lung.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" Chest X-ray few-shot semantic segmentation dataset """
2
+ import os
3
+ import glob
4
+
5
+ from torch.utils.data import Dataset
6
+ import torch.nn.functional as F
7
+ import torch
8
+ import PIL.Image as Image
9
+ import numpy as np
10
+
11
+
12
class DatasetLung(Dataset):
    """Chest X-ray (lung) 1-way few-shot segmentation dataset.

    Images live under ``<datapath>/CXR_png`` and binary lung masks under
    ``<datapath>/masks``; there is a single foreground category ('1').
    Fix applied: removed a dead ``os.path.join(...)`` statement in
    :meth:`build_img_metadata` whose result was discarded.
    """

    def __init__(self, datapath, fold, transform, split, shot=1, num_val=600):
        self.benchmark = 'lung'
        self.shot = shot
        self.split = split
        self.num_val = num_val  # number of random episodes evaluated in 'val'

        self.base_path = os.path.join(datapath)
        self.img_path = os.path.join(self.base_path, 'CXR_png')
        self.ann_path = os.path.join(self.base_path, 'masks')

        # single foreground class
        self.categories = ['1']

        self.class_ids = range(0, 1)
        self.img_metadata_classwise, self.num_images = self.build_img_metadata_classwise()

        self.transform = transform

    def __len__(self):
        # Validation evaluates a fixed number of random episodes; otherwise
        # iterate over the whole image pool.
        return self.num_images if self.split != 'val' else self.num_val

    def __getitem__(self, idx):
        """Sample one episode and return the transformed query/support tensors."""
        query_name, support_names, class_sample = self.sample_episode(idx)
        query_img, query_mask, support_imgs, support_masks = self.load_frame(query_name, support_names)

        query_img = self.transform(query_img)
        # resize masks to the transformed image resolution with nearest neighbour
        query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()

        support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])

        support_masks_tmp = []
        for smask in support_masks:
            smask = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
            support_masks_tmp.append(smask)
        support_masks = torch.stack(support_masks_tmp)

        batch = {'query_img': query_img,
                 'query_mask': query_mask,
                 'query_name': query_name,
                 'support_imgs': support_imgs,
                 'support_masks': support_masks,
                 'class_id': torch.tensor(class_sample),
                 'support_names': support_names,

                 'support_set': [support_imgs, support_masks],
                 'support_classes': torch.tensor([class_sample])
                 }

        return batch

    def load_frame(self, query_name, support_names):
        """Load masks by path and derive the corresponding image paths.

        Mask filenames are assumed to end in '_mask.png' (9 chars), which is
        stripped to recover the image filename — TODO confirm against dataset.
        """
        query_mask = self.read_mask(query_name)
        support_masks = [self.read_mask(name) for name in support_names]

        query_id = query_name[:-9] + '.png'
        query_img = Image.open(os.path.join(self.img_path, os.path.basename(query_id))).convert('RGB')

        support_ids = [os.path.basename(name)[:-9] + '.png' for name in support_names]
        support_names = [os.path.join(self.img_path, sid) for sid in support_ids]
        support_imgs = [Image.open(name).convert('RGB') for name in support_names]

        return query_img, query_mask, support_imgs, support_masks

    def read_mask(self, img_name):
        """Read a grayscale mask and binarize it (pixel >= 128 -> 1, else 0)."""
        mask = torch.tensor(np.array(Image.open(img_name).convert('L')))
        mask[mask < 128] = 0
        mask[mask >= 128] = 1
        return mask

    def sample_episode(self, idx):
        """Pick a class, then draw one query and `shot` support mask paths
        (supports are resampled until they differ from the query)."""
        class_id = idx % len(self.class_ids)
        class_sample = self.categories[class_id]

        query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
        support_names = []
        while True:  # keep sampling support set if query == support
            support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
            if query_name != support_name: support_names.append(support_name)
            if len(support_names) == self.shot: break

        return query_name, support_names, class_id

    def build_img_metadata(self):
        """Collect all .png image paths per category into one flat list."""
        img_metadata = []
        for cat in self.categories:
            img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.img_path, cat))])
            for img_path in img_paths:
                if os.path.basename(img_path).split('.')[1] == 'png':
                    img_metadata.append(img_path)
        return img_metadata

    def build_img_metadata_classwise(self):
        """Map each category to its mask paths and count the total.

        NOTE(review): the glob ignores `cat` and scans `ann_path` directly —
        harmless with the single category, but would duplicate entries if
        more categories were added.
        """
        num_images = 0
        img_metadata_classwise = {}
        for cat in self.categories:
            img_metadata_classwise[cat] = []

        for cat in self.categories:
            img_paths = sorted([path for path in glob.glob('%s/*' % self.ann_path)])
            for img_path in img_paths:
                if os.path.basename(img_path).split('.')[1] == 'png':
                    img_metadata_classwise[cat] += [img_path]
                    num_images += 1
        return img_metadata_classwise, num_images
data/pascal.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" PASCAL-5i few-shot semantic segmentation dataset """
2
+ import os
3
+
4
+ from torch.utils.data import Dataset
5
+ import torch.nn.functional as F
6
+ import torch
7
+ import PIL.Image as Image
8
+ import numpy as np
9
+
10
+
11
class DatasetPASCAL(Dataset):
    """PASCAL-5i few-shot segmentation dataset (4 folds of 5 classes each).

    Episodes are built from split files under data/splits/pascal/; masks use
    255 as an ignore-index boundary region following PFE-Net's evaluation.
    """

    def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize=False):
        # 'test' is mapped onto the validation split
        self.split = 'val' if split in ['val', 'test'] else 'trn'
        self.fold = fold
        self.nfolds = 4
        self.nclass = 20
        self.benchmark = 'pascal'
        self.shot = shot
        self.use_original_imgsize = use_original_imgsize

        self.img_path = os.path.join(datapath, 'VOC2012/JPEGImages/')
        self.ann_path = os.path.join(datapath, 'VOC2012/SegmentationClassAug/')
        self.transform = transform

        self.class_ids = self.build_class_ids()
        self.img_metadata = self.build_img_metadata()
        self.img_metadata_classwise = self.build_img_metadata_classwise()

    def __len__(self):
        # validation always runs 1000 episodes regardless of pool size
        return len(self.img_metadata) if self.split == 'trn' else 1000

    def __getitem__(self, idx):
        """Sample one episode: query + `shot` supports, masks binarized to the episode class."""
        idx %= len(self.img_metadata)  # for testing, as n_images < 1000
        query_name, support_names, class_sample = self.sample_episode(idx)
        query_img, query_cmask, support_imgs, support_cmasks, org_qry_imsize = self.load_frame(query_name, support_names)

        query_img = self.transform(query_img)
        if not self.use_original_imgsize:
            query_cmask = F.interpolate(query_cmask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()
        query_mask, query_ignore_idx = self.extract_ignore_idx(query_cmask.float(), class_sample)

        support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])

        support_masks = []
        support_ignore_idxs = []
        for scmask in support_cmasks:
            scmask = F.interpolate(scmask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
            support_mask, support_ignore_idx = self.extract_ignore_idx(scmask, class_sample)
            support_masks.append(support_mask)
            support_ignore_idxs.append(support_ignore_idx)
        support_masks = torch.stack(support_masks)
        support_ignore_idxs = torch.stack(support_ignore_idxs)

        batch = {'query_img': query_img,
                 'query_mask': query_mask,
                 'query_name': query_name,
                 'query_ignore_idx': query_ignore_idx,

                 'org_query_imsize': org_qry_imsize,

                 'support_imgs': support_imgs,
                 'support_masks': support_masks,
                 'support_names': support_names,
                 'support_ignore_idxs': support_ignore_idxs,

                 'class_id': torch.tensor(class_sample)}

        return batch

    def extract_ignore_idx(self, mask, class_id):
        """Binarize `mask` to the episode class and split off the 255 boundary.

        NOTE: mutates `mask` in place; `boundary` is 1 where the raw mask
        was 255 (PASCAL's ignore region).
        """
        boundary = (mask / 255).floor()
        mask[mask != class_id + 1] = 0
        mask[mask == class_id + 1] = 1

        return mask, boundary

    def load_frame(self, query_name, support_names):
        """Load PIL images and tensor masks; also return the raw query size."""
        query_img = self.read_img(query_name)
        query_mask = self.read_mask(query_name)
        support_imgs = [self.read_img(name) for name in support_names]
        support_masks = [self.read_mask(name) for name in support_names]

        org_qry_imsize = query_img.size

        return query_img, query_mask, support_imgs, support_masks, org_qry_imsize

    def read_mask(self, img_name):
        r"""Return the class-index segmentation mask as a torch tensor."""
        mask = torch.tensor(np.array(Image.open(os.path.join(self.ann_path, img_name) + '.png')))
        return mask

    def read_img(self, img_name):
        r"""Return RGB image in PIL Image"""
        return Image.open(os.path.join(self.img_path, img_name) + '.jpg')

    def sample_episode(self, idx):
        """Look up the query by index and sample `shot` distinct-from-query supports."""
        query_name, class_sample = self.img_metadata[idx]

        support_names = []
        while True:  # keep sampling support set if query == support
            support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
            if query_name != support_name: support_names.append(support_name)
            if len(support_names) == self.shot: break

        return query_name, support_names, class_sample

    def build_class_ids(self):
        """Return class ids for this split: the current fold's classes for
        'val', all remaining classes for 'trn'."""
        nclass_trn = self.nclass // self.nfolds
        class_ids_val = [self.fold * nclass_trn + i for i in range(nclass_trn)]
        class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val]

        if self.split == 'trn':
            return class_ids_trn
        else:
            return class_ids_val

    def build_img_metadata(self):
        """Read (image_name, class_id) pairs from the fold split files."""

        def read_metadata(split, fold_id):
            # split file lines look like '<name>__<1-based class id>'
            fold_n_metadata = os.path.join('data/splits/pascal/%s/fold%d.txt' % (split, fold_id))
            with open(fold_n_metadata, 'r') as f:
                fold_n_metadata = f.read().split('\n')[:-1]
            fold_n_metadata = [[data.split('__')[0], int(data.split('__')[1]) - 1] for data in fold_n_metadata]
            return fold_n_metadata

        img_metadata = []
        if self.split == 'trn':  # For training, read image-metadata of "the other" folds
            for fold_id in range(self.nfolds):
                if fold_id == self.fold:  # Skip validation fold
                    continue
                img_metadata += read_metadata(self.split, fold_id)
        elif self.split == 'val':  # For validation, read image-metadata of "current" fold
            img_metadata = read_metadata(self.split, self.fold)
        else:
            raise Exception('Undefined split %s: ' % self.split)

        print('Total (%s) images are : %d' % (self.split, len(img_metadata)))

        return img_metadata

    def build_img_metadata_classwise(self):
        """Group image names by their class id for support sampling."""
        img_metadata_classwise = {}
        for class_id in range(self.nclass):
            img_metadata_classwise[class_id] = []

        for img_name, img_class in self.img_metadata:
            img_metadata_classwise[img_class] += [img_name]
        return img_metadata_classwise
data/splits/fss/test.txt ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ bus
2
+ hotel_slipper
3
+ burj_al
4
+ reflex_camera
5
+ abe's_flyingfish
6
+ oiltank_car
7
+ doormat
8
+ fish_eagle
9
+ barber_shaver
10
+ motorbike
11
+ feather_clothes
12
+ wandering_albatross
13
+ rice_cooker
14
+ delta_wing
15
+ fish
16
+ nintendo_switch
17
+ bustard
18
+ diver
19
+ minicooper
20
+ cathedrale_paris
21
+ big_ben
22
+ combination_lock
23
+ villa_savoye
24
+ american_alligator
25
+ gym_ball
26
+ andean_condor
27
+ leggings
28
+ pyramid_cube
29
+ jet_aircraft
30
+ meatloaf
31
+ reel
32
+ swan
33
+ osprey
34
+ crt_screen
35
+ microscope
36
+ rubber_eraser
37
+ arrow
38
+ monkey
39
+ mitten
40
+ spiderman
41
+ parthenon
42
+ bat
43
+ chess_king
44
+ sulphur_butterfly
45
+ quail_egg
46
+ oriole
47
+ iron_man
48
+ wooden_boat
49
+ anise
50
+ steering_wheel
51
+ groenendael
52
+ dwarf_beans
53
+ pteropus
54
+ chalk_brush
55
+ bloodhound
56
+ moon
57
+ english_foxhound
58
+ boxing_gloves
59
+ peregine_falcon
60
+ pyraminx
61
+ cicada
62
+ screw
63
+ shower_curtain
64
+ tredmill
65
+ bulb
66
+ bell_pepper
67
+ lemur_catta
68
+ doughnut
69
+ twin_tower
70
+ astronaut
71
+ nintendo_3ds
72
+ fennel_bulb
73
+ indri
74
+ captain_america_shield
75
+ kunai
76
+ broom
77
+ iphone
78
+ earphone1
79
+ flying_squirrel
80
+ onion
81
+ vinyl
82
+ sydney_opera_house
83
+ oyster
84
+ harmonica
85
+ egg
86
+ breast_pump
87
+ guitar
88
+ potato_chips
89
+ tunnel
90
+ cuckoo
91
+ rubick_cube
92
+ plastic_bag
93
+ phonograph
94
+ net_surface_shoes
95
+ goldfinch
96
+ ipad
97
+ mite_predator
98
+ coffee_mug
99
+ golden_plover
100
+ f1_racing
101
+ lapwing
102
+ nintendo_gba
103
+ pizza
104
+ rally_car
105
+ drilling_platform
106
+ cd
107
+ fly
108
+ magpie_bird
109
+ leaf_fan
110
+ little_blue_heron
111
+ carriage
112
+ moist_proof_pad
113
+ flying_snakes
114
+ dart_target
115
+ warehouse_tray
116
+ nintendo_wiiu
117
+ chiffon_cake
118
+ bath_ball
119
+ manatee
120
+ cloud
121
+ marimba
122
+ eagle
123
+ ruler
124
+ soymilk_machine
125
+ sled
126
+ seagull
127
+ glider_flyingfish
128
+ doublebus
129
+ transport_helicopter
130
+ window_screen
131
+ truss_bridge
132
+ wasp
133
+ snowman
134
+ poached_egg
135
+ strawberry
136
+ spinach
137
+ earphone2
138
+ downy_pitch
139
+ taj_mahal
140
+ rocking_chair
141
+ cablestayed_bridge
142
+ sealion
143
+ banana_boat
144
+ pheasant
145
+ stone_lion
146
+ electronic_stove
147
+ fox
148
+ iguana
149
+ rugby_ball
150
+ hang_glider
151
+ water_buffalo
152
+ lotus
153
+ paper_plane
154
+ missile
155
+ flamingo
156
+ american_chamelon
157
+ kart
158
+ chinese_knot
159
+ cabbage_butterfly
160
+ key
161
+ church
162
+ tiltrotor
163
+ helicopter
164
+ french_fries
165
+ water_heater
166
+ snow_leopard
167
+ goblet
168
+ fan
169
+ snowplow
170
+ leafhopper
171
+ pspgo
172
+ black_bear
173
+ quail
174
+ condor
175
+ chandelier
176
+ hair_razor
177
+ white_wolf
178
+ toaster
179
+ pidan
180
+ pyramid
181
+ chicken_leg
182
+ letter_opener
183
+ apple_icon
184
+ porcupine
185
+ chicken
186
+ stingray
187
+ warplane
188
+ windmill
189
+ bamboo_slip
190
+ wig
191
+ flying_geckos
192
+ stonechat
193
+ haddock
194
+ australian_terrier
195
+ hover_board
196
+ siamang
197
+ canton_tower
198
+ santa_sledge
199
+ arch_bridge
200
+ curlew
201
+ sushi
202
+ beet_root
203
+ accordion
204
+ leaf_egg
205
+ stealth_aircraft
206
+ stork
207
+ bucket
208
+ hawk
209
+ chess_queen
210
+ ocarina
211
+ knife
212
+ whippet
213
+ cantilever_bridge
214
+ may_bug
215
+ wagtail
216
+ leather_shoes
217
+ wheelchair
218
+ shumai
219
+ speedboat
220
+ vacuum_cup
221
+ chess_knight
222
+ pumpkin_pie
223
+ wooden_spoon
224
+ bamboo_dragonfly
225
+ ganeva_chair
226
+ soap
227
+ clearwing_flyingfish
228
+ pencil_sharpener1
229
+ cricket
230
+ photocopier
231
+ nintendo_sp
232
+ samarra_mosque
233
+ clam
234
+ charge_battery
235
+ flying_frog
236
+ ferrari911
237
+ polo_shirt
238
+ echidna
239
+ coin
240
+ tower_pisa
data/splits/fss/trn.txt ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fountain
2
+ taxi
3
+ assult_rifle
4
+ radio
5
+ comb
6
+ box_turtle
7
+ igloo
8
+ head_cabbage
9
+ cottontail
10
+ coho
11
+ ashtray
12
+ joystick
13
+ sleeping_bag
14
+ jackfruit
15
+ trailer_truck
16
+ shower_cap
17
+ ibex
18
+ kinguin
19
+ squirrel
20
+ ac_wall
21
+ sidewinder
22
+ remote_control
23
+ marshmallow
24
+ bolotie
25
+ polar_bear
26
+ rock_beauty
27
+ tokyo_tower
28
+ wafer
29
+ red_bayberry
30
+ electronic_toothbrush
31
+ hartebeest
32
+ cassette
33
+ oil_filter
34
+ bomb
35
+ walnut
36
+ toilet_tissue
37
+ memory_stick
38
+ wild_boar
39
+ cableways
40
+ chihuahua
41
+ envelope
42
+ bison
43
+ poker
44
+ pubg_lvl3helmet
45
+ indian_cobra
46
+ staffordshire
47
+ park_bench
48
+ wombat
49
+ black_grouse
50
+ submarine
51
+ washer
52
+ agama
53
+ coyote
54
+ feeder
55
+ sarong
56
+ buckingham_palace
57
+ frog
58
+ steam_locomotive
59
+ acorn
60
+ german_pointer
61
+ obelisk
62
+ polecat
63
+ black_swan
64
+ butterfly
65
+ mountain_tent
66
+ gorilla
67
+ sloth_bear
68
+ aubergine
69
+ stinkhorn
70
+ stole
71
+ owl
72
+ mooli
73
+ pool_table
74
+ collar
75
+ lhasa_apso
76
+ ambulance
77
+ spade
78
+ pufferfish
79
+ paint_brush
80
+ lark
81
+ golf_ball
82
+ hock
83
+ fork
84
+ drake
85
+ bee_house
86
+ mooncake
87
+ wok
88
+ cocacola
89
+ water_bike
90
+ ladder
91
+ psp
92
+ bassoon
93
+ bear
94
+ border_terrier
95
+ petri_dish
96
+ pill_bottle
97
+ aircraft_carrier
98
+ panther
99
+ canoe
100
+ baseball_player
101
+ turtle
102
+ espresso
103
+ throne
104
+ cornet
105
+ coucal
106
+ eletrical_switch
107
+ bra
108
+ snail
109
+ backpack
110
+ jacamar
111
+ scroll_brush
112
+ gliding_lizard
113
+ raft
114
+ pinwheel
115
+ grasshopper
116
+ green_mamba
117
+ eft_newt
118
+ computer_mouse
119
+ vine_snake
120
+ recreational_vehicle
121
+ llama
122
+ meerkat
123
+ chainsaw
124
+ ferret
125
+ garbage_can
126
+ kangaroo
127
+ litchi
128
+ carbonara
129
+ housefinch
130
+ modem
131
+ tebby_cat
132
+ thatch
133
+ face_powder
134
+ tomb
135
+ apple
136
+ ladybug
137
+ killer_whale
138
+ rocket
139
+ airship
140
+ surfboard
141
+ lesser_panda
142
+ jordan_logo
143
+ banana
144
+ nail_scissor
145
+ swab
146
+ perfume
147
+ punching_bag
148
+ victor_icon
149
+ waffle_iron
150
+ trimaran
151
+ garlic
152
+ flute
153
+ langur
154
+ starfish
155
+ parallel_bars
156
+ dandie_dinmont
157
+ cosmetic_brush
158
+ screwdriver
159
+ brick_card
160
+ balance_weight
161
+ hornet
162
+ carton
163
+ toothpaste
164
+ bracelet
165
+ egg_tart
166
+ pencil_sharpener2
167
+ swimming_glasses
168
+ howler_monkey
169
+ camel
170
+ dragonfly
171
+ lionfish
172
+ convertible
173
+ mule
174
+ usb
175
+ conch
176
+ papaya
177
+ garbage_truck
178
+ dingo
179
+ radiator
180
+ solar_dish
181
+ streetcar
182
+ trilobite
183
+ bouzouki
184
+ ringlet_butterfly
185
+ space_shuttle
186
+ waffle
187
+ american_staffordshire
188
+ violin
189
+ flowerpot
190
+ forklift
191
+ manx
192
+ sundial
193
+ snowmobile
194
+ chickadee_bird
195
+ ruffed_grouse
196
+ brick_tea
197
+ paddle
198
+ stove
199
+ carousel
200
+ spatula
201
+ beaker
202
+ gas_pump
203
+ lawn_mower
204
+ speaker
205
+ tank
206
+ tresher
207
+ kappa_logo
208
+ hare
209
+ tennis_racket
210
+ shopping_cart
211
+ thimble
212
+ tractor
213
+ anemone_fish
214
+ trolleybus
215
+ steak
216
+ capuchin
217
+ red_breasted_merganser
218
+ golden_retriever
219
+ light_tube
220
+ flatworm
221
+ melon_seed
222
+ digital_watch
223
+ jacko_lantern
224
+ brown_bear
225
+ cairn
226
+ mushroom
227
+ chalk
228
+ skull
229
+ stapler
230
+ potato
231
+ telescope
232
+ proboscis
233
+ microphone
234
+ torii
235
+ baseball_bat
236
+ dhole
237
+ excavator
238
+ fig
239
+ snake
240
+ bradypod
241
+ pepitas
242
+ prairie_chicken
243
+ scorpion
244
+ shotgun
245
+ bottle_cap
246
+ file_cabinet
247
+ grey_whale
248
+ one-armed_bandit
249
+ banded_gecko
250
+ flying_disc
251
+ croissant
252
+ toothbrush
253
+ miniskirt
254
+ pokermon_ball
255
+ gazelle
256
+ grey_fox
257
+ esport_chair
258
+ necklace
259
+ ptarmigan
260
+ watermelon
261
+ besom
262
+ pomelo
263
+ radio_telescope
264
+ studio_couch
265
+ black_stork
266
+ vestment
267
+ koala
268
+ brambling
269
+ muscle_car
270
+ window_shade
271
+ space_heater
272
+ sunglasses
273
+ motor_scooter
274
+ ladyfinger
275
+ pencil_box
276
+ titi_monkey
277
+ chicken_wings
278
+ mount_fuji
279
+ giant_panda
280
+ dart
281
+ fire_engine
282
+ running_shoe
283
+ dumbbell
284
+ donkey
285
+ loafer
286
+ hard_disk
287
+ globe
288
+ lifeboat
289
+ medical_kit
290
+ brain_coral
291
+ paper_towel
292
+ dugong
293
+ seatbelt
294
+ skunk
295
+ military_vest
296
+ cocktail_shaker
297
+ zucchini
298
+ quad_drone
299
+ ocicat
300
+ shih-tzu
301
+ teapot
302
+ tile_roof
303
+ cheese_burger
304
+ handshower
305
+ red_wolf
306
+ stop_sign
307
+ mouse
308
+ battery
309
+ adidas_logo2
310
+ earplug
311
+ hummingbird
312
+ brush_pen
313
+ pistachio
314
+ hamster
315
+ air_strip
316
+ indian_elephant
317
+ otter
318
+ cucumber
319
+ scabbard
320
+ hawthorn
321
+ bullet_train
322
+ leopard
323
+ whale
324
+ cream
325
+ chinese_date
326
+ jellyfish
327
+ lobster
328
+ skua
329
+ single_log
330
+ chicory
331
+ bagel
332
+ beacon
333
+ pingpong_racket
334
+ spoon
335
+ yurt
336
+ wallaby
337
+ egret
338
+ christmas_stocking
339
+ mcdonald_uncle
340
+ wrench
341
+ spark_plug
342
+ triceratops
343
+ wall_clock
344
+ jinrikisha
345
+ pickup
346
+ rhinoceros
347
+ swimming_trunk
348
+ band-aid
349
+ spotted_salamander
350
+ leeks
351
+ marmot
352
+ warthog
353
+ cello
354
+ stool
355
+ chest
356
+ toilet_plunger
357
+ wardrobe
358
+ cannon
359
+ adidas_logo1
360
+ drumstick
361
+ lady_slipper
362
+ puma_logo
363
+ great_wall
364
+ white_shark
365
+ witch_hat
366
+ vending_machine
367
+ wreck
368
+ chopsticks
369
+ garfish
370
+ african_elephant
371
+ children_slide
372
+ hornbill
373
+ zebra
374
+ boa_constrictor
375
+ armour
376
+ pineapple
377
+ angora
378
+ brick
379
+ car_wheel
380
+ wallet
381
+ boston_bull
382
+ hyena
383
+ lynx
384
+ crash_helmet
385
+ terrapin_turtle
386
+ persian_cat
387
+ shift_gear
388
+ cactus_ball
389
+ fur_coat
390
+ plate
391
+ pen
392
+ okra
393
+ mario
394
+ airedale
395
+ cowboy_hat
396
+ celery
397
+ macaque
398
+ candle
399
+ goose
400
+ raccoon
401
+ brasscica
402
+ almond
403
+ maotai_bottle
404
+ soccer_ball
405
+ sports_car
406
+ tobacco_pipe
407
+ water_polo
408
+ eggnog
409
+ hook
410
+ ostrich
411
+ patas
412
+ table_lamp
413
+ teddy
414
+ mongoose
415
+ spoonbill
416
+ redheart
417
+ crane
418
+ dinosaur
419
+ kitchen_knife
420
+ seal
421
+ baboon
422
+ golfcart
423
+ roller_coaster
424
+ avocado
425
+ birdhouse
426
+ yorkshire_terrier
427
+ saluki
428
+ basketball
429
+ buckler
430
+ harvester
431
+ afghan_hound
432
+ beam_bridge
433
+ guinea_pig
434
+ lorikeet
435
+ shakuhachi
436
+ motarboard
437
+ statue_liberty
438
+ police_car
439
+ sulphur_crested
440
+ gourd
441
+ sombrero
442
+ mailbox
443
+ adhensive_tape
444
+ night_snake
445
+ bushtit
446
+ mouthpiece
447
+ beaver
448
+ bathtub
449
+ printer
450
+ cumquat
451
+ orange
452
+ cleaver
453
+ quill_pen
454
+ panpipe
455
+ diamond
456
+ gypsy_moth
457
+ cauliflower
458
+ lampshade
459
+ cougar
460
+ traffic_light
461
+ briefcase
462
+ ballpoint
463
+ african_grey
464
+ kremlin
465
+ barometer
466
+ peacock
467
+ paper_crane
468
+ sunscreen
469
+ tofu
470
+ bedlington_terrier
471
+ snowball
472
+ carrot
473
+ tiger
474
+ mink
475
+ cristo_redentor
476
+ ladle
477
+ keyboard
478
+ maraca
479
+ monitor
480
+ water_snake
481
+ can_opener
482
+ mud_turtle
483
+ bald_eagle
484
+ carp
485
+ cn_tower
486
+ egyptian_cat
487
+ hen_of_the_woods
488
+ measuring_cup
489
+ roller_skate
490
+ kite
491
+ sandwich_cookies
492
+ sandwich
493
+ persimmon
494
+ chess_bishop
495
+ coffin
496
+ ruddy_turnstone
497
+ prayer_rug
498
+ rain_barrel
499
+ neck_brace
500
+ nematode
501
+ rosehip
502
+ dutch_oven
503
+ goldfish
504
+ blossom_card
505
+ dough
506
+ trench_coat
507
+ sponge
508
+ stupa
509
+ wash_basin
510
+ electric_fan
511
+ spring_scroll
512
+ potted_plant
513
+ sparrow
514
+ car_mirror
515
+ gecko
516
+ diaper
517
+ leatherback_turtle
518
+ strainer
519
+ guacamole
520
+ microwave
data/splits/fss/val.txt ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ handcuff
2
+ mortar
3
+ matchstick
4
+ wine_bottle
5
+ dowitcher
6
+ triumphal_arch
7
+ gyromitra
8
+ hatchet
9
+ airliner
10
+ broccoli
11
+ olive
12
+ pubg_lvl3backpack
13
+ calculator
14
+ toucan
15
+ shovel
16
+ sewing_machine
17
+ icecream
18
+ woodpecker
19
+ pig
20
+ relay_stick
21
+ mcdonald_sign
22
+ cpu
23
+ peanut
24
+ pumpkin
25
+ sturgeon
26
+ hammer
27
+ hami_melon
28
+ squirrel_monkey
29
+ shuriken
30
+ power_drill
31
+ pingpong_ball
32
+ crocodile
33
+ carambola
34
+ monarch_butterfly
35
+ drum
36
+ water_tower
37
+ panda
38
+ toilet_brush
39
+ pay_phone
40
+ yonex_icon
41
+ cricketball
42
+ revolver
43
+ chimpanzee
44
+ crab
45
+ corn
46
+ baseball
47
+ rabbit
48
+ croquet_ball
49
+ artichoke
50
+ abacus
51
+ harp
52
+ bell
53
+ gas_tank
54
+ scissors
55
+ vase
56
+ upright_piano
57
+ typewriter
58
+ bittern
59
+ impala
60
+ tray
61
+ fire_hydrant
62
+ beer_bottle
63
+ sock
64
+ soup_bowl
65
+ spider
66
+ cherry
67
+ macaw
68
+ toilet_seat
69
+ fire_balloon
70
+ french_ball
71
+ fox_squirrel
72
+ volleyball
73
+ cornmeal
74
+ folding_chair
75
+ pubg_airdrop
76
+ beagle
77
+ skateboard
78
+ narcissus
79
+ whiptail
80
+ cup
81
+ arabian_camel
82
+ badger
83
+ stopwatch
84
+ ab_wheel
85
+ ox
86
+ lettuce
87
+ monocycle
88
+ redshank
89
+ vulture
90
+ whistle
91
+ smoothing_iron
92
+ mashed_potato
93
+ conveyor
94
+ yoga_pad
95
+ tow_truck
96
+ siamese_cat
97
+ cigar
98
+ white_stork
99
+ sniper_rifle
100
+ stretcher
101
+ tulip
102
+ handkerchief
103
+ basset
104
+ iceberg
105
+ gibbon
106
+ lacewing
107
+ thrush
108
+ cheetah
109
+ bighorn_sheep
110
+ espresso_maker
111
+ pretzel
112
+ english_setter
113
+ sandbar
114
+ cheese
115
+ daisy
116
+ arctic_fox
117
+ briard
118
+ colubus
119
+ balance_beam
120
+ coffeepot
121
+ soap_dispenser
122
+ yawl
123
+ consomme
124
+ parking_meter
125
+ cactus
126
+ turnstile
127
+ taro
128
+ fire_screen
129
+ digital_clock
130
+ rose
131
+ pomegranate
132
+ bee_eater
133
+ schooner
134
+ ski_mask
135
+ jay_bird
136
+ plaice
137
+ red_fox
138
+ syringe
139
+ camomile
140
+ pickelhaube
141
+ blenheim_spaniel
142
+ pear
143
+ parachute
144
+ common_newt
145
+ bowtie
146
+ cigarette
147
+ oscilloscope
148
+ laptop
149
+ african_crocodile
150
+ apron
151
+ coconut
152
+ sandal
153
+ kwanyin
154
+ lion
155
+ eel
156
+ balloon
157
+ crepe
158
+ armadillo
159
+ kazoo
160
+ lemon
161
+ spider_monkey
162
+ tape_player
163
+ ipod
164
+ bee
165
+ sea_cucumber
166
+ suitcase
167
+ television
168
+ pillow
169
+ banjo
170
+ rock_snake
171
+ partridge
172
+ platypus
173
+ lycaenid_butterfly
174
+ pinecone
175
+ conversion_plug
176
+ wolf
177
+ frying_pan
178
+ timber_wolf
179
+ bluetick
180
+ crayon
181
+ giant_schnauzer
182
+ orang
183
+ scarerow
184
+ kobe_logo
185
+ loguat
186
+ saxophone
187
+ ceiling_fan
188
+ cardoon
189
+ equestrian_helmet
190
+ louvre_pyramid
191
+ hotdog
192
+ ironing_board
193
+ razor
194
+ nagoya_castle
195
+ loggerhead_turtle
196
+ lipstick
197
+ cradle
198
+ strongbox
199
+ raven
200
+ kit_fox
201
+ albatross
202
+ flat-coated_retriever
203
+ beer_glass
204
+ ice_lolly
205
+ sungnyemun
206
+ totem_pole
207
+ vacuum
208
+ bolete
209
+ mango
210
+ ginger
211
+ weasel
212
+ cabbage
213
+ refrigerator
214
+ school_bus
215
+ hippo
216
+ tiger_cat
217
+ saltshaker
218
+ piano_keyboard
219
+ windsor_tie
220
+ sea_urchin
221
+ microsd
222
+ barbell
223
+ swim_ring
224
+ bulbul_bird
225
+ water_ouzel
226
+ ac_ground
227
+ sweatshirt
228
+ umbrella
229
+ hair_drier
230
+ hammerhead_shark
231
+ tomato
232
+ projector
233
+ cushion
234
+ dishwasher
235
+ three-toed_sloth
236
+ tiger_shark
237
+ har_gow
238
+ baby
239
+ thor's_hammer
240
+ nike_logo
data/suim.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" FSS-1000 few-shot semantic segmentation dataset """
2
+ import os
3
+ import glob
4
+
5
+ from torch.utils.data import Dataset
6
+ import torch.nn.functional as F
7
+ import torch
8
+ import PIL.Image as Image
9
+ import numpy as np
10
+
11
+
12
class DatasetSUIM(Dataset):
    """SUIM underwater few-shot segmentation dataset.

    Images live under ``<datapath>/images`` and per-category binary masks
    under ``<datapath>/masks/<category>``. Fixes applied: mask-to-image path
    mapping now uses `os.path` (portable, and keeps multi-dot stems intact)
    instead of splitting on '/', and a commented-out dead method was removed.
    """

    def __init__(self, datapath, fold, transform, split, shot, num_val=600):
        self.split = split
        self.benchmark = 'suim'
        self.shot = shot
        self.num_val = num_val  # number of random episodes evaluated in 'val'

        self.base_path = os.path.join(datapath)
        self.img_path = os.path.join(self.base_path, 'images')
        self.ann_path = os.path.join(self.base_path, 'masks')

        self.categories = ['FV', 'HD', 'PF', 'RI', 'RO', 'SR', 'WR']

        self.class_ids = range(len(self.categories))
        self.img_metadata_classwise, self.num_images = self.build_img_metadata_classwise()

        self.transform = transform

    def __len__(self):
        # if it is the target domain, then also test on entire dataset
        return self.num_images if self.split != 'val' else self.num_val

    def __getitem__(self, idx):
        """Sample one episode and return transformed query/support tensors."""
        query_name, support_names, class_sample = self.sample_episode(idx)
        query_img, query_mask, support_imgs, support_masks = self.load_frame(query_name, support_names)

        query_img = self.transform(query_img)
        # resize masks to the transformed image resolution with nearest neighbour
        query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()

        support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])

        support_masks_tmp = []
        for smask in support_masks:
            smask = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
            support_masks_tmp.append(smask)
        support_masks = torch.stack(support_masks_tmp)

        batch = {'query_img': query_img,
                 'query_mask': query_mask,
                 'support_set': (support_imgs, support_masks),
                 'support_classes': torch.tensor([class_sample]),  # adapt to Nway

                 'query_name': query_name,  # REMOVE
                 'support_imgs': support_imgs,  # REMOVE
                 'support_masks': support_masks,  # REMOVE
                 'support_names': support_names,  # REMOVE
                 'class_id': torch.tensor(class_sample)}  # REMOVE

        return batch

    def load_frame(self, query_mask_path, support_mask_paths):
        """Load images corresponding to the given mask paths, plus the masks."""

        def maskpath_to_imgpath(maskpath):
            # masks/<cat>/<stem>.<ext> -> images/<stem>.jpg (portable os.path form)
            stem = os.path.splitext(os.path.basename(maskpath))[0]
            return os.path.join(self.img_path, stem + '.jpg')

        query_img = Image.open(maskpath_to_imgpath(query_mask_path)).convert('RGB')

        support_imgs = [Image.open(maskpath_to_imgpath(s_mask_path)).convert('RGB') for s_mask_path in support_mask_paths]

        query_mask = self.read_mask(query_mask_path)
        support_masks = [self.read_mask(s_mask_path) for s_mask_path in support_mask_paths]

        return query_img, query_mask, support_imgs, support_masks

    def read_mask(self, img_name):
        """Read a grayscale mask and binarize it (pixel >= 128 -> 1, else 0)."""
        mask = torch.tensor(np.array(Image.open(img_name).convert('L')))
        mask[mask < 128] = 0
        mask[mask >= 128] = 1
        return mask

    def sample_episode(self, idx):
        """Pick a class, then draw one query and `shot` support mask paths
        (supports are resampled until they differ from the query)."""
        class_id = idx % len(self.class_ids)
        class_sample = self.categories[class_id]

        query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
        support_names = []
        while True:  # keep sampling support set if query == support
            support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
            if query_name != support_name: support_names.append(support_name)
            if len(support_names) == self.shot: break

        return query_name, support_names, class_id

    def build_img_metadata_classwise(self):
        """Map each category to its non-empty mask paths; also count the total.

        NOTE: reads every mask once at construction time to filter out empty
        masks — slow on first load, but only done once.
        """
        num_images = 0
        img_metadata_classwise = {}
        for cat in self.categories:
            img_metadata_classwise[cat] = []

        for cat in self.categories:
            mask_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.base_path, 'masks', cat))])
            for mask_path in mask_paths:
                if self.read_mask(mask_path).count_nonzero() > 0:  # no empty masks
                    img_metadata_classwise[cat] += [mask_path]
                    num_images += 1
        return img_metadata_classwise, num_images
eval/evaluation.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" Evaluate mask prediction """
2
+ import torch
3
+
4
+
5
class Evaluator:
    r""" Computes intersection and union between prediction and ground-truth """

    @classmethod
    def initialize(cls):
        # Pixels carrying this label are excluded from the IoU computation.
        cls.ignore_index = 255

    @classmethod
    def classify_prediction(cls, pred_mask, batch):
        r"""Return per-episode (intersection, union) pixel counts, each [2, bsz]."""
        gt_mask = batch.get('query_mask')

        # Apply ignore_index in PASCAL-5i masks (following evaluation scheme in PFE-Net (TPAMI 2020))
        ignore_region = batch.get('query_ignore_idx')
        if ignore_region is not None:
            assert torch.logical_and(ignore_region, gt_mask).sum() == 0
            ignore_region *= cls.ignore_index
            gt_mask = gt_mask + ignore_region
            pred_mask[gt_mask == cls.ignore_index] = cls.ignore_index

        # Per-episode counts; histc over {0, 1} gives [background, foreground].
        inter_counts, pred_counts, gt_counts = [], [], []
        for ep_pred, ep_gt in zip(pred_mask, gt_mask):
            agreeing = ep_pred[ep_pred == ep_gt]
            if agreeing.size(0) == 0:
                # torch.histc raises on an empty tensor (pytorch 1.5.1)
                inter_counts.append(torch.tensor([0, 0], device=ep_pred.device))
            else:
                inter_counts.append(torch.histc(agreeing, bins=2, min=0, max=1))
            pred_counts.append(torch.histc(ep_pred, bins=2, min=0, max=1))
            gt_counts.append(torch.histc(ep_gt, bins=2, min=0, max=1))

        area_inter = torch.stack(inter_counts).t()
        area_pred = torch.stack(pred_counts).t()
        area_gt = torch.stack(gt_counts).t()
        area_union = area_pred + area_gt - area_inter

        return area_inter, area_union
eval/logger.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" Logging during training/testing """
2
+ import datetime
3
+ import logging
4
+ import os
5
+
6
+ from tensorboardX import SummaryWriter
7
+ import torch
8
+
9
class AverageMeter:
    r""" Accumulates intersection/union buffers and losses for IoU evaluation. """

    # Number of classes per supported benchmark. The original __init__ used a
    # six-branch if/elif ladder whose branches all performed the identical
    # class_ids_interest assignment; a lookup table keeps one code path.
    _NCLASS_PER_BENCHMARK = {
        'pascal': 20,
        'fss': 1000,
        'deepglobe': 6,
        'isic': 3,
        'lung': 1,
        'suim': 7,
    }

    def __init__(self, dataset, device='cuda'):
        """
        Args:
            dataset: object exposing .benchmark (str) and .class_ids (list of int).
            device: device on which the accumulation buffers live.

        Raises:
            Exception: when dataset.benchmark is not a supported benchmark.
        """
        self.benchmark = dataset.benchmark
        if self.benchmark not in self._NCLASS_PER_BENCHMARK:
            raise Exception('Unknown dataset: %s' % dataset)
        self.nclass = self._NCLASS_PER_BENCHMARK[self.benchmark]
        self.class_ids_interest = torch.tensor(dataset.class_ids).to(device)

        # Row 0 = background, row 1 = foreground; one column per class.
        self.intersection_buf = torch.zeros([2, self.nclass]).float().to(device)
        self.union_buf = torch.zeros([2, self.nclass]).float().to(device)
        self.ones = torch.ones_like(self.union_buf)
        self.loss_buf = []

    def update(self, inter_b, union_b, class_id, loss):
        """Accumulate one batch of [2, bsz] area counts into the class buffers."""
        self.intersection_buf.index_add_(1, class_id, inter_b.float())
        self.union_buf.index_add_(1, class_id, union_b.float())
        if loss is None:
            # Eval-only runs pass no loss; keep loss_buf stackable anyway.
            loss = torch.tensor(0.0)
        self.loss_buf.append(loss)

    def compute_iou(self):
        """Return (mean foreground IoU, FB-IoU), both in percent."""
        # max(union, 1) guards against division by zero for unseen classes.
        iou = self.intersection_buf.float() / \
              torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0]
        iou = iou.index_select(1, self.class_ids_interest)
        miou = iou[1].mean() * 100

        fb_iou = (self.intersection_buf.index_select(1, self.class_ids_interest).sum(dim=1) /
                  self.union_buf.index_select(1, self.class_ids_interest).sum(dim=1)).mean() * 100

        return miou, fb_iou

    def write_result(self, split, epoch):
        """Log the final averaged loss/mIoU/FB-IoU for *split* at *epoch*."""
        iou, fb_iou = self.compute_iou()

        loss_buf = torch.stack(self.loss_buf)
        msg = '\n*** %s ' % split
        msg += '[@Epoch %02d] ' % epoch
        msg += 'Avg L: %6.5f ' % loss_buf.mean()
        msg += 'mIoU: %5.2f ' % iou
        msg += 'FB-IoU: %5.2f ' % fb_iou

        msg += '***\n'
        Logger.info(msg)

    def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20):
        """Log running metrics every *write_batch_idx* batches (epoch -1 = test mode)."""
        if batch_idx % write_batch_idx == 0:
            msg = '[Epoch: %02d] ' % epoch if epoch != -1 else ''
            msg += '[Batch: %04d/%04d] ' % (batch_idx + 1, datalen)
            iou, fb_iou = self.compute_iou()
            if epoch != -1:
                loss_buf = torch.stack(self.loss_buf)
                msg += 'L: %6.5f ' % loss_buf[-1]
                msg += 'Avg L: %6.5f ' % loss_buf.mean()
            msg += 'mIoU: %5.2f | ' % iou
            msg += 'FB-IoU: %5.2f' % fb_iou
            Logger.info(msg)
88
+
89
+
90
class Logger:
    r""" Writes evaluation results of training/testing """

    @classmethod
    def initialize(cls, args, training):
        """Create the log directory, file/console logging, and tensorboard writer.

        Args:
            args: parsed CLI arguments; reads args.logpath and args.benchmark.
            training: when False, a timestamp suffix gives each test run its own dir.
        """
        logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S')
        # Separate log dir per test run so evaluations never overwrite each other.
        logpath = args.logpath if training else args.logpath + '_TEST_' + logtime
        if logpath == '': logpath = logtime

        cls.logpath = os.path.join('logs', logpath + '.log')
        cls.benchmark = args.benchmark
        print("logdir: ", cls.logpath)
        os.makedirs(cls.logpath)

        logging.basicConfig(filemode='w',
                            filename=os.path.join(cls.logpath, 'log.txt'),
                            level=logging.INFO,
                            format='%(message)s',
                            datefmt='%m-%d %H:%M:%S')

        # Console log config
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(message)s')
        console.setFormatter(formatter)
        logging.getLogger('').addHandler(console)

        # Tensorboard writer
        cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))

        # Log arguments
        logging.info('\n:=========== Adapt Before Comparison - A New Perspective on Cross-Domain Few-Shot Segmentation ===========')
        for arg_key in args.__dict__:
            logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key])))
        logging.info(':================================================\n')

    @classmethod
    def info(cls, msg):
        r""" Writes log message to log.txt """
        logging.info(msg)

    @classmethod
    def save_model_miou(cls, model, epoch, val_miou):
        """Persist *model* as best_model.pt and log the epoch/mIoU it achieved."""
        torch.save(model.state_dict(), os.path.join(cls.logpath, 'best_model.pt'))
        cls.info('Model saved @%d w/ val. mIoU: %5.2f.\n' % (epoch, val_miou))

    @classmethod
    def log_params(cls, model):
        """Log backbone vs. learnable parameter counts of *model*.

        Returns:
            (backbone_param, learner_param) so callers/tests can inspect the
            counts; existing callers that ignore the return value are unaffected.
        """
        backbone_param = 0
        learner_param = 0
        for k in model.state_dict().keys():
            n_param = model.state_dict()[k].view(-1).size(0)
            # Fix: the original used `k.split('.')[0] in 'backbone'`, a SUBSTRING
            # test that also matches keys like 'b.x' or 'back.x' (and '' always).
            if k.split('.')[0] == 'backbone':
                if k.split('.')[1] in ['classifier', 'fc']:  # as fc layers are not used in HSNet
                    continue
                backbone_param += n_param
            else:
                learner_param += n_param
        Logger.info('Backbone # param.: %d' % backbone_param)
        Logger.info('Learnable # param.: %d' % learner_param)
        Logger.info('Total # param.: %d' % (backbone_param + learner_param))
        return backbone_param, learner_param
main.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from core import runner
2
+ import torch
3
+ import argparse
4
+
5
def parse_opts():
    r"""Parse command-line arguments.

    Returns:
        argparse.Namespace with .benchmark, .datapath and .nshot.
    """
    parser = argparse.ArgumentParser(description='Adapt Before Comparison - A New Perspective on Cross-Domain Few-Shot Segmentation')

    # common
    # Fix: the original choices list repeated 'fss' and 'lung' and omitted
    # 'suim', although the rest of the code (e.g. AverageMeter) supports suim.
    parser.add_argument('--benchmark', type=str, default='lung',
                        choices=['fss', 'deepglobe', 'lung', 'isic', 'suim'])
    parser.add_argument('--datapath', type=str)
    parser.add_argument('--nshot', type=int, default=1)

    args = parser.parse_args()
    return args
16
+
17
+
18
if __name__ == '__main__':
    # Use the first GPU when available; feature extraction/fitting runs there.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args = parse_opts()
    print(args)
    # Forward the CLI options into the runner's module-level config
    # before any dataloader/config/feature-maker is constructed.
    runner.args.benchmark = args.benchmark
    runner.args.datapath = args.datapath
    runner.args.nshot = args.nshot

    dataloader = runner.makeDataloader()
    config = runner.makeConfig()
    feat_maker = runner.makeFeatureMaker(dataloader.dataset, config, device=device)
    average_meter = runner.AverageMeterWrapper(dataloader, device)

    # Episode loop: each batch is one few-shot episode (query + support set).
    for idx, batch in enumerate(dataloader):
        sseval = runner.SingleSampleEval(batch, feat_maker)
        sseval.forward()        # predict the query mask from the support set
        sseval.calc_metrics()   # intersection/union for this episode
        average_meter.update(sseval)
        average_meter.write(idx)  # periodic progress logging
    # Final (mIoU, FB-IoU) over all episodes.
    print('Result m|FB:', average_meter.average_meter.compute_iou())
utils/commonutils.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" Helper functions """
2
+ import random
3
+
4
+ import torch
5
+ import numpy as np
6
+
7
+
8
def fix_randseed(seed):
    r""" Seed all RNGs (numpy + torch, CPU and CUDA) for reproducible runs. """
    if seed is None:
        # No seed supplied: pick a random five-digit one.
        seed = int(random.random() * 1e5)
    for seeder in (np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Deterministic cuDNN kernels, at the cost of autotuning speedups.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
18
+
19
+
20
def mean(x):
    """Arithmetic mean of *x*; 0.0 when *x* is empty."""
    n = len(x)
    if n == 0:
        return 0.0
    return sum(x) / n
22
+
23
+
24
def to_cuda(batch):
    """Move every tensor value of *batch* to the GPU in place; returns the same dict."""
    for key in batch:
        entry = batch[key]
        if torch.is_tensor(entry):
            batch[key] = entry.cuda()
    return batch
29
+
30
+
31
def to_cpu(tensor):
    """Return a detached CPU copy of *tensor*; the original stays untouched."""
    detached = tensor.detach()
    return detached.cpu().clone()
utils/segutils.py ADDED
@@ -0,0 +1,584 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ from torchvision import transforms
6
+ from PIL import Image, ImageDraw
7
+ import torch.nn.functional as F
8
+
9
def norm(t):
    """Min-max normalize *t* to the [0, 1] range."""
    return (t - t.min()) / (t.max() - t.min())


def denorm(t, min_, max_):
    """Inverse of norm(): map values in [0, 1] back to [min_, max_]."""
    return t * (max_ - min_) + min_


def percentilerange(t, perc):
    """Value at fraction *perc* of t's [min, max] range."""
    return t.min() + perc * (t.max() - t.min())


def midrange(t):
    """Midpoint of t's value range (percentilerange at 0.5)."""
    return percentilerange(t, .5)


def downsample_mask(mask, H, W):
    """Bilinearly resize a [bsz, h, w] mask to [bsz, H, W]."""
    # Lambda assignments (PEP 8 E731) replaced by named defs: same callables,
    # but with docstrings and proper names in tracebacks.
    return F.interpolate(mask.unsqueeze(1), size=(H, W), mode='bilinear',
                         align_corners=False).squeeze(1)
17
+
18
+
19
def fg_bg_proto(sfeat_volume, downsampled_smask):
    """Masked-average foreground/background prototypes.

    Args:
        sfeat_volume: [bsz, c, vecs] support feature volume (vecs can be H*W).
        downsampled_smask: [bsz, vecs] soft/binary support mask.

    Returns:
        (fg_proto, bg_proto), each of shape [bsz, c].
    """
    B, C, vecs = sfeat_volume.shape
    weights = downsampled_smask.expand(B, vecs).unsqueeze(1)  # -> [B, 1, vecs]
    inv_weights = 1 - weights

    # Average over the spatial axis; epsilon guards against all-zero masks.
    fg_proto = (weights * sfeat_volume).sum(dim=-1) / (weights.sum(dim=-1) + 1e-8)
    bg_proto = (inv_weights * sfeat_volume).sum(dim=-1) / (inv_weights.sum(dim=-1) + 1e-8)

    assert fg_proto.shape == (B, C), ":o"
    return fg_proto, bg_proto
33
+
34
+
35
+ # intersection = lambda pred, target: (pred * target).float().sum()
36
+ # union = lambda pred, target: (pred + target).clamp(0, 1).float().sum()
37
+ #
38
+ #
39
+ # def iou(pred, target): # binary only, input bsz,h,w
40
+ # i, u = intersection(pred, target), union(pred, target)
41
+ # iou = (i + 1e-8) / (u + 1e-8)
42
+ # return iou.item()
43
+ #
44
+ #
45
+ # class SimpleAvgMeter:
46
+ # def __init__(self, n_classes, device=torch.device('cuda')):
47
+ # self.n_lasses = n_classes
48
+ # self.intersection_buf = torch.zeros(n_classes).to(device)
49
+ # self.union_buf = torch.zeros(n_classes).to(device)
50
+ #
51
+ # def update(self, pred, target, class_id):
52
+ # self.intersection_buf[class_id] += intersection(pred, target)
53
+ # self.union_buf[class_id] += union(pred, target)
54
+ #
55
+ # def IoU(self, class_id):
56
+ # return self.intersection_buf[class_id] / self.union_buf[class_id] * 100
57
+ #
58
+ # def cls_mIoU(self, class_ids):
59
+ # return (self.intersection_buf[class_ids] / self.union_buf[class_ids]).mean() * 100
60
+ #
61
+ # def compute_mIoU(self):
62
+ # noentry = self.union_buf == 0
63
+ # if noentry.sum() > 0: print("SimpleAvgMeter warning: ", noentry.sum(), "elements of", self.nclasses,
64
+ # "have no empty.")
65
+ # return self.cls_mIoU(~noentry)
66
+
67
+ # class KMeans():
68
+ # # expects input to be in shape [bsz, -1]
69
+ # def __init__(self, data, k=2, num_iterations=10):
70
+ # self.k = k
71
+ # self.device = data.device
72
+ # self.centroids = self._init_centroids(data)
73
+ #
74
+ # for _ in range(num_iterations):
75
+ # labels = self._assign_clusters(data)
76
+ # self._update_centroids(data, labels)
77
+ #
78
+ # self.labels = self._assign_clusters(data) # Final cluster assignment
79
+ #
80
+ # def _init_centroids(self, data):
81
+ # # Randomly initialize centroids
82
+ # centroids = []
83
+ # min_values = data.min(dim=1, keepdim=True).values
84
+ # range_values = (data.max(dim=1, keepdim=True).values - min_values)
85
+ #
86
+ # for _ in range(self.k):
87
+ # random_values = torch.rand((data.shape[0], 1)).to(self.device)
88
+ # centroids.append(min_values + random_values * range_values)
89
+ #
90
+ # return torch.cat(centroids, dim=1)
91
+ #
92
+ # def _assign_clusters(self, data):
93
+ # # Calculate distances between data points and centroids
94
+ # distances = torch.abs(data.unsqueeze(2) - self.centroids) # Expand data tensor to calculate distances
95
+ # # Determine the closest centroid for each data point
96
+ # labels = torch.argmin(distances, dim=2)
97
+ # # Sort labels so that the largest mean data point has the highest label
98
+ # cluster_means = [data[labels == k].mean() for k in range(self.k)]
99
+ # sorted_labels = {k: rank for rank, k in enumerate(sorted(range(self.k), key=lambda k: cluster_means[k]))}
100
+ # labels = torch.tensor([sorted_labels[label.item()] for label in labels.flatten()]).reshape_as(labels).to(
101
+ # self.device)
102
+ #
103
+ # return labels
104
+ #
105
+ # def _update_centroids(self, data, labels):
106
+ # # Calculate new centroids as the mean of the data points closest to each centroid
107
+ # mask = torch.nn.functional.one_hot(labels, num_classes=self.k).to(torch.float32)
108
+ # summed_data = torch.bmm(mask.transpose(1, 2), data.unsqueeze(2)) # Sum data points per centroid
109
+ # self.centroids = summed_data.squeeze() / mask.sum(dim=1, keepdim=True) # Normalize to get the mean
110
+ #
111
+ # def compute_thresholds(self):
112
+ # # Flatten the centroids along the middle dimension
113
+ # flat_centroids = self.centroids.view(self.centroids.size(0), -1)
114
+ #
115
+ # # Sort the flattened centroids
116
+ # sorted_centroids, _ = torch.sort(flat_centroids, dim=1)
117
+ #
118
+ # # Compute the midpoints between consecutive centroids
119
+ # thresholds = (sorted_centroids[:, :-1] + sorted_centroids[:, 1:]) / 2.0
120
+ #
121
+ # return thresholds
122
+ #
123
+ # def inference(self, data):
124
+ # # Assign data points to the nearest centroid
125
+ # return self._assign_clusters(data)
126
+
127
+ # def iterative_triclass_thresholding(image, max_iterations=100, tolerance=25):
128
+ # # Ensure image is grayscale
129
+ # if len(image.shape) == 3:
130
+ # image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
131
+ #
132
+ # # Initialize iteration parameters
133
+ # TBD_region = image.copy()
134
+ # iteration = 0
135
+ # prev_threshold = 0
136
+ #
137
+ # while iteration < max_iterations:
138
+ # iteration += 1
139
+ #
140
+ # # Step 1: Apply Otsu's thresholding on the TBD region
141
+ # current_threshold, _ = cv2.threshold(TBD_region, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
142
+ #
143
+ # # Check stopping criteria
144
+ # if abs(current_threshold - prev_threshold) < tolerance:
145
+ # break
146
+ # prev_threshold = current_threshold
147
+ #
148
+ # # Step 2: Calculate means for upper and lower regions
149
+ # upper_region = TBD_region[TBD_region > current_threshold]
150
+ # lower_region = TBD_region[TBD_region <= current_threshold]
151
+ #
152
+ # if len(upper_region) == 0 or len(lower_region) == 0:
153
+ # break # No further division possible
154
+ #
155
+ # mean_upper = np.mean(upper_region)
156
+ # mean_lower = np.mean(lower_region)
157
+ #
158
+ # # Step 3: Update temporary foreground, background, and TBD regions
159
+ # TBD_region[(TBD_region > mean_upper)] = 255 # Temporary foreground F
160
+ # TBD_region[(TBD_region < mean_lower)] = 0 # Temporary background B
161
+ #
162
+ # # Extracting the new TBD region (between mean_lower and mean_upper)
163
+ # mask = (TBD_region > mean_lower) & (TBD_region < mean_upper)
164
+ # TBD_region = TBD_region[mask] # Apply mask to extract region
165
+ #
166
+ # # Final classification after convergence or max iterations
167
+ # final_foreground = (image > current_threshold).astype(np.uint8) * 255
168
+ # final_background = (image <= current_threshold).astype(np.uint8) * 255
169
+ #
170
+ # return current_threshold, final_foreground
171
+
172
def otsus(batched_tensor_image, drop_least=0.05, mode='ordinary'):
    """Per-image Otsu thresholding on a batched score tensor.

    Args:
        batched_tensor_image: [bsz, H, W] tensor of scores, any numeric range.
        drop_least: fraction of the rescaled [0, 255] range whose lowest values
            are excluded before the threshold is computed.
        mode: 'ordinary' uses cv2's Otsu. NOTE(review): 'via_triclass' calls
            iterative_triclass_thresholding, which is commented out above in
            this module — that branch currently raises NameError.

    Returns:
        (thresh_batch, binary_tensor_batch): per-image thresholds mapped back
        to each image's original value range (on the input's device/dtype),
        and float {0, 1} masks built via torch.from_numpy (CPU).
    """
    bsz = batched_tensor_image.size(0)
    binary_tensors = []
    thresholds = []

    for i in range(bsz):
        # Convert the tensor to numpy array
        numpy_image = batched_tensor_image[i].cpu().numpy()

        # Rescale to [0, 255] and convert to uint8 type for OpenCV compatibility
        npmin, npmax = numpy_image.min(), numpy_image.max()
        numpy_image = (norm(numpy_image) * 255).astype(np.uint8)

        # Drop values that are in the lowest percentiles
        truncated_vals = numpy_image[numpy_image >= int(255 * drop_least)]

        # Apply Otsu's thresholding
        if mode == 'via_triclass':
            thresh_value, _ = iterative_triclass_thresholding(truncated_vals)
        else:
            thresh_value, _ = cv2.threshold(truncated_vals, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Apply the computed threshold on the original image
        binary_image = (numpy_image > thresh_value).astype(np.uint8) * 255

        # Convert the result back to a tensor and append to the list
        binary_tensors.append(torch.from_numpy(binary_image).float() / 255)

        # Map the threshold from the rescaled [0, 255] domain back into the
        # image's original value range before returning it.
        thresholds.append(torch.tensor(denorm(thresh_value / 255, npmin, npmax)) \
                          .to(batched_tensor_image.device, dtype=batched_tensor_image.dtype))

    # Convert list of tensors back to a single batched tensor
    binary_tensor_batch = torch.stack(binary_tensors, dim=0)
    thresh_batch = torch.stack(thresholds, dim=0)
    return thresh_batch, binary_tensor_batch
207
+
208
+
209
def iterative_otsus(probab_mask, s_mask, maxiters=5, mode='ordinary',
                    debug=False):  # verify that it works correctly when batch_size >1
    """Repeat Otsu thresholding, clipping below the previous threshold each
    round, until the threshold reaches at least the support mask's mean.

    Args:
        probab_mask: probability map in [0, 1] (asserted below).
        s_mask: support mask; its mean serves both as the acceptance bound for
            the threshold and as the fallback threshold when maxiters runs out.
        maxiters: maximum number of Otsu rounds before falling back.
        mode: forwarded to otsus().
        debug: prints diagnostics. NOTE(review): the debug branch calls
            display/pilImageRow, which are not defined in this module.

    Returns:
        (threshold, binary_mask). NOTE(review): the success path returns the
        tensor threshold from otsus(), the fallback returns s_mask.mean() —
        the two cases may differ in shape for batch_size > 1.
    """
    it = 1
    otsuthresh = 0
    assert probab_mask.min() >= 0 and probab_mask.max() <= 1, 'you should pass probabilites'
    while True:
        # Zero out everything below the current threshold, then re-run Otsu
        # on the surviving values only.
        clipped = torch.where(probab_mask < otsuthresh, 0, probab_mask)
        otsuthresh, newmask = otsus(clipped.detach(), drop_least=.02, mode=mode)
        if otsuthresh >= s_mask.mean():
            return otsuthresh.to(probab_mask.device), newmask.to(probab_mask.device)
        if it >= maxiters:
            if debug:
                print('reached maxiter:', it, 'with thresh', otsuthresh.item(), \
                      'removed', int(((clipped == 0).sum() / clipped.numel()).item() * 10000) / 100, \
                      '% at lower and and new min,max is', clipped[clipped > 0].min().item(), clipped.max().item())
                display(pilImageRow(norm(probab_mask[0]), s_mask[0], maxwidth=300))
            # Fallback: threshold at the support mask's mean activation.
            return s_mask.mean(), (probab_mask > s_mask.mean()).float()  # otsuthresh
        it += 1
227
+
228
+
229
+ # def upgrade_scipy():
230
+ # os.system('!pip install - -upgrade scipy')
231
+ #
232
+ #
233
+ # def slicRGB(q_img, n_segments=50, compactness=10., sigma=1, mask=None, debug=False):
234
+ # import skimage.segmentation as skseg
235
+ #
236
+ # rgb_labels = skseg.slic(q_img, n_segments=n_segments, compactness=compactness, sigma=sigma, mask=mask,
237
+ # enforce_connectivity=True)
238
+ #
239
+ # if debug:
240
+ # plt.imshow(skseg.mark_boundaries(q_img, rgb_labels))
241
+ # plt.show()
242
+ #
243
+ # return rgb_labels
244
+ #
245
+ #
246
+ #
247
+ # def slicRGBP(q_img, fg_pred, n_segments=30, compactness=0.1, sigma=1, mask=None, debug=False):
248
+ # import skimage.segmentation as skseg
249
+ #
250
+ # def concat_rgb_pred(rgbimg, pred):
251
+ # h, w = rgbimg.shape[:2]
252
+ # return np.concatenate((rgbimg, pred.reshape(h, w, 1)), axis=-1)
253
+ #
254
+ # rgbp_img = concat_rgb_pred(q_img, fg_pred)
255
+ # rgbp_labels = skseg.slic(rgbp_img, n_segments=n_segments, compactness=compactness, mask=mask, sigma=sigma,
256
+ # enforce_connectivity=True)
257
+ #
258
+ # if debug:
259
+ # rgb_labels = skseg.slic(q_img, n_segments=n_segments, compactness=10., sigma=sigma, mask=mask,
260
+ # enforce_connectivity=True)
261
+ # pred_labels = skseg.slic(fg_pred, n_segments=n_segments, compactness=compactness, sigma=sigma, mask=mask,
262
+ # channel_axis=None, enforce_connectivity=True)
263
+ #
264
+ # rows, cols = 1, 3
265
+ # fig, ax = plt.subplots(rows, cols, figsize=(10, 10), sharex=True, sharey=True)
266
+ # ax[0].imshow(skseg.mark_boundaries(q_img, rgbp_labels))
267
+ # ax[1].imshow(skseg.mark_boundaries(q_img, rgb_labels))
268
+ # ax[2].imshow(skseg.mark_boundaries(q_img, pred_labels))
269
+ # plt.show()
270
+ #
271
+ # return rgbp_labels
272
+ #
273
+ #
274
+ # def calc_cluster_means(label_id_map, fg_prob):
275
+ # fg_pred_clustered = np.zeros_like(fg_prob)
276
+ # label_ids = np.unique(label_id_map)
277
+ # for lab_id in label_ids:
278
+ # cluster = fg_prob[label_id_map == lab_id]
279
+ # fg_pred_clustered[label_id_map == lab_id] = cluster.mean()
280
+ # return fg_pred_clustered
281
+
282
+
283
def install_pydensecrf():
    """Install pydensecrf from GitHub into the *current* interpreter.

    Uses `sys.executable -m pip` instead of a bare `pip` shell command so the
    package lands in the environment actually running this code, and a list
    argv (shell=False) to avoid shell quoting/injection issues. Best-effort:
    failures are not raised here (the caller retries the import and reports).
    """
    import subprocess
    import sys
    subprocess.run([sys.executable, '-m', 'pip', 'install',
                    'git+https://github.com/lucasb-eyer/pydensecrf.git'],
                   check=False)
285
+
286
+
287
class CRF:
    """DenseCRF-based refinement of soft foreground probability maps."""

    def __init__(self, gaussian_stdxy=(3, 3), gaussian_compat=3,
                 bilateral_stdxy=(80, 80), bilateral_compat=10, stdrgb=(13, 13, 13)):
        # Pairwise-potential hyperparameters, passed through to pydensecrf's
        # DenseCRF2D (addPairwiseGaussian / addPairwiseBilateral).
        self.gaussian_stdxy = gaussian_stdxy
        self.gaussian_compat = gaussian_compat
        self.bilateral_stdxy = bilateral_stdxy
        self.bilateral_compat = bilateral_compat
        self.stdrgb = stdrgb
        self.iters = 5       # number of mean-field inference iterations
        self.debug = False

    def refine(self, image_tensor, fg_probs, soft_thresh=None, T=1):

        """
        Refine segmentation using DenseCRF.

        Args:
        - image_tensor (tensor): Original image, shape [1, 3, H, W].
        - fg_probs (tensor): Fg probabilities from the network, shape [1, H, W]
        - soft_thresh: The preferred threshold for fg_probs for segmenting into binary prediction mask
        - T: a temperature for softmax/sigmoid

        Returns:
        - Refined segmentation mask, shape [1, H, W].
        """
        try:
            import pydensecrf.densecrf as dcrf
            from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral
        except ImportError as e:
            print("pydensecrf not found. Installing...")
            install_pydensecrf()  # Ensure this function installs pydensecrf and handles any potential errors during installation.

            # After installation, retry importing. This is placed inside the except block to avoid repeating the import statements.
            try:
                import pydensecrf.densecrf as dcrf
                from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral
            except ImportError as e:
                print("Failed to import after installation. Please check the installation of pydensecrf.")
                raise  # This will raise the last exception that was handled by the except block

        # We find the segmentation threshold that splits fg-bg
        if soft_thresh is None:
            soft_thresh, _ = otsus(fg_probs)
        image_tensor, fg_probs, soft_thresh = image_tensor.cpu(), fg_probs.cpu(), soft_thresh.cpu()
        # Then we presume at this threshold the probability should be 0.5
        # probability 0 should stay 0, 1 should stay 1
        # sigmoid=lambda x: 1/(1 + np.exp(-x))
        fg_probs = torch.sigmoid(T * (fg_probs - soft_thresh))
        probs = torch.stack([1 - fg_probs, fg_probs], dim=1)  # crf expects both classes as input
        if self.debug:
            print('softthresh', soft_thresh)
            print('fg_probs min max', fg_probs.min(), fg_probs.max())
        # C: Number of classes
        bsz, C, H, W = probs.shape

        refined_masks = []
        # CRF wants uint8 HWC images; assumes image_tensor values are in [0, 1].
        image_numpy = np.ascontiguousarray( \
            (255 * image_tensor.permute(0, 2, 3, 1)).numpy().astype(np.uint8))
        probs_numpy = probs.numpy()
        for (image, prob) in zip(image_numpy, probs_numpy):
            # Unary potentials
            unary = np.ascontiguousarray(unary_from_softmax(prob))
            d = dcrf.DenseCRF2D(W, H, C)
            d.setUnaryEnergy(unary)

            # Add pairwise potentials
            d.addPairwiseGaussian(sxy=self.gaussian_stdxy, compat=self.gaussian_compat)
            d.addPairwiseBilateral(sxy=self.bilateral_stdxy, srgb=self.stdrgb,
                                   rgbim=image, compat=self.bilateral_compat)

            # Perform inference
            Q = d.inference(self.iters)
            if self.debug:
                print('Q:', np.array(Q).shape, np.array(Q)[0].mean(), np.array(Q).mean())
            # Returns both class maps, not an argmax'ed hard mask.
            result = np.reshape(Q, (2, H, W))  # np.argmax(Q, axis=0).reshape((H, W))
            refined_masks.append(result)

        return torch.from_numpy(np.stack(refined_masks, axis=0))

    # def iterrefine(self, iters, image_tensor, fg_probs, soft_thresh=None, T=1):
    #     q1 = fg_probs
    #     for iter in range(iters):
    #         print(q1.shape)
    #         q1 = self.refine(image_tensor, q1, soft_thresh=None, T=1)[:,1]
    #     return q1

    def iterrefine(self, iters, q_img, fg_probs, thresh_fn, debug=False):
        """Apply refine() *iters* times, re-thresholding each round via *thresh_fn*.

        NOTE(review): hard-codes batch size 1 in the expand below, and the
        debug branch references `i` (the loop variable is `it`) as well as
        display/to_pil, none of which are defined here — debug=True would
        raise NameError. Left unchanged; flagged for a follow-up fix.
        """
        pred = fg_probs.unsqueeze(1).expand(1, 2, *fg_probs.shape[-2:])
        for it in range(iters):
            thresh = thresh_fn(pred[:, 1])[0]

            if debug and i % 10 == 0:
                print('thresh', thresh)
                display(to_pil(pred[0, 1]))

            pred = self.refine(q_img, pred[:, 1], soft_thresh=thresh)
        return pred
384
+
385
+
386
+ #
387
+ # class Subplot:
388
+ # def __init__(self):
389
+ # self.vertical_lines = []
390
+ # self.histograms = []
391
+ # self.gaussian_curves = []
392
+ # self.colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
393
+ # self.title = ''
394
+ #
395
+ # class Element:
396
+ # def __init__(self, x=None, y=None, label=''):
397
+ # if x is not None:
398
+ # self.x = Subplot.to_np(x)
399
+ # if y is not None:
400
+ # self.y = Subplot.to_np(y)
401
+ #
402
+ # self.label = label
403
+ #
404
+ # @staticmethod
405
+ # def to_np(t):
406
+ # return t.detach().cpu().numpy()
407
+ #
408
+ # def add_vertical(self, x, label=''):
409
+ # self.vertical_lines.append(Subplot.Element(x=x, label=label))
410
+ # return self
411
+ #
412
+ # def add_histogram(self, samples, label=''):
413
+ # self.histograms.append(Subplot.Element(x=samples, label=label))
414
+ # return self
415
+ #
416
+ # def add_gaussian(self, gaussian):
417
+ # samples, mu, var = gaussian.samples, gaussian.mean, gaussian.covs
418
+ # # Generate a range of x values
419
+ # x_values = np.linspace(samples.min(), samples.max(), 100)
420
+ # x_values = np.linspace(samples.min(), samples.max(), 100)
421
+ #
422
+ # # Compute Gaussian values for these x values
423
+ # gaussian1_values = gaussian.gaussian_pdf(x_values, mu[0].item(), var[0].item())
424
+ # gaussian2_values = gaussian.gaussian_pdf(x_values, mu[1].item(), var[1].item())
425
+ # self.gaussian_curves.append(Subplot.Element(x_values, gaussian1_values))
426
+ # self.gaussian_curves.append(Subplot.Element(x_values, gaussian2_values))
427
+ # return self
428
+ #
429
+ #
430
+ # class PredHistos2():
431
+ # def __init__(self, n_cols=1):
432
+ # self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
433
+ # self.n_cols = n_cols
434
+ # if n_cols == 1:
435
+ # self.builder = Subplot()
436
+ # self.subplots = [Subplot() for x in range(n_cols)]
437
+ # self.alpha = 0.5
438
+ # self.bins = 200
439
+ #
440
+ # def reload(self, n_cols=1):
441
+ # self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
442
+ #
443
+ # def aggr(self, ax, sub):
444
+ # for hist, col in zip(sub.histograms, sub.colors):
445
+ # ax.hist(hist.x, self.bins, density=True, color=col, alpha=self.alpha, label=hist.label)
446
+ # for vline, col in zip(sub.vertical_lines, sub.colors):
447
+ # ax.axvline(x=vline.x, color=col, label=vline.label, linestyle='--')
448
+ # for gaussian, col in zip(sub.gaussian_curves, sub.colors):
449
+ # ax.plot(gaussian.x, gaussian.y, gaussian.label, col)
450
+ # ax.legend()
451
+ #
452
+ # def plot(self, name=''):
453
+ #
454
+ # if self.n_cols == 1:
455
+ # self.aggr(plt, self.builder)
456
+ # else:
457
+ # for ax, sub in zip(self.axes, self.subplots):
458
+ # self.aggr(ax, sub)
459
+ # ax.set_title(sub.title)
460
+ #
461
+ # plt.legend()
462
+ # plt.title(name)
463
+ # plt.show()
464
+ #
465
+ #
466
+ # from sklearn.mixture import GaussianMixture
467
+ # import scipy.optimize as opt
468
+ # from scipy.optimize import fsolve
469
+ #
470
+ #
471
+ # class GMM:
472
+ # def __init__(self, q_pred_coarse, name='gaussian', n_components=2):
473
+ # samples = q_pred_coarse.detach().cpu().numpy()
474
+ # self.samples = samples.reshape(-1, 1)
475
+ #
476
+ # # Fit a mixture of 2 Gaussians using EM
477
+ # gmm = GaussianMixture(n_components)
478
+ # gmm.fit(samples)
479
+ # self.means = gmm.means_.flatten()
480
+ # self.covs = gmm.covariances_.flatten()
481
+ # self.weights = gmm.weights_
482
+ # self.label = name
483
+ #
484
+ # def intersect(self):
485
+ # # Use fsolve to find the intersection
486
+ # gaussian_intersect, = fsolve(difference, self.means.mean(), args=(
487
+ # self.means[0].item(), self.covs[0].item(), self.means[1].item(), self.means[1].item()))
488
+ # return gaussian_intersect
489
+ #
490
+ #
491
+ # class PredHistoSNS():
492
+ # def __init__(self, n_cols=1):
493
+ # import seaborn as sns
494
+ # sns.set_theme(style="whitegrid") # Set the Seaborn theme. You can change the style as needed.
495
+ # self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
496
+ # self.n_cols = n_cols
497
+ # if n_cols == 1:
498
+ # self.axes = [self.axes] # Wrap the single axis in a list to simplify the loop logic later.
499
+ # self.builder = Subplot() # This is assuming Subplot is a properly defined class.
500
+ # self.subplots = [Subplot() for _ in range(n_cols)] # Use underscore for unused loop variable.
501
+ # self.alpha = 0.5
502
+ # self.bins = 200
503
+ #
504
+ # def reload(self, n_cols=1):
505
+ # self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
506
+ #
507
+ # def aggr(self, ax, sub):
508
+ # import seaborn as sns
509
+ # for hist, col in zip(sub.histograms, sub.colors):
510
+ # sns.histplot(hist.x, bins=self.bins, kde=False, color=col, ax=ax, alpha=self.alpha, label=hist.label)
511
+ # for vline, col in zip(sub.vertical_lines, sub.colors):
512
+ # ax.axvline(x=vline.x, color=col, label=vline.label, linestyle='--')
513
+ # for gaussian, col in zip(sub.gaussian_curves, sub.colors):
514
+ # sns.lineplot(x=gaussian.x, y=gaussian.y, label=gaussian.label, color=col, ax=ax)
515
+ # ax.legend()
516
+ #
517
+ # def plot(self, name=''):
518
+ #
519
+ # if self.n_cols == 1:
520
+ # self.aggr(self.axes[0], self.builder)
521
+ # else:
522
+ # for ax, sub in zip(self.axes, self.subplots):
523
+ # self.aggr(ax, sub)
524
+ # ax.set_title(sub.title)
525
+ #
526
+ # plt.show()
527
+ #
528
+ #
529
+ # def overlay_mask(image, mask, color=[255, 0, 0], alpha=0.2):
530
+ # """
531
+ # Apply an overlay of a binary mask onto an image using a specified color.
532
+ #
533
+ # :param image: A PyTorch tensor of the image (C x H x W) with pixel values in [0, 1].
534
+ # :param mask: A PyTorch tensor of the mask (H x W) with binary values (0 or 1).
535
+ # :param color: A list of 3 elements representing the RGB values of the overlay color.
536
+ # :param alpha: A float representing the transparency of the overlay (0 to 1).
537
+ # :return: An overlayed image tensor.
538
+ # """
539
+ # # Ensure the mask is binary
540
+ # mask = (mask > 0).float()
541
+ #
542
+ # # Create an RGB version of the mask
543
+ # mask_rgb = torch.tensor(color).view(3, 1, 1) / 255.0 # Normalize the color vector
544
+ # mask_rgb = mask_rgb * mask
545
+ #
546
+ # # Overlay the mask onto the image
547
+ # overlayed_image = (1 - alpha) * image + alpha * mask_rgb
548
+ #
549
+ # # Ensure the resulting tensor values are between 0 and 1
550
+ # overlayed_image = torch.clamp(overlayed_image, 0, 1)
551
+ #
552
+ # return overlayed_image
553
+ #
554
+ #
555
+ # import pandas as pd
556
+
557
+ # to_pil = lambda t: transforms.ToPILImage()(t) if t.shape[-1] > 4 else transforms.ToPILImage()(t.permute(2, 0, 1))
558
+ #
559
+ #
560
+ # def pilImageRow(*imgs, maxwidth=800, bordercolor=0x000000):
561
+ # imgs = [to_pil(im.float()) for im in imgs]
562
+ # dst = Image.new('RGB', (sum(im.width for im in imgs), imgs[0].height))
563
+ # for i, im in enumerate(imgs):
564
+ # loc = [x0, y0, x1, y1] = [i * im.width, 0, (i + 1) * im.width, im.height]
565
+ # dst.paste(im, (x0, y0))
566
+ # ImageDraw.Draw(dst).rectangle(loc, width=2, outline=bordercolor)
567
+ # factorToBig = dst.width / maxwidth
568
+ # dst = dst.resize((int(dst.width / factorToBig), int(dst.height / factorToBig)))
569
+ # return dst
570
+ #
571
+ #
572
+ # def tensor_table(**kwargs):
573
+ # tensor_overview = {}
574
+ # for name, tensor in kwargs.items():
575
+ # if callable(tensor):
576
+ # print(name, [tensor(t) for _, t in kwargs.items() if isinstance(t, torch.Tensor)])
577
+ # else:
578
+ # tensor_overview[name] = {
579
+ # 'min': tensor.min().item(),
580
+ # 'max': tensor.max().item(),
581
+ # 'shape': tensor.shape,
582
+ # }
583
+ # return pd.DataFrame.from_dict(tensor_overview, orient='index')
584
+