# BrainFM — Trainer/models/criterion.py
"""
Criterion modules.
"""
import numpy as np
import torch
import torch.nn as nn
from Trainer.models.losses import GradientLoss, SmoothnessLoss, HessianLoss, gaussian_loss, laplace_loss, l1_loss
from utils.misc import viewVolume
# Maps the cfg uncertainty mode to its heteroscedastic regression loss
# (these losses take a model-predicted per-voxel sigma; see SetCriterion.loss_image).
uncertainty_loss = {'gaussian': gaussian_loss, 'laplace': laplace_loss}
class SetCriterion(nn.Module):
    """
    This class computes the loss for BrainID.

    Each name in `loss_names` is dispatched through `self.loss_map`; every
    loss method returns a single-key dict ('loss_<name>' -> value) so the
    results from `forward` can simply be merged.
    """
    def __init__(self, gen_args, train_args, weight_dict, loss_names, device):
        """ Create the criterion.
        Parameters:
            gen_args: general exp cfg
            train_args: training cfg (loss types, uncertainty mode, contrastive temperatures)
            weight_dict: dict containing as key the names of the losses and as values their
                         relative weight.
            loss_names: list of all the losses to be applied. See get_loss for list of
                        available loss_names.
            device: device on which the per-label segmentation weights are allocated
        """
        super(SetCriterion, self).__init__()
        self.gen_args = gen_args
        self.train_args = train_args
        self.weight_dict = weight_dict
        self.loss_names = loss_names
        self.mse = nn.MSELoss()
        # Image regression loss: heteroscedastic ('gaussian'/'laplace', model predicts a
        # per-voxel sigma) when uncertainty is enabled, otherwise plain l1.
        self.loss_regression_type = train_args.losses.uncertainty if train_args.losses.uncertainty is not None else 'l1'
        self.loss_regression = uncertainty_loss[train_args.losses.uncertainty] if train_args.losses.uncertainty is not None else l1_loss
        self.grad = GradientLoss('l1')
        self.smoothness = SmoothnessLoss('l2')
        self.hessian = HessianLoss('l2')
        self.bflog_loss = nn.L1Loss() if train_args.losses.bias_field_log_type == 'l1' else self.mse

        if 'contrastive' in self.loss_names:
            # Temperatures are only needed (and only guaranteed present in cfg)
            # when the contrastive loss is requested.
            self.temp_alpha = train_args.contrastive_temperatures.alpha
            self.temp_beta = train_args.contrastive_temperatures.beta
            self.temp_gamma = train_args.contrastive_temperatures.gamma

        # initialize weights # NOTE all = 1 for now
        weights_brainseg = torch.ones(gen_args.n_labels).to(device)
        # NOTE(review): assumes label_list_segmentation_with_csf is an ndarray/tensor so
        # `== 77` yields a boolean mask; with a plain Python list this would index entry 0
        # instead — TODO confirm against the cfg loader.
        weights_brainseg[gen_args.label_list_segmentation_with_csf==77] = train_args.relative_weight_lesions # give (more) importance to lesions
        weights_brainseg = weights_brainseg / torch.sum(weights_brainseg)
        self.weights_ce = weights_brainseg[None, :, None, None, None] # broadcast over (b, c, s, r, c)
        self.weights_dice = weights_brainseg[None, :] # broadcast over (b, c)

        # archived
        #self.csf_ind = torch.tensor(np.where(np.array(gen_args.label_list_segmentation)==24)[0][0])
        #self.csf_v = torch.tensor(np.concatenate([np.arange(0, self.csf_ind), np.arange(self.csf_ind+1, gen_args.n_labels)]))

        # Dispatch table: cfg loss name -> bound loss method.
        self.loss_map = {
            'seg_ce': self.loss_seg_ce,
            'seg_dice': self.loss_seg_dice,
            'pathol_ce': self.loss_pathol_ce,
            'pathol_dice': self.loss_pathol_dice,
            'implicit_pathol_ce': self.loss_implicit_pathol_ce,
            'implicit_pathol_dice': self.loss_implicit_pathol_dice,
            'implicit_aux_pathol_ce': self.loss_implicit_aux_pathol_ce,
            'implicit_aux_pathol_dice': self.loss_implicit_aux_pathol_dice,
            'T1': self.loss_T1,
            'T1_grad': self.loss_T1_grad,
            'T2': self.loss_T2,
            'T2_grad': self.loss_T2_grad,
            'FLAIR': self.loss_FLAIR,
            'FLAIR_grad': self.loss_FLAIR_grad,
            'CT': self.loss_CT,
            'CT_grad': self.loss_CT_grad,
            'SR': self.loss_SR,
            'SR_grad': self.loss_SR_grad,
            "age": self.loss_age,
            "distance": self.loss_distance,
            "registration": self.loss_registration,
            "registration_grad": self.loss_registration_grad,
            "registration_hessian": self.loss_registration_hessian,
            "registration_smooth": self.loss_registration_smooth,
            "bias_field_log": self.loss_bias_field_log,
            'contrastive': self.loss_feat_contrastive,
            "surface": self.loss_surface, # TODO
            #'supervised_seg': self.loss_supervised_seg, # archived
        }

    def loss_feat_contrastive(self, outputs, *args):
        """
        Contrastive loss between the last-level features of two views.
        outputs: [feat1, feat2]
        feat shape: (b, feat_dim, s, r, c)
        """
        feat1, feat2 = outputs[0]['feat'][-1], outputs[1]['feat'][-1]
        num = torch.sum(torch.exp(feat1 * feat2 / self.temp_alpha), dim = 1)
        den = torch.zeros_like(feat1[:, 0])
        # NOTE(review): both denominator terms are built from feat1 only; if the
        # second term is meant to involve feat2, this needs revisiting upstream.
        for i in range(feat1.shape[1]):
            den1 = torch.exp(feat1[:, i] ** 2 / self.temp_beta)
            den2 = torch.exp((torch.sum(feat1[:, i][:, None] * feat1, dim = 1) - feat1[:, i] ** 2) / self.temp_gamma)
            den += den1 + den2
        loss_contrastive = torch.mean(- torch.log(num / den))
        return {'loss_contrastive': loss_contrastive}

    def loss_seg_ce(self, outputs, targets, *args):
        """
        Cross entropy of segmentation (label-weighted, clamped for log stability).
        """
        loss_seg_ce = torch.mean(-torch.sum(torch.log(torch.clamp(outputs['segmentation'], min=1e-5)) * self.weights_ce * targets['segmentation'], dim=1))
        return {'loss_seg_ce': loss_seg_ce}

    def loss_seg_dice(self, outputs, targets, *args):
        """
        Dice of segmentation (label-weighted soft Dice, summed over batch and labels).
        """
        loss_seg_dice = torch.sum(self.weights_dice * (1.0 - 2.0 * ((outputs['segmentation'] * targets['segmentation']).sum(dim=[2, 3, 4]))
                                                        / torch.clamp((outputs['segmentation'] + targets['segmentation']).sum(dim=[2, 3, 4]), min=1e-5)))
        return {'loss_seg_dice': loss_seg_dice}

    def loss_implicit_pathol_ce(self, outputs, targets, samples, *args):
        """
        Cross entropy of implicit pathology segmentation (against the model's own
        'implicit_pathol_orig'); 0. when the prediction is absent (no GT image exists).
        """
        if 'implicit_pathol_pred' in outputs:
            #loss_implicit_pathol_ce = torch.mean(-torch.sum(torch.log(torch.clamp(outputs['implicit_pathol_pred'], min=1e-5)) * self.weights_ce * samples['pathol'], dim=1))
            loss_implicit_pathol_ce = torch.mean(-torch.sum(torch.log(torch.clamp(outputs['implicit_pathol_pred'], min=1e-5)) * outputs['implicit_pathol_orig'], dim=1))
        else: # no GT image exists
            loss_implicit_pathol_ce = 0.
        return {'loss_implicit_pathol_ce': loss_implicit_pathol_ce}

    def loss_implicit_pathol_dice(self, outputs, targets, samples, *args):
        """
        Dice of implicit pathology segmentation; 0. when the prediction is absent.
        NOTE(review): unlike the aux variant below, this does NOT apply
        self.weights_dice — confirm whether that asymmetry is intentional.
        """
        if 'implicit_pathol_pred' in outputs:
            #loss_implicit_pathol_dice = torch.sum(self.weights_dice * (1.0 - 2.0 * ((outputs['implicit_pathol_pred'] * samples['pathol']).sum(dim=[2, 3, 4]))
            #                                                / torch.clamp((outputs['implicit_pathol_pred'] + samples['pathol']).sum(dim=[2, 3, 4]), min=1e-5)))
            loss_implicit_pathol_dice = torch.sum((1.0 - 2.0 * ((outputs['implicit_pathol_pred'] * outputs['implicit_pathol_orig']).sum(dim=[2, 3, 4]))
                                                            / torch.clamp((outputs['implicit_pathol_pred'] + outputs['implicit_pathol_orig']).sum(dim=[2, 3, 4]), min=1e-5)))
        else:
            loss_implicit_pathol_dice = 0.
        return {'loss_implicit_pathol_dice': loss_implicit_pathol_dice}

    def loss_implicit_aux_pathol_ce(self, outputs, targets, samples):
        """
        Cross entropy of auxiliary implicit pathology segmentation; 0. when absent.
        """
        if 'implicit_aux_pathol_pred' in outputs:
            #loss_implicit_aux_pathol_ce = torch.mean(-torch.sum(torch.log(torch.clamp(outputs['implicit_aux_pathol_pred'], min=1e-5)) * self.weights_ce * samples['pathol'], dim=1))
            loss_implicit_aux_pathol_ce = torch.mean(-torch.sum(torch.log(torch.clamp(outputs['implicit_aux_pathol_pred'], min=1e-5)) * self.weights_ce * outputs['implicit_aux_pathol_orig'], dim=1))
        else:
            loss_implicit_aux_pathol_ce = 0.
        return {'loss_implicit_aux_pathol_ce': loss_implicit_aux_pathol_ce}

    def loss_implicit_aux_pathol_dice(self, outputs, targets, samples):
        """
        Dice of auxiliary implicit pathology segmentation; 0. when absent.
        """
        if 'implicit_aux_pathol_pred' in outputs:
            #loss_implicit_aux_pathol_dice = torch.sum(self.weights_dice * (1.0 - 2.0 * ((outputs['implicit_aux_pathol_pred'] * samples['pathol']).sum(dim=[2, 3, 4]))
            #                                                / torch.clamp((outputs['implicit_aux_pathol_pred'] + samples['pathol']).sum(dim=[2, 3, 4]), min=1e-5)))
            loss_implicit_aux_pathol_dice = torch.sum(self.weights_dice * (1.0 - 2.0 * ((outputs['implicit_aux_pathol_pred'] * outputs['implicit_aux_pathol_orig']).sum(dim=[2, 3, 4]))
                                                            / torch.clamp((outputs['implicit_aux_pathol_pred'] + outputs['implicit_aux_pathol_orig']).sum(dim=[2, 3, 4]), min=1e-5)))
        else:
            loss_implicit_aux_pathol_dice = 0.
        return {'loss_implicit_aux_pathol_dice': loss_implicit_aux_pathol_dice}

    def loss_surface(self, outputs, targets, *args):
        """Regression loss on the predicted surface map."""
        return {'loss_surface': self.loss_image(outputs['surface'], targets['surface'])}

    def loss_distance(self, outputs, targets, *args):
        """Regression loss on the predicted distance map."""
        return {'loss_distance': self.loss_image(outputs['distance'], targets['distance'])}

    def loss_registration(self, outputs, targets, *args):
        """Regression loss on the predicted registration (deformation) field."""
        return {'loss_registration': self.loss_image(outputs['registration'], targets['registration'])}

    def loss_registration_grad(self, outputs, targets, *args):
        """Gradient-matching loss on the registration field."""
        return {'loss_registration_grad': self.loss_image_grad(outputs['registration'], targets['registration'])}

    def loss_registration_smooth(self, outputs, *args):
        """Smoothness regularizer on the registration field (no target needed)."""
        return {'loss_registration_smooth': self.smoothness(outputs['registration'])}

    def loss_registration_hessian(self, outputs, *args):
        """Hessian regularizer on the registration field (no target needed)."""
        return {'loss_registration_hessian': self.hessian(outputs['registration'])}

    def loss_pathol_ce(self, outputs, targets, samples):
        """
        Cross entropy of pathology segmentation; 0. when the prediction is
        missing or its shape does not match the target.
        """
        if 'pathology' in outputs and outputs['pathology'].shape == targets['pathology'].shape:
            loss_pathol_ce = torch.mean(-torch.sum(torch.log(torch.clamp(outputs['pathology'], min=1e-5)) * targets['pathology'], dim=1))
        else:
            loss_pathol_ce = 0.
        return {'loss_pathol_ce': loss_pathol_ce}

    def loss_pathol_dice(self, outputs, targets, samples):
        """
        Dice of pathology segmentation; 0. when the prediction is missing or
        its shape does not match the target.
        """
        if 'pathology' in outputs and outputs['pathology'].shape == targets['pathology'].shape:
            loss_pathol_dice = torch.sum((1.0 - 2.0 * ((outputs['pathology'] * targets['pathology']).sum(dim=[2, 3, 4]))
                                                            / torch.clamp((outputs['pathology'] + targets['pathology']).sum(dim=[2, 3, 4]), min=1e-5)))
        else:
            loss_pathol_dice = 0.
        return {'loss_pathol_dice': loss_pathol_dice}

    def loss_T1(self, outputs, targets, *args):
        """T1 reconstruction loss, down-weighted inside the T1 distance map if provided."""
        #weights = 1. - targets['pathology'] if targets['pathology'].shape == targets['T1'].shape else 1.
        weights = 1. - targets['T1_DM'] if 'T1_DM' in targets else 1.
        #weights = 1.
        return {'loss_T1': self.loss_image(outputs['T1'], targets['T1'], outputs['T1_sigma'] if 'T1_sigma' in outputs else None, weights = weights)}

    def loss_T1_grad(self, outputs, targets, *args):
        """T1 gradient-matching loss with the same distance-map weighting."""
        #weights = 1. - targets['pathology'] if targets['pathology'].shape == targets['T1'].shape else 1.
        weights = 1. - targets['T1_DM'] if 'T1_DM' in targets else 1.
        #weights = 1.
        return {'loss_T1_grad': self.loss_image_grad(outputs['T1'], targets['T1'], weights)}

    def loss_T2(self, outputs, targets, *args):
        """T2 reconstruction loss, down-weighted inside the T2 distance map if provided."""
        #weights = 1. - targets['pathology'] if targets['pathology'].shape == targets['T2'].shape else 1.
        weights = 1. - targets['T2_DM'] if 'T2_DM' in targets else 1.
        #weights = 1.
        return {'loss_T2': self.loss_image(outputs['T2'], targets['T2'], outputs['T2_sigma'] if 'T2_sigma' in outputs else None, weights)}

    def loss_T2_grad(self, outputs, targets, *args):
        """T2 gradient-matching loss with the same distance-map weighting."""
        #weights = 1. - targets['pathology'] if targets['pathology'].shape == targets['T2'].shape else 1.
        weights = 1. - targets['T2_DM'] if 'T2_DM' in targets else 1.
        #weights = 1.
        return {'loss_T2_grad': self.loss_image_grad(outputs['T2'], targets['T2'], weights)}

    def loss_FLAIR(self, outputs, targets, *args):
        """FLAIR reconstruction loss, down-weighted inside the FLAIR distance map if provided."""
        #weights = 1. - targets['pathology'] if targets['pathology'].shape == targets['FLAIR'].shape else 1.
        weights = 1. - targets['FLAIR_DM'] if 'FLAIR_DM' in targets else 1.
        #weights = 1.
        return {'loss_FLAIR': self.loss_image(outputs['FLAIR'], targets['FLAIR'], outputs['FLAIR_sigma'] if 'FLAIR_sigma' in outputs else None, weights)}

    def loss_FLAIR_grad(self, outputs, targets, *args):
        """FLAIR gradient-matching loss with the same distance-map weighting."""
        #weights = 1. - targets['pathology'] if targets['pathology'].shape == targets['FLAIR'].shape else 1.
        weights = 1. - targets['FLAIR_DM'] if 'FLAIR_DM' in targets else 1.
        #weights = 1.
        return {'loss_FLAIR_grad': self.loss_image_grad(outputs['FLAIR'], targets['FLAIR'], weights)}

    def loss_CT(self, outputs, targets, *args):
        """CT reconstruction loss, down-weighted inside the CT distance map if provided."""
        #weights = 1. - targets['pathology'] if targets['pathology'].shape == targets['CT'].shape else 1.
        weights = 1. - targets['CT_DM'] if 'CT_DM' in targets else 1.
        #weights = 1.
        return {'loss_CT': self.loss_image(outputs['CT'], targets['CT'], outputs['CT_sigma'] if 'CT_sigma' in outputs else None, weights)}

    def loss_CT_grad(self, outputs, targets, *args):
        """CT gradient-matching loss with the same distance-map weighting."""
        #weights = 1. - targets['pathology'] if targets['pathology'].shape == targets['CT'].shape else 1.
        weights = 1. - targets['CT_DM'] if 'CT_DM' in targets else 1.
        #weights = 1.
        return {'loss_CT_grad': self.loss_image_grad(outputs['CT'], targets['CT'], weights)}

    def loss_SR(self, outputs, targets, samples):
        """Super-resolution loss on the high-resolution residual."""
        loss_SR = self.loss_image(outputs['high_res_residual'], samples['high_res_residual'])
        return {'loss_SR': loss_SR}

    def loss_SR_grad(self, outputs, targets, samples):
        """Gradient-matching loss on the high-resolution residual."""
        loss_SR_grad = self.loss_image_grad(outputs['high_res_residual'], samples['high_res_residual'])
        return {'loss_SR_grad': loss_SR_grad}

    def loss_bias_field_log(self, outputs, targets, samples):
        """
        L1/L2 loss on the log bias field, masked to exclude the background
        (channel 0 of the target segmentation); 0. when no GT bias field exists.
        """
        if 'bias_field_log' in samples:
            bf_soft_mask = 1. - targets['segmentation'][:, 0]
            loss_bias_field_log = self.bflog_loss(outputs['bias_field_log'] * bf_soft_mask, samples['bias_field_log'] * bf_soft_mask)
        else:
            loss_bias_field_log = 0.
        return {'loss_bias_field_log': loss_bias_field_log}

    def loss_age(self, outputs, targets, *args):
        """Absolute error of the predicted age."""
        loss_age = abs(outputs['age'] - targets['age'])
        #print(outputs['age'].item(), outputs['age'].shape, targets['age'].item(), targets['age'].shape)
        return {'loss_age': loss_age}

    def loss_image(self, output, target, output_sigma = None, weights = 1., *args):
        """
        Regression loss between a predicted image and its target.
        Returns 0. when the shapes disagree (no ground truth at this size).
        When output_sigma is given the heteroscedastic loss is used and
        `weights` is not applied (matches the uncertainty loss signature).
        """
        if output.shape == target.shape:
            # Must compare against None explicitly: truth-testing a multi-element
            # tensor raises "Boolean value of Tensor ... is ambiguous".
            if output_sigma is not None:
                loss_image = self.loss_regression(output, output_sigma, target)
            else:
                loss_image = self.loss_regression(output, target, weights)
        else:
            loss_image = 0.
        return loss_image

    def loss_image_grad(self, output, target, weights = 1., *args):
        """Weighted gradient-matching loss; 0. when shapes disagree."""
        return self.grad(output, target, weights) if output.shape == target.shape else 0.

    def loss_supervised_seg(self, outputs, targets, *args):
        """
        Supervised segmentation differences (for dataset_name == synth)
        NOTE(review): archived — relies on self.csf_v / self.csf_ind /
        self.weights_dice_sup, which are commented out in __init__ and never set.
        """
        onehot_withoutcsf = targets['segmentation'].clone()
        onehot_withoutcsf = onehot_withoutcsf[:, self.csf_v, ...]
        onehot_withoutcsf[:, 0, :, :, :] = onehot_withoutcsf[:, 0, :, :, :] + targets['segmentation'][:, self.csf_ind, :, :, :]
        loss_supervised_seg = torch.sum(self.weights_dice_sup * (1.0 - 2.0 * ((outputs['supervised_seg'] * onehot_withoutcsf).sum(dim=[2, 3, 4]))
                                                        / torch.clamp((outputs['supervised_seg'] + onehot_withoutcsf).sum(dim=[2, 3, 4]), min=1e-5)))
        return {'loss_supervised_seg': loss_supervised_seg}

    def get_loss(self, loss_name, outputs, targets, *args):
        """Dispatch one named loss; extra positional args are forwarded to it."""
        assert loss_name in self.loss_map, f'do you really want to compute {loss_name} loss?'
        return self.loss_map[loss_name](outputs, targets, *args)

    def forward(self, outputs, targets, *args):
        """ This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depends on the losses applied,
                     see each loss' doc
        Returns:
            dict mapping 'loss_<name>' -> value for every requested loss.
        """
        # Compute all the requested losses
        losses = {}
        for loss_name in self.loss_names:
            losses.update(self.get_loss(loss_name, outputs, targets, *args))
        return losses
class SetMultiCriterion(SetCriterion):
    """
    This class computes the loss for BrainID with a list of results as inputs.

    Each loss is evaluated once per (outputs, samples) pair and the per-sample
    values are averaged over `gen_args.generator.all_samples`.
    """
    def __init__(self, gen_args, train_args, weight_dict, loss_names, device):
        """ Create the criterion.
        Parameters:
            gen_args: general exp cfg
            train_args: training exp cfg
            weight_dict: dict containing as key the names of the losses and as values their
                         relative weight.
            loss_names: list of all the losses to be applied. See get_loss for list of
                        available loss_names.
            device: device forwarded to the base criterion
        """
        super(SetMultiCriterion, self).__init__(gen_args, train_args, weight_dict, loss_names, device)
        # Divisor for the per-sample average below.
        self.all_samples = gen_args.generator.all_samples

    def get_loss(self, loss_name, outputs_list, targets, samples_list):
        """Evaluate one named loss over all samples and return its average."""
        assert loss_name in self.loss_map, f'do you really want to compute {loss_name} loss?'
        key = 'loss_' + loss_name
        loss_fn = self.loss_map[loss_name]
        total_loss = 0.
        for outputs, samples in zip(outputs_list, samples_list):
            total_loss = total_loss + loss_fn(outputs, targets, samples)[key]
        return {key: total_loss / self.all_samples}

    def forward(self, outputs_list, targets, samples_list):
        """ This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depends on the losses applied,
                     see each loss' doc
        """
        # Compute all the requested losses and merge the single-key dicts.
        return {
            key: value
            for loss_name in self.loss_names
            for key, value in self.get_loss(loss_name, outputs_list, targets, samples_list).items()
        }