# File size: 6,918 Bytes
# 322161a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from data.dataset import FSSDataset
from core.backbone import Backbone
from eval.logger import Logger, AverageMeter
from eval.evaluation import Evaluator
from utils import commonutils as utils
import utils.segutils as segutils
import core.contrastivehead as ctrutils
import core.denseaffinity as dautils
import torch

class args:
    # Static run configuration used in place of parsed command-line arguments.
    # The class itself is handed to Logger.initialize() and its attributes are
    # read by makeDataloader() and makeFeatureMaker() in this file.
    backbone = 'resnet50'  # name passed to Backbone() when no feature extractor is supplied
    logpath = '/kaggle/working/logs'  # presumably the log output location used by Logger -- TODO confirm
    nworker = 0  # forwarded to FSSDataset.build_dataloader
    bsz = 1  # forwarded to FSSDataset.build_dataloader
    benchmark='' # benchmark name, e.g. deepglobe, isic, etc.; empty here, set it before building the dataloader
    datapath='' # path to the selected dataset; empty here, set it before building the dataloader
    fold = 0  # forwarded to FSSDataset.build_dataloader
    nshot = 1  # forwarded to FSSDataset.build_dataloader

class SingleSampleEval:
    """Run the evaluation steps for a single episode (one batch).

    The steps are task adaption (``taskAdapt``), feature comparison
    (``compare_feats``) and thresholding (``threshold``); ``forward`` runs
    these three in order. ``apply_crf``, ``calc_metrics`` and ``plots`` work
    on the results afterwards.
    """

    # Step results default to ``None`` on the class, so the "did the previous
    # step run?" guards below work on a fresh instance instead of failing with
    # AttributeError on an attribute that was never assigned.
    task_adapted = None
    logit_mask = None
    thresh = None
    pred_mask = None

    def __init__(self, batch, feat_maker, debug=False):
        """Store the episode data and the feature maker.

        Args:
            batch: Mapping read with the keys 'query_img', 'support_imgs',
                'support_masks' and 'class_id' (``taskAdapt``) and
                'query_mask' (``calc_metrics`` / ``plots``).
            feat_maker: Object providing
                ``taskAdapt(q_img, s_img, s_mask, class_id)``.
            debug: Stored on the instance; not read by this class.
        """
        self.damat_comp = dautils.DAMatComparison()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.batch = batch
        self.feat_maker = feat_maker
        self.debug = debug
        self.thresh_method = 'pred_mean'

    def taskAdapt(self, detach=True):
        """Unpack the batch and let the feature maker adapt to this episode.

        The batch is passed through ``utils.to_cuda`` first when a CUDA device
        was selected. The result of ``feat_maker.taskAdapt`` is stored in
        ``self.task_adapted``.

        Args:
            detach: Unused; kept so existing callers keep working.
        """
        b = self.batch
        if self.device.type == 'cuda':
            b = utils.to_cuda(b)
        self.q_img = b['query_img']
        self.s_img = b['support_imgs']
        self.s_mask = b['support_masks']
        self.class_id = b['class_id'].item()
        self.task_adapted = self.feat_maker.taskAdapt(self.q_img, self.s_img, self.s_mask, self.class_id)

    def compare_feats(self):
        """Compute the logit mask from the task-adapted features.

        Returns:
            The output of ``damat_comp.forward`` (also stored in
            ``self.logit_mask``), or ``None`` after printing a message when
            ``taskAdapt`` has not been run yet.
        """
        if self.task_adapted is None:
            print("error, do task adaption first")
            return None
        # The first two entries of task_adapted are presumably the query and
        # support features -- TODO confirm against FeatureMaker.taskAdapt.
        self.logit_mask = self.damat_comp.forward(self.task_adapted[0], self.task_adapted[1], self.s_mask)
        return self.logit_mask

    def threshold(self, method=None):
        """Binarise the logit mask.

        Args:
            method: Threshold method name understood by ``calcthresh``;
                defaults to ``self.thresh_method``.

        Returns:
            Tuple ``(thresh, pred_mask)`` where ``pred_mask`` is
            ``(logit_mask > thresh).float()``.

        Raises:
            RuntimeError: If no logit mask has been computed yet. Previously
                the message was only printed and execution continued into
                ``calcthresh`` with ``None``.
        """
        if self.logit_mask is None:
            raise RuntimeError("error, calculate logit mask first (do forward pass)")
        if method is None:
            method = self.thresh_method
        self.thresh = calcthresh(self.logit_mask, self.s_mask, method)
        self.pred_mask = (self.logit_mask > self.thresh).float()
        return self.thresh, self.pred_mask

    def apply_crf(self):
        """Refine the logit mask with the module-level ``apply_crf`` helper.

        Uses the query image, the logit mask and a threshold function built
        from ``self.thresh_method``.
        """
        return apply_crf(self.q_img, self.logit_mask, thresh_fn(self.thresh_method))

    # this method calls above components sequentially
    def forward(self):
        """Run task adaption, feature comparison and thresholding.

        Returns:
            Tuple ``(logit_mask, pred_mask)``.
        """
        self.taskAdapt()

        self.logit_mask = self.compare_feats()

        self.thresh, self.pred_mask = self.threshold()

        return self.logit_mask, self.pred_mask

    def calc_metrics(self):
        """Compute intersection/union areas and foreground ratios.

        Stores ``area_inter``, ``area_union``, ``fgratio_pred`` and
        ``fgratio_gt`` on the instance.

        Returns:
            ``area_inter[1] / area_union[1]`` (the foreground IoU).
        """
        # NOTE(review): a disabled upstream assertion suggests logit_mask is
        # expected to lie in [0, 1]; nothing enforces that here.
        self.area_inter, self.area_union = Evaluator.classify_prediction(self.pred_mask, self.batch)
        self.fgratio_pred = self.pred_mask.float().mean()
        self.fgratio_gt = self.batch['query_mask'].float().mean()
        return self.area_inter[1] / self.area_union[1]  # fg-iou

    def plots(self):
        """Display the masks and images of this episode and print statistics.

        NOTE(review): ``display``, ``pilImageRow`` and ``norm`` are not
        imported in the visible part of this file; presumably they are
        provided by the notebook environment -- confirm before calling this
        from a plain script.
        """
        display(pilImageRow(norm(self.logit_mask[0]), (self.logit_mask[0] > self.thresh).float(), self.pred_mask,
                            self.batch['query_mask'][:1], norm(self.q_img[0]), norm(self.s_img[0, 0])))
        display(segutils.tensor_table(probs=self.logit_mask))

        # NOTE(review): the second value printed is logit_mask.mean(), although
        # the label reads 'pred_mask.mean'.
        print('s_mask.mean, pred_mask.mean, thresh:', self.s_mask.mean().item(), self.logit_mask.mean().item(),
              self.thresh.item())

class AverageMeterWrapper:
    """Convenience wrapper that feeds evaluation results into an ``AverageMeter``."""

    def __init__(self, dataloader, device='cpu', initlogger=True):
        """Optionally initialise the logger, then create the meter for the dataloader's dataset."""
        if initlogger:
            Logger.initialize(args, training=False)
        self.average_meter = AverageMeter(dataloader.dataset, device)
        self.dataloader = dataloader
        self.device = device
        # Forwarded as the last argument of write_process() in write().
        self.write_batch_idx = 50

    def update(self, sseval):
        """Record the areas and class id stored on a ``SingleSampleEval``-like object."""
        inter, union = sseval.area_inter, sseval.area_union
        class_tensor = torch.tensor(sseval.class_id).to(self.device)
        self.average_meter.update(inter, union, class_tensor, loss=None)

    def update_manual(self, area_inter, area_union, class_id):
        """Record explicitly supplied areas; a plain ``int`` class id is wrapped in a tensor first."""
        if isinstance(class_id, int):
            class_id = torch.tensor(class_id).to(self.device)
        self.average_meter.update(area_inter, area_union, class_id, loss=None)

    def write(self, i):
        """Forward batch index ``i`` and the dataloader length to ``write_process``."""
        n_batches = len(self.dataloader)
        self.average_meter.write_process(i, n_batches, 0, self.write_batch_idx)

def makeDataloader():
    """Initialise ``FSSDataset`` and build the 'test' dataloader described by ``args``."""
    FSSDataset.initialize(img_size=400, datapath=args.datapath)
    return FSSDataset.build_dataloader(
        args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot
    )


def makeConfig():
    """Create a ``ContrastiveConfig`` and apply this script's overrides.

    The overrides are listed as ``(section, option, value)`` rows and applied
    in the listed order.
    """
    overrides = (
        ('fitting', 'protoloss', False),
        ('fitting', 'o_t_contr_proto_loss', True),
        ('fitting', 'selfattentionloss', False),
        ('fitting', 'keepvarloss', True),
        ('fitting', 'symmetricloss', False),
        ('fitting', 'q_nceloss', True),
        ('fitting', 's_nceloss', True),
        ('fitting', 'num_epochs', 25),
        ('fitting', 'lr', 1e-2),
        ('fitting', 'debug', False),
        ('model', 'out_channels', 64),
        ('featext', 'fit_every_episode', False),
        ('aug', 'blurkernelsize', [1]),
        ('aug', 'n_transformed_imgs', 2),
        ('aug', 'maxjitter', 0.0),
        ('aug', 'maxangle', 0),
        ('aug', 'maxscale', 1),
        ('aug', 'maxshear', 20),
        ('aug', 'apply_affine', True),
        ('aug', 'debug', False),
    )
    config = ctrutils.ContrastiveConfig()
    for section, option, value in overrides:
        setattr(getattr(config, section), option, value)
    return config


def makeFeatureMaker(dataset, config, device='cpu', randseed=2, feat_extr_method=None):
    """Build a ``FeatureMaker`` for ``dataset.class_ids`` using ``config``.

    When ``feat_extr_method`` is not given, the ``extract_feats`` method of a
    ``Backbone(args.backbone)`` moved to ``device`` is used. The random seed
    is fixed once before the backbone is created and once more after the
    feature maker has been constructed.
    """
    utils.fix_randseed(randseed)
    extractor = feat_extr_method
    if extractor is None:
        backbone = Backbone(args.backbone).to(device)
        extractor = backbone.extract_feats
    maker = ctrutils.FeatureMaker(extractor, dataset.class_ids, config)
    utils.fix_randseed(randseed)
    maker.norm_bb_feats = False
    return maker
def apply_crf(rgb_img, fg_pred, thresh_fn,iterations=5): #5 on deployment, 1 on support-aug test for speedup
    """Refine ``fg_pred`` with ``segutils.CRF`` and return the argmax over dim 1.

    ``thresh_fn`` and ``iterations`` are handed to ``CRF.iterrefine``
    unchanged.
    """
    crf_settings = dict(
        gaussian_stdxy=(1, 1),
        gaussian_compat=2,
        bilateral_stdxy=(35, 35),
        bilateral_compat=1,
        stdrgb=(13, 13, 13),
    )
    refiner = segutils.CRF(**crf_settings)
    refined = refiner.iterrefine(iterations, rgb_img, fg_pred, thresh_fn)
    return refined.argmax(1)

def calcthresh(fused_pred, s_masks, method='otsus'):
    """Compute a binarisation threshold for ``fused_pred``.

    Args:
        fused_pred: Prediction to threshold; passed to the ``segutils``
            helpers and, for 'pred_mean', reduced with ``.mean()``.
        s_masks: Support masks; only used by the iterative Otsu variants.
        method: One of
            'iterotsus'  - ``segutils.iterative_otsus`` with ``maxiters=5``,
            '1iterotsus' - ``segutils.iterative_otsus`` with ``maxiters=1``,
            'otsus'      - ``segutils.otsus``,
            'pred_mean'  - the larger of the Otsu threshold and the mean of
                           ``fused_pred``.

    Returns:
        The threshold (first element of the helper's result, or the
        ``torch.max`` of the two candidates for 'pred_mean').

    Raises:
        ValueError: If ``method`` is not one of the names above. Previously
            this case reached ``return thresh`` with ``thresh`` unbound and
            failed with an UnboundLocalError.
    """
    # The 'via_triclass' and '3kmeans' variants were disabled upstream and
    # are therefore not accepted here.
    if method == 'iterotsus':
        return segutils.iterative_otsus(fused_pred, s_masks, maxiters=5)[0]
    if method == '1iterotsus':
        return segutils.iterative_otsus(fused_pred, s_masks, maxiters=1)[0]
    if method == 'otsus':
        return segutils.otsus(fused_pred)[0]
    if method == 'pred_mean':
        otsu_thresh = segutils.otsus(fused_pred)[0]
        return torch.max(otsu_thresh, fused_pred.mean())
    raise ValueError(
        f"unknown threshold method {method!r}; expected one of "
        "'iterotsus', '1iterotsus', 'otsus', 'pred_mean'"
    )

def thresh_fn(method):
    """Return a callable ``f(fused_pred, s_masks=None)`` that applies ``calcthresh`` with ``method``."""
    return lambda fused_pred, s_masks=None: calcthresh(fused_pred, s_masks, method)