|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Created in September 2022 |
|
|
@author: fabrizio.guillaro |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import time |
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
import matplotlib |
|
|
matplotlib.use('agg') |
|
|
import matplotlib.pyplot as plt |
|
|
import itertools |
|
|
|
|
|
|
|
|
|
|
|
def adjust_learning_rate(optimizer, base_lr, max_iters, cur_iters, power=0.9):
    """
    Apply polynomial ("poly") learning-rate decay to every parameter group.

        lr = base_lr * (1 - cur_iters / max_iters) ** power

    Args:
        optimizer: object exposing ``param_groups`` (a sequence of dicts), e.g. a
            ``torch.optim.Optimizer``; every group's ``'lr'`` is overwritten.
        base_lr: learning rate at iteration 0.
        max_iters: total number of iterations of the schedule (must be non-zero).
        cur_iters: current iteration index.
        power: decay exponent.

    Returns:
        The learning rate that was written into the parameter groups.
    """
    # Clamp progress at 1.0: past max_iters, (1 - progress) would be negative and
    # a negative float raised to a fractional power is a *complex* number in Python.
    progress = min(float(cur_iters) / max_iters, 1.0)
    lr = base_lr * ((1.0 - progress) ** power)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr
|
|
|
|
|
|
|
|
class FullModel(nn.Module):
    """
    Bundle a network with its training criteria so that ``forward`` returns the loss.

    Computing the loss inside ``forward`` lets a multi-GPU wrapper distribute the
    loss computation, reducing the memory cost on the main GPU.
    """

    def __init__(self, model, config=None):
        """
        Args:
            model: network called as ``model(rgbs)``; it must return a 4-tuple
                ``(outputs, conf, det, npp)``.
            config: experiment config. It is required despite the ``None`` default.
                ``MODEL.NAME`` and ``LOSS.LOSSES`` are read here, and the criteria
                are built by ``get_criterion(config)``.

        Raises:
            ValueError: if ``config`` is None.
        """
        super(FullModel, self).__init__()
        if config is None:
            raise ValueError('FullModel requires a config (MODEL.NAME, LOSS.LOSSES, ...)')
        self.model = model
        self.model_name = config.MODEL.NAME
        self.cfg = config
        # Sequence of (loss_kind, weight, criterion_name) entries.
        self.losses = config.LOSS.LOSSES
        self.loss_loc, self.loss_conf, self.loss_det = get_criterion(config)

    def forward(self, labels=None, rgbs=None):
        """
        Run the model on ``rgbs`` and compute the weighted sum of the configured losses.

        Returns:
            ``(final_loss, outputs, conf, det)``. ``final_loss`` accumulates
            ``weight * loss`` with each loss unsqueezed along dim 0. It stays the
            int ``0`` if no losses are configured. The model's fourth output is
            not used.

        Raises:
            ValueError: for an unknown loss kind, or a 'DET' loss whose criterion
                was not built (``get_criterion`` leaves it None when
                ``MODEL.EXTRA.DETECTION`` is None or 'none').
        """
        outputs, conf, det, _ = self.model(rgbs)

        final_loss = 0
        for kind, weight, _ in self.losses:
            if kind == 'LOC':
                loss = self.loss_loc(outputs, labels)
            elif kind == 'CONF':
                loss = self.loss_conf(outputs, labels, conf)
            elif kind == 'DET':
                if self.loss_det is None:
                    raise ValueError("'DET' loss requested but no detection criterion was built "
                                     "(is MODEL.EXTRA.DETECTION disabled?)")
                loss = self.loss_det(det, labels)
            else:
                # Without this branch an unknown kind would reuse the previous
                # iteration's `loss` (or raise UnboundLocalError on the first one).
                raise ValueError('Unknown loss type {!r}'.format(kind))

            # Presumably each criterion returns a 0-d tensor; give it a leading dim
            # so multi-GPU wrappers can gather per-device losses — TODO confirm.
            final_loss += weight * torch.unsqueeze(loss, 0)

        return final_loss, outputs, conf, det
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_model(config):
    """
    Instantiate the network selected by ``config.MODEL.NAME``.

    Only ``'detconfcmx'`` is supported. The model class is imported lazily, so
    its module is loaded only when that model is actually requested.

    Raises:
        NotImplementedError: for any other model name.
    """
    requested = config.MODEL.NAME
    if requested != 'detconfcmx':
        raise NotImplementedError("Model not implemented")

    from lib.models.cmx.builder_np_conf import EncoderDecoder
    return EncoderDecoder(cfg=config)
|
|
|
|
|
|
|
|
def get_criterion(config):
    """
    Build the loss criteria listed in ``config.LOSS.LOSSES``.

    Each entry of ``config.LOSS.LOSSES`` is unpacked as ``(kind, weight, criterion)``.
    ``kind`` ('LOC', 'CONF' or 'DET') selects which returned slot is filled, and
    ``criterion`` names the implementation. The weight is not used here. Every
    criterion that is built is moved to the GPU with ``.cuda()``.

    Supported criteria:
        LOC:  'dice', 'binary_dice', 'cross_entropy', 'dice_entropy'
        CONF: 'mse'
        DET:  'cross_entropy' (only when MODEL.EXTRA.DETECTION is enabled)

    Returns:
        ``(criterion_loc, criterion_conf, criterion_det)``. A slot is None when its
        kind is not configured. The DET slot also stays None when
        ``config.MODEL.EXTRA.DETECTION`` is None or 'none'; the DET entry is then
        skipped without checking its criterion name.

    Raises:
        ValueError: for an unknown loss kind or an unsupported criterion name.
    """
    ignore_label = config.TRAIN.IGNORE_LABEL
    smooth = config.LOSS.SMOOTH
    # Per-class weights, passed to the cross-entropy based localization losses.
    weight = torch.FloatTensor(config.DATASET.CLASS_WEIGHTS)
    detection = config.MODEL.EXTRA.DETECTION
    detection_enabled = detection is not None and detection != 'none'

    criterion_loc, criterion_conf, criterion_det = None, None, None

    for kind, _, criterion in config.LOSS.LOSSES:
        if kind == 'LOC':
            if criterion == 'dice':
                from lib.core.criterion import DiceLoss
                criterion_loc = DiceLoss(ignore_label=ignore_label, smooth=smooth).cuda()
            elif criterion == 'binary_dice':
                from lib.core.criterion import BinaryDiceLoss
                criterion_loc = BinaryDiceLoss(ignore_label=ignore_label, smooth=smooth).cuda()
            elif criterion == 'cross_entropy':
                from lib.core.criterion import CrossEntropy
                criterion_loc = CrossEntropy(ignore_label=ignore_label, weight=weight).cuda()
            elif criterion == 'dice_entropy':
                from lib.core.criterion import DiceEntropyLoss
                criterion_loc = DiceEntropyLoss(ignore_label=ignore_label, weight=weight, smooth=smooth).cuda()
            else:
                raise ValueError('Localization loss not implemented')

        elif kind == 'CONF':
            if criterion == 'mse':
                from lib.core.criterion_conf import MSE
                criterion_conf = MSE().cuda()
            else:
                raise ValueError('Confidence loss not implemented')

        elif kind == 'DET':
            # Detection disabled in the config: leave the DET slot as None.
            if not detection_enabled:
                continue
            if criterion == 'cross_entropy':
                from lib.core.criterion_det import CrossEntropy
                criterion_det = CrossEntropy().cuda()
            else:
                raise ValueError('Detection loss not implemented')

        else:
            # Explicit error instead of `assert`: assertions are stripped under `python -O`.
            raise ValueError("Unknown loss type {!r}; expected 'LOC', 'CONF' or 'DET'".format(kind))

    return criterion_loc, criterion_conf, criterion_det
|
|
|
|
|
|
|
|
|
|
|
def get_optimizer(model, config):
    """
    Create the optimizer named by ``config.TRAIN.OPTIMIZER`` ('sgd' or 'adam').

    For CMX models (``'cmx'`` in ``config.MODEL.NAME``), parameter groups come from
    ``group_weight``, which receives ``nn.BatchNorm2d`` as the norm layer type.
    Otherwise a single group holds every parameter with ``requires_grad``.
    ``config.TRAIN.LR`` is used both as the per-group and the default learning rate.

    Raises:
        ValueError: if the optimizer name is not supported.
    """
    use_cmx_grouping = 'cmx' in config.MODEL.NAME
    base_lr = config.TRAIN.LR

    if use_cmx_grouping:
        from lib.models.cmx.init_func import group_weight
        param_groups = group_weight([], model, nn.BatchNorm2d, base_lr)
    else:
        trainable = filter(lambda p: p.requires_grad, model.parameters())
        param_groups = [{'params': trainable, 'lr': base_lr}]

    name = config.TRAIN.OPTIMIZER
    if name == 'sgd':
        return torch.optim.SGD(
            param_groups,
            lr=base_lr,
            momentum=config.TRAIN.MOMENTUM,
            weight_decay=config.TRAIN.WD,
            nesterov=config.TRAIN.NESTEROV,
        )
    if name == 'adam':
        return torch.optim.Adam(
            param_groups,
            lr=base_lr,
            betas=(0.9, 0.999),
            weight_decay=config.TRAIN.WD,
        )
    raise ValueError('Optimizer not implemented')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AverageMeter:
    """Keep the most recent value of a metric together with its weighted running average."""

    def __init__(self):
        # All statistics stay None until the first update() seeds them.
        self.initialized = False
        self.val = self.avg = self.sum = self.count = None

    def initialize(self, val, weight):
        """Seed the meter with its first observation ``val`` of weight ``weight``."""
        self.val, self.avg = val, val
        self.sum, self.count = val * weight, weight
        self.initialized = True

    def update(self, val, weight=1):
        """Record ``val`` with ``weight``; the first call seeds the meter."""
        record = self.add if self.initialized else self.initialize
        record(val, weight)

    def add(self, val, weight):
        """Fold ``val`` (weighted by ``weight``) into the running sum and average."""
        self.val = val
        self.sum += val * weight
        self.count += weight
        self.avg = self.sum / self.count

    def value(self):
        """Return the most recently recorded value."""
        return self.val

    def average(self):
        """Return the weighted average of all recorded values."""
        return self.avg
|
|
|
|
|
|
|
|
|
|
|
def create_logger(cfg, cfg_name, phase='train'):
    """
    Create the run's output directory and configure root logging.

    Args:
        cfg: config; ``OUTPUT_DIR``, ``LOG_DIR`` and ``MODEL.NAME`` are read.
        cfg_name: experiment name. It is used as a (possibly nested) subdirectory
            of ``OUTPUT_DIR`` and, with '/' replaced by '_', as the log-file prefix.
        phase: suffix appended to the log-file name.

    Returns:
        ``(logger, final_output_dir, tensorboard_log_dir)``: the root logger, the
        created output directory, and the TensorBoard directory path
        ``LOG_DIR/<MODEL.NAME>/<cfg_name>_<time>``, which is computed but not
        created. Both paths are returned as strings.
    """
    root_output_dir = Path(cfg.OUTPUT_DIR)
    if not root_output_dir.exists():
        print('=> creating {}'.format(root_output_dir))
        # parents=True: OUTPUT_DIR may be nested; exist_ok tolerates a concurrent creator.
        root_output_dir.mkdir(parents=True, exist_ok=True)

    model = cfg.MODEL.NAME
    final_output_dir = root_output_dir / cfg_name

    print('=> creating {}'.format(final_output_dir))
    final_output_dir.mkdir(parents=True, exist_ok=True)

    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}_{}.log'.format(cfg_name.replace('/', '_'), time_str, phase)
    final_log_file = final_output_dir / log_file
    head = '%(asctime)-15s %(message)s'
    # Note: basicConfig does nothing if the root logger already has handlers,
    # so only the first call in a process attaches the log file.
    logging.basicConfig(filename=str(final_log_file), format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Attach a console handler only once; repeated calls would otherwise print
    # every record several times. `type(...) is` excludes subclasses such as the
    # FileHandler added by basicConfig.
    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        logger.addHandler(logging.StreamHandler())

    tensorboard_log_dir = Path(cfg.LOG_DIR) / model / (cfg_name + '_' + time_str)

    return logger, str(final_output_dir), str(tensorboard_log_dir)
|
|
|
|
|
|
|
|
|
|
|
def get_confusion_matrix(label, pred, size, num_class, ignore=-1):
    """
    Calculate the confusion matrix between ground-truth labels and per-class scores.

    Args:
        label: ground-truth tensor indexed ``[batch, H, W]``. It is cropped to
            ``size[-2] x size[-1]``, which must match the prediction's spatial size.
        pred: score tensor indexed ``[batch, class, H, W]``. The prediction is
            the argmax over the class dimension.
        size: shape-like whose last two entries give the crop height and width.
        num_class: number of classes.
        ignore: label value whose pixels are excluded.

    Returns:
        float64 array of shape ``(num_class, num_class)``. Entry ``[g, p]``
        counts pixels with ground truth ``g`` predicted as ``p``. Ground-truth
        values are assumed to lie in ``[0, num_class)`` apart from ``ignore``.
    """
    output = pred.cpu().numpy().transpose(0, 2, 3, 1)
    seg_pred = np.asarray(np.argmax(output, axis=3), dtype=np.int64)
    # np.int was removed in NumPy 1.24; use an explicit integer dtype.
    seg_gt = np.asarray(label.cpu().numpy()[:, :size[-2], :size[-1]], dtype=np.int64)

    keep = seg_gt != ignore
    seg_gt = seg_gt[keep]
    seg_pred = seg_pred[keep]

    # Encode each (gt, pred) pair as one flat cell index and count occurrences.
    # Counts beyond num_class**2 are discarded, as in a row-major fill.
    n_cells = num_class * num_class
    index = (seg_gt * num_class + seg_pred).astype('int32')
    cell_counts = np.bincount(index, minlength=n_cells)[:n_cells]
    return cell_counts.reshape(num_class, num_class).astype(np.float64)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_confusion_matrix_1ch(label, confid, size, num_class, ignore=-1):
    """
    Calculate the confusion matrix from a single-channel score map.

    A pixel is predicted as class 1 when its score is strictly positive and as
    class 0 otherwise.

    Args:
        label: ground-truth tensor indexed ``[batch, H, W]``. It is cropped to
            ``size[-2] x size[-1]``, which must match the score map's spatial size.
        confid: score tensor with a singleton channel dimension at dim 1.
        size: shape-like whose last two entries give the crop height and width.
        num_class: number of classes (the matrix side length).
        ignore: label value whose pixels are excluded.

    Returns:
        float64 array of shape ``(num_class, num_class)``. Entry ``[g, p]``
        counts pixels with ground truth ``g`` predicted as ``p``.
    """
    output = confid.squeeze(dim=1).cpu().numpy()

    seg_pred = np.asarray(output > 0, dtype=np.uint8)
    # np.int was removed in NumPy 1.24; use an explicit integer dtype.
    seg_gt = np.asarray(label.cpu().numpy()[:, :size[-2], :size[-1]], dtype=np.int64)

    keep = seg_gt != ignore
    seg_gt = seg_gt[keep]
    seg_pred = seg_pred[keep]

    # Encode each (gt, pred) pair as one flat cell index and count occurrences.
    # Counts beyond num_class**2 are discarded, as in a row-major fill.
    n_cells = num_class * num_class
    index = (seg_gt * num_class + seg_pred).astype('int32')
    cell_counts = np.bincount(index, minlength=n_cells)[:n_cells]
    return cell_counts.reshape(num_class, num_class).astype(np.float64)
|
|
|
|
|
|
|
|
def plot_confusion_matrix(confusion_matrix):
    """
    Render a confusion matrix to an RGB image array.

    Each non-zero cell is annotated with its value in ``.3e`` notation; zero cells
    show '.'. A colorbar is attached to the heatmap.

    Args:
        confusion_matrix: 2-D array-like; rows are true labels, columns are predictions.

    Returns:
        uint8 array of shape ``(3, H, W)`` (channel-first RGB). The figure is
        3x3 inches at 200 dpi.
    """
    confusion_matrix = np.asarray(confusion_matrix)
    n_rows, n_cols = confusion_matrix.shape

    fig = plt.figure(figsize=(3, 3), dpi=200, facecolor='w', edgecolor='k')
    ax = fig.add_subplot(1, 1, 1)
    im = ax.imshow(confusion_matrix, cmap='bwr')

    ax.set_xlabel('Predicted', fontsize=10)
    ax.set_xticks(list(range(n_cols)))
    ax.xaxis.set_label_position('bottom')
    ax.xaxis.tick_bottom()

    ax.set_ylabel('True Label', fontsize=10)
    ax.set_yticks(list(range(n_rows)))
    ax.yaxis.set_label_position('left')
    ax.yaxis.tick_left()

    for i, j in itertools.product(range(n_rows), range(n_cols)):
        value = confusion_matrix[i, j]
        ax.text(j, i, format(value, '.3e') if value != 0 else '.',
                horizontalalignment="center", fontsize=10,
                verticalalignment='center', color="black")

    # set_tight_layout is discouraged since Matplotlib 3.6 in favour of set_layout_engine.
    if hasattr(fig, 'set_layout_engine'):
        fig.set_layout_engine('tight')
    else:
        fig.set_tight_layout(True)
    fig.colorbar(im, fraction=0.046, pad=0.04)

    fig.canvas.draw()
    # canvas.tostring_rgb() was deprecated in Matplotlib 3.8 and removed in 3.10;
    # buffer_rgba() yields an (H, W, 4) buffer. Copy it before closing the figure.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    cm = rgba[:, :, :3].transpose(2, 0, 1).copy()
    plt.close(fig)

    return cm
|
|
|