File size: 6,171 Bytes
322161a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
r""" Logging during training/testing """
import datetime
import logging
import os
from tensorboardX import SummaryWriter
import torch
class AverageMeter:
    r""" Stores loss, evaluation results """

    # Number of (foreground) classes per supported benchmark.
    NCLASS = {'pascal': 20, 'fss': 1000, 'deepglobe': 6, 'isic': 3, 'lung': 1, 'suim': 7}

    def __init__(self, dataset, device='cuda'):
        r"""Initialize evaluation buffers for *dataset*.

        Args:
            dataset: object exposing ``benchmark`` (str) and ``class_ids``
                     (iterable of class indices to evaluate on).
            device:  device for the accumulation buffers (default 'cuda').

        Raises:
            ValueError: if the benchmark name is not one of NCLASS.
        """
        self.benchmark = dataset.benchmark
        if self.benchmark not in self.NCLASS:
            raise ValueError('Unknown dataset: %s' % dataset)
        self.nclass = self.NCLASS[self.benchmark]
        self.class_ids_interest = torch.tensor(dataset.class_ids).to(device)

        # Row 0: background, row 1: foreground; one column per class.
        self.intersection_buf = torch.zeros([2, self.nclass]).float().to(device)
        self.union_buf = torch.zeros([2, self.nclass]).float().to(device)
        # All-ones tensor used as a floor for the union so IoU never divides by zero.
        self.ones = torch.ones_like(self.union_buf)
        self.loss_buf = []

    def update(self, inter_b, union_b, class_id, loss):
        r"""Accumulate per-batch intersection/union counts and loss.

        Args:
            inter_b:  intersection counts; shape [2, B] (bg/fg rows) — inferred
                      from the index_add_ along dim 1, confirm at call site.
            union_b:  union counts, same shape as *inter_b*.
            class_id: [B] tensor of class indices for each batch sample.
            loss:     scalar loss tensor, or None (recorded as 0).
        """
        self.intersection_buf.index_add_(1, class_id, inter_b.float())
        self.union_buf.index_add_(1, class_id, union_b.float())
        if loss is None:
            loss = torch.tensor(0.0)
        self.loss_buf.append(loss)

    def compute_iou(self):
        r"""Return (mean IoU, foreground-background IoU), both scaled to %."""
        # Clamp union at 1 so classes never seen so far yield IoU 0 instead of NaN.
        iou = self.intersection_buf.float() / \
              torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0]
        iou = iou.index_select(1, self.class_ids_interest)
        miou = iou[1].mean() * 100  # row 1 = foreground

        # FB-IoU: pool counts over all classes of interest, then average bg/fg.
        fb_iou = (self.intersection_buf.index_select(1, self.class_ids_interest).sum(dim=1) /
                  self.union_buf.index_select(1, self.class_ids_interest).sum(dim=1)).mean() * 100

        return miou, fb_iou

    def write_result(self, split, epoch):
        r"""Log the epoch summary (avg loss, mIoU, FB-IoU) for *split*."""
        miou, fb_iou = self.compute_iou()

        loss_buf = torch.stack(self.loss_buf)
        msg = '\n*** %s ' % split
        msg += '[@Epoch %02d] ' % epoch
        msg += 'Avg L: %6.5f  ' % loss_buf.mean()
        msg += 'mIoU: %5.2f   ' % miou
        msg += 'FB-IoU: %5.2f   ' % fb_iou

        msg += '***\n'
        Logger.info(msg)

    def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20):
        r"""Log running metrics every *write_batch_idx* batches; epoch=-1 means testing (no loss)."""
        if batch_idx % write_batch_idx == 0:
            msg = '[Epoch: %02d] ' % epoch if epoch != -1 else ''
            msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)
            miou, fb_iou = self.compute_iou()
            if epoch != -1:
                loss_buf = torch.stack(self.loss_buf)
                msg += 'L: %6.5f  ' % loss_buf[-1]
                msg += 'Avg L: %6.5f  ' % loss_buf.mean()
            msg += 'mIoU: %5.2f  |  ' % miou
            msg += 'FB-IoU: %5.2f' % fb_iou
            Logger.info(msg)
class Logger:
    r""" Writes evaluation results of training/testing """

    @classmethod
    def initialize(cls, args, training):
        r"""Set up file + console logging and a TensorBoard writer under ``logs/``.

        Args:
            args:     parsed arguments; must provide ``logpath`` and ``benchmark``.
            training: if False, a timestamped '_TEST_' suffix is appended so test
                      runs never clobber the training log directory.
        """
        logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S')
        logpath = args.logpath if training else args.logpath + '_TEST_' + logtime
        if logpath == '':
            logpath = logtime

        cls.logpath = os.path.join('logs', logpath + '.log')
        cls.benchmark = args.benchmark
        print("logdir: ", cls.logpath)
        # Intentionally errors if the directory already exists, protecting old logs.
        os.makedirs(cls.logpath)

        logging.basicConfig(filemode='w',
                            filename=os.path.join(cls.logpath, 'log.txt'),
                            level=logging.INFO,
                            format='%(message)s',
                            datefmt='%m-%d %H:%M:%S')

        # Console log config: mirror everything to stdout with the same bare format.
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(message)s')
        console.setFormatter(formatter)
        logging.getLogger('').addHandler(console)

        # Tensorboard writer
        cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))

        # Log arguments
        logging.info('\n:=========== Adapt Before Comparison - A New Perspective on Cross-Domain Few-Shot Segmentation ===========')
        for arg_key in args.__dict__:
            logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key])))
        logging.info(':================================================\n')

    @classmethod
    def info(cls, msg):
        r""" Writes log message to log.txt """
        logging.info(msg)

    @classmethod
    def save_model_miou(cls, model, epoch, val_miou):
        r"""Persist *model* weights as best_model.pt and log the validation mIoU."""
        torch.save(model.state_dict(), os.path.join(cls.logpath, 'best_model.pt'))
        cls.info('Model saved @%d w/ val. mIoU: %5.2f.\n' % (epoch, val_miou))

    @classmethod
    def log_params(cls, model):
        r"""Log parameter counts split into backbone vs. learnable modules."""
        backbone_param = 0
        learner_param = 0
        state_dict = model.state_dict()  # hoisted: state_dict() rebuilds its dict on every call
        for k, param in state_dict.items():
            n_param = param.numel()
            parts = k.split('.')
            # Exact comparison: the previous substring test (`parts[0] in 'backbone'`)
            # wrongly matched any key whose first component is a substring of
            # 'backbone' (e.g. 'back', 'bone', 'one').
            if parts[0] == 'backbone':
                if len(parts) > 1 and parts[1] in ['classifier', 'fc']:  # as fc layers are not used in HSNet
                    continue
                backbone_param += n_param
            else:
                learner_param += n_param
        Logger.info('Backbone # param.: %d' % backbone_param)
        Logger.info('Learnable # param.: %d' % learner_param)
        Logger.info('Total # param.: %d' % (backbone_param + learner_param))
|