add submission code
Browse files- README.md +31 -4
- core/backbone.py +63 -0
- core/contrastivehead.py +582 -0
- core/denseaffinity.py +93 -0
- core/runner.py +165 -0
- data/coco.py +111 -0
- data/dataset.py +52 -0
- data/deepglobe.py +119 -0
- data/fss.py +114 -0
- data/isic.py +113 -0
- data/lung.py +116 -0
- data/pascal.py +148 -0
- data/splits/fss/test.txt +240 -0
- data/splits/fss/trn.txt +520 -0
- data/splits/fss/val.txt +240 -0
- data/suim.py +119 -0
- eval/evaluation.py +39 -0
- eval/logger.py +149 -0
- main.py +37 -0
- utils/commonutils.py +32 -0
- utils/segutils.py +584 -0
README.md
CHANGED
|
@@ -1,7 +1,34 @@
|
|
| 1 |
-
# Cross-Domain Few-Shot Segmentation
|
| 2 |
|
| 3 |
-
|
| 4 |
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapt Before Comparison - A New Perspective on Cross-Domain Few-Shot Segmentation
|
| 2 |
|
| 3 |
+
Code for reproducing the paper.
|
| 4 |
|
| 5 |
+
## Preparing Data
|
| 6 |
+
Because we follow PATNet and RtD, please refer to their work for preparation of the following datasets:
|
| 7 |
+
- Deepglobe (PAT)
|
| 8 |
+
- ISIC (PAT)
|
| 9 |
+
- Chest X-Ray (Lung) (PAT)
|
| 10 |
+
- FSS-1000 (PAT)
|
| 11 |
+
- SUIM (RtD)
|
| 12 |
|
| 13 |
+
You do not need to get all datasets. Just prepare the one you want to test our method with.
|
| 14 |
+
|
| 15 |
+
## Python package prerequisites
|
| 16 |
+
1. torch
|
| 17 |
+
2. torchvision
|
| 18 |
+
3. cv2
|
| 19 |
+
4. numpy
|
| 20 |
+
5. for others, follow the console output
|
| 21 |
+
|
| 22 |
+
## Run it
|
| 23 |
+
Call
|
| 24 |
+
`python main.py --benchmark {} --datapath {} --nshot {}`
|
| 25 |
+
|
| 26 |
+
for example
|
| 27 |
+
`python main.py --benchmark deepglobe --datapath ./datasets/deepglobe/ --nshot 1`
|
| 28 |
+
|
| 29 |
+
Available `benchmark` strings: deepglobe,isic,lung,fss,suim
|
| 30 |
+
Easiest to prepare should be Lung or FSS.
|
| 31 |
+
|
| 32 |
+
Default is quick-infer mode.
|
| 33 |
+
To change this, set `config.featext.fit_every_episode = True` in the main file.
|
| 34 |
+
You can change all other parameters likewise, check the available parameters in runner.makeConfig.
|
core/backbone.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import reduce
|
| 2 |
+
from operator import add
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torchvision.models import resnet
|
| 7 |
+
|
| 8 |
+
class Backbone(nn.Module):
|
| 9 |
+
|
| 10 |
+
def __init__(self, typestr):
|
| 11 |
+
super(Backbone, self).__init__()
|
| 12 |
+
|
| 13 |
+
self.backbone = typestr
|
| 14 |
+
|
| 15 |
+
# feature extractor initialization
|
| 16 |
+
if typestr == 'resnet50':
|
| 17 |
+
self.feature_extractor = resnet.resnet50(weights=resnet.ResNet50_Weights.DEFAULT)
|
| 18 |
+
self.feat_channels = [256, 512, 1024, 2048]
|
| 19 |
+
self.nlayers = [3, 4, 6, 3]
|
| 20 |
+
self.feat_ids = list(range(0, 17))
|
| 21 |
+
else:
|
| 22 |
+
raise Exception('Unavailable backbone: %s' % typestr)
|
| 23 |
+
self.feature_extractor.eval()
|
| 24 |
+
|
| 25 |
+
# define model
|
| 26 |
+
self.lids = reduce(add, [[i + 1] * x for i, x in enumerate(self.nlayers)])
|
| 27 |
+
self.stack_ids = torch.tensor(self.lids).bincount()[-4:].cumsum(dim=0)
|
| 28 |
+
|
| 29 |
+
self.cross_entropy_loss = nn.CrossEntropyLoss()
|
| 30 |
+
|
| 31 |
+
def extract_feats(self, img):
|
| 32 |
+
r""" Extract input image features """
|
| 33 |
+
feats = []
|
| 34 |
+
bottleneck_ids = reduce(add, list(map(lambda x: list(range(x)), self.nlayers)))
|
| 35 |
+
# Layer 0
|
| 36 |
+
feat = self.feature_extractor.conv1.forward(img)
|
| 37 |
+
feat = self.feature_extractor.bn1.forward(feat)
|
| 38 |
+
feat = self.feature_extractor.relu.forward(feat)
|
| 39 |
+
feat = self.feature_extractor.maxpool.forward(feat)
|
| 40 |
+
|
| 41 |
+
# Layer 1-4
|
| 42 |
+
for hid, (bid, lid) in enumerate(zip(bottleneck_ids, self.lids)):
|
| 43 |
+
res = feat
|
| 44 |
+
feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].conv1.forward(feat)
|
| 45 |
+
feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].bn1.forward(feat)
|
| 46 |
+
feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
|
| 47 |
+
feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].conv2.forward(feat)
|
| 48 |
+
feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].bn2.forward(feat)
|
| 49 |
+
feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
|
| 50 |
+
feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].conv3.forward(feat)
|
| 51 |
+
feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].bn3.forward(feat)
|
| 52 |
+
|
| 53 |
+
if bid == 0:
|
| 54 |
+
res = self.feature_extractor.__getattr__('layer%d' % lid)[bid].downsample.forward(res)
|
| 55 |
+
|
| 56 |
+
feat += res
|
| 57 |
+
|
| 58 |
+
if hid + 1 in self.feat_ids:
|
| 59 |
+
feats.append(feat.clone())
|
| 60 |
+
|
| 61 |
+
feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
|
| 62 |
+
|
| 63 |
+
return feats
|
core/contrastivehead.py
ADDED
|
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn.functional as F
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from utils import segutils
|
| 5 |
+
import core.denseaffinity as dautils
|
| 6 |
+
|
| 7 |
+
identity_mapping = lambda x, *args, **kwargs: x
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ContrastiveConfig:
|
| 11 |
+
def __init__(self, config=None):
|
| 12 |
+
# Define the internal dictionary with default settings.
|
| 13 |
+
if config is None:
|
| 14 |
+
self._data = {
|
| 15 |
+
'aug': {
|
| 16 |
+
'n_transformed_imgs': 2,
|
| 17 |
+
'blurkernelsize': [1], # chooses one of this kernel sizes
|
| 18 |
+
'maxjitter': 0.0,
|
| 19 |
+
'maxangle': 0, # rotation
|
| 20 |
+
# 'translate': (0,0), # BE CAREFUL WITH TRANSLATE - if you apply it on the feature volume that has smaller spatial dims correspondences break
|
| 21 |
+
'maxscale': 1.0, # 1.0 = No scaling
|
| 22 |
+
'maxshear': 20,
|
| 23 |
+
'randomhflip': False,
|
| 24 |
+
'apply_affine': True,
|
| 25 |
+
'debug': False
|
| 26 |
+
},
|
| 27 |
+
'model': {
|
| 28 |
+
'out_channels': 64,
|
| 29 |
+
'kernel_size': 1,
|
| 30 |
+
'prepend_relu': False,
|
| 31 |
+
'append_normalize': False,
|
| 32 |
+
'debug': False
|
| 33 |
+
},
|
| 34 |
+
'fitting': {
|
| 35 |
+
'lr': 1e-2,
|
| 36 |
+
'optimizer': torch.optim.SGD,
|
| 37 |
+
'num_epochs': 25,
|
| 38 |
+
'nce': {
|
| 39 |
+
'temperature': 0.5,
|
| 40 |
+
'debug': False
|
| 41 |
+
},
|
| 42 |
+
'normalize_after_fwd_pass': True,
|
| 43 |
+
'q_nceloss': True,
|
| 44 |
+
's_nceloss': True,
|
| 45 |
+
'protoloss': False,
|
| 46 |
+
'keepvarloss': True,
|
| 47 |
+
'symmetricloss': False,
|
| 48 |
+
'selfattentionloss': False,
|
| 49 |
+
'o_t_contr_proto_loss': True,
|
| 50 |
+
'debug': False
|
| 51 |
+
},
|
| 52 |
+
'featext': {
|
| 53 |
+
'l0': 3, # the first resnet bottleneck id to consider (0,1,2,3,4,5...15)
|
| 54 |
+
'fit_every_episode': False
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
else:
|
| 58 |
+
self._data = config
|
| 59 |
+
|
| 60 |
+
def __getattr__(self, key):
|
| 61 |
+
# Try to get '_data' without causing a recursive call to __getattr__
|
| 62 |
+
_data = super().__getattribute__('_data') if '_data' in self.__dict__ else None
|
| 63 |
+
|
| 64 |
+
if _data is not None and key in _data:
|
| 65 |
+
if isinstance(_data[key], dict):
|
| 66 |
+
return ContrastiveConfig(_data[key])
|
| 67 |
+
return _data[key]
|
| 68 |
+
|
| 69 |
+
# If we're here, it means the key was not found in the data,
|
| 70 |
+
# so we let Python raise the appropriate AttributeError.
|
| 71 |
+
raise AttributeError(f"No setting named {key}")
|
| 72 |
+
|
| 73 |
+
def __setattr__(self, key, value):
|
| 74 |
+
# Prevent overwriting of the '_data' attribute by normal means
|
| 75 |
+
if key == '_data':
|
| 76 |
+
super().__setattr__(key, value)
|
| 77 |
+
else:
|
| 78 |
+
# Try to get '_data' without causing a recursive call to __getattr__
|
| 79 |
+
_data = super().__getattribute__('_data') if '_data' in self.__dict__ else None
|
| 80 |
+
|
| 81 |
+
if _data is not None:
|
| 82 |
+
_data[key] = value
|
| 83 |
+
else:
|
| 84 |
+
# This situation should not normally occur, handle appropriately (e.g., log an error, raise exception)
|
| 85 |
+
raise AttributeError("Unexpected")
|
| 86 |
+
|
| 87 |
+
# Optional: Representation for better debugging.
|
| 88 |
+
def __repr__(self):
|
| 89 |
+
return str(self._data)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def dense_info_nce_loss(original_features, transformed_features, config_nce):
    """Dense (per-pixel) InfoNCE loss between two feature volumes.

    The feature vector at each spatial position of `original_features` forms a
    positive pair with the vector at the SAME position of
    `transformed_features`; all other positions of the same image act as
    negatives.

    Args:
        original_features: [1 or B, C, H, W] reference features (broadcast over B).
        transformed_features: [B, C, H, W] features of the augmented views.
        config_nce: config with `temperature` (float) and `debug` (bool).

    Returns:
        Scalar mean loss over batch and spatial positions.
    """
    B, C, H, W = transformed_features.shape
    # [*, C, H, W] -> [B, H*W, C]
    o_features = original_features.expand(B, C, H, W).permute(0, 2, 3, 1).view(B, H * W, C)
    t_features = transformed_features.permute(0, 2, 3, 1).view(B, H * W, C)

    # Calculate dot product between original and transformed feature vectors for positive pairs
    positive_logits = torch.einsum('bik,bik->bi', o_features, t_features) / config_nce.temperature

    # Calculate dot product between original features and all other transformed features for negative pairs
    all_logits = torch.einsum('bik,bjk->bij', o_features, t_features) / config_nce.temperature

    if config_nce.debug: print('pos/neg:', positive_logits.mean().detach(), all_logits.mean().detach())

    # Using the log-sum-exp trick for numerical stability
    max_logits = torch.max(all_logits, dim=-1, keepdim=True).values
    log_sum_exp = max_logits + torch.log(torch.sum(torch.exp(all_logits - max_logits), dim=-1, keepdim=True))

    # Compute InfoNCE loss.
    # squeeze(-1), NOT squeeze(): a bare squeeze also drops a batch or spatial
    # dim of size 1 (B==1 or H*W==1), which silently broadcasts logits across
    # the wrong axes.
    loss = - (positive_logits - log_sum_exp.squeeze(-1))
    return loss.mean()  # [B=k*aug] or [B=k] -> scalar
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def ssim(a, b):
    """Cosine similarity between `a` and `b` along dim 1 (the channel axis)."""
    return F.cosine_similarity(a, b, dim=1, eps=1e-8)
|
| 116 |
+
|
| 117 |
+
def augwise_proto(feat_vol, mask, k, aug):
    """Compute one foreground and one background prototype PER AUGMENTATION.

    feat_vol: [k*aug, c, h, w]; mask: [k*aug, H, W] (downsampled to h x w here).
    The k shots of each augmentation are pooled together, so the result is one
    (fg, bg) prototype pair of shape [aug, c] each.
    """
    # bind spatial/channel dims; k and aug are passed through unchanged
    k, aug, c, h, w = k, aug, *feat_vol.shape[-3:]
    # [k*aug, c, h, w] -> [k, aug, c, h*w] -> concat the k shots spatially -> [aug, c, k*h*w]
    feature_vectors_augwise = torch.cat(feat_vol.view(k, aug, c, h * w).unbind(0), dim=-1)
    # same layout for the (downsampled) masks -> [aug, k*h*w]
    mask_augwise = torch.cat(segutils.downsample_mask(mask, h, w).view(k, aug, h * w).unbind(0), dim=-1)
    assert feature_vectors_augwise.shape == (aug, c, k * h * w) and mask_augwise.shape == (
        aug, k * h * w), "of transformed"

    # masked pooling into prototypes — presumably fg = mask==1, bg = mask==0;
    # verify against segutils.fg_bg_proto
    fg_proto, bg_proto = segutils.fg_bg_proto(feature_vectors_augwise, mask_augwise)
    assert fg_proto.shape == bg_proto.shape == (aug, c)

    return fg_proto, bg_proto
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# NOTE(review): `l0` is accepted but never used in this function — confirm whether it
# was meant to select feature layers or can be removed from the signature.
def calc_q_pred_coarse_nodetach(qft, sft, s_mask, l0=3):
    """Coarse query foreground prediction from dense query<->support affinities.

    Unlike a detached variant, the result stays in the autograd graph so it can
    be used as a differentiable loss signal.

    qft: [bsz, c, hq, wq] query features; sft: [bsz, k, c, hs, ws] support
    features; s_mask: [bsz, k, H, W] support masks.
    Returns [bsz, hq, wq] per-pixel foreground scores.
    """
    bsz, c, hq, wq = qft.shape
    hs, ws = sft.shape[-2:]

    # concatenate the k support shots into one wide image: bsz,k,c,h,w -> bsz,c,h,w*k
    sft_row = torch.cat(sft.unbind(1), -1)
    smasks_downsampled = [segutils.downsample_mask(m, hs, ws) for m in s_mask.unbind(1)]
    smask_row = torch.cat(smasks_downsampled, -1)

    # dense pairwise affinities between every query and support position
    damat = dautils.buildDenseAffinityMat(qft, sft_row)
    # aggregate affinities through the support mask -> one score per query pixel
    filtered = dautils.filterDenseAffinityMap(damat, smask_row)
    q_pred_coarse = filtered.view(bsz, hq, wq)
    return q_pred_coarse
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# input k*aug,c,h,w
|
| 145 |
+
# input k*aug,c,h,w
def self_attention_loss(f_base, f_transformed, mask_base, mask_transformed, k, aug):
    """BCE between a coarse self-matching prediction and the base masks.

    The k base shots are concatenated into a pseudo-query row image; the
    transformed features/masks act as pseudo-support. The dense-affinity
    prediction of the pseudo-query is then compared with its own mask.
    """
    c, h, w = f_base.shape[-3:]
    pseudoquery = torch.cat(f_base.view(k, aug, c, h, w).unbind(0), -1)  # shape aug,c,h,w*k
    pseudoquerymask = torch.cat(mask_base.view(k, aug, h, w).unbind(0), -1)  # shape aug,h,w*k
    pseudosupport = f_transformed.view(k, aug, c, h, w).transpose(0, 1)  # shape bsz,k,c,h,w
    pseudosupportmask = mask_transformed.view(k, aug, h, w).transpose(0, 1)  # shape bsz,k,h,w
    # display(segutils.tensor_table(q=pseudoquery, s=pseudosupport, m=pseudosupportmask))
    pred_map = calc_q_pred_coarse_nodetach(pseudoquery, pseudosupport, pseudosupportmask, l0=0)

    # NOTE(review): BCELoss requires pred_map values in [0, 1] — presumably guaranteed
    # by filterDenseAffinityMap; confirm.
    loss = torch.nn.BCELoss()(pred_map.float(), pseudoquerymask.float())
    return loss.mean()
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# features of base, transformed: [b,c,h,w]
|
| 159 |
+
# if base features are aligned with transformed features, pass both same
|
| 160 |
+
# features of base, transformed: [b,c,h,w]
# if base features are aligned with transformed features, pass both same
def ctrstive_prototype_loss(base, transformed, mask_base, mask_transformed, k, aug):
    """Contrastive prototype loss between base and transformed feature volumes.

    For each augmentation, pulls the fg prototype of `base` towards the fg
    prototype of `transformed` while pushing it away from the transformed bg
    prototype (InfoNCE with a single negative).
    """
    assert transformed.shape == base.shape, ".."
    b, c, h, w = base.shape
    assert b == k * aug, 'provide correct k and aug such that dim0=k*aug'
    assert mask_base.shape == mask_transformed.shape == (b, h, w), ".."
    # one (fg, bg) prototype pair per augmentation, shape [aug, c] each
    fg_proto_o, bg_proto_o = augwise_proto(base, mask_base, k, aug)
    fg_proto_t, bg_proto_t = augwise_proto(transformed, mask_transformed, k, aug)
    # i: fg, b: bg
    # p_b_i, p_b_j = segutils.fg_bg_proto(base.view(b,c,h*w), mask_base.view(b,h*w))
    # p_t_i, p_t_j = segutils.fg_bg_proto(transformed.view(b,c,h*w), mask_transformed.view(b,h*w))
    enumer = torch.exp(
        ssim(fg_proto_o, fg_proto_t))  # 5vs5 (augvsaug), but in 5-shot: 25vs25, no, you want also augvsaug
    # denominator: positive pair + the single fg-vs-bg negative
    denom = torch.exp(ssim(fg_proto_o, fg_proto_t)) + torch.exp(ssim(fg_proto_o, bg_proto_t))
    assert enumer.shape == denom.shape == torch.Size([aug]), 'you want to calculate one prototype for each augmentation'
    loss = -torch.log(enumer / denom)  # [bsz]
    return loss.mean()
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def opposite_proto_sim_in_aug(transformed_features, mapped_s_masks, k, aug):
    """Mean cosine similarity between the fg and bg prototypes of the
    transformed features (one prototype pair per augmentation)."""
    protos = augwise_proto(transformed_features, mapped_s_masks, k, aug)
    return ssim(protos[0], protos[1]).mean()
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def proto_align_val_measure(original_features, transformed_features, mapped_s_masks, k, aug):
    """Validation metric: mean cosine similarity between the fg prototypes of
    the original and the transformed features (same masks for both)."""
    fg_original = augwise_proto(original_features, mapped_s_masks, k, aug)[0]
    fg_transformed = augwise_proto(transformed_features, mapped_s_masks, k, aug)[0]
    return ssim(fg_original, fg_transformed).mean()
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def atest():
    """Smoke test: run self_attention_loss on random features and masks."""
    k, aug, c, h, w = 2, 5, 8, 20, 20
    base_feats = torch.rand(k * aug, c, h, w).float()
    base_feats.requires_grad = True
    trans_feats = torch.rand(k * aug, c, h, w).float()
    base_masks = torch.randint(0, 2, (k * aug, h, w)).float()
    trans_masks = torch.randint(0, 2, (k * aug, h, w)).float()

    return self_attention_loss(base_feats, trans_feats, base_masks, trans_masks, k, aug)
|
| 200 |
+
|
| 201 |
+
def keep_var_loss(original_features, transformed_features):
    """L1 penalty on the gap in per-channel spatial mean and (unbiased) variance
    between the untransformed and the transformed feature volumes."""
    spatial = (-2, -1)
    mean_gap = (original_features.mean(spatial) - transformed_features.mean(spatial)).abs()
    var_gap = (original_features.var(spatial) - transformed_features.var(spatial)).abs()
    # [k*aug, c] or [aug, c] -> scalar
    return mean_gap.mean() + var_gap.mean()
|
| 207 |
+
|
| 208 |
+
class ContrastiveFeatureTransformer(nn.Module):
    """Small conv head fitted per episode to project backbone features into a
    domain-adapted space via contrastive losses (conv -> bn -> relu -> 1x1)."""

    def __init__(self, in_channels, config_model):
        super(ContrastiveFeatureTransformer, self).__init__()

        out_channels, kernel_size = config_model.out_channels, config_model.kernel_size
        # Add a convolutional layer and a batch normalization layer for learning
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2)
        self.bn = nn.BatchNorm2d(out_channels)
        self.linear = nn.Conv2d(out_channels, out_channels, 1)

        self.prepend_relu = config_model.prepend_relu
        self.append_normalize = config_model.append_normalize
        self.debug = config_model.debug

    def forward(self, x):
        """[b, in_channels, h, w] -> [b, out_channels, h, w]; optionally ReLU'd
        on entry and L2-normalized per channel on exit (config flags)."""
        if self.prepend_relu:
            x = nn.ReLU()(x)
        x = self.conv(x)
        x = self.bn(x)
        x = nn.ReLU()(x)
        x = self.linear(x)
        if self.append_normalize:
            x = F.normalize(x, p=2, dim=1)
        return x

    # fits the model for one semantic class, therefore does not work with batches
    # mapped_qfeat_vol, aug_qfeat_vols: [aug,c,h,w]
    # mapped_sfeat_vol, aug_sfeat_vols: [k*aug,c,h,w]
    # augmented_smasks: [k*aug,h,w]
    def fit(self, mapped_qfeat_vol, aug_qfeat_vols, mapped_sfeat_vol, aug_sfeat_vols, augmented_smasks, config_fit):
        """Optimize this head on one episode.

        The loss combines (all switched by config_fit flags):
        dense InfoNCE on query and support, a mean/variance-keeping term,
        exactly one of three prototype/self-attention terms, and an optional
        |qloss - sloss| symmetry regularizer.
        """
        # identity when normalization is disabled (identity_mapping swallows p=, dim=)
        f_norm = F.normalize if config_fit.normalize_after_fwd_pass else identity_mapping
        optimizer = config_fit.optimizer(self.parameters(), lr=config_fit.lr)
        for epoch in range(config_fit.num_epochs):
            # Pass original and transformed image batches through the model

            # Q
            original_features = f_norm(self(mapped_qfeat_vol), p=2, dim=1)  # fwd pass non-augmented
            transformed_features = f_norm(self(aug_qfeat_vols), p=2, dim=1)  # fwd pass augmented

            qloss = dense_info_nce_loss(original_features, transformed_features,
                                        config_fit.nce) if config_fit.q_nceloss else 0
            if config_fit.keepvarloss:  # 1. idea: Let query and support have the same feature distribution (mean/var per channel)
                qloss += keep_var_loss(original_features, transformed_features)
            # S — note: original_features/transformed_features are rebound to the
            # SUPPORT outputs here; the proto losses below operate on these
            original_features = f_norm(self(mapped_sfeat_vol), p=2, dim=1)  # fwd pass non-augmented
            transformed_features = f_norm(self(aug_sfeat_vols), p=2, dim=1)  # fwd pass augmented

            sloss = dense_info_nce_loss(original_features, transformed_features,
                                        config_fit.nce) if config_fit.s_nceloss else 0
            if config_fit.keepvarloss:
                sloss += keep_var_loss(original_features, transformed_features)

            # 2. class-aware loss: opposite classes should get opposite features
            # for prototype calculation, we want only one prototype per class
            # so we average over features of entire k
            # but calculate prototype for each augmentation individually [k*aug,c,h,w]->[aug,c,k*h*w]->[aug,c]
            kaug, c, h, w = transformed_features.shape
            aug = aug_qfeat_vols.shape[0]
            k = kaug // aug
            if config_fit.protoloss:
                assert not config_fit.o_t_contr_proto_loss, 'only one of the proto losses should be used'
                opposite_proto_sim = opposite_proto_sim_in_aug(transformed_features, augmented_smasks, k, aug)
                if config_fit.debug and (epoch == config_fit.num_epochs - 1 or epoch == 0): print(
                    'proto-sim intER-class transf<->transf', opposite_proto_sim.item())
                proto_loss = opposite_proto_sim
            elif config_fit.selfattentionloss:
                proto_loss = self_attention_loss(original_features, transformed_features, augmented_smasks,
                                                 augmented_smasks, k, aug)
                if config_fit.debug and (epoch == config_fit.num_epochs - 1 or epoch == 0): print(
                    'self-att non-transf<->transformed bce', proto_loss.item())
            elif config_fit.o_t_contr_proto_loss:
                o_t_contr_proto_loss = ctrstive_prototype_loss(original_features, transformed_features,
                                                               augmented_smasks, augmented_smasks, k, aug)
                if config_fit.debug and (epoch == config_fit.num_epochs - 1 or epoch == 0): print(
                    'proto-contr non-transf<->transformed', o_t_contr_proto_loss.item())
                proto_loss = o_t_contr_proto_loss
            else:
                proto_loss = 0

            if config_fit.debug and (epoch == config_fit.num_epochs - 1 or epoch == 0):
                proto_align_val = proto_align_val_measure(original_features, transformed_features, augmented_smasks, k,
                                                          aug)
                print('proto-sim intRA-class non-transf<->transformed (for validation)', proto_align_val.item())

            # 3. do not let only one image fit well - regularization
            q_s_loss_diff = torch.abs(qloss - sloss) if config_fit.symmetricloss else 0

            # Aggregate loss
            loss = qloss + sloss + q_s_loss_diff + proto_loss
            assert loss.isfinite().all(), f"invalid contrastive loss:{loss}"

            # Backpropagation and optimization
            if config_fit.debug and (epoch == config_fit.num_epochs - 1 or epoch == 0):
                # per-term gradient magnitudes (debug only); each call re-runs backward
                # with retain_graph so the final loss.backward() below still works
                def gradient_magnitude(loss_term):
                    optimizer.zero_grad()
                    loss_term.backward(retain_graph=True)
                    magn = torch.abs(self.conv.weight.grad.mean()) + torch.abs(self.linear.weight.grad.mean())
                    return magn

                q_loss_grad_magnitude = gradient_magnitude(qloss)
                s_loss_grad_magnitude = gradient_magnitude(sloss)
                proto_loss_grad_magnitude = gradient_magnitude(proto_loss)
                q_s_loss_diff_grad_magnitude = gradient_magnitude(q_s_loss_diff)
                # NOTE(review): `display` is an IPython/notebook builtin — this debug
                # branch raises NameError when run as a plain script; confirm intended.
                display(segutils.tensor_table(q_loss_grad_magnitude=q_loss_grad_magnitude,
                                              s_loss_grad_magnitude=s_loss_grad_magnitude,
                                              proto_loss_grad_magnitude=proto_loss_grad_magnitude,
                                              q_s_loss_diff_grad_magnitude=q_s_loss_diff_grad_magnitude))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if config_fit.debug and epoch % 10 == 0: print('loss', loss.detach())
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
import numpy as np
|
| 324 |
+
import torch.nn.functional as F
|
| 325 |
+
from torchvision.transforms.functional import affine
|
| 326 |
+
from torchvision.transforms import GaussianBlur, ColorJitter
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class AffineProxy:
    """Holds one fixed set of affine parameters so the identical geometric
    transform can be replayed on images, masks and feature volumes alike."""

    def __init__(self, angle, translate, scale, shear):
        self.affine_params = {
            'angle': angle,
            'translate': translate,
            'scale': scale,
            'shear': shear,
        }

    def apply(self, img):
        """Apply the stored affine transform to `img` (torchvision functional affine)."""
        params = self.affine_params
        return affine(
            img,
            angle=params['angle'],
            translate=params['translate'],
            scale=params['scale'],
            shear=params['shear'],
        )
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
# def affine_proxy(angle, translate, scale, shear):
|
| 344 |
+
# def inner(img):
|
| 345 |
+
# return affine(img, angle=angle, translate=translate, scale=scale, shear=shear)
|
| 346 |
+
|
| 347 |
+
# return inner
|
| 348 |
+
|
| 349 |
+
class Augmen:
    """Samples a fixed set of augmentations once, then applies the SAME
    transforms to every image/mask/feature volume so correspondences hold."""

    def __init__(self, config_aug):
        self.config = config_aug
        # one (blur, jitter, affine) triple per transformed image, frozen at init
        self.blurs, self.jitters, self.affines = self.setup_augmentations()

    def copy_construct(self, blurs, jitters, affines, config_aug):
        # reuse previously sampled transforms instead of drawing new ones
        self.config = config_aug
        self.blurs, self.jitters, self.affines = blurs, jitters, affines

    def setup_augmentations(self):
        """Sample n_transformed_imgs random (blur, jitter, affine) transforms."""
        blurkernelsize = self.config.blurkernelsize
        maxjitter = self.config.maxjitter

        maxangle = self.config.maxangle
        # translation deliberately disabled — breaks feature-volume correspondences
        translate = (0, 0)
        maxscale = self.config.maxscale
        maxshear = self.config.maxshear

        blurs = []
        jitters = []
        affine_trans = []
        for i in range(self.config.n_transformed_imgs):
            # Randomize kernel size for GaussianBlur
            kernel_size = np.random.choice(torch.tensor(blurkernelsize), (1,)).item()
            blur = GaussianBlur(kernel_size)
            blurs.append(blur)

            # Randomize values for ColorJitter
            brightness_val = torch.rand(1).item() * maxjitter  # up to <maxjitter> change
            contrast_val = torch.rand(1).item() * maxjitter
            saturation_val = torch.rand(1).item() * maxjitter
            jitter = ColorJitter(brightness=brightness_val, contrast=contrast_val, saturation=saturation_val)
            jitters.append(jitter)

            # Random values for each iteration
            angle = torch.randint(-maxangle, maxangle + 1, (1,)).item()
            shear = [torch.randint(-maxshear, maxshear + 1, (1,)).item() for _ in range(2)]
            # scale sampled uniformly in [maxscale, 1.0] (maxscale=1.0 -> no scaling)
            scale = torch.rand(1).item() * (1 - maxscale) + maxscale
            affine_trans.append(AffineProxy(angle=angle, translate=translate, scale=scale, shear=shear))

        return (blurs, jitters, affine_trans)  # tuple of lists

    def augment(self, original_image, orignal_mask):
        """Apply every stored transform to one image+mask.

        Returns (imgs, masks) stacked along a new dim 1 of size n_transformed_imgs.
        Only the affine (geometric) part is applied to the mask.
        """
        transformed_imgs = []
        transformed_masks = []
        for blur, jitter, affine_trans in zip(self.blurs, self.jitters, self.affines):
            # Apply non-geometric transformations (photometric only; mask untouched)
            t_img = blur(original_image)
            t_img = jitter(t_img)
            t_mask = orignal_mask.clone()

            if self.config.apply_affine:
                t_img = affine_trans.apply(t_img)
                t_mask = affine_trans.apply(t_mask)

            transformed_imgs.append(t_img)
            transformed_masks.append(t_mask)
        return torch.stack(transformed_imgs, dim=1), torch.stack(transformed_masks, dim=1)

    # [bsz,ch,h,w] -> [bsz,aug,ch,h,w], where aug is the number of augmentated images
    def applyAffines(self, feat_vol):
        # replay ONLY the geometric part on a feature volume (photometric ops
        # are meaningless on features)
        return torch.stack([trans.apply(feat_vol) for trans in self.affines], dim=1)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class CTrBuilder:
|
| 414 |
+
# call init 1st, pass all config parameters (initatiate a ContrastiveConfig class in your code)
|
| 415 |
+
def __init__(self, config, augmentator=None):
|
| 416 |
+
if augmentator is None:
|
| 417 |
+
augmentator = Augmen(config.aug)
|
| 418 |
+
self.augmentator = augmentator
|
| 419 |
+
|
| 420 |
+
self.augimgs = self.AugImgStack(augmentator)
|
| 421 |
+
|
| 422 |
+
self.hasfit = False
|
| 423 |
+
self.config = config
|
| 424 |
+
|
| 425 |
+
class AugImgStack():
|
| 426 |
+
def __init__(self, augmentator):
|
| 427 |
+
self.augmentator = augmentator
|
| 428 |
+
self.q, self.s, self.s_mask = None, None, None
|
| 429 |
+
|
| 430 |
+
def init(self, s_img):
|
| 431 |
+
# c is color channels here, not feature channels
|
| 432 |
+
bsz, k, aug, c, h, w = *s_img.shape[:2], self.augmentator.config.n_transformed_imgs, *s_img.shape[-3:]
|
| 433 |
+
self.q = torch.empty(bsz, aug, c, h, w).to(s_img.device)
|
| 434 |
+
self.s = torch.empty(bsz, k, aug, c, h, w).to(s_img.device)
|
| 435 |
+
self.s_mask = torch.empty(bsz, k, aug, h, w).to(s_img.device)
|
| 436 |
+
|
| 437 |
+
def show(self):
|
| 438 |
+
bsz_, k_, aug_ = self.s.shape[:3]
|
| 439 |
+
for b in range(bsz_):
|
| 440 |
+
display('aug x queries', segutils.pilImageRow(*[segutils.norm(img) for img in self.q[b]]))
|
| 441 |
+
for k in range(k_):
|
| 442 |
+
print('k=', k, ' aug x (s, smask):')
|
| 443 |
+
display(segutils.pilImageRow(*[segutils.norm(img) for img in self.s[b, k]]))
|
| 444 |
+
display(segutils.pilImageRow(*self.s_mask[b, k]))
|
| 445 |
+
|
| 446 |
+
def showAugmented(self):
|
| 447 |
+
self.augimgs.show()
|
| 448 |
+
|
| 449 |
+
# 2nd call makeAugmented
|
| 450 |
+
def makeAugmented(self, q_img, s_img, s_mask):
|
| 451 |
+
# 2. Augmentation
|
| 452 |
+
# 2.1 Apply transformations to images
|
| 453 |
+
self.augimgs.init(s_img)
|
| 454 |
+
self.augimgs.q, _ = self.augmentator.augment(q_img, s_mask)
|
| 455 |
+
|
| 456 |
+
for k in range(s_img.shape[1]):
|
| 457 |
+
s_aug_imgs, s_aug_masks = self.augmentator.augment(s_img[:, k], s_mask[:, k])
|
| 458 |
+
self.augimgs.s[:, k] = s_aug_imgs
|
| 459 |
+
self.augimgs.s_mask[:, k] = s_aug_masks
|
| 460 |
+
if self.config.aug.debug: self.augimgs.show()
|
| 461 |
+
|
| 462 |
+
    # 3rd call build_and_fit
    def build_and_fit(self, q_feat, s_feat, q_feataug, s_feataug, s_maskaug=None):
        """Build and fit one contrastive transformer per feature layer.

        If no augmented support mask is supplied, the one produced by
        makeAugmented() is used. Sets self.hasfit so callers can reuse
        the fitted transformers across episodes of the same class.
        """
        if s_maskaug is None: s_maskaug = self.augimgs.s_mask
        self.ctrs = self.buildContrastiveTransformers(q_feat, s_feat, q_feataug, s_feataug, s_maskaug)
        self.hasfit = True
|
| 467 |
+
|
| 468 |
+
    def buildContrastiveTransformers(self, qfeat_alllayers, sfeat_alllayers, query_feats_aug, support_feats_aug,
                                     supp_aug_mask, s_mask=None):
        """Fit one ContrastiveFeatureTransformer per backbone layer (from
        layer index config.featext.l0 upward) and return them as a list.

        Inputs per layer: query feats [bsz,c,h,w], support feats [bsz,k,c,h,w],
        their augmented counterparts with an extra `aug` dimension, and the
        augmented support masks [bsz,k,aug,h,w].
        Passing s_mask enables a debug visualization of the affinity maps.
        """
        contrastive_transformers = []
        l0 = self.config.featext.l0
        # [bsz,k,aug,h,w] -> [k*aug,h,w]
        s_aug_mask = supp_aug_mask.view(-1, *supp_aug_mask.shape[-2:])
        # iterate over feature layers
        for (qfeat, sfeat, qfeataug, sfeataug) in zip(qfeat_alllayers[l0:], sfeat_alllayers[l0:], query_feats_aug[l0:],
                                                      support_feats_aug[l0:]):
            bsz, k, aug, ch, h, w = sfeataug.shape
            # we fit it for exactly one class, so use no batches
            assert bsz == 1, "bsz should be 1"
            assert supp_aug_mask.shape[1] == sfeat.shape[
                1] == k, f'augmented support shot-dimension mismatch:{s_aug_mask.shape[1]=},{sfeat.shape[1]=},(bsz,k,aug,ch,h,w)={bsz, k, aug, ch, h, w}'
            assert supp_aug_mask.shape[2] == qfeataug.shape[1] == aug, 'augmented shot-dimension mismatch'
            # [bsz,c,h,w] -> [1,c,h,w]
            qfeat = qfeat.view(-1, *qfeat.shape[-3:])
            # [bsz,k,c,h,w] -> [k,c,h,w]
            sfeat = sfeat.view(-1, *sfeat.shape[-3:])
            # [bsz,aug,c,h,w] -> [aug,c,h,w]
            qfeataug = qfeataug.view(-1, *qfeataug.shape[-3:])
            # [bsz,k,aug,c,h,w] -> [k*aug,c,h,w]
            sfeataug = sfeataug.view(-1, *qfeataug.shape[-3:])

            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            # a fresh head is trained per layer (channel count differs per layer)
            contrastive_head = ContrastiveFeatureTransformer(in_channels=ch, config_model=self.config.model).to(device)

            # 3. Feature volumes from untransformed image need to be geometrically mapped to allow for dense matching
            mapped_qfeat = self.augmentator.applyAffines(qfeat)
            assert mapped_qfeat.shape[1] == aug, "should be 1,aug,c,h,w"
            mapped_qfeat = mapped_qfeat.view(-1, *qfeat.shape[-3:])  # ->[aug,c,h,w]
            mapped_sfeat = self.augmentator.applyAffines(sfeat)
            assert mapped_sfeat.shape[1] == aug and mapped_sfeat.shape[0] == k, "should be k,aug,c,h,w"
            mapped_sfeat = mapped_sfeat.view(-1, *sfeat.shape[-3:])  # ->[k*aug,c,h,w]

            contrastive_head.fit(mapped_qfeat, qfeataug, mapped_sfeat, sfeataug,
                                 segutils.downsample_mask(s_aug_mask, h, w), self.config.fitting)

            contrastive_transformers.append(contrastive_head)
            # show how support image and its augmentations would produce a affinity map
            # NOTE(review): `display` is a notebook global — debug path only.
            if s_mask != None:
                display(segutils.to_pil(segutils.norm(dautils.filterDenseAffinityMap(
                    dautils.buildDenseAffinityMat(contrastive_head(sfeat), contrastive_head(sfeataug[:1])),
                    segutils.downsample_mask(s_mask, h, w)).view(1, h, w))))
                display(segutils.to_pil(segutils.norm(dautils.filterDenseAffinityMap(
                    dautils.buildDenseAffinityMat(contrastive_head(qfeat), contrastive_head(sfeat)),
                    segutils.downsample_mask(s_mask, h, w)).view(1, h, w))))
        return contrastive_transformers
|
| 516 |
+
|
| 517 |
+
# You have fitted the contrastive transformers, now apply the transform and then pass to the downstream DCAMA
|
| 518 |
+
# you just need to append the empty layers you exluded ([:3]), they're also skipped in dcama
|
| 519 |
+
# Obtain the result of the contrastive head, which will be the new query and support feat representation
|
| 520 |
+
def getTaskAdaptedFeats(self, layerwise_feats):
|
| 521 |
+
if (self.ctrs == None): print("error: call buildContrastiveTransformers() first")
|
| 522 |
+
task_adapted_feats = []
|
| 523 |
+
|
| 524 |
+
for idx in range(len(layerwise_feats)):
|
| 525 |
+
if idx < self.config.featext.l0:
|
| 526 |
+
task_adapted_feats.append(None)
|
| 527 |
+
else:
|
| 528 |
+
input_shape = layerwise_feats[idx].shape
|
| 529 |
+
idxth_feat = layerwise_feats[idx].view(-1, *input_shape[-3:])
|
| 530 |
+
forward_pass_res = self.ctrs[idx - self.config.featext.l0](idxth_feat)
|
| 531 |
+
target_shape = *input_shape[:-3], *forward_pass_res.shape[
|
| 532 |
+
-3:] # borrow channel dim from result, but bsz,k dims from input
|
| 533 |
+
task_adapted_feats.append(forward_pass_res.view(target_shape))
|
| 534 |
+
|
| 535 |
+
return task_adapted_feats
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
class FeatureMaker:
    """Extracts backbone features and routes them through a per-class
    contrastive transformer for task adaption.

    One CTrBuilder is kept per class id so a transformer fitted on one
    episode can be reused for later episodes of the same class.
    """

    def __init__(self, feat_extraction_method, class_ids, config=None):
        # BUGFIX(idiom): the default used to be `config=ContrastiveConfig()`,
        # a default evaluated once at def time and silently shared by every
        # FeatureMaker constructed without an explicit config.
        if config is None:
            config = ContrastiveConfig()
        self.featextractor = feat_extraction_method  # callable: img batch -> list of layer features
        self.c_trs = {ctr: CTrBuilder(config) for ctr in class_ids}  # one transformer per class id
        self.norm_bb_feats = False  # optionally channel-normalize backbone features

    def extract_bb_feats(self, img):
        """Run the (frozen) backbone; gradients are never needed here."""
        with torch.no_grad():
            return self.featextractor(img)

    def create_and_fit(self, c_tr, q_img, s_img, s_mask, q_feat, s_feat):
        """Augment the episode, extract features of the augmented views and
        fit the contrastive transformer `c_tr` on them."""
        print('doing contrastive')
        c_tr.makeAugmented(q_img, s_img, s_mask)

        bsz, k, c, h, w = s_img.shape
        aug = c_tr.augmentator.config.n_transformed_imgs
        # [bsz,aug,c,h,w]->[bsz*aug,c,h,w] squeeze for forward pass
        q_feataug = self.extract_bb_feats(c_tr.augimgs.q.view(-1, c, h, w))  # returns layer-list
        # then restore
        q_feataug = [l.view(bsz, aug, *l.shape[1:]) for l in q_feataug]
        # [bsz,k,aug,c,h,w]->[bsz*k*aug,c,h,w]->[bsz,k,aug,c,h,w]
        s_feataug = self.extract_bb_feats(c_tr.augimgs.s.view(-1, c, h, w))
        s_feataug = [l.view(bsz, k, aug, *l.shape[1:]) for l in s_feataug]

        c_tr.build_and_fit(q_feat, s_feat, q_feataug, s_feataug)

    def taskAdapt(self, q_img, s_img, s_mask, class_id):
        """Return (query_feats, support_feats) after task adaption, fitting
        the class's contrastive transformer first if needed.

        q_img: [bsz,c,h,w]; s_img: [bsz,k,c,h,w]; s_mask: [bsz,k,h,w].
        """
        # BUGFIX: keepdim=True so the division broadcasts over the channel
        # dim as intended (norm over dim=1 otherwise drops that dim).
        ch_norm = lambda t: t / torch.linalg.norm(t, dim=1, keepdim=True)
        q_feat = self.extract_bb_feats(q_img)
        bsz, k, c, h, w = s_img.shape
        s_feat = self.extract_bb_feats(s_img.view(-1, c, h, w))
        if self.norm_bb_feats:
            q_feat = [ch_norm(l) for l in q_feat]
            # BUGFIX: previously iterated q_feat here, discarding the support
            # features entirely whenever norm_bb_feats was enabled.
            s_feat = [ch_norm(l) for l in s_feat]
        s_feat = [l.view(bsz, k, *l.shape[1:]) for l in s_feat]

        c_tr = self.c_trs[class_id]  # select the relevant ctr for this class

        if c_tr.hasfit is False or c_tr.config.featext.fit_every_episode:  # create and fit a contrastive transformer if not existing yet
            self.create_and_fit(c_tr, q_img, s_img, s_mask, q_feat, s_feat)

        q_feat_t, s_feat_t = c_tr.getTaskAdaptedFeats(q_feat), c_tr.getTaskAdaptedFeats(
            s_feat)  # tocheck: do they require_grad here?
        return q_feat_t, s_feat_t
|
core/denseaffinity.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import math
|
| 4 |
+
from utils import segutils
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def buildHyperCol(feat_pyram):
    """Merge a feature pyramid into one hypercolumn tensor.

    Every volume is bilinearly resized to the spatial size of the first
    (largest) level, then all levels are concatenated along the channel dim.
    """
    ref_hw = feat_pyram[0].shape[-2:]
    resized = [F.interpolate(vol, size=ref_hw, mode='bilinear', align_corners=False)
               for vol in feat_pyram]
    return torch.cat(resized, dim=1)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# accepts both:
# s_feat_vol: [bsz,k,c,h,w]->[bsz,c,h,w*k]
# s_mask: [bsz,k,h,w]->[bsz,h,w*k]
def paste_supports_together(supports):
    """Concatenate the k support shots side by side along the width dim."""
    shots = [supports[:, i] for i in range(supports.shape[1])]
    return torch.cat(shots, dim=-1)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Attention regular:
# 1. Dot product
# 2. Divide by square root of key length (#nchannels)
# 3. Softmax
# 4. Multiply with V (mask)

def buildDenseAffinityMat(qfeat_volume, sfeat_volume, softmax_arg2=True):  # bsz,C,H,W
    """Dense pixel-to-pixel affinity between query and support features.

    Returns [bsz, H*Wq, H*Ws]; with softmax_arg2=True each query pixel's
    affinities are softmax-normalized over the support pixels (scaled
    dot-product attention without the value multiply).
    """
    q2d = qfeat_volume.permute(0, 2, 3, 1).flatten(1, 2)  # [bsz, H*Wq, C]
    s2d = sfeat_volume.permute(0, 2, 3, 1).flatten(1, 2)  # [bsz, H*Ws, C]
    nchannels = q2d.shape[-1]
    # [px,C][C,px]=[px,px]
    affinity = torch.bmm(q2d, s2d.transpose(1, 2))
    if softmax_arg2 is False:
        return affinity
    # each query pixel's affinities sum up to 1 over support pxls
    return (affinity / math.sqrt(nchannels)).softmax(dim=-1)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# filter with support mask following DAM
def filterDenseAffinityMap(dense_affinity_mat, downsampled_smask):
    """For each query pixel, sum its affinities over support pixels where
    the (downsampled) support mask is 1.

    [px,px] @ [px,1] = [px,1]; returns [bsz, H*Wq].
    """
    bsz, HWq, HWs = dense_affinity_mat.shape
    # let mean(V)=1 -> sum(V)=len(V) -> d_mask / mean(d_mask)
    # downsampled_smask_norm = downsampled_smask / downsampled_smask.mean()
    mask_col = downsampled_smask.reshape(bsz, HWs, 1)
    q_coarse = torch.bmm(dense_affinity_mat, mask_col)
    return q_coarse.squeeze(-1)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def upsample(volume, h, w):
    """Bilinearly resize a [bsz,c,H,W] volume to spatial size (h, w)."""
    target_size = (h, w)
    return F.interpolate(volume, size=target_size, mode='bilinear', align_corners=False)
|
| 57 |
+
|
| 58 |
+
class DAMatComparison:
    """Training-free comparison of task-adapted query/support features via
    dense affinity maps, producing a coarse query foreground map."""

    def algo_mean(self, q_pred_coarses_t, s_mask=None):
        # fuse per-layer coarse predictions by averaging over the layer dim
        return q_pred_coarses_t.mean(1)

    def calc_q_pred_coarses(self, q_feat_t, s_feat_t, s_mask, l0=3):
        """Build one coarse query prediction per feature layer (from l0 up).

        q_feat_t entries: [bsz,c,h,w]; s_feat_t entries: [bsz,k,c,h,w];
        s_mask: [bsz,k,H,W]. Returns [bsz, n_layers, h0, w0].
        """
        q_pred_coarses = []
        h0, w0 = q_feat_t[l0].shape[-2:]
        for (qft, sft) in zip(q_feat_t[l0:], s_feat_t[l0:]):
            qft, sft = qft.detach(), sft.detach()
            bsz, c, hq, wq = qft.shape
            hs, ws = sft.shape[-2:]

            sft_row = torch.cat(sft.unbind(1), -1)  # bsz,k,c,h,w -> bsz,c,h,w*k
            smasks_downsampled = [segutils.downsample_mask(m, hs, ws) for m in s_mask.unbind(1)]
            smask_row = torch.cat(smasks_downsampled, -1)

            # affinity of every query pixel to every (masked) support pixel
            damat = buildDenseAffinityMat(qft, sft_row)
            filtered = filterDenseAffinityMap(damat, smask_row)
            # bring every layer's prediction to the l0 spatial size
            q_pred_coarse = upsample(filtered.view(bsz, 1, hq, wq), h0, w0).squeeze(1)
            q_pred_coarses.append(q_pred_coarse)
        return torch.stack(q_pred_coarses, dim=1)

    def forward(self, q_feat_t, s_feat_t, s_mask, upsample=True, debug=False):
        """Return the fused coarse logit mask for the query.

        NOTE(review): the `upsample` parameter shadows the module-level
        upsample() function inside this method body.
        """
        q_pred_coarses_t = self.calc_q_pred_coarses(q_feat_t, s_feat_t, s_mask)

        if debug: display(segutils.pilImageRow(*q_pred_coarses_t.unbind(1), q_pred_coarses_t.mean(1)))

        # select the algorithm
        postprocessing_algorithm = self.algo_mean
        # do the postprocessing
        logit_mask = postprocessing_algorithm(q_pred_coarses_t, s_mask)
        if upsample:  # if query and support have different shape, then you must do upsampling yourself afterwards
            logit_mask = segutils.downsample_mask(logit_mask, *s_mask.shape[-2:])

        return logit_mask
|
core/runner.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from data.dataset import FSSDataset
|
| 2 |
+
from core.backbone import Backbone
|
| 3 |
+
from eval.logger import Logger, AverageMeter
|
| 4 |
+
from eval.evaluation import Evaluator
|
| 5 |
+
from utils import commonutils as utils
|
| 6 |
+
import utils.segutils as segutils
|
| 7 |
+
import core.contrastivehead as ctrutils
|
| 8 |
+
import core.denseaffinity as dautils
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
class args:
    """Static run configuration, namespaced like an argparse result.

    NOTE(review): main.py presumably overwrites benchmark/datapath/nshot
    from the command line — confirm against main.py.
    """
    backbone = 'resnet50'  # feature backbone identifier
    logpath = '/kaggle/working/logs'  # where Logger writes its output
    nworker = 0  # dataloader workers (0 = main process)
    bsz = 1  # batch size; evaluation assumes 1
    benchmark='' #e.g. deepglobe,isic,etc.
    datapath='' #path to the selected dataset
    fold = 0  # cross-validation fold
    nshot = 1  # number of support shots per episode
|
| 20 |
+
|
| 21 |
+
class SingleSampleEval:
    """Runs the full pipeline for one episode (batch): task adaption,
    dense-affinity comparison, thresholding, and metric computation."""

    def __init__(self, batch, feat_maker, debug=False):
        self.damat_comp = dautils.DAMatComparison()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.batch = batch
        self.feat_maker = feat_maker
        self.debug = debug
        # default binarization strategy, see calcthresh()
        self.thresh_method = 'pred_mean'

    def taskAdapt(self, detach=True):
        """Step 1: unpack the episode and task-adapt its features."""
        b = self.batch
        if self.device.type == 'cuda': b = utils.to_cuda(b)
        self.q_img, self.s_img, self.s_mask, self.class_id = b['query_img'], b['support_imgs'], b['support_masks'], b[
            'class_id'].item()
        self.task_adapted = self.feat_maker.taskAdapt(self.q_img, self.s_img, self.s_mask, self.class_id)

    def compare_feats(self):
        """Step 2: compare adapted query/support features into a logit mask."""
        if self.task_adapted is None:
            print("error, do task adaption first")
            return None
        self.logit_mask = self.damat_comp.forward(self.task_adapted[0], self.task_adapted[1], self.s_mask)
        return self.logit_mask

    def threshold(self, method=None):
        """Step 3: binarize the logit mask; returns (thresh, pred_mask)."""
        if self.logit_mask is None:
            print("error, calculate logit mask first (do forward pass)")
        if method is None:
            method = self.thresh_method
        self.thresh = calcthresh(self.logit_mask, self.s_mask, method)
        self.pred_mask = (self.logit_mask > self.thresh).float()
        return self.thresh, self.pred_mask

    def apply_crf(self):
        # optional CRF refinement of the soft prediction
        return apply_crf(self.q_img, self.logit_mask, thresh_fn(self.thresh_method))

    # this method calls above components sequentially
    def forward(self):
        self.taskAdapt()

        self.logit_mask = self.compare_feats()

        self.thresh, self.pred_mask = self.threshold()

        return self.logit_mask, self.pred_mask

    def calc_metrics(self):
        """Compute intersection/union stats; returns foreground IoU."""
        # assert torch.logical_or(self.logit_mask<0, self.logit_mask>1).sum()==0, display(tensor_table(logit_mask=self.logit_mask))
        self.area_inter, self.area_union = Evaluator.classify_prediction(self.pred_mask, self.batch)
        self.fgratio_pred = self.pred_mask.float().mean()
        self.fgratio_gt = self.batch['query_mask'].float().mean()
        return self.area_inter[1] / self.area_union[1]  # fg-iou

    def plots(self):
        """Debug visualization; requires a notebook `display` in scope."""
        display(pilImageRow(norm(self.logit_mask[0]), (self.logit_mask[0] > self.thresh).float(), self.pred_mask,
                            self.batch['query_mask'][:1], norm(self.q_img[0]), norm(self.s_img[0, 0])))
        display(segutils.tensor_table(probs=self.logit_mask))

        print('s_mask.mean, pred_mask.mean, thresh:', self.s_mask.mean().item(), self.logit_mask.mean().item(),
              self.thresh.item())
|
| 80 |
+
|
| 81 |
+
class AverageMeterWrapper:
    """Thin convenience wrapper around eval.logger.AverageMeter that also
    initializes the Logger and handles class-id tensor conversion."""
    def __init__(self, dataloader, device='cpu', initlogger=True):
        if initlogger: Logger.initialize(args, training=False)
        self.average_meter = AverageMeter(dataloader.dataset, device)
        self.device=device
        self.dataloader = dataloader
        # write progress to the log every N batches
        self.write_batch_idx = 50
    def update(self, sseval):
        # pull intersection/union stats out of a finished SingleSampleEval
        self.average_meter.update(sseval.area_inter, sseval.area_union, torch.tensor(sseval.class_id).to(self.device), loss=None)
    def update_manual(self, area_inter, area_union, class_id):
        # same as update() but with raw values; accepts a plain int class id
        if isinstance(class_id, int): class_id = torch.tensor(class_id).to(self.device)
        self.average_meter.update(area_inter, area_union, class_id, loss=None)
    def write(self, i):
        # log running mIoU after batch i
        self.average_meter.write_process(i, len(self.dataloader), 0, self.write_batch_idx)
|
| 95 |
+
|
| 96 |
+
def makeDataloader():
    """Build the test dataloader for the benchmark configured in `args`."""
    FSSDataset.initialize(img_size=400, datapath=args.datapath)
    dataloader = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot)
    return dataloader
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def makeConfig():
    """Return the ContrastiveConfig used for all experiments.

    Values are the deployed hyperparameter choices; loss flags select which
    terms of the contrastive fitting objective are active.
    """
    config = ctrutils.ContrastiveConfig()
    # fitting losses: which objective terms are enabled
    config.fitting.protoloss = False
    config.fitting.o_t_contr_proto_loss = True
    config.fitting.selfattentionloss = False
    config.fitting.keepvarloss = True
    config.fitting.symmetricloss = False
    config.fitting.q_nceloss = True
    config.fitting.s_nceloss = True
    # optimization
    config.fitting.num_epochs = 25
    config.fitting.lr = 1e-2
    config.fitting.debug = False
    config.model.out_channels = 64
    # fit the transformer once per class, not once per episode
    config.featext.fit_every_episode = False
    # augmentation: shear-only affine, 2 transformed views, no blur/jitter
    config.aug.blurkernelsize = [1]
    config.aug.n_transformed_imgs = 2
    config.aug.maxjitter = 0.0
    config.aug.maxangle = 0
    config.aug.maxscale = 1
    config.aug.maxshear = 20
    config.aug.apply_affine = True
    config.aug.debug = False
    return config
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def makeFeatureMaker(dataset, config, device='cpu', randseed=2, feat_extr_method=None):
    """Construct a FeatureMaker for the dataset's class ids.

    The random seed is fixed before and after construction for
    reproducibility. A custom feature extractor may be injected; by
    default the configured backbone's extract_feats is used.
    """
    utils.fix_randseed(randseed)
    if feat_extr_method is None:
        feat_extr_method = Backbone(args.backbone).to(device).extract_feats
    feat_maker = ctrutils.FeatureMaker(feat_extr_method, dataset.class_ids, config)
    utils.fix_randseed(randseed)
    feat_maker.norm_bb_feats = False
    return feat_maker
|
| 136 |
+
def apply_crf(rgb_img, fg_pred, thresh_fn,iterations=5): #5 on deployment, 1 on support-aug test for speedup
    """Refine a soft foreground prediction with a dense CRF and return the
    hard label map (argmax over fg/bg)."""
    crf = segutils.CRF(gaussian_stdxy=(1,1), gaussian_compat=2,
                       bilateral_stdxy=(35,35), bilateral_compat=1, stdrgb=(13,13,13))
    q = crf.iterrefine(iterations, rgb_img, fg_pred, thresh_fn)
    return q.argmax(1)
|
| 141 |
+
|
| 142 |
+
def calcthresh(fused_pred, s_masks, method='otsus'):
    """Compute a binarization threshold for a fused soft prediction map.

    Methods:
      'iterotsus'  - iterative Otsu (up to 5 iterations), guided by s_masks
      '1iterotsus' - a single iteration of the above
      'otsus'      - plain Otsu's threshold
      'pred_mean'  - Otsu's threshold, but never below the prediction mean

    Raises ValueError for an unknown method name (previously an unknown
    method fell through to `return thresh` and crashed with
    UnboundLocalError).
    """
    if method == 'iterotsus':
        return segutils.iterative_otsus(fused_pred, s_masks, maxiters=5)[0]
    elif method == '1iterotsus':
        return segutils.iterative_otsus(fused_pred, s_masks, maxiters=1)[0]
    elif method == 'otsus':
        return segutils.otsus(fused_pred)[0]
    elif method == 'pred_mean':
        # guard against Otsu picking an implausibly low threshold
        otsu_thresh = segutils.otsus(fused_pred)[0]
        return torch.max(otsu_thresh, fused_pred.mean())
    raise ValueError(f"unknown threshold method: {method!r}")
|
| 161 |
+
|
| 162 |
+
def thresh_fn(method):
    """Bind a threshold-method name into a callable with the signature
    expected by the CRF refinement (pred, optional support masks)."""
    def _calc(fused_pred, s_masks=None):
        return calcthresh(fused_pred, s_masks, method)
    return _calc
|
data/coco.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r""" COCO-20i few-shot semantic segmentation dataset """
|
| 2 |
+
import os
|
| 3 |
+
import pickle
|
| 4 |
+
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torch
|
| 8 |
+
import PIL.Image as Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DatasetCOCO(Dataset):
    """COCO-20i few-shot segmentation dataset.

    Episodes are sampled uniformly over object classes rather than indexed
    by `idx` (COCO is too large for exhaustive enumeration); the val split
    is truncated to 1000 episodes.
    """

    def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize=False):
        # 'test' is mapped onto the val split
        self.split = 'val' if split in ['val', 'test'] else 'trn'
        self.fold = fold
        self.nfolds = 4
        self.nclass = 80
        self.benchmark = 'coco'
        self.shot = shot
        self.split_coco = split if split == 'val2014' else 'train2014'
        self.base_path = os.path.join(datapath, 'COCO2014')
        self.transform = transform
        self.use_original_imgsize = use_original_imgsize

        self.class_ids = self.build_class_ids()
        self.img_metadata_classwise = self.build_img_metadata_classwise()
        self.img_metadata = self.build_img_metadata()

    def __len__(self):
        # validation is capped at 1000 randomly sampled episodes
        return len(self.img_metadata) if self.split == 'trn' else 1000

    def __getitem__(self, idx):
        # ignores idx during training & testing and perform uniform sampling over object classes to form an episode
        # (due to the large size of the COCO dataset)
        query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize = self.load_frame()

        query_img = self.transform(query_img)
        query_mask = query_mask.float()
        if not self.use_original_imgsize:
            # resize mask to the transformed query image size
            query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()

        support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])
        for midx, smask in enumerate(support_masks):
            support_masks[midx] = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
        support_masks = torch.stack(support_masks)

        batch = {'query_img': query_img,
                 'query_mask': query_mask,
                 'query_name': query_name,

                 'org_query_imsize': org_qry_imsize,

                 'support_imgs': support_imgs,
                 'support_masks': support_masks,
                 'support_names': support_names,
                 'class_id': torch.tensor(class_sample)}

        return batch

    def build_class_ids(self):
        """Split the 80 classes into fold-wise train/val id lists."""
        nclass_trn = self.nclass // self.nfolds
        class_ids_val = [self.fold + self.nfolds * v for v in range(nclass_trn)]
        class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val]
        class_ids = class_ids_trn if self.split == 'trn' else class_ids_val

        return class_ids

    def build_img_metadata_classwise(self):
        # precomputed {class_id: [image names]} mapping shipped with the repo
        with open('./data/splits/coco/%s/fold%d.pkl' % (self.split, self.fold), 'rb') as f:
            img_metadata_classwise = pickle.load(f)
        return img_metadata_classwise

    def build_img_metadata(self):
        """Flatten the classwise mapping into a sorted, de-duplicated list."""
        img_metadata = []
        for k in self.img_metadata_classwise.keys():
            img_metadata += self.img_metadata_classwise[k]
        return sorted(list(set(img_metadata)))

    def read_mask(self, name):
        """Load the annotation PNG matching an image name as a tensor."""
        mask_path = os.path.join(self.base_path, 'annotations', name)
        mask = torch.tensor(np.array(Image.open(mask_path[:mask_path.index('.jpg')] + '.png')))
        return mask

    def load_frame(self):
        """Sample one random episode: a query and `shot` distinct supports
        of the same randomly chosen class; masks are binarized to {0,1}.

        NOTE(review): the support-sampling loop never terminates if a class
        has fewer than shot+1 images — assumed not to occur in COCO splits.
        """
        class_sample = np.random.choice(self.class_ids, 1, replace=False)[0]
        query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
        query_img = Image.open(os.path.join(self.base_path, query_name)).convert('RGB')
        query_mask = self.read_mask(query_name)

        org_qry_imsize = query_img.size

        # annotation stores class_id + 1; binarize to fg/bg
        query_mask[query_mask != class_sample + 1] = 0
        query_mask[query_mask == class_sample + 1] = 1

        support_names = []
        while True:  # keep sampling support set if query == support
            support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
            if query_name != support_name: support_names.append(support_name)
            if len(support_names) == self.shot: break

        support_imgs = []
        support_masks = []
        for support_name in support_names:
            support_imgs.append(Image.open(os.path.join(self.base_path, support_name)).convert('RGB'))
            support_mask = self.read_mask(support_name)
            support_mask[support_mask != class_sample + 1] = 0
            support_mask[support_mask == class_sample + 1] = 1
            support_masks.append(support_mask)

        return query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize
|
| 111 |
+
|
data/dataset.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r""" Dataloader builder for few-shot semantic segmentation dataset """
|
| 2 |
+
from torch.utils.data.distributed import DistributedSampler as Sampler
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
from torchvision import transforms
|
| 5 |
+
|
| 6 |
+
from data.pascal import DatasetPASCAL
|
| 7 |
+
from data.coco import DatasetCOCO
|
| 8 |
+
from data.fss import DatasetFSS
|
| 9 |
+
from data.deepglobe import DatasetDeepglobe
|
| 10 |
+
from data.isic import DatasetISIC
|
| 11 |
+
from data.lung import DatasetLung
|
| 12 |
+
from data.fss import DatasetFSS
|
| 13 |
+
from data.suim import DatasetSUIM
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class FSSDataset:
    """Registry and factory for few-shot segmentation dataloaders.

    Call initialize() once with the image size and data root, then
    build_dataloader() per benchmark/split.
    """

    @classmethod
    def initialize(cls, img_size, datapath):
        """Register all benchmark dataset classes and build the shared
        resize + normalize transform."""

        cls.datasets = {
            'pascal': DatasetPASCAL,
            'coco': DatasetCOCO,
            'fss': DatasetFSS,
            'deepglobe': DatasetDeepglobe,
            'isic': DatasetISIC,
            'lung': DatasetLung,
            'suim': DatasetSUIM
        }

        # ImageNet statistics (backbones are ImageNet-pretrained)
        cls.img_mean = [0.485, 0.456, 0.406]
        cls.img_std = [0.229, 0.224, 0.225]
        cls.datapath = datapath

        cls.transform = transforms.Compose([transforms.Resize(size=(img_size, img_size)),
                                            transforms.ToTensor(),
                                            transforms.Normalize(cls.img_mean, cls.img_std)])

    @classmethod
    def build_dataloader(cls, benchmark, bsz, nworker, fold, split, shot=1):
        """Instantiate the benchmark's dataset and wrap it in a DataLoader.

        Workers are forced to 0 outside training for reproducibility.
        """
        nworker = nworker if split == 'trn' else 0

        dataset = cls.datasets[benchmark](cls.datapath, fold=fold,
                                          transform=cls.transform,
                                          split=split, shot=shot)
        # Force randomness during training for diverse episode combinations
        # Freeze randomness during testing for reproducibility
        #train_sampler = Sampler(dataset) if split == 'trn' else None
        dataloader = DataLoader(dataset, batch_size=bsz, shuffle=split=='trn', num_workers=nworker,
                                pin_memory=True)

        return dataloader
|
data/deepglobe.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r""" Deepglobe few-shot semantic segmentation dataset """
|
| 2 |
+
import os
|
| 3 |
+
import glob
|
| 4 |
+
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torch
|
| 8 |
+
import PIL.Image as Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DatasetDeepglobe(Dataset):
|
| 13 |
+
    def __init__(self, datapath, fold, transform, split, shot, num_val=600):
        """Deepglobe land-cover few-shot dataset.

        num_val caps the number of validation episodes; `fold` is accepted
        for interface compatibility but unused here (single target domain).
        """
        self.split = split
        self.benchmark = 'deepglobe'
        self.shot = shot
        self.num_val = num_val

        self.base_path = os.path.join(datapath)
        # image path -> corresponding ground-truth annotation path
        self.to_annpath = lambda p: p.replace('jpg', 'png').replace('origin', 'groundtruth')

        # the six Deepglobe land-cover category directories
        self.categories = ['1','2','3','4','5','6']

        self.class_ids = range(0, 6)
        self.img_metadata_classwise, self.num_images = self.build_img_metadata_classwise()

        self.transform = transform
|
| 29 |
+
    def __len__(self):
        """Number of episodes: full dataset for non-val splits, capped for val."""
        # if it is the target domain, then also test on entire dataset
        return self.num_images if self.split !='val' else self.num_val
|
| 33 |
+
def __getitem__(self, idx):
|
| 34 |
+
query_name, support_names, class_sample = self.sample_episode(idx)
|
| 35 |
+
query_img, query_mask, support_imgs, support_masks = self.load_frame(query_name, support_names)
|
| 36 |
+
|
| 37 |
+
query_img = self.transform(query_img)
|
| 38 |
+
query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()
|
| 39 |
+
|
| 40 |
+
support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])
|
| 41 |
+
|
| 42 |
+
support_masks_tmp = []
|
| 43 |
+
for smask in support_masks:
|
| 44 |
+
smask = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
|
| 45 |
+
support_masks_tmp.append(smask)
|
| 46 |
+
support_masks = torch.stack(support_masks_tmp)
|
| 47 |
+
|
| 48 |
+
batch = {'query_img': query_img,
|
| 49 |
+
'query_mask': query_mask,
|
| 50 |
+
'support_set': (support_imgs, support_masks),
|
| 51 |
+
'support_classes': torch.tensor([class_sample]), # adapt to Nway
|
| 52 |
+
|
| 53 |
+
'query_name': query_name, # REMOVE
|
| 54 |
+
'support_imgs': support_imgs, # REMOVE
|
| 55 |
+
'support_masks': support_masks, # REMOVE
|
| 56 |
+
'support_names': support_names, # REMOVE
|
| 57 |
+
'class_id': torch.tensor(class_sample)} # REMOVE
|
| 58 |
+
|
| 59 |
+
return batch
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def load_frame(self, query_name, support_names):
|
| 63 |
+
query_img = Image.open(query_name).convert('RGB')
|
| 64 |
+
support_imgs = [Image.open(name).convert('RGB') for name in support_names]
|
| 65 |
+
|
| 66 |
+
query_id = query_name.split('/')[-1].split('.')[0]
|
| 67 |
+
ann_path = os.path.join(self.base_path, query_name.split('/')[-4], 'test', 'groundtruth')
|
| 68 |
+
query_name = os.path.join(ann_path, query_id) + '.png'
|
| 69 |
+
support_ids = [name.split('/')[-1].split('.')[0] for name in support_names]
|
| 70 |
+
support_names = [os.path.join(ann_path, sid) + '.png' for name, sid in zip(support_names, support_ids)]
|
| 71 |
+
|
| 72 |
+
query_mask = self.read_mask(query_name)
|
| 73 |
+
support_masks = [self.read_mask(name) for name in support_names]
|
| 74 |
+
|
| 75 |
+
return query_img, query_mask, support_imgs, support_masks
|
| 76 |
+
|
| 77 |
+
def read_mask(self, img_name):
|
| 78 |
+
mask = torch.tensor(np.array(Image.open(img_name).convert('L')))
|
| 79 |
+
mask[mask < 128] = 0
|
| 80 |
+
mask[mask >= 128] = 1
|
| 81 |
+
return mask
|
| 82 |
+
|
| 83 |
+
def sample_episode(self, idx):
|
| 84 |
+
class_id = idx % len(self.class_ids)
|
| 85 |
+
class_sample = self.categories[class_id]
|
| 86 |
+
|
| 87 |
+
query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
|
| 88 |
+
support_names = []
|
| 89 |
+
while True: # keep sampling support set if query == support
|
| 90 |
+
support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
|
| 91 |
+
if query_name != support_name: support_names.append(support_name)
|
| 92 |
+
if len(support_names) == self.shot: break
|
| 93 |
+
|
| 94 |
+
return query_name, support_names, class_id
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# def build_img_metadata(self):
|
| 98 |
+
# img_metadata = []
|
| 99 |
+
# for cat in self.categories:
|
| 100 |
+
# os.path.join(self.base_path, cat)
|
| 101 |
+
# img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.base_path, cat, 'test', 'origin'))])
|
| 102 |
+
# for img_path in img_paths:
|
| 103 |
+
# if os.path.basename(img_path).split('.')[1] == 'jpg':
|
| 104 |
+
# img_metadata.append(img_path)
|
| 105 |
+
# return img_metadata
|
| 106 |
+
|
| 107 |
+
def build_img_metadata_classwise(self):
|
| 108 |
+
num_images=0
|
| 109 |
+
img_metadata_classwise = {}
|
| 110 |
+
for cat in self.categories:
|
| 111 |
+
img_metadata_classwise[cat] = []
|
| 112 |
+
|
| 113 |
+
for cat in self.categories:
|
| 114 |
+
img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.base_path, cat, 'test', 'origin'))])
|
| 115 |
+
for img_path in img_paths:
|
| 116 |
+
if os.path.basename(img_path).split('.')[1] == 'jpg':
|
| 117 |
+
img_metadata_classwise[cat] += [img_path]
|
| 118 |
+
num_images += 1
|
| 119 |
+
return img_metadata_classwise, num_images
|
data/fss.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r""" FSS-1000 few-shot semantic segmentation dataset """
|
| 2 |
+
import os
|
| 3 |
+
import glob
|
| 4 |
+
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torch
|
| 8 |
+
import PIL.Image as Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DatasetFSS(Dataset):
    r"""FSS-1000 few-shot segmentation dataset.

    The query image is deterministic in ``idx``; the support set is drawn at
    random from the remaining images of the same category folder.
    """

    def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize=False):
        self.split = split
        self.benchmark = 'fss'
        self.shot = shot

        self.base_path = os.path.join(datapath, 'FSS-1000')

        # Category lists per split are predefined
        # (trn/val/test split reference: https://github.com/HKUSTCV/FSS-1000/issues/7).
        with open('./data/splits/fss/%s.txt' % split, 'r') as f:
            self.categories = sorted(f.read().split('\n')[:-1])

        self.class_ids = self.build_class_ids()
        self.img_metadata = self.build_img_metadata()

        self.transform = transform

    def __len__(self):
        return len(self.img_metadata)

    def __getitem__(self, idx):
        query_name, support_names, class_sample = self.sample_episode(idx)
        query_img, query_mask, support_imgs, support_masks = self.load_frame(query_name, support_names)

        query_img = self.transform(query_img)
        # Nearest-neighbour resize preserves the binary mask values.
        query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(),
                                   query_img.size()[-2:], mode='nearest').squeeze()

        support_imgs = torch.stack([self.transform(img) for img in support_imgs])

        resized = [
            F.interpolate(m.unsqueeze(0).unsqueeze(0).float(),
                          support_imgs.size()[-2:], mode='nearest').squeeze()
            for m in support_masks
        ]
        support_masks = torch.stack(resized)

        return {'query_img': query_img,
                'query_mask': query_mask,
                'query_name': query_name,

                'support_imgs': support_imgs,
                'support_masks': support_masks,
                'support_names': support_names,

                'class_id': torch.tensor(class_sample)}

    def load_frame(self, query_name, support_names):
        """Load query/support images (RGB) and the matching binary masks."""
        query_img = Image.open(query_name).convert('RGB')
        support_imgs = [Image.open(name).convert('RGB') for name in support_names]

        def mask_path(img_path):
            # Masks sit next to the images: same stem, .png extension.
            stem = img_path.split('/')[-1].split('.')[0]
            return os.path.join(os.path.dirname(img_path), stem) + '.png'

        query_mask = self.read_mask(mask_path(query_name))
        support_masks = [self.read_mask(mask_path(name)) for name in support_names]

        return query_img, query_mask, support_imgs, support_masks

    def read_mask(self, img_name):
        """Load a grayscale mask and binarise it at threshold 128."""
        m = torch.tensor(np.array(Image.open(img_name).convert('L')))
        m[m < 128] = 0
        m[m >= 128] = 1
        return m

    def sample_episode(self, idx):
        """Return (query path, support paths, globally-offset class id)."""
        query_name = self.img_metadata[idx]
        class_sample = self.categories.index(query_name.split('/')[-2])
        # Offset so trn (0-519), val (520-759) and test (760-999) ids never overlap.
        if self.split == 'val':
            class_sample += 520
        elif self.split == 'test':
            class_sample += 760

        support_names = []
        while len(support_names) < self.shot:  # re-draw whenever support == query
            drawn = np.random.choice(range(1, 11), 1, replace=False)[0]
            candidate = os.path.join(os.path.dirname(query_name), str(drawn)) + '.jpg'
            if candidate != query_name:
                support_names.append(candidate)

        return query_name, support_names, class_sample

    def build_class_ids(self):
        # Global class-id ranges: trn 0-519, val 520-759, test 760-999.
        split_ranges = {'trn': range(0, 520),
                        'val': range(520, 760),
                        'test': range(760, 1000)}
        return split_ranges[self.split]

    def build_img_metadata(self):
        """Collect all .jpg image paths, category by category, in sorted order."""
        img_metadata = []
        for cat in self.categories:
            cat_dir = os.path.join(self.base_path, cat)
            img_metadata.extend(
                p for p in sorted(glob.glob('%s/*' % cat_dir))
                if os.path.basename(p).split('.')[1] == 'jpg'
            )
        return img_metadata
|
data/isic.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r""" ISIC few-shot semantic segmentation dataset """
|
| 2 |
+
import os
|
| 3 |
+
import glob
|
| 4 |
+
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torch
|
| 8 |
+
import PIL.Image as Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DatasetISIC(Dataset):
    r"""ISIC-2018 skin-lesion few-shot segmentation dataset (episodic sampling).

    The episode class is chosen round-robin from ``idx``; query and support
    images are then drawn at random from that class.
    """

    def __init__(self, datapath, fold, transform, split, shot, num_val=600):
        self.split = split
        self.benchmark = 'isic'
        self.shot = shot
        self.num_val = num_val  # number of episodes reported for the 'val' split

        self.base_path = os.path.join(datapath)
        self.categories = ['1', '2', '3']

        self.class_ids = range(0, 3)
        self.img_metadata_classwise, self.num_images = self.build_img_metadata_classwise()

        self.transform = transform

    def __len__(self):
        # Validation uses a fixed episode count; otherwise one episode per image.
        return self.num_images if self.split != 'val' else self.num_val

    def __getitem__(self, idx):
        query_name, support_names, class_sample = self.sample_episode(idx)
        query_img, query_mask, support_imgs, support_masks = self.load_frame(query_name, support_names)

        query_img = self.transform(query_img)
        # Nearest-neighbour resize keeps the mask binary.
        query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()

        support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])

        support_masks_tmp = []
        for smask in support_masks:
            smask = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
            support_masks_tmp.append(smask)
        support_masks = torch.stack(support_masks_tmp)

        batch = {'query_img': query_img,
                 'query_mask': query_mask,
                 'query_name': query_name,

                 'support_imgs': support_imgs,
                 'support_masks': support_masks,
                 'support_names': support_names,

                 'class_id': torch.tensor(class_sample)}

        return batch

    def load_frame(self, query_name, support_names):
        """Load query/support RGB images and their binary ground-truth masks."""
        query_img = Image.open(query_name).convert('RGB')
        support_imgs = [Image.open(name).convert('RGB') for name in support_names]

        # Ground-truth masks share the image id with a '_segmentation' suffix.
        query_id = query_name.split('/')[-1].split('.')[0]
        ann_path = os.path.join(self.base_path, 'ISIC2018_Task1_Training_GroundTruth')
        query_name = os.path.join(ann_path, query_id) + '_segmentation.png'
        support_ids = [name.split('/')[-1].split('.')[0] for name in support_names]
        support_names = [os.path.join(ann_path, sid) + '_segmentation.png' for sid in support_ids]

        query_mask = self.read_mask(query_name)
        support_masks = [self.read_mask(name) for name in support_names]

        return query_img, query_mask, support_imgs, support_masks

    def read_mask(self, img_name):
        """Read a grayscale mask and binarise it at threshold 128."""
        mask = torch.tensor(np.array(Image.open(img_name).convert('L')))
        mask[mask < 128] = 0
        mask[mask >= 128] = 1
        return mask

    def sample_episode(self, idx):
        """Pick the class round-robin from ``idx``; draw query + ``shot`` supports."""
        class_id = idx % len(self.class_ids)
        class_sample = self.categories[class_id]

        query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
        support_names = []
        while True:  # keep sampling support set if query == support
            support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
            if query_name != support_name: support_names.append(support_name)
            if len(support_names) == self.shot: break

        return query_name, support_names, class_id

    def build_img_metadata(self):
        """Flat list of all .jpg image paths (kept for API parity; currently unused)."""
        img_metadata = []
        for cat in self.categories:
            img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.base_path, 'ISIC2018_Task1-2_Training_Input', cat))])
            for img_path in img_paths:
                if os.path.basename(img_path).split('.')[1] == 'jpg':
                    img_metadata.append(img_path)
        return img_metadata

    def build_img_metadata_classwise(self):
        """Map category -> sorted list of .jpg image paths; also count total images."""
        num_images = 0
        img_metadata_classwise = {cat: [] for cat in self.categories}

        for cat in self.categories:
            img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.base_path, 'ISIC2018_Task1-2_Training_Input', cat))])
            for img_path in img_paths:
                if os.path.basename(img_path).split('.')[1] == 'jpg':
                    img_metadata_classwise[cat] += [img_path]
                    num_images += 1
        return img_metadata_classwise, num_images
|
data/lung.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r""" Chest X-ray few-shot semantic segmentation dataset """
|
| 2 |
+
import os
|
| 3 |
+
import glob
|
| 4 |
+
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torch
|
| 8 |
+
import PIL.Image as Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DatasetLung(Dataset):
    r"""Chest X-ray (lung) few-shot segmentation dataset (episodic sampling).

    Single foreground class; episodes are indexed by the mask files, and the
    corresponding image is looked up in the ``CXR_png`` folder.
    """

    def __init__(self, datapath, fold, transform, split, shot=1, num_val=600):
        self.benchmark = 'lung'
        self.shot = shot
        self.split = split
        self.num_val = num_val  # number of episodes reported for the 'val' split

        self.base_path = os.path.join(datapath)
        self.img_path = os.path.join(self.base_path, 'CXR_png')
        self.ann_path = os.path.join(self.base_path, 'masks')

        self.categories = ['1']  # single foreground class

        self.class_ids = range(0, 1)
        self.img_metadata_classwise, self.num_images = self.build_img_metadata_classwise()

        self.transform = transform

    def __len__(self):
        # Validation uses a fixed episode count; otherwise one episode per image.
        return self.num_images if self.split != 'val' else self.num_val

    def __getitem__(self, idx):
        query_name, support_names, class_sample = self.sample_episode(idx)
        query_img, query_mask, support_imgs, support_masks = self.load_frame(query_name, support_names)

        query_img = self.transform(query_img)
        # Nearest-neighbour resize keeps the mask binary.
        query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()

        support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])

        support_masks_tmp = []
        for smask in support_masks:
            smask = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
            support_masks_tmp.append(smask)
        support_masks = torch.stack(support_masks_tmp)

        batch = {'query_img': query_img,
                 'query_mask': query_mask,
                 'query_name': query_name,
                 'support_imgs': support_imgs,
                 'support_masks': support_masks,
                 'class_id': torch.tensor(class_sample),
                 'support_names': support_names,

                 'support_set': [support_imgs, support_masks],
                 'support_classes': torch.tensor([class_sample])
                 }

        return batch

    def load_frame(self, query_name, support_names):
        """Load masks first (episode names index the mask folder), then the images."""
        query_mask = self.read_mask(query_name)
        support_masks = [self.read_mask(name) for name in support_names]

        # Strip the 9-char mask suffix (presumably '_mask.png' — TODO confirm
        # against the dataset layout) to recover the image file name.
        query_id = query_name[:-9] + '.png'
        query_img = Image.open(os.path.join(self.img_path, os.path.basename(query_id))).convert('RGB')

        support_ids = [os.path.basename(name)[:-9] + '.png' for name in support_names]
        support_names = [os.path.join(self.img_path, sid) for sid in support_ids]
        support_imgs = [Image.open(name).convert('RGB') for name in support_names]

        return query_img, query_mask, support_imgs, support_masks

    def read_mask(self, img_name):
        """Read a grayscale mask and binarise it at threshold 128."""
        mask = torch.tensor(np.array(Image.open(img_name).convert('L')))
        mask[mask < 128] = 0
        mask[mask >= 128] = 1
        return mask

    def sample_episode(self, idx):
        """Pick the (single) class from ``idx``; draw query + ``shot`` supports."""
        class_id = idx % len(self.class_ids)
        class_sample = self.categories[class_id]

        query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
        support_names = []
        while True:  # keep sampling support set if query == support
            support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
            if query_name != support_name: support_names.append(support_name)
            if len(support_names) == self.shot: break

        return query_name, support_names, class_id

    def build_img_metadata(self):
        """Flat list of .png image paths (kept for API parity; currently unused)."""
        img_metadata = []
        for cat in self.categories:
            img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.img_path, cat))])
            for img_path in img_paths:
                if os.path.basename(img_path).split('.')[1] == 'png':
                    img_metadata.append(img_path)
        return img_metadata

    def build_img_metadata_classwise(self):
        """Map category -> list of mask paths; also count total images.

        NOTE: the glob ignores ``cat`` — this is only correct while there is a
        single category, which is the case for this dataset.
        """
        num_images = 0
        img_metadata_classwise = {cat: [] for cat in self.categories}

        for cat in self.categories:
            img_paths = sorted([path for path in glob.glob('%s/*' % self.ann_path)])
            for img_path in img_paths:
                if os.path.basename(img_path).split('.')[1] == 'png':
                    img_metadata_classwise[cat] += [img_path]
                    num_images += 1
        return img_metadata_classwise, num_images
|
data/pascal.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r""" PASCAL-5i few-shot semantic segmentation dataset """
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import torch
|
| 7 |
+
import PIL.Image as Image
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DatasetPASCAL(Dataset):
    r"""PASCAL-5i few-shot segmentation dataset: 20 classes split into 4 folds;
    the chosen fold's classes are held out for validation."""
    def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize=False):
        # 'test' maps onto 'val': PASCAL-5i has no separate test split.
        self.split = 'val' if split in ['val', 'test'] else 'trn'
        self.fold = fold
        self.nfolds = 4
        self.nclass = 20
        self.benchmark = 'pascal'
        self.shot = shot
        self.use_original_imgsize = use_original_imgsize

        self.img_path = os.path.join(datapath, 'VOC2012/JPEGImages/')
        self.ann_path = os.path.join(datapath, 'VOC2012/SegmentationClassAug/')
        self.transform = transform

        self.class_ids = self.build_class_ids()
        self.img_metadata = self.build_img_metadata()
        self.img_metadata_classwise = self.build_img_metadata_classwise()

    def __len__(self):
        # Validation always reports 1000 episodes regardless of image count.
        return len(self.img_metadata) if self.split == 'trn' else 1000

    def __getitem__(self, idx):
        idx %= len(self.img_metadata)  # for testing, as n_images < 1000
        query_name, support_names, class_sample = self.sample_episode(idx)
        query_img, query_cmask, support_imgs, support_cmasks, org_qry_imsize = self.load_frame(query_name, support_names)

        query_img = self.transform(query_img)
        if not self.use_original_imgsize:
            # Nearest-neighbour resize keeps the class-index mask values intact.
            query_cmask = F.interpolate(query_cmask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()
        query_mask, query_ignore_idx = self.extract_ignore_idx(query_cmask.float(), class_sample)

        support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])

        support_masks = []
        support_ignore_idxs = []
        for scmask in support_cmasks:
            scmask = F.interpolate(scmask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
            support_mask, support_ignore_idx = self.extract_ignore_idx(scmask, class_sample)
            support_masks.append(support_mask)
            support_ignore_idxs.append(support_ignore_idx)
        support_masks = torch.stack(support_masks)
        support_ignore_idxs = torch.stack(support_ignore_idxs)

        batch = {'query_img': query_img,
                 'query_mask': query_mask,
                 'query_name': query_name,
                 'query_ignore_idx': query_ignore_idx,

                 'org_query_imsize': org_qry_imsize,

                 'support_imgs': support_imgs,
                 'support_masks': support_masks,
                 'support_names': support_names,
                 'support_ignore_idxs': support_ignore_idxs,

                 'class_id': torch.tensor(class_sample)}

        return batch

    def extract_ignore_idx(self, mask, class_id):
        """Binarise a class-index mask for the episode class; also return the
        boundary (ignore) map, which is 1 exactly where the label is 255.

        NOTE: mutates ``mask`` in place — the boundary must be computed first.
        """
        boundary = (mask / 255).floor()
        mask[mask != class_id + 1] = 0
        mask[mask == class_id + 1] = 1

        return mask, boundary

    def load_frame(self, query_name, support_names):
        """Load query/support images and class-index masks, plus the original
        query image size (needed when use_original_imgsize is set)."""
        query_img = self.read_img(query_name)
        query_mask = self.read_mask(query_name)
        support_imgs = [self.read_img(name) for name in support_names]
        support_masks = [self.read_mask(name) for name in support_names]

        org_qry_imsize = query_img.size

        return query_img, query_mask, support_imgs, support_masks, org_qry_imsize

    def read_mask(self, img_name):
        r"""Return segmentation mask in PIL Image"""
        mask = torch.tensor(np.array(Image.open(os.path.join(self.ann_path, img_name) + '.png')))
        return mask

    def read_img(self, img_name):
        r"""Return RGB image in PIL Image"""
        return Image.open(os.path.join(self.img_path, img_name) + '.jpg')

    def sample_episode(self, idx):
        """Deterministic query from ``idx``; random same-class support images."""
        query_name, class_sample = self.img_metadata[idx]

        support_names = []
        while True:  # keep sampling support set if query == support
            support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
            if query_name != support_name: support_names.append(support_name)
            if len(support_names) == self.shot: break

        return query_name, support_names, class_sample

    def build_class_ids(self):
        """Return the class ids for this split: the held-out fold for 'val',
        the remaining classes for 'trn'."""
        nclass_trn = self.nclass // self.nfolds
        class_ids_val = [self.fold * nclass_trn + i for i in range(nclass_trn)]
        class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val]

        if self.split == 'trn':
            return class_ids_trn
        else:
            return class_ids_val

    def build_img_metadata(self):
        """Read (image name, class id) pairs from the fold split files."""

        def read_metadata(split, fold_id):
            # Split files store 'imgname__classid' per line; class ids are 1-based.
            fold_n_metadata = os.path.join('data/splits/pascal/%s/fold%d.txt' % (split, fold_id))
            with open(fold_n_metadata, 'r') as f:
                fold_n_metadata = f.read().split('\n')[:-1]
            fold_n_metadata = [[data.split('__')[0], int(data.split('__')[1]) - 1] for data in fold_n_metadata]
            return fold_n_metadata

        img_metadata = []
        if self.split == 'trn':  # For training, read image-metadata of "the other" folds
            for fold_id in range(self.nfolds):
                if fold_id == self.fold:  # Skip validation fold
                    continue
                img_metadata += read_metadata(self.split, fold_id)
        elif self.split == 'val':  # For validation, read image-metadata of "current" fold
            img_metadata = read_metadata(self.split, self.fold)
        else:
            raise Exception('Undefined split %s: ' % self.split)

        print('Total (%s) images are : %d' % (self.split, len(img_metadata)))

        return img_metadata

    def build_img_metadata_classwise(self):
        """Group image names by class id for support-set sampling."""
        img_metadata_classwise = {}
        for class_id in range(self.nclass):
            img_metadata_classwise[class_id] = []

        for img_name, img_class in self.img_metadata:
            img_metadata_classwise[img_class] += [img_name]
        return img_metadata_classwise
|
data/splits/fss/test.txt
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
bus
|
| 2 |
+
hotel_slipper
|
| 3 |
+
burj_al
|
| 4 |
+
reflex_camera
|
| 5 |
+
abe's_flyingfish
|
| 6 |
+
oiltank_car
|
| 7 |
+
doormat
|
| 8 |
+
fish_eagle
|
| 9 |
+
barber_shaver
|
| 10 |
+
motorbike
|
| 11 |
+
feather_clothes
|
| 12 |
+
wandering_albatross
|
| 13 |
+
rice_cooker
|
| 14 |
+
delta_wing
|
| 15 |
+
fish
|
| 16 |
+
nintendo_switch
|
| 17 |
+
bustard
|
| 18 |
+
diver
|
| 19 |
+
minicooper
|
| 20 |
+
cathedrale_paris
|
| 21 |
+
big_ben
|
| 22 |
+
combination_lock
|
| 23 |
+
villa_savoye
|
| 24 |
+
american_alligator
|
| 25 |
+
gym_ball
|
| 26 |
+
andean_condor
|
| 27 |
+
leggings
|
| 28 |
+
pyramid_cube
|
| 29 |
+
jet_aircraft
|
| 30 |
+
meatloaf
|
| 31 |
+
reel
|
| 32 |
+
swan
|
| 33 |
+
osprey
|
| 34 |
+
crt_screen
|
| 35 |
+
microscope
|
| 36 |
+
rubber_eraser
|
| 37 |
+
arrow
|
| 38 |
+
monkey
|
| 39 |
+
mitten
|
| 40 |
+
spiderman
|
| 41 |
+
parthenon
|
| 42 |
+
bat
|
| 43 |
+
chess_king
|
| 44 |
+
sulphur_butterfly
|
| 45 |
+
quail_egg
|
| 46 |
+
oriole
|
| 47 |
+
iron_man
|
| 48 |
+
wooden_boat
|
| 49 |
+
anise
|
| 50 |
+
steering_wheel
|
| 51 |
+
groenendael
|
| 52 |
+
dwarf_beans
|
| 53 |
+
pteropus
|
| 54 |
+
chalk_brush
|
| 55 |
+
bloodhound
|
| 56 |
+
moon
|
| 57 |
+
english_foxhound
|
| 58 |
+
boxing_gloves
|
| 59 |
+
peregine_falcon
|
| 60 |
+
pyraminx
|
| 61 |
+
cicada
|
| 62 |
+
screw
|
| 63 |
+
shower_curtain
|
| 64 |
+
tredmill
|
| 65 |
+
bulb
|
| 66 |
+
bell_pepper
|
| 67 |
+
lemur_catta
|
| 68 |
+
doughnut
|
| 69 |
+
twin_tower
|
| 70 |
+
astronaut
|
| 71 |
+
nintendo_3ds
|
| 72 |
+
fennel_bulb
|
| 73 |
+
indri
|
| 74 |
+
captain_america_shield
|
| 75 |
+
kunai
|
| 76 |
+
broom
|
| 77 |
+
iphone
|
| 78 |
+
earphone1
|
| 79 |
+
flying_squirrel
|
| 80 |
+
onion
|
| 81 |
+
vinyl
|
| 82 |
+
sydney_opera_house
|
| 83 |
+
oyster
|
| 84 |
+
harmonica
|
| 85 |
+
egg
|
| 86 |
+
breast_pump
|
| 87 |
+
guitar
|
| 88 |
+
potato_chips
|
| 89 |
+
tunnel
|
| 90 |
+
cuckoo
|
| 91 |
+
rubick_cube
|
| 92 |
+
plastic_bag
|
| 93 |
+
phonograph
|
| 94 |
+
net_surface_shoes
|
| 95 |
+
goldfinch
|
| 96 |
+
ipad
|
| 97 |
+
mite_predator
|
| 98 |
+
coffee_mug
|
| 99 |
+
golden_plover
|
| 100 |
+
f1_racing
|
| 101 |
+
lapwing
|
| 102 |
+
nintendo_gba
|
| 103 |
+
pizza
|
| 104 |
+
rally_car
|
| 105 |
+
drilling_platform
|
| 106 |
+
cd
|
| 107 |
+
fly
|
| 108 |
+
magpie_bird
|
| 109 |
+
leaf_fan
|
| 110 |
+
little_blue_heron
|
| 111 |
+
carriage
|
| 112 |
+
moist_proof_pad
|
| 113 |
+
flying_snakes
|
| 114 |
+
dart_target
|
| 115 |
+
warehouse_tray
|
| 116 |
+
nintendo_wiiu
|
| 117 |
+
chiffon_cake
|
| 118 |
+
bath_ball
|
| 119 |
+
manatee
|
| 120 |
+
cloud
|
| 121 |
+
marimba
|
| 122 |
+
eagle
|
| 123 |
+
ruler
|
| 124 |
+
soymilk_machine
|
| 125 |
+
sled
|
| 126 |
+
seagull
|
| 127 |
+
glider_flyingfish
|
| 128 |
+
doublebus
|
| 129 |
+
transport_helicopter
|
| 130 |
+
window_screen
|
| 131 |
+
truss_bridge
|
| 132 |
+
wasp
|
| 133 |
+
snowman
|
| 134 |
+
poached_egg
|
| 135 |
+
strawberry
|
| 136 |
+
spinach
|
| 137 |
+
earphone2
|
| 138 |
+
downy_pitch
|
| 139 |
+
taj_mahal
|
| 140 |
+
rocking_chair
|
| 141 |
+
cablestayed_bridge
|
| 142 |
+
sealion
|
| 143 |
+
banana_boat
|
| 144 |
+
pheasant
|
| 145 |
+
stone_lion
|
| 146 |
+
electronic_stove
|
| 147 |
+
fox
|
| 148 |
+
iguana
|
| 149 |
+
rugby_ball
|
| 150 |
+
hang_glider
|
| 151 |
+
water_buffalo
|
| 152 |
+
lotus
|
| 153 |
+
paper_plane
|
| 154 |
+
missile
|
| 155 |
+
flamingo
|
| 156 |
+
american_chamelon
|
| 157 |
+
kart
|
| 158 |
+
chinese_knot
|
| 159 |
+
cabbage_butterfly
|
| 160 |
+
key
|
| 161 |
+
church
|
| 162 |
+
tiltrotor
|
| 163 |
+
helicopter
|
| 164 |
+
french_fries
|
| 165 |
+
water_heater
|
| 166 |
+
snow_leopard
|
| 167 |
+
goblet
|
| 168 |
+
fan
|
| 169 |
+
snowplow
|
| 170 |
+
leafhopper
|
| 171 |
+
pspgo
|
| 172 |
+
black_bear
|
| 173 |
+
quail
|
| 174 |
+
condor
|
| 175 |
+
chandelier
|
| 176 |
+
hair_razor
|
| 177 |
+
white_wolf
|
| 178 |
+
toaster
|
| 179 |
+
pidan
|
| 180 |
+
pyramid
|
| 181 |
+
chicken_leg
|
| 182 |
+
letter_opener
|
| 183 |
+
apple_icon
|
| 184 |
+
porcupine
|
| 185 |
+
chicken
|
| 186 |
+
stingray
|
| 187 |
+
warplane
|
| 188 |
+
windmill
|
| 189 |
+
bamboo_slip
|
| 190 |
+
wig
|
| 191 |
+
flying_geckos
|
| 192 |
+
stonechat
|
| 193 |
+
haddock
|
| 194 |
+
australian_terrier
|
| 195 |
+
hover_board
|
| 196 |
+
siamang
|
| 197 |
+
canton_tower
|
| 198 |
+
santa_sledge
|
| 199 |
+
arch_bridge
|
| 200 |
+
curlew
|
| 201 |
+
sushi
|
| 202 |
+
beet_root
|
| 203 |
+
accordion
|
| 204 |
+
leaf_egg
|
| 205 |
+
stealth_aircraft
|
| 206 |
+
stork
|
| 207 |
+
bucket
|
| 208 |
+
hawk
|
| 209 |
+
chess_queen
|
| 210 |
+
ocarina
|
| 211 |
+
knife
|
| 212 |
+
whippet
|
| 213 |
+
cantilever_bridge
|
| 214 |
+
may_bug
|
| 215 |
+
wagtail
|
| 216 |
+
leather_shoes
|
| 217 |
+
wheelchair
|
| 218 |
+
shumai
|
| 219 |
+
speedboat
|
| 220 |
+
vacuum_cup
|
| 221 |
+
chess_knight
|
| 222 |
+
pumpkin_pie
|
| 223 |
+
wooden_spoon
|
| 224 |
+
bamboo_dragonfly
|
| 225 |
+
ganeva_chair
|
| 226 |
+
soap
|
| 227 |
+
clearwing_flyingfish
|
| 228 |
+
pencil_sharpener1
|
| 229 |
+
cricket
|
| 230 |
+
photocopier
|
| 231 |
+
nintendo_sp
|
| 232 |
+
samarra_mosque
|
| 233 |
+
clam
|
| 234 |
+
charge_battery
|
| 235 |
+
flying_frog
|
| 236 |
+
ferrari911
|
| 237 |
+
polo_shirt
|
| 238 |
+
echidna
|
| 239 |
+
coin
|
| 240 |
+
tower_pisa
|
data/splits/fss/trn.txt
ADDED
|
@@ -0,0 +1,520 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fountain
|
| 2 |
+
taxi
|
| 3 |
+
assult_rifle
|
| 4 |
+
radio
|
| 5 |
+
comb
|
| 6 |
+
box_turtle
|
| 7 |
+
igloo
|
| 8 |
+
head_cabbage
|
| 9 |
+
cottontail
|
| 10 |
+
coho
|
| 11 |
+
ashtray
|
| 12 |
+
joystick
|
| 13 |
+
sleeping_bag
|
| 14 |
+
jackfruit
|
| 15 |
+
trailer_truck
|
| 16 |
+
shower_cap
|
| 17 |
+
ibex
|
| 18 |
+
kinguin
|
| 19 |
+
squirrel
|
| 20 |
+
ac_wall
|
| 21 |
+
sidewinder
|
| 22 |
+
remote_control
|
| 23 |
+
marshmallow
|
| 24 |
+
bolotie
|
| 25 |
+
polar_bear
|
| 26 |
+
rock_beauty
|
| 27 |
+
tokyo_tower
|
| 28 |
+
wafer
|
| 29 |
+
red_bayberry
|
| 30 |
+
electronic_toothbrush
|
| 31 |
+
hartebeest
|
| 32 |
+
cassette
|
| 33 |
+
oil_filter
|
| 34 |
+
bomb
|
| 35 |
+
walnut
|
| 36 |
+
toilet_tissue
|
| 37 |
+
memory_stick
|
| 38 |
+
wild_boar
|
| 39 |
+
cableways
|
| 40 |
+
chihuahua
|
| 41 |
+
envelope
|
| 42 |
+
bison
|
| 43 |
+
poker
|
| 44 |
+
pubg_lvl3helmet
|
| 45 |
+
indian_cobra
|
| 46 |
+
staffordshire
|
| 47 |
+
park_bench
|
| 48 |
+
wombat
|
| 49 |
+
black_grouse
|
| 50 |
+
submarine
|
| 51 |
+
washer
|
| 52 |
+
agama
|
| 53 |
+
coyote
|
| 54 |
+
feeder
|
| 55 |
+
sarong
|
| 56 |
+
buckingham_palace
|
| 57 |
+
frog
|
| 58 |
+
steam_locomotive
|
| 59 |
+
acorn
|
| 60 |
+
german_pointer
|
| 61 |
+
obelisk
|
| 62 |
+
polecat
|
| 63 |
+
black_swan
|
| 64 |
+
butterfly
|
| 65 |
+
mountain_tent
|
| 66 |
+
gorilla
|
| 67 |
+
sloth_bear
|
| 68 |
+
aubergine
|
| 69 |
+
stinkhorn
|
| 70 |
+
stole
|
| 71 |
+
owl
|
| 72 |
+
mooli
|
| 73 |
+
pool_table
|
| 74 |
+
collar
|
| 75 |
+
lhasa_apso
|
| 76 |
+
ambulance
|
| 77 |
+
spade
|
| 78 |
+
pufferfish
|
| 79 |
+
paint_brush
|
| 80 |
+
lark
|
| 81 |
+
golf_ball
|
| 82 |
+
hock
|
| 83 |
+
fork
|
| 84 |
+
drake
|
| 85 |
+
bee_house
|
| 86 |
+
mooncake
|
| 87 |
+
wok
|
| 88 |
+
cocacola
|
| 89 |
+
water_bike
|
| 90 |
+
ladder
|
| 91 |
+
psp
|
| 92 |
+
bassoon
|
| 93 |
+
bear
|
| 94 |
+
border_terrier
|
| 95 |
+
petri_dish
|
| 96 |
+
pill_bottle
|
| 97 |
+
aircraft_carrier
|
| 98 |
+
panther
|
| 99 |
+
canoe
|
| 100 |
+
baseball_player
|
| 101 |
+
turtle
|
| 102 |
+
espresso
|
| 103 |
+
throne
|
| 104 |
+
cornet
|
| 105 |
+
coucal
|
| 106 |
+
eletrical_switch
|
| 107 |
+
bra
|
| 108 |
+
snail
|
| 109 |
+
backpack
|
| 110 |
+
jacamar
|
| 111 |
+
scroll_brush
|
| 112 |
+
gliding_lizard
|
| 113 |
+
raft
|
| 114 |
+
pinwheel
|
| 115 |
+
grasshopper
|
| 116 |
+
green_mamba
|
| 117 |
+
eft_newt
|
| 118 |
+
computer_mouse
|
| 119 |
+
vine_snake
|
| 120 |
+
recreational_vehicle
|
| 121 |
+
llama
|
| 122 |
+
meerkat
|
| 123 |
+
chainsaw
|
| 124 |
+
ferret
|
| 125 |
+
garbage_can
|
| 126 |
+
kangaroo
|
| 127 |
+
litchi
|
| 128 |
+
carbonara
|
| 129 |
+
housefinch
|
| 130 |
+
modem
|
| 131 |
+
tebby_cat
|
| 132 |
+
thatch
|
| 133 |
+
face_powder
|
| 134 |
+
tomb
|
| 135 |
+
apple
|
| 136 |
+
ladybug
|
| 137 |
+
killer_whale
|
| 138 |
+
rocket
|
| 139 |
+
airship
|
| 140 |
+
surfboard
|
| 141 |
+
lesser_panda
|
| 142 |
+
jordan_logo
|
| 143 |
+
banana
|
| 144 |
+
nail_scissor
|
| 145 |
+
swab
|
| 146 |
+
perfume
|
| 147 |
+
punching_bag
|
| 148 |
+
victor_icon
|
| 149 |
+
waffle_iron
|
| 150 |
+
trimaran
|
| 151 |
+
garlic
|
| 152 |
+
flute
|
| 153 |
+
langur
|
| 154 |
+
starfish
|
| 155 |
+
parallel_bars
|
| 156 |
+
dandie_dinmont
|
| 157 |
+
cosmetic_brush
|
| 158 |
+
screwdriver
|
| 159 |
+
brick_card
|
| 160 |
+
balance_weight
|
| 161 |
+
hornet
|
| 162 |
+
carton
|
| 163 |
+
toothpaste
|
| 164 |
+
bracelet
|
| 165 |
+
egg_tart
|
| 166 |
+
pencil_sharpener2
|
| 167 |
+
swimming_glasses
|
| 168 |
+
howler_monkey
|
| 169 |
+
camel
|
| 170 |
+
dragonfly
|
| 171 |
+
lionfish
|
| 172 |
+
convertible
|
| 173 |
+
mule
|
| 174 |
+
usb
|
| 175 |
+
conch
|
| 176 |
+
papaya
|
| 177 |
+
garbage_truck
|
| 178 |
+
dingo
|
| 179 |
+
radiator
|
| 180 |
+
solar_dish
|
| 181 |
+
streetcar
|
| 182 |
+
trilobite
|
| 183 |
+
bouzouki
|
| 184 |
+
ringlet_butterfly
|
| 185 |
+
space_shuttle
|
| 186 |
+
waffle
|
| 187 |
+
american_staffordshire
|
| 188 |
+
violin
|
| 189 |
+
flowerpot
|
| 190 |
+
forklift
|
| 191 |
+
manx
|
| 192 |
+
sundial
|
| 193 |
+
snowmobile
|
| 194 |
+
chickadee_bird
|
| 195 |
+
ruffed_grouse
|
| 196 |
+
brick_tea
|
| 197 |
+
paddle
|
| 198 |
+
stove
|
| 199 |
+
carousel
|
| 200 |
+
spatula
|
| 201 |
+
beaker
|
| 202 |
+
gas_pump
|
| 203 |
+
lawn_mower
|
| 204 |
+
speaker
|
| 205 |
+
tank
|
| 206 |
+
tresher
|
| 207 |
+
kappa_logo
|
| 208 |
+
hare
|
| 209 |
+
tennis_racket
|
| 210 |
+
shopping_cart
|
| 211 |
+
thimble
|
| 212 |
+
tractor
|
| 213 |
+
anemone_fish
|
| 214 |
+
trolleybus
|
| 215 |
+
steak
|
| 216 |
+
capuchin
|
| 217 |
+
red_breasted_merganser
|
| 218 |
+
golden_retriever
|
| 219 |
+
light_tube
|
| 220 |
+
flatworm
|
| 221 |
+
melon_seed
|
| 222 |
+
digital_watch
|
| 223 |
+
jacko_lantern
|
| 224 |
+
brown_bear
|
| 225 |
+
cairn
|
| 226 |
+
mushroom
|
| 227 |
+
chalk
|
| 228 |
+
skull
|
| 229 |
+
stapler
|
| 230 |
+
potato
|
| 231 |
+
telescope
|
| 232 |
+
proboscis
|
| 233 |
+
microphone
|
| 234 |
+
torii
|
| 235 |
+
baseball_bat
|
| 236 |
+
dhole
|
| 237 |
+
excavator
|
| 238 |
+
fig
|
| 239 |
+
snake
|
| 240 |
+
bradypod
|
| 241 |
+
pepitas
|
| 242 |
+
prairie_chicken
|
| 243 |
+
scorpion
|
| 244 |
+
shotgun
|
| 245 |
+
bottle_cap
|
| 246 |
+
file_cabinet
|
| 247 |
+
grey_whale
|
| 248 |
+
one-armed_bandit
|
| 249 |
+
banded_gecko
|
| 250 |
+
flying_disc
|
| 251 |
+
croissant
|
| 252 |
+
toothbrush
|
| 253 |
+
miniskirt
|
| 254 |
+
pokermon_ball
|
| 255 |
+
gazelle
|
| 256 |
+
grey_fox
|
| 257 |
+
esport_chair
|
| 258 |
+
necklace
|
| 259 |
+
ptarmigan
|
| 260 |
+
watermelon
|
| 261 |
+
besom
|
| 262 |
+
pomelo
|
| 263 |
+
radio_telescope
|
| 264 |
+
studio_couch
|
| 265 |
+
black_stork
|
| 266 |
+
vestment
|
| 267 |
+
koala
|
| 268 |
+
brambling
|
| 269 |
+
muscle_car
|
| 270 |
+
window_shade
|
| 271 |
+
space_heater
|
| 272 |
+
sunglasses
|
| 273 |
+
motor_scooter
|
| 274 |
+
ladyfinger
|
| 275 |
+
pencil_box
|
| 276 |
+
titi_monkey
|
| 277 |
+
chicken_wings
|
| 278 |
+
mount_fuji
|
| 279 |
+
giant_panda
|
| 280 |
+
dart
|
| 281 |
+
fire_engine
|
| 282 |
+
running_shoe
|
| 283 |
+
dumbbell
|
| 284 |
+
donkey
|
| 285 |
+
loafer
|
| 286 |
+
hard_disk
|
| 287 |
+
globe
|
| 288 |
+
lifeboat
|
| 289 |
+
medical_kit
|
| 290 |
+
brain_coral
|
| 291 |
+
paper_towel
|
| 292 |
+
dugong
|
| 293 |
+
seatbelt
|
| 294 |
+
skunk
|
| 295 |
+
military_vest
|
| 296 |
+
cocktail_shaker
|
| 297 |
+
zucchini
|
| 298 |
+
quad_drone
|
| 299 |
+
ocicat
|
| 300 |
+
shih-tzu
|
| 301 |
+
teapot
|
| 302 |
+
tile_roof
|
| 303 |
+
cheese_burger
|
| 304 |
+
handshower
|
| 305 |
+
red_wolf
|
| 306 |
+
stop_sign
|
| 307 |
+
mouse
|
| 308 |
+
battery
|
| 309 |
+
adidas_logo2
|
| 310 |
+
earplug
|
| 311 |
+
hummingbird
|
| 312 |
+
brush_pen
|
| 313 |
+
pistachio
|
| 314 |
+
hamster
|
| 315 |
+
air_strip
|
| 316 |
+
indian_elephant
|
| 317 |
+
otter
|
| 318 |
+
cucumber
|
| 319 |
+
scabbard
|
| 320 |
+
hawthorn
|
| 321 |
+
bullet_train
|
| 322 |
+
leopard
|
| 323 |
+
whale
|
| 324 |
+
cream
|
| 325 |
+
chinese_date
|
| 326 |
+
jellyfish
|
| 327 |
+
lobster
|
| 328 |
+
skua
|
| 329 |
+
single_log
|
| 330 |
+
chicory
|
| 331 |
+
bagel
|
| 332 |
+
beacon
|
| 333 |
+
pingpong_racket
|
| 334 |
+
spoon
|
| 335 |
+
yurt
|
| 336 |
+
wallaby
|
| 337 |
+
egret
|
| 338 |
+
christmas_stocking
|
| 339 |
+
mcdonald_uncle
|
| 340 |
+
wrench
|
| 341 |
+
spark_plug
|
| 342 |
+
triceratops
|
| 343 |
+
wall_clock
|
| 344 |
+
jinrikisha
|
| 345 |
+
pickup
|
| 346 |
+
rhinoceros
|
| 347 |
+
swimming_trunk
|
| 348 |
+
band-aid
|
| 349 |
+
spotted_salamander
|
| 350 |
+
leeks
|
| 351 |
+
marmot
|
| 352 |
+
warthog
|
| 353 |
+
cello
|
| 354 |
+
stool
|
| 355 |
+
chest
|
| 356 |
+
toilet_plunger
|
| 357 |
+
wardrobe
|
| 358 |
+
cannon
|
| 359 |
+
adidas_logo1
|
| 360 |
+
drumstick
|
| 361 |
+
lady_slipper
|
| 362 |
+
puma_logo
|
| 363 |
+
great_wall
|
| 364 |
+
white_shark
|
| 365 |
+
witch_hat
|
| 366 |
+
vending_machine
|
| 367 |
+
wreck
|
| 368 |
+
chopsticks
|
| 369 |
+
garfish
|
| 370 |
+
african_elephant
|
| 371 |
+
children_slide
|
| 372 |
+
hornbill
|
| 373 |
+
zebra
|
| 374 |
+
boa_constrictor
|
| 375 |
+
armour
|
| 376 |
+
pineapple
|
| 377 |
+
angora
|
| 378 |
+
brick
|
| 379 |
+
car_wheel
|
| 380 |
+
wallet
|
| 381 |
+
boston_bull
|
| 382 |
+
hyena
|
| 383 |
+
lynx
|
| 384 |
+
crash_helmet
|
| 385 |
+
terrapin_turtle
|
| 386 |
+
persian_cat
|
| 387 |
+
shift_gear
|
| 388 |
+
cactus_ball
|
| 389 |
+
fur_coat
|
| 390 |
+
plate
|
| 391 |
+
pen
|
| 392 |
+
okra
|
| 393 |
+
mario
|
| 394 |
+
airedale
|
| 395 |
+
cowboy_hat
|
| 396 |
+
celery
|
| 397 |
+
macaque
|
| 398 |
+
candle
|
| 399 |
+
goose
|
| 400 |
+
raccoon
|
| 401 |
+
brasscica
|
| 402 |
+
almond
|
| 403 |
+
maotai_bottle
|
| 404 |
+
soccer_ball
|
| 405 |
+
sports_car
|
| 406 |
+
tobacco_pipe
|
| 407 |
+
water_polo
|
| 408 |
+
eggnog
|
| 409 |
+
hook
|
| 410 |
+
ostrich
|
| 411 |
+
patas
|
| 412 |
+
table_lamp
|
| 413 |
+
teddy
|
| 414 |
+
mongoose
|
| 415 |
+
spoonbill
|
| 416 |
+
redheart
|
| 417 |
+
crane
|
| 418 |
+
dinosaur
|
| 419 |
+
kitchen_knife
|
| 420 |
+
seal
|
| 421 |
+
baboon
|
| 422 |
+
golfcart
|
| 423 |
+
roller_coaster
|
| 424 |
+
avocado
|
| 425 |
+
birdhouse
|
| 426 |
+
yorkshire_terrier
|
| 427 |
+
saluki
|
| 428 |
+
basketball
|
| 429 |
+
buckler
|
| 430 |
+
harvester
|
| 431 |
+
afghan_hound
|
| 432 |
+
beam_bridge
|
| 433 |
+
guinea_pig
|
| 434 |
+
lorikeet
|
| 435 |
+
shakuhachi
|
| 436 |
+
motarboard
|
| 437 |
+
statue_liberty
|
| 438 |
+
police_car
|
| 439 |
+
sulphur_crested
|
| 440 |
+
gourd
|
| 441 |
+
sombrero
|
| 442 |
+
mailbox
|
| 443 |
+
adhensive_tape
|
| 444 |
+
night_snake
|
| 445 |
+
bushtit
|
| 446 |
+
mouthpiece
|
| 447 |
+
beaver
|
| 448 |
+
bathtub
|
| 449 |
+
printer
|
| 450 |
+
cumquat
|
| 451 |
+
orange
|
| 452 |
+
cleaver
|
| 453 |
+
quill_pen
|
| 454 |
+
panpipe
|
| 455 |
+
diamond
|
| 456 |
+
gypsy_moth
|
| 457 |
+
cauliflower
|
| 458 |
+
lampshade
|
| 459 |
+
cougar
|
| 460 |
+
traffic_light
|
| 461 |
+
briefcase
|
| 462 |
+
ballpoint
|
| 463 |
+
african_grey
|
| 464 |
+
kremlin
|
| 465 |
+
barometer
|
| 466 |
+
peacock
|
| 467 |
+
paper_crane
|
| 468 |
+
sunscreen
|
| 469 |
+
tofu
|
| 470 |
+
bedlington_terrier
|
| 471 |
+
snowball
|
| 472 |
+
carrot
|
| 473 |
+
tiger
|
| 474 |
+
mink
|
| 475 |
+
cristo_redentor
|
| 476 |
+
ladle
|
| 477 |
+
keyboard
|
| 478 |
+
maraca
|
| 479 |
+
monitor
|
| 480 |
+
water_snake
|
| 481 |
+
can_opener
|
| 482 |
+
mud_turtle
|
| 483 |
+
bald_eagle
|
| 484 |
+
carp
|
| 485 |
+
cn_tower
|
| 486 |
+
egyptian_cat
|
| 487 |
+
hen_of_the_woods
|
| 488 |
+
measuring_cup
|
| 489 |
+
roller_skate
|
| 490 |
+
kite
|
| 491 |
+
sandwich_cookies
|
| 492 |
+
sandwich
|
| 493 |
+
persimmon
|
| 494 |
+
chess_bishop
|
| 495 |
+
coffin
|
| 496 |
+
ruddy_turnstone
|
| 497 |
+
prayer_rug
|
| 498 |
+
rain_barrel
|
| 499 |
+
neck_brace
|
| 500 |
+
nematode
|
| 501 |
+
rosehip
|
| 502 |
+
dutch_oven
|
| 503 |
+
goldfish
|
| 504 |
+
blossom_card
|
| 505 |
+
dough
|
| 506 |
+
trench_coat
|
| 507 |
+
sponge
|
| 508 |
+
stupa
|
| 509 |
+
wash_basin
|
| 510 |
+
electric_fan
|
| 511 |
+
spring_scroll
|
| 512 |
+
potted_plant
|
| 513 |
+
sparrow
|
| 514 |
+
car_mirror
|
| 515 |
+
gecko
|
| 516 |
+
diaper
|
| 517 |
+
leatherback_turtle
|
| 518 |
+
strainer
|
| 519 |
+
guacamole
|
| 520 |
+
microwave
|
data/splits/fss/val.txt
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
handcuff
|
| 2 |
+
mortar
|
| 3 |
+
matchstick
|
| 4 |
+
wine_bottle
|
| 5 |
+
dowitcher
|
| 6 |
+
triumphal_arch
|
| 7 |
+
gyromitra
|
| 8 |
+
hatchet
|
| 9 |
+
airliner
|
| 10 |
+
broccoli
|
| 11 |
+
olive
|
| 12 |
+
pubg_lvl3backpack
|
| 13 |
+
calculator
|
| 14 |
+
toucan
|
| 15 |
+
shovel
|
| 16 |
+
sewing_machine
|
| 17 |
+
icecream
|
| 18 |
+
woodpecker
|
| 19 |
+
pig
|
| 20 |
+
relay_stick
|
| 21 |
+
mcdonald_sign
|
| 22 |
+
cpu
|
| 23 |
+
peanut
|
| 24 |
+
pumpkin
|
| 25 |
+
sturgeon
|
| 26 |
+
hammer
|
| 27 |
+
hami_melon
|
| 28 |
+
squirrel_monkey
|
| 29 |
+
shuriken
|
| 30 |
+
power_drill
|
| 31 |
+
pingpong_ball
|
| 32 |
+
crocodile
|
| 33 |
+
carambola
|
| 34 |
+
monarch_butterfly
|
| 35 |
+
drum
|
| 36 |
+
water_tower
|
| 37 |
+
panda
|
| 38 |
+
toilet_brush
|
| 39 |
+
pay_phone
|
| 40 |
+
yonex_icon
|
| 41 |
+
cricketball
|
| 42 |
+
revolver
|
| 43 |
+
chimpanzee
|
| 44 |
+
crab
|
| 45 |
+
corn
|
| 46 |
+
baseball
|
| 47 |
+
rabbit
|
| 48 |
+
croquet_ball
|
| 49 |
+
artichoke
|
| 50 |
+
abacus
|
| 51 |
+
harp
|
| 52 |
+
bell
|
| 53 |
+
gas_tank
|
| 54 |
+
scissors
|
| 55 |
+
vase
|
| 56 |
+
upright_piano
|
| 57 |
+
typewriter
|
| 58 |
+
bittern
|
| 59 |
+
impala
|
| 60 |
+
tray
|
| 61 |
+
fire_hydrant
|
| 62 |
+
beer_bottle
|
| 63 |
+
sock
|
| 64 |
+
soup_bowl
|
| 65 |
+
spider
|
| 66 |
+
cherry
|
| 67 |
+
macaw
|
| 68 |
+
toilet_seat
|
| 69 |
+
fire_balloon
|
| 70 |
+
french_ball
|
| 71 |
+
fox_squirrel
|
| 72 |
+
volleyball
|
| 73 |
+
cornmeal
|
| 74 |
+
folding_chair
|
| 75 |
+
pubg_airdrop
|
| 76 |
+
beagle
|
| 77 |
+
skateboard
|
| 78 |
+
narcissus
|
| 79 |
+
whiptail
|
| 80 |
+
cup
|
| 81 |
+
arabian_camel
|
| 82 |
+
badger
|
| 83 |
+
stopwatch
|
| 84 |
+
ab_wheel
|
| 85 |
+
ox
|
| 86 |
+
lettuce
|
| 87 |
+
monocycle
|
| 88 |
+
redshank
|
| 89 |
+
vulture
|
| 90 |
+
whistle
|
| 91 |
+
smoothing_iron
|
| 92 |
+
mashed_potato
|
| 93 |
+
conveyor
|
| 94 |
+
yoga_pad
|
| 95 |
+
tow_truck
|
| 96 |
+
siamese_cat
|
| 97 |
+
cigar
|
| 98 |
+
white_stork
|
| 99 |
+
sniper_rifle
|
| 100 |
+
stretcher
|
| 101 |
+
tulip
|
| 102 |
+
handkerchief
|
| 103 |
+
basset
|
| 104 |
+
iceberg
|
| 105 |
+
gibbon
|
| 106 |
+
lacewing
|
| 107 |
+
thrush
|
| 108 |
+
cheetah
|
| 109 |
+
bighorn_sheep
|
| 110 |
+
espresso_maker
|
| 111 |
+
pretzel
|
| 112 |
+
english_setter
|
| 113 |
+
sandbar
|
| 114 |
+
cheese
|
| 115 |
+
daisy
|
| 116 |
+
arctic_fox
|
| 117 |
+
briard
|
| 118 |
+
colubus
|
| 119 |
+
balance_beam
|
| 120 |
+
coffeepot
|
| 121 |
+
soap_dispenser
|
| 122 |
+
yawl
|
| 123 |
+
consomme
|
| 124 |
+
parking_meter
|
| 125 |
+
cactus
|
| 126 |
+
turnstile
|
| 127 |
+
taro
|
| 128 |
+
fire_screen
|
| 129 |
+
digital_clock
|
| 130 |
+
rose
|
| 131 |
+
pomegranate
|
| 132 |
+
bee_eater
|
| 133 |
+
schooner
|
| 134 |
+
ski_mask
|
| 135 |
+
jay_bird
|
| 136 |
+
plaice
|
| 137 |
+
red_fox
|
| 138 |
+
syringe
|
| 139 |
+
camomile
|
| 140 |
+
pickelhaube
|
| 141 |
+
blenheim_spaniel
|
| 142 |
+
pear
|
| 143 |
+
parachute
|
| 144 |
+
common_newt
|
| 145 |
+
bowtie
|
| 146 |
+
cigarette
|
| 147 |
+
oscilloscope
|
| 148 |
+
laptop
|
| 149 |
+
african_crocodile
|
| 150 |
+
apron
|
| 151 |
+
coconut
|
| 152 |
+
sandal
|
| 153 |
+
kwanyin
|
| 154 |
+
lion
|
| 155 |
+
eel
|
| 156 |
+
balloon
|
| 157 |
+
crepe
|
| 158 |
+
armadillo
|
| 159 |
+
kazoo
|
| 160 |
+
lemon
|
| 161 |
+
spider_monkey
|
| 162 |
+
tape_player
|
| 163 |
+
ipod
|
| 164 |
+
bee
|
| 165 |
+
sea_cucumber
|
| 166 |
+
suitcase
|
| 167 |
+
television
|
| 168 |
+
pillow
|
| 169 |
+
banjo
|
| 170 |
+
rock_snake
|
| 171 |
+
partridge
|
| 172 |
+
platypus
|
| 173 |
+
lycaenid_butterfly
|
| 174 |
+
pinecone
|
| 175 |
+
conversion_plug
|
| 176 |
+
wolf
|
| 177 |
+
frying_pan
|
| 178 |
+
timber_wolf
|
| 179 |
+
bluetick
|
| 180 |
+
crayon
|
| 181 |
+
giant_schnauzer
|
| 182 |
+
orang
|
| 183 |
+
scarerow
|
| 184 |
+
kobe_logo
|
| 185 |
+
loguat
|
| 186 |
+
saxophone
|
| 187 |
+
ceiling_fan
|
| 188 |
+
cardoon
|
| 189 |
+
equestrian_helmet
|
| 190 |
+
louvre_pyramid
|
| 191 |
+
hotdog
|
| 192 |
+
ironing_board
|
| 193 |
+
razor
|
| 194 |
+
nagoya_castle
|
| 195 |
+
loggerhead_turtle
|
| 196 |
+
lipstick
|
| 197 |
+
cradle
|
| 198 |
+
strongbox
|
| 199 |
+
raven
|
| 200 |
+
kit_fox
|
| 201 |
+
albatross
|
| 202 |
+
flat-coated_retriever
|
| 203 |
+
beer_glass
|
| 204 |
+
ice_lolly
|
| 205 |
+
sungnyemun
|
| 206 |
+
totem_pole
|
| 207 |
+
vacuum
|
| 208 |
+
bolete
|
| 209 |
+
mango
|
| 210 |
+
ginger
|
| 211 |
+
weasel
|
| 212 |
+
cabbage
|
| 213 |
+
refrigerator
|
| 214 |
+
school_bus
|
| 215 |
+
hippo
|
| 216 |
+
tiger_cat
|
| 217 |
+
saltshaker
|
| 218 |
+
piano_keyboard
|
| 219 |
+
windsor_tie
|
| 220 |
+
sea_urchin
|
| 221 |
+
microsd
|
| 222 |
+
barbell
|
| 223 |
+
swim_ring
|
| 224 |
+
bulbul_bird
|
| 225 |
+
water_ouzel
|
| 226 |
+
ac_ground
|
| 227 |
+
sweatshirt
|
| 228 |
+
umbrella
|
| 229 |
+
hair_drier
|
| 230 |
+
hammerhead_shark
|
| 231 |
+
tomato
|
| 232 |
+
projector
|
| 233 |
+
cushion
|
| 234 |
+
dishwasher
|
| 235 |
+
three-toed_sloth
|
| 236 |
+
tiger_shark
|
| 237 |
+
har_gow
|
| 238 |
+
baby
|
| 239 |
+
thor's_hammer
|
| 240 |
+
nike_logo
|
data/suim.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r""" FSS-1000 few-shot semantic segmentation dataset """
|
| 2 |
+
import os
|
| 3 |
+
import glob
|
| 4 |
+
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torch
|
| 8 |
+
import PIL.Image as Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DatasetSUIM(Dataset):
    r"""SUIM underwater few-shot semantic segmentation dataset.

    Each episode pairs one query image with ``shot`` support images drawn
    from the same category. Masks are read as grayscale and binarized into
    foreground/background at intensity 128.
    """

    def __init__(self, datapath, fold, transform, split, shot, num_val=600):
        """
        Args:
            datapath: root directory containing ``images/`` and ``masks/``.
            fold: unused; kept for interface parity with the other datasets.
            transform: torchvision-style transform applied to every image.
            split: ``'val'`` caps ``__len__`` at ``num_val`` episodes;
                any other value iterates over the full image count.
            shot: number of support examples per episode.
            num_val: episode budget used when ``split == 'val'``.
        """
        self.split = split
        self.benchmark = 'suim'
        self.shot = shot
        self.num_val = num_val

        self.base_path = os.path.join(datapath)
        self.img_path = os.path.join(self.base_path, 'images')
        self.ann_path = os.path.join(self.base_path, 'masks')

        # SUIM category codes; these are also the mask sub-directory names.
        self.categories = ['FV', 'HD', 'PF', 'RI', 'RO', 'SR', 'WR']

        self.class_ids = range(len(self.categories))
        self.img_metadata_classwise, self.num_images = self.build_img_metadata_classwise()

        self.transform = transform

    def __len__(self):
        # If it is the target domain ('val'), evaluate a fixed episode budget;
        # otherwise also test on the entire dataset.
        return self.num_images if self.split != 'val' else self.num_val

    def __getitem__(self, idx):
        """Build one few-shot episode (query + support set) as a dict."""
        query_name, support_names, class_sample = self.sample_episode(idx)
        query_img, query_mask, support_imgs, support_masks = self.load_frame(query_name, support_names)

        query_img = self.transform(query_img)
        # Nearest-neighbour resize keeps the mask strictly binary.
        query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()

        support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])

        support_masks_tmp = []
        for smask in support_masks:
            smask = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
            support_masks_tmp.append(smask)
        support_masks = torch.stack(support_masks_tmp)

        batch = {'query_img': query_img,
                 'query_mask': query_mask,
                 'support_set': (support_imgs, support_masks),
                 'support_classes': torch.tensor([class_sample]),  # adapt to Nway

                 'query_name': query_name,  # REMOVE
                 'support_imgs': support_imgs,  # REMOVE
                 'support_masks': support_masks,  # REMOVE
                 'support_names': support_names,  # REMOVE
                 'class_id': torch.tensor(class_sample)}  # REMOVE

        return batch

    def load_frame(self, query_mask_path, support_mask_paths):
        """Load the query/support RGB images and their binarized masks."""
        def maskpath_to_imgpath(maskpath):
            # Masks live under masks/<cat>/; the matching image shares the
            # file stem but is a .jpg directly under images/.
            # BUGFIX: use os.path.basename instead of splitting on '/', so
            # OS-native separators (e.g. on Windows) are handled correctly.
            filename = os.path.basename(maskpath).split('.')[0]
            return os.path.join(self.img_path, filename) + '.jpg'

        query_img = Image.open(maskpath_to_imgpath(query_mask_path)).convert('RGB')

        support_imgs = [Image.open(maskpath_to_imgpath(s_mask_path)).convert('RGB') for s_mask_path in support_mask_paths]

        query_mask = self.read_mask(query_mask_path)
        support_masks = [self.read_mask(s_mask_path) for s_mask_path in support_mask_paths]

        return query_img, query_mask, support_imgs, support_masks

    def read_mask(self, img_name):
        """Read a mask image as grayscale and binarize it at intensity 128."""
        mask = torch.tensor(np.array(Image.open(img_name).convert('L')))
        mask[mask < 128] = 0
        mask[mask >= 128] = 1
        return mask

    def sample_episode(self, idx):
        """Pick a class (round-robin over idx) and random query/support masks.

        Returns:
            (query mask path, list of support mask paths, integer class id)
        """
        class_id = idx % len(self.class_ids)
        class_sample = self.categories[class_id]

        query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
        support_names = []
        # Keep sampling support candidates until `shot` items distinct from
        # the query are collected (duplicates among supports are possible).
        # NOTE(review): loops forever if a category holds only one usable
        # mask -- confirm the prepared dataset guarantees more.
        while True:
            support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
            if query_name != support_name: support_names.append(support_name)
            if len(support_names) == self.shot: break

        return query_name, support_names, class_id

    def build_img_metadata_classwise(self):
        """Collect per-category mask paths, skipping all-background masks.

        Returns:
            (dict mapping category code -> sorted list of mask paths,
             total number of kept masks across all categories)
        """
        num_images = 0
        img_metadata_classwise = {cat: [] for cat in self.categories}

        for cat in self.categories:
            mask_paths = sorted(glob.glob('%s/*' % os.path.join(self.base_path, 'masks', cat)))
            for mask_path in mask_paths:
                if self.read_mask(mask_path).count_nonzero() > 0:  # no empty masks
                    img_metadata_classwise[cat] += [mask_path]
                    num_images += 1
        return img_metadata_classwise, num_images
|
eval/evaluation.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r""" Evaluate mask prediction """
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Evaluator:
    r""" Computes intersection and union between prediction and ground-truth """

    @classmethod
    def initialize(cls):
        # Label value marking pixels excluded from evaluation.
        cls.ignore_index = 255

    @classmethod
    def classify_prediction(cls, pred_mask, batch):
        """Return per-class (background, foreground) intersection/union areas.

        Args:
            pred_mask: (B, H, W) predicted binary masks. Pixels inside the
                ignore region are overwritten with ``ignore_index`` in place.
            batch: episode dict with 'query_mask' and, optionally,
                'query_ignore_idx' (PASCAL-5i boundary-ignore regions,
                following the evaluation scheme of PFE-Net (TPAMI 2020)).

        Returns:
            (area_inter, area_union): two (2, B) tensors whose rows are the
            background/foreground pixel counts per episode.
        """
        gt_mask = batch.get('query_mask')

        # Apply ignore_index in PASCAL-5i masks (following evaluation scheme in PFE-Net (TPAMI 2020))
        query_ignore_idx = batch.get('query_ignore_idx')
        if query_ignore_idx is not None:
            # Ignore pixels must never overlap the foreground ground truth.
            assert torch.logical_and(query_ignore_idx, gt_mask).sum() == 0
            # BUGFIX: multiply out-of-place. The original `*=` mutated the
            # tensor stored inside `batch`, corrupting it for any later use
            # of the same batch dict.
            query_ignore_idx = query_ignore_idx * cls.ignore_index
            gt_mask = gt_mask + query_ignore_idx
            pred_mask[gt_mask == cls.ignore_index] = cls.ignore_index

        # compute intersection and union of each episode in a batch
        area_inter, area_pred, area_gt = [], [], []
        for _pred_mask, _gt_mask in zip(pred_mask, gt_mask):
            _inter = _pred_mask[_pred_mask == _gt_mask]
            if _inter.size(0) == 0:  # as torch.histc returns error if it gets empty tensor (pytorch 1.5.1)
                _area_inter = torch.tensor([0, 0], device=_pred_mask.device)
            else:
                # Ignored pixels (value 255) fall outside [0, 1] and are
                # dropped by histc automatically.
                _area_inter = torch.histc(_inter, bins=2, min=0, max=1)
            area_inter.append(_area_inter)
            area_pred.append(torch.histc(_pred_mask, bins=2, min=0, max=1))
            area_gt.append(torch.histc(_gt_mask, bins=2, min=0, max=1))
        area_inter = torch.stack(area_inter).t()
        area_pred = torch.stack(area_pred).t()
        area_gt = torch.stack(area_gt).t()
        area_union = area_pred + area_gt - area_inter

        return area_inter, area_union
|
eval/logger.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r""" Logging during training/testing """
|
| 2 |
+
import datetime
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from tensorboardX import SummaryWriter
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
class AverageMeter:
    r""" Stores loss, evaluation results """

    # Number of semantic classes per supported benchmark; replaces six
    # copy-pasted if/elif branches that differed only in this value.
    _NCLASS = {'pascal': 20, 'fss': 1000, 'deepglobe': 6,
               'isic': 3, 'lung': 1, 'suim': 7}

    def __init__(self, dataset, device='cuda'):
        """Create per-class intersection/union buffers for *dataset*.

        Args:
            dataset: object exposing .benchmark (str) and .class_ids (list[int]).
            device: device the accumulation buffers live on.

        Raises:
            Exception: if the benchmark is not one of the supported ones.
        """
        self.benchmark = dataset.benchmark
        if self.benchmark not in self._NCLASS:
            raise Exception('Unknown dataset: %s' % dataset)
        self.nclass = self._NCLASS[self.benchmark]
        self.class_ids_interest = torch.tensor(dataset.class_ids).to(device)

        # Row 0: background, row 1: foreground; one column per class.
        self.intersection_buf = torch.zeros([2, self.nclass]).float().to(device)
        self.union_buf = torch.zeros([2, self.nclass]).float().to(device)
        self.ones = torch.ones_like(self.union_buf)
        self.loss_buf = []

    def update(self, inter_b, union_b, class_id, loss):
        """Accumulate one batch of intersection/union areas and (optionally) a loss."""
        self.intersection_buf.index_add_(1, class_id, inter_b.float())
        self.union_buf.index_add_(1, class_id, union_b.float())
        if loss is None:
            # Keep the loss buffer stackable even for loss-free (test) episodes.
            loss = torch.tensor(0.0)
        self.loss_buf.append(loss)

    def compute_iou(self):
        """Return (mean foreground IoU, foreground-background IoU), in percent."""
        # max(union, 1) avoids division by zero for classes never encountered.
        iou = self.intersection_buf.float() / \
              torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0]
        iou = iou.index_select(1, self.class_ids_interest)
        miou = iou[1].mean() * 100

        fb_iou = (self.intersection_buf.index_select(1, self.class_ids_interest).sum(dim=1) /
                  self.union_buf.index_select(1, self.class_ids_interest).sum(dim=1)).mean() * 100

        return miou, fb_iou

    def write_result(self, split, epoch):
        """Log the aggregated result for *split* at *epoch* via Logger."""
        iou, fb_iou = self.compute_iou()

        loss_buf = torch.stack(self.loss_buf)
        msg = '\n*** %s ' % split
        msg += '[@Epoch %02d] ' % epoch
        msg += 'Avg L: %6.5f ' % loss_buf.mean()
        msg += 'mIoU: %5.2f ' % iou
        msg += 'FB-IoU: %5.2f ' % fb_iou

        msg += '***\n'
        Logger.info(msg)

    def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20):
        """Log running metrics every *write_batch_idx* batches (epoch == -1 skips loss)."""
        if batch_idx % write_batch_idx == 0:
            msg = '[Epoch: %02d] ' % epoch if epoch != -1 else ''
            msg += '[Batch: %04d/%04d] ' % (batch_idx + 1, datalen)
            iou, fb_iou = self.compute_iou()
            if epoch != -1:
                loss_buf = torch.stack(self.loss_buf)
                msg += 'L: %6.5f ' % loss_buf[-1]
                msg += 'Avg L: %6.5f ' % loss_buf.mean()
            msg += 'mIoU: %5.2f | ' % iou
            msg += 'FB-IoU: %5.2f' % fb_iou
            Logger.info(msg)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class Logger:
    r""" Writes evaluation results of training/testing """

    @classmethod
    def initialize(cls, args, training):
        """Set up file + console logging and a TensorBoard writer.

        Creates logs/<logpath>.log/ as the run directory. Test runs get a
        time-stamped directory so they never overwrite training logs.

        Args:
            args: parsed arguments; must provide .logpath and .benchmark.
            training: True for training runs, False for test runs.
        """
        logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S')
        logpath = args.logpath if training else args.logpath + '_TEST_' + logtime  # changed lopath created for test
        if logpath == '': logpath = logtime

        cls.logpath = os.path.join('logs', logpath + '.log')
        cls.benchmark = args.benchmark
        print("logdir: ", cls.logpath)
        os.makedirs(cls.logpath)

        logging.basicConfig(filemode='w',
                            filename=os.path.join(cls.logpath, 'log.txt'),
                            level=logging.INFO,
                            format='%(message)s',
                            datefmt='%m-%d %H:%M:%S')

        # Console log config
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(message)s')
        console.setFormatter(formatter)
        logging.getLogger('').addHandler(console)

        # Tensorboard writer
        cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))

        # Log arguments
        logging.info('\n:=========== Adapt Before Comparison - A New Perspective on Cross-Domain Few-Shot Segmentation ===========')
        for arg_key in args.__dict__:
            logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key])))
        logging.info(':================================================\n')

    @classmethod
    def info(cls, msg):
        r""" Writes log message to log.txt """
        logging.info(msg)

    @classmethod
    def save_model_miou(cls, model, epoch, val_miou):
        """Persist *model* as best_model.pt in the log directory and log val mIoU."""
        torch.save(model.state_dict(), os.path.join(cls.logpath, 'best_model.pt'))
        cls.info('Model saved @%d w/ val. mIoU: %5.2f.\n' % (epoch, val_miou))

    @classmethod
    def log_params(cls, model):
        """Log parameter counts split into backbone vs. learnable parts.

        Returns:
            (backbone_param, learner_param): the two counts, so callers can
            inspect them programmatically (previously returned None; callers
            ignoring the return value are unaffected).
        """
        backbone_param = 0
        learner_param = 0
        for k in model.state_dict().keys():
            n_param = model.state_dict()[k].view(-1).size(0)
            # Bugfix: was `k.split('.')[0] in 'backbone'` — a *substring* test
            # (e.g. 'one' in 'backbone' is True), so unrelated keys could be
            # miscounted as backbone parameters. Compare for equality instead.
            if k.split('.')[0] == 'backbone':
                if k.split('.')[1] in ['classifier', 'fc']:  # as fc layers are not used in HSNet
                    continue
                backbone_param += n_param
            else:
                learner_param += n_param
        Logger.info('Backbone # param.: %d' % backbone_param)
        Logger.info('Learnable # param.: %d' % learner_param)
        Logger.info('Total # param.: %d' % (backbone_param + learner_param))
        return backbone_param, learner_param
|
main.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from core import runner
|
| 2 |
+
import torch
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
def parse_opts():
    r"""Parse command-line arguments for an evaluation run.

    Returns:
        argparse.Namespace with .benchmark, .datapath and .nshot.
    """
    parser = argparse.ArgumentParser(description='Adapt Before Comparison - A New Perspective on Cross-Domain Few-Shot Segmentation')

    # common
    # Bugfix: the choices list contained duplicates ('fss' and 'lung' twice) and
    # omitted 'suim', which the README and the runner document as supported.
    parser.add_argument('--benchmark', type=str, default='lung',
                        choices=['fss', 'deepglobe', 'lung', 'isic', 'suim'])
    parser.add_argument('--datapath', type=str)
    parser.add_argument('--nshot', type=int, default=1)

    args = parser.parse_args()
    return args
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
if __name__ == '__main__':
    # Prefer the first GPU; fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args = parse_opts()
    print(args)
    # The runner exposes its configuration as a module-level `args` object; it
    # must be populated before any of the factory calls below read it.
    runner.args.benchmark = args.benchmark
    runner.args.datapath = args.datapath
    runner.args.nshot = args.nshot

    # Build the episode loader, the method configuration, the feature
    # extractor and the metric accumulator.
    dataloader = runner.makeDataloader()
    config = runner.makeConfig()
    feat_maker = runner.makeFeatureMaker(dataloader.dataset, config, device=device)
    average_meter = runner.AverageMeterWrapper(dataloader, device)

    # Evaluate one episode at a time and accumulate IoU statistics.
    for idx, batch in enumerate(dataloader):
        sseval = runner.SingleSampleEval(batch, feat_maker)
        sseval.forward()
        sseval.calc_metrics()
        average_meter.update(sseval)
        average_meter.write(idx)
    # compute_iou() yields (mean IoU, foreground-background IoU) — NOTE(review):
    # presumably in percent, mirroring eval/logger.py's AverageMeter; confirm.
    print('Result m|FB:', average_meter.average_meter.compute_iou())
|
utils/commonutils.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r""" Helper functions """
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def fix_randseed(seed):
    r""" Set random seeds for reproducibility """
    if seed is None:
        # No seed requested: pick an arbitrary run-specific one.
        seed = int(random.random() * 1e5)
    # Seed numpy and every torch RNG (CPU + all CUDA devices).
    for seeder in (np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade cuDNN autotuning for deterministic kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def mean(x):
    """Arithmetic mean of a sequence; 0.0 for an empty one."""
    if len(x) == 0:
        return 0.0
    return sum(x) / len(x)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def to_cuda(batch):
    """Move every tensor value of *batch* onto the default CUDA device (in place)."""
    for key in batch:
        entry = batch[key]
        if isinstance(entry, torch.Tensor):
            batch[key] = entry.cuda()
    return batch
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def to_cpu(tensor):
    """Return a detached, independent CPU copy of *tensor*."""
    detached = tensor.detach()
    return detached.clone().cpu()
|
utils/segutils.py
ADDED
|
@@ -0,0 +1,584 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
from PIL import Image, ImageDraw
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
def norm(t):
    """Min-max normalise *t* into the [0, 1] range."""
    return (t - t.min()) / (t.max() - t.min())


def denorm(t, min_, max_):
    """Invert :func:`norm`: map values in [0, 1] back to [min_, max_]."""
    return t * (max_ - min_) + min_


def percentilerange(t, perc):
    """Value at fraction *perc* of t's [min, max] range."""
    return t.min() + perc * (t.max() - t.min())


def midrange(t):
    """Midpoint of t's value range."""
    return percentilerange(t, .5)


def downsample_mask(mask, H, W):
    """Bilinearly resize a [B, h, w] mask to [B, H, W]."""
    # Lambdas assigned to names (PEP 8 E731) rewritten as documented defs;
    # call signatures and behaviour are unchanged.
    return F.interpolate(mask.unsqueeze(1), size=(H, W), mode='bilinear',
                         align_corners=False).squeeze(1)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# sfeat_volume:      [bsz, c, vecs]   (vecs can be H*W, for example)
# downsampled_smask: [bsz, vecs]
# returns two [bsz, c] prototype tensors (foreground, background)
def fg_bg_proto(sfeat_volume, downsampled_smask):
    """Average-pool support features into foreground/background prototypes."""
    bsz, channels, n_vecs = sfeat_volume.shape
    # Broadcast the flat mask over the channel dimension: [bsz, 1, vecs].
    mask = downsampled_smask.expand(bsz, n_vecs).unsqueeze(1)

    # Masked means; the 1e-8 guards against an all-background mask.
    fg_proto = (mask * sfeat_volume).sum(dim=-1) / (mask.sum(dim=-1) + 1e-8)
    inv_mask = 1 - mask
    bg_proto = (inv_mask * sfeat_volume).sum(dim=-1) / (inv_mask.sum(dim=-1) + 1e-8)

    assert fg_proto.shape == (bsz, channels), ":o"
    return fg_proto, bg_proto
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# intersection = lambda pred, target: (pred * target).float().sum()
|
| 36 |
+
# union = lambda pred, target: (pred + target).clamp(0, 1).float().sum()
|
| 37 |
+
#
|
| 38 |
+
#
|
| 39 |
+
# def iou(pred, target): # binary only, input bsz,h,w
|
| 40 |
+
# i, u = intersection(pred, target), union(pred, target)
|
| 41 |
+
# iou = (i + 1e-8) / (u + 1e-8)
|
| 42 |
+
# return iou.item()
|
| 43 |
+
#
|
| 44 |
+
#
|
| 45 |
+
# class SimpleAvgMeter:
|
| 46 |
+
# def __init__(self, n_classes, device=torch.device('cuda')):
|
| 47 |
+
# self.n_lasses = n_classes
|
| 48 |
+
# self.intersection_buf = torch.zeros(n_classes).to(device)
|
| 49 |
+
# self.union_buf = torch.zeros(n_classes).to(device)
|
| 50 |
+
#
|
| 51 |
+
# def update(self, pred, target, class_id):
|
| 52 |
+
# self.intersection_buf[class_id] += intersection(pred, target)
|
| 53 |
+
# self.union_buf[class_id] += union(pred, target)
|
| 54 |
+
#
|
| 55 |
+
# def IoU(self, class_id):
|
| 56 |
+
# return self.intersection_buf[class_id] / self.union_buf[class_id] * 100
|
| 57 |
+
#
|
| 58 |
+
# def cls_mIoU(self, class_ids):
|
| 59 |
+
# return (self.intersection_buf[class_ids] / self.union_buf[class_ids]).mean() * 100
|
| 60 |
+
#
|
| 61 |
+
# def compute_mIoU(self):
|
| 62 |
+
# noentry = self.union_buf == 0
|
| 63 |
+
# if noentry.sum() > 0: print("SimpleAvgMeter warning: ", noentry.sum(), "elements of", self.nclasses,
|
| 64 |
+
# "have no empty.")
|
| 65 |
+
# return self.cls_mIoU(~noentry)
|
| 66 |
+
|
| 67 |
+
# class KMeans():
|
| 68 |
+
# # expects input to be in shape [bsz, -1]
|
| 69 |
+
# def __init__(self, data, k=2, num_iterations=10):
|
| 70 |
+
# self.k = k
|
| 71 |
+
# self.device = data.device
|
| 72 |
+
# self.centroids = self._init_centroids(data)
|
| 73 |
+
#
|
| 74 |
+
# for _ in range(num_iterations):
|
| 75 |
+
# labels = self._assign_clusters(data)
|
| 76 |
+
# self._update_centroids(data, labels)
|
| 77 |
+
#
|
| 78 |
+
# self.labels = self._assign_clusters(data) # Final cluster assignment
|
| 79 |
+
#
|
| 80 |
+
# def _init_centroids(self, data):
|
| 81 |
+
# # Randomly initialize centroids
|
| 82 |
+
# centroids = []
|
| 83 |
+
# min_values = data.min(dim=1, keepdim=True).values
|
| 84 |
+
# range_values = (data.max(dim=1, keepdim=True).values - min_values)
|
| 85 |
+
#
|
| 86 |
+
# for _ in range(self.k):
|
| 87 |
+
# random_values = torch.rand((data.shape[0], 1)).to(self.device)
|
| 88 |
+
# centroids.append(min_values + random_values * range_values)
|
| 89 |
+
#
|
| 90 |
+
# return torch.cat(centroids, dim=1)
|
| 91 |
+
#
|
| 92 |
+
# def _assign_clusters(self, data):
|
| 93 |
+
# # Calculate distances between data points and centroids
|
| 94 |
+
# distances = torch.abs(data.unsqueeze(2) - self.centroids) # Expand data tensor to calculate distances
|
| 95 |
+
# # Determine the closest centroid for each data point
|
| 96 |
+
# labels = torch.argmin(distances, dim=2)
|
| 97 |
+
# # Sort labels so that the largest mean data point has the highest label
|
| 98 |
+
# cluster_means = [data[labels == k].mean() for k in range(self.k)]
|
| 99 |
+
# sorted_labels = {k: rank for rank, k in enumerate(sorted(range(self.k), key=lambda k: cluster_means[k]))}
|
| 100 |
+
# labels = torch.tensor([sorted_labels[label.item()] for label in labels.flatten()]).reshape_as(labels).to(
|
| 101 |
+
# self.device)
|
| 102 |
+
#
|
| 103 |
+
# return labels
|
| 104 |
+
#
|
| 105 |
+
# def _update_centroids(self, data, labels):
|
| 106 |
+
# # Calculate new centroids as the mean of the data points closest to each centroid
|
| 107 |
+
# mask = torch.nn.functional.one_hot(labels, num_classes=self.k).to(torch.float32)
|
| 108 |
+
# summed_data = torch.bmm(mask.transpose(1, 2), data.unsqueeze(2)) # Sum data points per centroid
|
| 109 |
+
# self.centroids = summed_data.squeeze() / mask.sum(dim=1, keepdim=True) # Normalize to get the mean
|
| 110 |
+
#
|
| 111 |
+
# def compute_thresholds(self):
|
| 112 |
+
# # Flatten the centroids along the middle dimension
|
| 113 |
+
# flat_centroids = self.centroids.view(self.centroids.size(0), -1)
|
| 114 |
+
#
|
| 115 |
+
# # Sort the flattened centroids
|
| 116 |
+
# sorted_centroids, _ = torch.sort(flat_centroids, dim=1)
|
| 117 |
+
#
|
| 118 |
+
# # Compute the midpoints between consecutive centroids
|
| 119 |
+
# thresholds = (sorted_centroids[:, :-1] + sorted_centroids[:, 1:]) / 2.0
|
| 120 |
+
#
|
| 121 |
+
# return thresholds
|
| 122 |
+
#
|
| 123 |
+
# def inference(self, data):
|
| 124 |
+
# # Assign data points to the nearest centroid
|
| 125 |
+
# return self._assign_clusters(data)
|
| 126 |
+
|
| 127 |
+
# def iterative_triclass_thresholding(image, max_iterations=100, tolerance=25):
|
| 128 |
+
# # Ensure image is grayscale
|
| 129 |
+
# if len(image.shape) == 3:
|
| 130 |
+
# image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
| 131 |
+
#
|
| 132 |
+
# # Initialize iteration parameters
|
| 133 |
+
# TBD_region = image.copy()
|
| 134 |
+
# iteration = 0
|
| 135 |
+
# prev_threshold = 0
|
| 136 |
+
#
|
| 137 |
+
# while iteration < max_iterations:
|
| 138 |
+
# iteration += 1
|
| 139 |
+
#
|
| 140 |
+
# # Step 1: Apply Otsu's thresholding on the TBD region
|
| 141 |
+
# current_threshold, _ = cv2.threshold(TBD_region, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
| 142 |
+
#
|
| 143 |
+
# # Check stopping criteria
|
| 144 |
+
# if abs(current_threshold - prev_threshold) < tolerance:
|
| 145 |
+
# break
|
| 146 |
+
# prev_threshold = current_threshold
|
| 147 |
+
#
|
| 148 |
+
# # Step 2: Calculate means for upper and lower regions
|
| 149 |
+
# upper_region = TBD_region[TBD_region > current_threshold]
|
| 150 |
+
# lower_region = TBD_region[TBD_region <= current_threshold]
|
| 151 |
+
#
|
| 152 |
+
# if len(upper_region) == 0 or len(lower_region) == 0:
|
| 153 |
+
# break # No further division possible
|
| 154 |
+
#
|
| 155 |
+
# mean_upper = np.mean(upper_region)
|
| 156 |
+
# mean_lower = np.mean(lower_region)
|
| 157 |
+
#
|
| 158 |
+
# # Step 3: Update temporary foreground, background, and TBD regions
|
| 159 |
+
# TBD_region[(TBD_region > mean_upper)] = 255 # Temporary foreground F
|
| 160 |
+
# TBD_region[(TBD_region < mean_lower)] = 0 # Temporary background B
|
| 161 |
+
#
|
| 162 |
+
# # Extracting the new TBD region (between mean_lower and mean_upper)
|
| 163 |
+
# mask = (TBD_region > mean_lower) & (TBD_region < mean_upper)
|
| 164 |
+
# TBD_region = TBD_region[mask] # Apply mask to extract region
|
| 165 |
+
#
|
| 166 |
+
# # Final classification after convergence or max iterations
|
| 167 |
+
# final_foreground = (image > current_threshold).astype(np.uint8) * 255
|
| 168 |
+
# final_background = (image <= current_threshold).astype(np.uint8) * 255
|
| 169 |
+
#
|
| 170 |
+
# return current_threshold, final_foreground
|
| 171 |
+
|
| 172 |
+
def otsus(batched_tensor_image, drop_least=0.05, mode='ordinary'):
    """Per-sample Otsu thresholding of a batched tensor.

    Args:
        batched_tensor_image: float tensor [bsz, ...] (any value range);
            each sample is rescaled to [0, 255] before thresholding.
        drop_least: fraction of the rescaled range whose lowest values are
            excluded before estimating the threshold.
        mode: 'ordinary' for plain Otsu via OpenCV. NOTE(review):
            'via_triclass' calls iterative_triclass_thresholding, which is
            commented out above in this module — that branch currently
            raises NameError if taken; confirm intent.

    Returns:
        (thresh_batch, binary_tensor_batch): per-sample thresholds mapped
        back to each input's original value range (on the input's device),
        and {0, 1} float masks (built on CPU via torch.from_numpy).
    """
    bsz = batched_tensor_image.size(0)
    binary_tensors = []
    thresholds = []

    for i in range(bsz):
        # Convert the tensor to numpy array
        numpy_image = batched_tensor_image[i].cpu().numpy()

        # Rescale to [0, 255] and convert to uint8 type for OpenCV compatibility
        # (npmin/npmax are kept so the threshold can be mapped back later).
        npmin, npmax = numpy_image.min(), numpy_image.max()
        numpy_image = (norm(numpy_image) * 255).astype(np.uint8)

        # Drop values that are in the lowest percentiles
        truncated_vals = numpy_image[numpy_image >= int(255 * drop_least)]

        # Apply Otsu's thresholding
        if mode == 'via_triclass':
            thresh_value, _ = iterative_triclass_thresholding(truncated_vals)
        else:
            thresh_value, _ = cv2.threshold(truncated_vals, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Apply the computed threshold on the original image
        binary_image = (numpy_image > thresh_value).astype(np.uint8) * 255

        # Convert the result back to a tensor and append to the list
        binary_tensors.append(torch.from_numpy(binary_image).float() / 255)

        # Map the threshold from the uint8 scale back to the input's range.
        thresholds.append(torch.tensor(denorm(thresh_value / 255, npmin, npmax)) \
                          .to(batched_tensor_image.device, dtype=batched_tensor_image.dtype))

    # Convert list of tensors back to a single batched tensor
    binary_tensor_batch = torch.stack(binary_tensors, dim=0)
    thresh_batch = torch.stack(thresholds, dim=0)
    return thresh_batch, binary_tensor_batch
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def iterative_otsus(probab_mask, s_mask, maxiters=5, mode='ordinary',
                    debug=False):  # verify that it works correctly when batch_size >1
    """Repeated Otsu thresholding with progressive lower-clipping.

    Each round zeroes values below the current threshold and re-runs otsus();
    the loop stops once the threshold reaches at least the support-mask mean,
    or falls back to that mean after *maxiters* rounds.

    Args:
        probab_mask: foreground probabilities in [0, 1] (asserted below).
        s_mask: support mask; its mean is the acceptance bound and the
            fallback threshold.
        maxiters: maximum number of refinement rounds.
        mode: forwarded to otsus().
        debug: print diagnostics on fallback. NOTE(review): the debug branch
            calls display()/pilImageRow(), which are not defined in this
            module (notebook leftovers) and would raise NameError if reached.

    Returns:
        (threshold, binary mask): the converged Otsu result, or the
        s_mask-mean fallback pair.
    """
    it = 1
    otsuthresh = 0
    assert probab_mask.min() >= 0 and probab_mask.max() <= 1, 'you should pass probabilites'
    while True:
        # Zero out everything below the last threshold, then re-estimate it.
        clipped = torch.where(probab_mask < otsuthresh, 0, probab_mask)
        otsuthresh, newmask = otsus(clipped.detach(), drop_least=.02, mode=mode)
        if otsuthresh >= s_mask.mean():
            return otsuthresh.to(probab_mask.device), newmask.to(probab_mask.device)
        if it >= maxiters:
            if debug:
                print('reached maxiter:', it, 'with thresh', otsuthresh.item(), \
                      'removed', int(((clipped == 0).sum() / clipped.numel()).item() * 10000) / 100, \
                      '% at lower and and new min,max is', clipped[clipped > 0].min().item(), clipped.max().item())
                display(pilImageRow(norm(probab_mask[0]), s_mask[0], maxwidth=300))
            return s_mask.mean(), (probab_mask > s_mask.mean()).float()  # otsuthresh
        it += 1
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
# def upgrade_scipy():
|
| 230 |
+
# os.system('!pip install - -upgrade scipy')
|
| 231 |
+
#
|
| 232 |
+
#
|
| 233 |
+
# def slicRGB(q_img, n_segments=50, compactness=10., sigma=1, mask=None, debug=False):
|
| 234 |
+
# import skimage.segmentation as skseg
|
| 235 |
+
#
|
| 236 |
+
# rgb_labels = skseg.slic(q_img, n_segments=n_segments, compactness=compactness, sigma=sigma, mask=mask,
|
| 237 |
+
# enforce_connectivity=True)
|
| 238 |
+
#
|
| 239 |
+
# if debug:
|
| 240 |
+
# plt.imshow(skseg.mark_boundaries(q_img, rgb_labels))
|
| 241 |
+
# plt.show()
|
| 242 |
+
#
|
| 243 |
+
# return rgb_labels
|
| 244 |
+
#
|
| 245 |
+
#
|
| 246 |
+
#
|
| 247 |
+
# def slicRGBP(q_img, fg_pred, n_segments=30, compactness=0.1, sigma=1, mask=None, debug=False):
|
| 248 |
+
# import skimage.segmentation as skseg
|
| 249 |
+
#
|
| 250 |
+
# def concat_rgb_pred(rgbimg, pred):
|
| 251 |
+
# h, w = rgbimg.shape[:2]
|
| 252 |
+
# return np.concatenate((rgbimg, pred.reshape(h, w, 1)), axis=-1)
|
| 253 |
+
#
|
| 254 |
+
# rgbp_img = concat_rgb_pred(q_img, fg_pred)
|
| 255 |
+
# rgbp_labels = skseg.slic(rgbp_img, n_segments=n_segments, compactness=compactness, mask=mask, sigma=sigma,
|
| 256 |
+
# enforce_connectivity=True)
|
| 257 |
+
#
|
| 258 |
+
# if debug:
|
| 259 |
+
# rgb_labels = skseg.slic(q_img, n_segments=n_segments, compactness=10., sigma=sigma, mask=mask,
|
| 260 |
+
# enforce_connectivity=True)
|
| 261 |
+
# pred_labels = skseg.slic(fg_pred, n_segments=n_segments, compactness=compactness, sigma=sigma, mask=mask,
|
| 262 |
+
# channel_axis=None, enforce_connectivity=True)
|
| 263 |
+
#
|
| 264 |
+
# rows, cols = 1, 3
|
| 265 |
+
# fig, ax = plt.subplots(rows, cols, figsize=(10, 10), sharex=True, sharey=True)
|
| 266 |
+
# ax[0].imshow(skseg.mark_boundaries(q_img, rgbp_labels))
|
| 267 |
+
# ax[1].imshow(skseg.mark_boundaries(q_img, rgb_labels))
|
| 268 |
+
# ax[2].imshow(skseg.mark_boundaries(q_img, pred_labels))
|
| 269 |
+
# plt.show()
|
| 270 |
+
#
|
| 271 |
+
# return rgbp_labels
|
| 272 |
+
#
|
| 273 |
+
#
|
| 274 |
+
# def calc_cluster_means(label_id_map, fg_prob):
|
| 275 |
+
# fg_pred_clustered = np.zeros_like(fg_prob)
|
| 276 |
+
# label_ids = np.unique(label_id_map)
|
| 277 |
+
# for lab_id in label_ids:
|
| 278 |
+
# cluster = fg_prob[label_id_map == lab_id]
|
| 279 |
+
# fg_pred_clustered[label_id_map == lab_id] = cluster.mean()
|
| 280 |
+
# return fg_pred_clustered
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def install_pydensecrf():
    """Install pydensecrf from GitHub at runtime (fallback used by CRF.refine).

    NOTE(review): shells out to pip and modifies the active environment;
    assumes `pip` is on PATH — consider documenting it as a manual
    prerequisite instead.
    """
    os.system('pip install git+https://github.com/lucasb-eyer/pydensecrf.git')
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class CRF:
|
| 288 |
+
def __init__(self, gaussian_stdxy=(3, 3), gaussian_compat=3,
             bilateral_stdxy=(80, 80), bilateral_compat=10, stdrgb=(13, 13, 13)):
    """Configure the DenseCRF pairwise-potential parameters used by refine().

    Args:
        gaussian_stdxy: spatial std-dev (x, y) of the smoothing (Gaussian) kernel.
        gaussian_compat: compatibility (weight) of the Gaussian kernel.
        bilateral_stdxy: spatial std-dev (x, y) of the appearance (bilateral) kernel.
        bilateral_compat: compatibility (weight) of the bilateral kernel.
        stdrgb: colour std-dev (r, g, b) of the bilateral kernel.
    """
    self.gaussian_stdxy = gaussian_stdxy
    self.gaussian_compat = gaussian_compat
    self.bilateral_stdxy = bilateral_stdxy
    self.bilateral_compat = bilateral_compat
    self.stdrgb = stdrgb
    # Number of mean-field inference iterations (passed to d.inference in refine()).
    self.iters = 5
    self.debug = False
|
| 297 |
+
|
| 298 |
+
def refine(self, image_tensor, fg_probs, soft_thresh=None, T=1):
|
| 299 |
+
|
| 300 |
+
"""
|
| 301 |
+
Refine segmentation using DenseCRF.
|
| 302 |
+
|
| 303 |
+
Args:
|
| 304 |
+
- image_tensor (tensor): Original image, shape [1, 3, H, W].
|
| 305 |
+
- fg_probs (tensor): Fg probabilities from the network, shape [1, H, W]
|
| 306 |
+
- soft_thresh: The preferred threshold for fg_probs for segmenting into binary prediction mask
|
| 307 |
+
- T: a temperature for softmax/sigmoid
|
| 308 |
+
|
| 309 |
+
Returns:
|
| 310 |
+
- Refined segmentation mask, shape [1, H, W].
|
| 311 |
+
"""
|
| 312 |
+
try:
|
| 313 |
+
import pydensecrf.densecrf as dcrf
|
| 314 |
+
from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral
|
| 315 |
+
except ImportError as e:
|
| 316 |
+
print("pydensecrf not found. Installing...")
|
| 317 |
+
install_pydensecrf() # Ensure this function installs pydensecrf and handles any potential errors during installation.
|
| 318 |
+
|
| 319 |
+
# After installation, retry importing. This is placed inside the except block to avoid repeating the import statements.
|
| 320 |
+
try:
|
| 321 |
+
import pydensecrf.densecrf as dcrf
|
| 322 |
+
from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral
|
| 323 |
+
except ImportError as e:
|
| 324 |
+
print("Failed to import after installation. Please check the installation of pydensecrf.")
|
| 325 |
+
raise # This will raise the last exception that was handled by the except block
|
| 326 |
+
|
| 327 |
+
# We find the segmentation threshold that splits fg-bg
|
| 328 |
+
if soft_thresh is None:
|
| 329 |
+
soft_thresh, _ = otsus(fg_probs)
|
| 330 |
+
image_tensor, fg_probs, soft_thresh = image_tensor.cpu(), fg_probs.cpu(), soft_thresh.cpu()
|
| 331 |
+
# Then we presume at this threshold the probability should be 0.5
|
| 332 |
+
# probability 0 should stay 0, 1 should stay 1
|
| 333 |
+
# sigmoid=lambda x: 1/(1 + np.exp(-x))
|
| 334 |
+
fg_probs = torch.sigmoid(T * (fg_probs - soft_thresh))
|
| 335 |
+
probs = torch.stack([1 - fg_probs, fg_probs], dim=1) # crf expects both classes as input
|
| 336 |
+
if self.debug:
|
| 337 |
+
print('softthresh', soft_thresh)
|
| 338 |
+
print('fg_probs min max', fg_probs.min(), fg_probs.max())
|
| 339 |
+
# C: Number of classes
|
| 340 |
+
bsz, C, H, W = probs.shape
|
| 341 |
+
|
| 342 |
+
refined_masks = []
|
| 343 |
+
image_numpy = np.ascontiguousarray( \
|
| 344 |
+
(255 * image_tensor.permute(0, 2, 3, 1)).numpy().astype(np.uint8))
|
| 345 |
+
probs_numpy = probs.numpy()
|
| 346 |
+
for (image, prob) in zip(image_numpy, probs_numpy):
|
| 347 |
+
# Unary potentials
|
| 348 |
+
unary = np.ascontiguousarray(unary_from_softmax(prob))
|
| 349 |
+
d = dcrf.DenseCRF2D(W, H, C)
|
| 350 |
+
d.setUnaryEnergy(unary)
|
| 351 |
+
|
| 352 |
+
# Add pairwise potentials
|
| 353 |
+
d.addPairwiseGaussian(sxy=self.gaussian_stdxy, compat=self.gaussian_compat)
|
| 354 |
+
d.addPairwiseBilateral(sxy=self.bilateral_stdxy, srgb=self.stdrgb,
|
| 355 |
+
rgbim=image, compat=self.bilateral_compat)
|
| 356 |
+
|
| 357 |
+
# Perform inference
|
| 358 |
+
Q = d.inference(self.iters)
|
| 359 |
+
if self.debug:
|
| 360 |
+
print('Q:', np.array(Q).shape, np.array(Q)[0].mean(), np.array(Q).mean())
|
| 361 |
+
result = np.reshape(Q, (2, H, W)) # np.argmax(Q, axis=0).reshape((H, W))
|
| 362 |
+
refined_masks.append(result)
|
| 363 |
+
|
| 364 |
+
return torch.from_numpy(np.stack(refined_masks, axis=0))
|
| 365 |
+
|
| 366 |
+
# def iterrefine(self, iters, image_tensor, fg_probs, soft_thresh=None, T=1):
|
| 367 |
+
# q1 = fg_probs
|
| 368 |
+
# for iter in range(iters):
|
| 369 |
+
# print(q1.shape)
|
| 370 |
+
# q1 = self.refine(image_tensor, q1, soft_thresh=None, T=1)[:,1]
|
| 371 |
+
# return q1
|
| 372 |
+
|
| 373 |
+
def iterrefine(self, iters, q_img, fg_probs, thresh_fn, debug=False):
|
| 374 |
+
pred = fg_probs.unsqueeze(1).expand(1, 2, *fg_probs.shape[-2:])
|
| 375 |
+
for it in range(iters):
|
| 376 |
+
thresh = thresh_fn(pred[:, 1])[0]
|
| 377 |
+
|
| 378 |
+
if debug and i % 10 == 0:
|
| 379 |
+
print('thresh', thresh)
|
| 380 |
+
display(to_pil(pred[0, 1]))
|
| 381 |
+
|
| 382 |
+
pred = self.refine(q_img, pred[:, 1], soft_thresh=thresh)
|
| 383 |
+
return pred
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
#
|
| 387 |
+
# class Subplot:
|
| 388 |
+
# def __init__(self):
|
| 389 |
+
# self.vertical_lines = []
|
| 390 |
+
# self.histograms = []
|
| 391 |
+
# self.gaussian_curves = []
|
| 392 |
+
# self.colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
|
| 393 |
+
# self.title = ''
|
| 394 |
+
#
|
| 395 |
+
# class Element:
|
| 396 |
+
# def __init__(self, x=None, y=None, label=''):
|
| 397 |
+
# if x is not None:
|
| 398 |
+
# self.x = Subplot.to_np(x)
|
| 399 |
+
# if y is not None:
|
| 400 |
+
# self.y = Subplot.to_np(y)
|
| 401 |
+
#
|
| 402 |
+
# self.label = label
|
| 403 |
+
#
|
| 404 |
+
# @staticmethod
|
| 405 |
+
# def to_np(t):
|
| 406 |
+
# return t.detach().cpu().numpy()
|
| 407 |
+
#
|
| 408 |
+
# def add_vertical(self, x, label=''):
|
| 409 |
+
# self.vertical_lines.append(Subplot.Element(x=x, label=label))
|
| 410 |
+
# return self
|
| 411 |
+
#
|
| 412 |
+
# def add_histogram(self, samples, label=''):
|
| 413 |
+
# self.histograms.append(Subplot.Element(x=samples, label=label))
|
| 414 |
+
# return self
|
| 415 |
+
#
|
| 416 |
+
# def add_gaussian(self, gaussian):
|
| 417 |
+
# samples, mu, var = gaussian.samples, gaussian.mean, gaussian.covs
|
| 418 |
+
# # Generate a range of x values
|
| 419 |
+
# x_values = np.linspace(samples.min(), samples.max(), 100)
|
| 420 |
+
# x_values = np.linspace(samples.min(), samples.max(), 100)
|
| 421 |
+
#
|
| 422 |
+
# # Compute Gaussian values for these x values
|
| 423 |
+
# gaussian1_values = gaussian.gaussian_pdf(x_values, mu[0].item(), var[0].item())
|
| 424 |
+
# gaussian2_values = gaussian.gaussian_pdf(x_values, mu[1].item(), var[1].item())
|
| 425 |
+
# self.gaussian_curves.append(Subplot.Element(x_values, gaussian1_values))
|
| 426 |
+
# self.gaussian_curves.append(Subplot.Element(x_values, gaussian2_values))
|
| 427 |
+
# return self
|
| 428 |
+
#
|
| 429 |
+
#
|
| 430 |
+
# class PredHistos2():
|
| 431 |
+
# def __init__(self, n_cols=1):
|
| 432 |
+
# self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
|
| 433 |
+
# self.n_cols = n_cols
|
| 434 |
+
# if n_cols == 1:
|
| 435 |
+
# self.builder = Subplot()
|
| 436 |
+
# self.subplots = [Subplot() for x in range(n_cols)]
|
| 437 |
+
# self.alpha = 0.5
|
| 438 |
+
# self.bins = 200
|
| 439 |
+
#
|
| 440 |
+
# def reload(self, n_cols=1):
|
| 441 |
+
# self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
|
| 442 |
+
#
|
| 443 |
+
# def aggr(self, ax, sub):
|
| 444 |
+
# for hist, col in zip(sub.histograms, sub.colors):
|
| 445 |
+
# ax.hist(hist.x, self.bins, density=True, color=col, alpha=self.alpha, label=hist.label)
|
| 446 |
+
# for vline, col in zip(sub.vertical_lines, sub.colors):
|
| 447 |
+
# ax.axvline(x=vline.x, color=col, label=vline.label, linestyle='--')
|
| 448 |
+
# for gaussian, col in zip(sub.gaussian_curves, sub.colors):
|
| 449 |
+
# ax.plot(gaussian.x, gaussian.y, gaussian.label, col)
|
| 450 |
+
# ax.legend()
|
| 451 |
+
#
|
| 452 |
+
# def plot(self, name=''):
|
| 453 |
+
#
|
| 454 |
+
# if self.n_cols == 1:
|
| 455 |
+
# self.aggr(plt, self.builder)
|
| 456 |
+
# else:
|
| 457 |
+
# for ax, sub in zip(self.axes, self.subplots):
|
| 458 |
+
# self.aggr(ax, sub)
|
| 459 |
+
# ax.set_title(sub.title)
|
| 460 |
+
#
|
| 461 |
+
# plt.legend()
|
| 462 |
+
# plt.title(name)
|
| 463 |
+
# plt.show()
|
| 464 |
+
#
|
| 465 |
+
#
|
| 466 |
+
# from sklearn.mixture import GaussianMixture
|
| 467 |
+
# import scipy.optimize as opt
|
| 468 |
+
# from scipy.optimize import fsolve
|
| 469 |
+
#
|
| 470 |
+
#
|
| 471 |
+
# class GMM:
|
| 472 |
+
# def __init__(self, q_pred_coarse, name='gaussian', n_components=2):
|
| 473 |
+
# samples = q_pred_coarse.detach().cpu().numpy()
|
| 474 |
+
# self.samples = samples.reshape(-1, 1)
|
| 475 |
+
#
|
| 476 |
+
# # Fit a mixture of 2 Gaussians using EM
|
| 477 |
+
# gmm = GaussianMixture(n_components)
|
| 478 |
+
# gmm.fit(samples)
|
| 479 |
+
# self.means = gmm.means_.flatten()
|
| 480 |
+
# self.covs = gmm.covariances_.flatten()
|
| 481 |
+
# self.weights = gmm.weights_
|
| 482 |
+
# self.label = name
|
| 483 |
+
#
|
| 484 |
+
# def intersect(self):
|
| 485 |
+
# # Use fsolve to find the intersection
|
| 486 |
+
# gaussian_intersect, = fsolve(difference, self.means.mean(), args=(
|
| 487 |
+
# self.means[0].item(), self.covs[0].item(), self.means[1].item(), self.means[1].item()))
|
| 488 |
+
# return gaussian_intersect
|
| 489 |
+
#
|
| 490 |
+
#
|
| 491 |
+
# class PredHistoSNS():
|
| 492 |
+
# def __init__(self, n_cols=1):
|
| 493 |
+
# import seaborn as sns
|
| 494 |
+
# sns.set_theme(style="whitegrid") # Set the Seaborn theme. You can change the style as needed.
|
| 495 |
+
# self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
|
| 496 |
+
# self.n_cols = n_cols
|
| 497 |
+
# if n_cols == 1:
|
| 498 |
+
# self.axes = [self.axes] # Wrap the single axis in a list to simplify the loop logic later.
|
| 499 |
+
# self.builder = Subplot() # This is assuming Subplot is a properly defined class.
|
| 500 |
+
# self.subplots = [Subplot() for _ in range(n_cols)] # Use underscore for unused loop variable.
|
| 501 |
+
# self.alpha = 0.5
|
| 502 |
+
# self.bins = 200
|
| 503 |
+
#
|
| 504 |
+
# def reload(self, n_cols=1):
|
| 505 |
+
# self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
|
| 506 |
+
#
|
| 507 |
+
# def aggr(self, ax, sub):
|
| 508 |
+
# import seaborn as sns
|
| 509 |
+
# for hist, col in zip(sub.histograms, sub.colors):
|
| 510 |
+
# sns.histplot(hist.x, bins=self.bins, kde=False, color=col, ax=ax, alpha=self.alpha, label=hist.label)
|
| 511 |
+
# for vline, col in zip(sub.vertical_lines, sub.colors):
|
| 512 |
+
# ax.axvline(x=vline.x, color=col, label=vline.label, linestyle='--')
|
| 513 |
+
# for gaussian, col in zip(sub.gaussian_curves, sub.colors):
|
| 514 |
+
# sns.lineplot(x=gaussian.x, y=gaussian.y, label=gaussian.label, color=col, ax=ax)
|
| 515 |
+
# ax.legend()
|
| 516 |
+
#
|
| 517 |
+
# def plot(self, name=''):
|
| 518 |
+
#
|
| 519 |
+
# if self.n_cols == 1:
|
| 520 |
+
# self.aggr(self.axes[0], self.builder)
|
| 521 |
+
# else:
|
| 522 |
+
# for ax, sub in zip(self.axes, self.subplots):
|
| 523 |
+
# self.aggr(ax, sub)
|
| 524 |
+
# ax.set_title(sub.title)
|
| 525 |
+
#
|
| 526 |
+
# plt.show()
|
| 527 |
+
#
|
| 528 |
+
#
|
| 529 |
+
# def overlay_mask(image, mask, color=[255, 0, 0], alpha=0.2):
|
| 530 |
+
# """
|
| 531 |
+
# Apply an overlay of a binary mask onto an image using a specified color.
|
| 532 |
+
#
|
| 533 |
+
# :param image: A PyTorch tensor of the image (C x H x W) with pixel values in [0, 1].
|
| 534 |
+
# :param mask: A PyTorch tensor of the mask (H x W) with binary values (0 or 1).
|
| 535 |
+
# :param color: A list of 3 elements representing the RGB values of the overlay color.
|
| 536 |
+
# :param alpha: A float representing the transparency of the overlay (0 to 1).
|
| 537 |
+
# :return: An overlayed image tensor.
|
| 538 |
+
# """
|
| 539 |
+
# # Ensure the mask is binary
|
| 540 |
+
# mask = (mask > 0).float()
|
| 541 |
+
#
|
| 542 |
+
# # Create an RGB version of the mask
|
| 543 |
+
# mask_rgb = torch.tensor(color).view(3, 1, 1) / 255.0 # Normalize the color vector
|
| 544 |
+
# mask_rgb = mask_rgb * mask
|
| 545 |
+
#
|
| 546 |
+
# # Overlay the mask onto the image
|
| 547 |
+
# overlayed_image = (1 - alpha) * image + alpha * mask_rgb
|
| 548 |
+
#
|
| 549 |
+
# # Ensure the resulting tensor values are between 0 and 1
|
| 550 |
+
# overlayed_image = torch.clamp(overlayed_image, 0, 1)
|
| 551 |
+
#
|
| 552 |
+
# return overlayed_image
|
| 553 |
+
#
|
| 554 |
+
#
|
| 555 |
+
# import pandas as pd
|
| 556 |
+
|
| 557 |
+
# to_pil = lambda t: transforms.ToPILImage()(t) if t.shape[-1] > 4 else transforms.ToPILImage()(t.permute(2, 0, 1))
|
| 558 |
+
#
|
| 559 |
+
#
|
| 560 |
+
# def pilImageRow(*imgs, maxwidth=800, bordercolor=0x000000):
|
| 561 |
+
# imgs = [to_pil(im.float()) for im in imgs]
|
| 562 |
+
# dst = Image.new('RGB', (sum(im.width for im in imgs), imgs[0].height))
|
| 563 |
+
# for i, im in enumerate(imgs):
|
| 564 |
+
# loc = [x0, y0, x1, y1] = [i * im.width, 0, (i + 1) * im.width, im.height]
|
| 565 |
+
# dst.paste(im, (x0, y0))
|
| 566 |
+
# ImageDraw.Draw(dst).rectangle(loc, width=2, outline=bordercolor)
|
| 567 |
+
# factorToBig = dst.width / maxwidth
|
| 568 |
+
# dst = dst.resize((int(dst.width / factorToBig), int(dst.height / factorToBig)))
|
| 569 |
+
# return dst
|
| 570 |
+
#
|
| 571 |
+
#
|
| 572 |
+
# def tensor_table(**kwargs):
|
| 573 |
+
# tensor_overview = {}
|
| 574 |
+
# for name, tensor in kwargs.items():
|
| 575 |
+
# if callable(tensor):
|
| 576 |
+
# print(name, [tensor(t) for _, t in kwargs.items() if isinstance(t, torch.Tensor)])
|
| 577 |
+
# else:
|
| 578 |
+
# tensor_overview[name] = {
|
| 579 |
+
# 'min': tensor.min().item(),
|
| 580 |
+
# 'max': tensor.max().item(),
|
| 581 |
+
# 'shape': tensor.shape,
|
| 582 |
+
# }
|
| 583 |
+
# return pd.DataFrame.from_dict(tensor_overview, orient='index')
|
| 584 |
+
|