File size: 6,918 Bytes
322161a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
from data.dataset import FSSDataset
from core.backbone import Backbone
from eval.logger import Logger, AverageMeter
from eval.evaluation import Evaluator
from utils import commonutils as utils
import utils.segutils as segutils
import core.contrastivehead as ctrutils
import core.denseaffinity as dautils
import torch
class args:
    # Static run configuration, used as a plain attribute namespace.
    # Documented with comments (not a docstring / annotations) so that the
    # class __dict__ handed to Logger.initialize stays exactly as before.

    # Name passed to Backbone(...) when building the feature extractor.
    backbone = 'resnet50'
    # Log directory setting.
    logpath = '/kaggle/working/logs'

    # Dataloader settings forwarded to FSSDataset.build_dataloader.
    nworker = 0
    bsz = 1
    benchmark = ''  # e.g. deepglobe, isic, etc.
    datapath = ''  # path to the selected dataset
    fold = 0
    nshot = 1
class SingleSampleEval:
    """Evaluate one episode (batch) step by step.

    The stages are ``taskAdapt`` -> ``compare_feats`` -> ``threshold``;
    ``forward`` runs them in that order. ``calc_metrics`` and ``plots``
    operate on the results stored on the instance by those stages.
    """

    def __init__(self, batch, feat_maker, debug=False):
        """Store the episode and the feature maker used to adapt it.

        Args:
            batch: mapping with at least the keys 'query_img', 'support_imgs',
                'support_masks', 'class_id' and (for metrics/plots) 'query_mask'.
            feat_maker: object exposing ``taskAdapt(q_img, s_img, s_mask, class_id)``.
            debug: stored on the instance; not read inside this class.
        """
        self.damat_comp = dautils.DAMatComparison()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.batch = batch
        self.feat_maker = feat_maker
        self.debug = debug
        self.thresh_method = 'pred_mean'
        # Stage outputs start as None so the "run the previous stage first"
        # guards below work instead of failing with AttributeError.
        self.task_adapted = None
        self.logit_mask = None
        self.thresh = None
        self.pred_mask = None

    def taskAdapt(self, detach=True):
        """Unpack the batch and let the feature maker adapt to this task.

        Args:
            detach: accepted for interface compatibility; currently unused.
        """
        b = self.batch
        if self.device.type == 'cuda':
            b = utils.to_cuda(b)
        self.q_img = b['query_img']
        self.s_img = b['support_imgs']
        self.s_mask = b['support_masks']
        self.class_id = b['class_id'].item()
        self.task_adapted = self.feat_maker.taskAdapt(self.q_img, self.s_img, self.s_mask, self.class_id)

    def compare_feats(self):
        """Compute and return the logit mask from the task-adapted features.

        Returns:
            The result of the dense-affinity comparison, or None (after
            printing an error) when ``taskAdapt`` has not been run yet.
        """
        if self.task_adapted is None:
            print("error, do task adaption first")
            return None
        self.logit_mask = self.damat_comp.forward(self.task_adapted[0], self.task_adapted[1], self.s_mask)
        return self.logit_mask

    def threshold(self, method=None):
        """Binarise the logit mask.

        Args:
            method: threshold method name passed to ``calcthresh``; defaults
                to ``self.thresh_method`` when None.

        Returns:
            Tuple ``(thresh, pred_mask)``.

        Raises:
            RuntimeError: if no logit mask has been computed yet.
        """
        if self.logit_mask is None:
            # Previously this only printed and then crashed inside calcthresh.
            raise RuntimeError("error, calculate logit mask first (do forward pass)")
        if method is None:
            method = self.thresh_method
        self.thresh = calcthresh(self.logit_mask, self.s_mask, method)
        self.pred_mask = (self.logit_mask > self.thresh).float()
        return self.thresh, self.pred_mask

    def apply_crf(self):
        """Run the module-level ``apply_crf`` on the query image and logit mask."""
        return apply_crf(self.q_img, self.logit_mask, thresh_fn(self.thresh_method))

    # this method calls above components sequentially
    def forward(self):
        """Run task adaptation, feature comparison and thresholding.

        Returns:
            Tuple ``(logit_mask, pred_mask)``.
        """
        self.taskAdapt()
        self.logit_mask = self.compare_feats()
        self.thresh, self.pred_mask = self.threshold()
        return self.logit_mask, self.pred_mask

    def calc_metrics(self):
        """Compute intersection/union for the prediction and return the fg IoU.

        Also stores the mean of the predicted mask and of the query mask
        as ``fgratio_pred`` / ``fgratio_gt``.
        """
        self.area_inter, self.area_union = Evaluator.classify_prediction(self.pred_mask, self.batch)
        self.fgratio_pred = self.pred_mask.float().mean()
        self.fgratio_gt = self.batch['query_mask'].float().mean()
        return self.area_inter[1] / self.area_union[1]  # fg-iou

    def plots(self):
        """Display debug images and summary statistics for this episode.

        NOTE(review): ``display``, ``pilImageRow`` and ``norm`` are not defined
        or imported in the visible file - presumably notebook globals; confirm.
        """
        display(pilImageRow(norm(self.logit_mask[0]), (self.logit_mask[0] > self.thresh).float(), self.pred_mask,
                            self.batch['query_mask'][:1], norm(self.q_img[0]), norm(self.s_img[0, 0])))
        display(segutils.tensor_table(probs=self.logit_mask))
        # NOTE(review): the label says "pred_mask.mean" but the value printed
        # is the mean of logit_mask - confirm which one is intended.
        print('s_mask.mean, pred_mask.mean, thresh:', self.s_mask.mean().item(), self.logit_mask.mean().item(),
              self.thresh.item())
class AverageMeterWrapper:
    """Thin convenience layer around ``AverageMeter`` for a given dataloader."""

    def __init__(self, dataloader, device='cpu', initlogger=True):
        """Optionally initialise the logger, then build the underlying meter.

        Args:
            dataloader: loader whose ``dataset`` is handed to ``AverageMeter``.
            device: device passed to ``AverageMeter`` and used for class-id tensors.
            initlogger: when true, call ``Logger.initialize(args, training=False)`` first.
        """
        if initlogger:
            Logger.initialize(args, training=False)
        self.average_meter = AverageMeter(dataloader.dataset, device)
        self.dataloader = dataloader
        self.device = device
        self.write_batch_idx = 50

    def update(self, sseval):
        """Feed the areas and class id stored on ``sseval`` into the meter."""
        class_id = torch.tensor(sseval.class_id).to(self.device)
        self.average_meter.update(sseval.area_inter, sseval.area_union, class_id, loss=None)

    def update_manual(self, area_inter, area_union, class_id):
        """Feed explicitly supplied areas into the meter.

        A plain ``int`` class id is converted to a tensor on ``self.device``;
        any other value is passed through unchanged.
        """
        if isinstance(class_id, int):
            class_id = torch.tensor(class_id).to(self.device)
        self.average_meter.update(area_inter, area_union, class_id, loss=None)

    def write(self, i):
        """Forward batch index ``i`` to the meter's ``write_process``."""
        n_batches = len(self.dataloader)
        self.average_meter.write_process(i, n_batches, 0, self.write_batch_idx)
def makeDataloader():
    """Initialise ``FSSDataset`` and return the 'test' split dataloader.

    All settings (data path, benchmark, batch size, workers, fold, shots)
    are read from the module-level ``args`` namespace; image size is 400.
    """
    FSSDataset.initialize(img_size=400, datapath=args.datapath)
    return FSSDataset.build_dataloader(
        args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot
    )
def makeConfig():
    """Create a ``ContrastiveConfig`` and apply this script's settings.

    Returns:
        The configured ``ctrutils.ContrastiveConfig`` instance.
    """
    config = ctrutils.ContrastiveConfig()

    # Fitting options: which losses are enabled and the optimisation setup.
    fit = config.fitting
    fit.protoloss = False
    fit.o_t_contr_proto_loss = True
    fit.selfattentionloss = False
    fit.keepvarloss = True
    fit.symmetricloss = False
    fit.q_nceloss = True
    fit.s_nceloss = True
    fit.num_epochs = 25
    fit.lr = 1e-2
    fit.debug = False

    # Model and feature-extraction options.
    config.model.out_channels = 64
    config.featext.fit_every_episode = False

    # Augmentation options.
    aug = config.aug
    aug.blurkernelsize = [1]
    aug.n_transformed_imgs = 2
    aug.maxjitter = 0.0
    aug.maxangle = 0
    aug.maxscale = 1
    aug.maxshear = 20
    aug.apply_affine = True
    aug.debug = False

    return config
def makeFeatureMaker(dataset, config, device='cpu', randseed=2, feat_extr_method=None):
    """Build a ``ctrutils.FeatureMaker`` for the given dataset and config.

    Args:
        dataset: object whose ``class_ids`` is passed to the feature maker.
        config: configuration passed to the feature maker.
        device: device the default backbone is moved to.
        randseed: seed given to ``utils.fix_randseed`` (called before the
            backbone is built and again after the feature maker is created).
        feat_extr_method: feature-extraction callable; when None, the
            ``extract_feats`` method of ``Backbone(args.backbone)`` is used.

    Returns:
        The feature maker, with ``norm_bb_feats`` set to False.
    """
    utils.fix_randseed(randseed)
    if feat_extr_method is None:
        backbone = Backbone(args.backbone).to(device)
        feat_extr_method = backbone.extract_feats
    maker = ctrutils.FeatureMaker(feat_extr_method, dataset.class_ids, config)
    utils.fix_randseed(randseed)
    maker.norm_bb_feats = False
    return maker
def apply_crf(rgb_img, fg_pred, thresh_fn, iterations=5):  # 5 on deployment, 1 on support-aug test for speedup
    """Refine a foreground prediction with ``segutils.CRF`` and return labels.

    Args:
        rgb_img: image passed to the CRF refinement.
        fg_pred: foreground prediction passed to the CRF refinement.
        thresh_fn: threshold callable passed to the CRF refinement.
        iterations: number of refinement iterations.

    Returns:
        ``argmax`` over dimension 1 of the refined output.
    """
    refiner = segutils.CRF(
        gaussian_stdxy=(1, 1),
        gaussian_compat=2,
        bilateral_stdxy=(35, 35),
        bilateral_compat=1,
        stdrgb=(13, 13, 13),
    )
    refined = refiner.iterrefine(iterations, rgb_img, fg_pred, thresh_fn)
    return refined.argmax(1)
def calcthresh(fused_pred, s_masks, method='otsus'):
    """Compute a binarisation threshold for ``fused_pred``.

    Args:
        fused_pred: prediction to threshold.
        s_masks: support masks; only used by the iterative Otsu variants.
        method: one of
            'iterotsus'  - ``segutils.iterative_otsus`` with ``maxiters=5``,
            '1iterotsus' - ``segutils.iterative_otsus`` with ``maxiters=1``,
            'otsus'      - ``segutils.otsus``,
            'pred_mean'  - the larger of the Otsu threshold and
                           ``fused_pred.mean()``.

    Returns:
        The threshold (first element of the helper's result, or the
        ``torch.max`` result for 'pred_mean').

    Raises:
        ValueError: if ``method`` is not one of the names above. Previously
            this case failed with UnboundLocalError at the final return.
    """
    if method == 'iterotsus':
        return segutils.iterative_otsus(fused_pred, s_masks, maxiters=5)[0]
    if method == '1iterotsus':
        return segutils.iterative_otsus(fused_pred, s_masks, maxiters=1)[0]
    if method == 'otsus':
        return segutils.otsus(fused_pred)[0]
    if method == 'pred_mean':
        otsu_thresh = segutils.otsus(fused_pred)[0]
        return torch.max(otsu_thresh, fused_pred.mean())
    raise ValueError(
        f"unknown threshold method {method!r}; expected one of "
        "'iterotsus', '1iterotsus', 'otsus', 'pred_mean'"
    )
def thresh_fn(method):
    """Return a callable that applies ``calcthresh`` with ``method`` fixed.

    The returned callable takes ``(fused_pred, s_masks=None)``.
    """
    def _with_method(fused_pred, s_masks=None):
        # Forward to calcthresh, binding the method chosen at creation time.
        return calcthresh(fused_pred, s_masks, method)

    return _with_method