Spaces:
Sleeping
Sleeping
File size: 5,122 Bytes
c70812a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
r""" Dense Cross-Query-and-Support Attention Weighted Mask Aggregation for Few-Shot Segmentation """
import torch.nn as nn
import torch
from model.DCAMA import DCAMA
from common.logger import Logger, AverageMeter
from common.vis import Visualizer
from common.evaluation import Evaluator
from common.config import parse_opts
from common import utils
from data.dataset import FSSDataset
import cv2
import numpy as np
import os
# from gpu_mem_track import MemTracker
# gpu_tracker = MemTracker()
def test(model, dataloader, nshot):
    r"""Evaluate ``model`` over every batch of ``dataloader``.

    Each batch is moved to the GPU and segmented with
    ``model.module.predict_mask_nshot``. Intersection/union areas are
    accumulated in an ``AverageMeter``, and predictions are optionally
    visualised.

    Args:
        model: wrapper exposing ``model.module.predict_mask_nshot`` (the
            script entry point passes an ``nn.DataParallel``).
        dataloader: test loader yielding dict batches that contain at least
            ``support_imgs``, ``support_masks``, ``query_img``, ``query_mask``
            and ``class_id``.
        nshot: kept only for interface compatibility. The shot count actually
            used is read from each batch as ``batch['support_imgs'].size(1)``.

    Returns:
        The ``(miou, fb_iou)`` pair returned by ``AverageMeter.compute_iou``.
    """
    # Freeze randomness during testing for reproducibility
    utils.fix_randseed(0)
    average_meter = AverageMeter(dataloader.dataset)
    total_batches = len(dataloader)

    for idx, batch in enumerate(dataloader):
        # 1. Forward pass. The number of shots is taken from the batch itself,
        #    not from the ``nshot`` argument.
        batch_nshot = batch['support_imgs'].size(1)
        batch = utils.to_cuda(batch)
        # The similarity outputs are not used by the evaluation below.
        pred_mask, _simi, _simi_map = model.module.predict_mask_nshot(batch, nshot=batch_nshot)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        assert pred_mask.size() == batch['query_mask'].size()

        # 2. Evaluate prediction
        area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch)
        average_meter.update(area_inter, area_union, batch['class_id'], loss=None)
        average_meter.write_process(idx, total_batches, epoch=-1, write_batch_idx=1)

        # Visualize predictions. Index 1 is presumably the foreground class;
        # confirm against Evaluator.classify_prediction.
        if Visualizer.visualize:
            Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'],
                                                  batch['query_img'], batch['query_mask'],
                                                  pred_mask, batch['class_id'], idx,
                                                  iou_b=area_inter[1].float() / area_union[1].float())

    # Write evaluation results
    average_meter.write_result('Test', 0)
    miou, fb_iou = average_meter.compute_iou()
    return miou, fb_iou
if __name__ == '__main__':
# Arguments parsing
args = parse_opts()
Logger.initialize(args, training=False)
# Model initialization
model = DCAMA(args.backbone, args.feature_extractor_path, args.use_original_imgsize)
model.eval()
# Device setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Logger.info('# available GPUs: %d' % torch.cuda.device_count())
model = nn.DataParallel(model)
model.to(device)
# Load trained model
if args.load == '': raise Exception('Pretrained model not specified.')
params = model.state_dict()
state_dict = torch.load(args.load)
if 'state_dict' in state_dict.keys():
state_dict = state_dict['state_dict']
state_dict2 = {}
for k, v in state_dict.items():
if 'scorer' not in k:
state_dict2[k] = v
state_dict = state_dict2
for k1, k2 in zip(list(state_dict.keys()), params.keys()):
state_dict[k2] = state_dict.pop(k1)
try:
model.load_state_dict(state_dict, strict=True)
except:
for k in params.keys():
if k not in state_dict.keys():
state_dict[k] = params[k]
model.load_state_dict(state_dict, strict=True)
# Helper classes (for testing) initialization
Evaluator.initialize()
Visualizer.initialize(args.visualize, args.vispath)
# Dataset initialization
FSSDataset.initialize(img_size=384, datapath=args.datapath, use_original_imgsize=args.use_original_imgsize)
dataloader_test = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot)
# Test
with torch.no_grad():
test_miou, test_fb_iou = test(model, dataloader_test, args.nshot)
Logger.info('Fold %d mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, test_miou.item(), test_fb_iou.item()))
Logger.info('==================== Finished Testing ====================')
|