| | """ |
| | Criterion modules. |
| | """ |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from Trainer.models.losses import GradientLoss, SmoothnessLoss, HessianLoss, gaussian_loss, laplace_loss, l1_loss |
| | from utils.misc import viewVolume |
| |
|
# Regression losses that also consume a predicted uncertainty (sigma) map,
# keyed by the value of `train_args.losses.uncertainty`.
uncertainty_loss = dict(gaussian=gaussian_loss, laplace=laplace_loss)
| |
|
| |
|
class SetCriterion(nn.Module):
    """
    This class computes the loss for BrainID.

    Each individual loss is implemented by a ``loss_<name>`` method registered in
    ``self.loss_map``. ``forward`` evaluates every name listed in ``loss_names`` and
    returns a flat dict ``{'loss_<name>': value}``. ``weight_dict`` is stored but
    not applied here; combining the terms is left to the caller.
    """
    def __init__(self, gen_args, train_args, weight_dict, loss_names, device):
        """ Create the criterion.
        Parameters:
            gen_args: generator config; ``n_labels`` and
                ``label_list_segmentation_with_csf`` are read here.
            train_args: training config; ``losses.uncertainty``,
                ``losses.bias_field_log_type``, ``relative_weight_lesions`` and (only
                when 'contrastive' is requested) ``contrastive_temperatures`` are read.
            weight_dict: dict containing as key the names of the losses and as values their
                         relative weight.
            loss_names: list of all the losses to be applied. See ``self.loss_map`` for the
                        available loss_names.
            device: device on which the class-weight tensor is allocated.
        """
        super(SetCriterion, self).__init__()
        self.gen_args = gen_args
        self.train_args = train_args
        self.weight_dict = weight_dict
        self.loss_names = loss_names

        self.mse = nn.MSELoss()

        # Without an uncertainty type, plain L1 regression is used for image losses.
        uncertainty = train_args.losses.uncertainty
        self.loss_regression_type = uncertainty if uncertainty is not None else 'l1'
        self.loss_regression = uncertainty_loss[uncertainty] if uncertainty is not None else l1_loss

        self.grad = GradientLoss('l1')
        self.smoothness = SmoothnessLoss('l2')
        self.hessian = HessianLoss('l2')

        self.bflog_loss = nn.L1Loss() if train_args.losses.bias_field_log_type == 'l1' else self.mse

        if 'contrastive' in self.loss_names:
            self.temp_alpha = train_args.contrastive_temperatures.alpha
            self.temp_beta = train_args.contrastive_temperatures.beta
            self.temp_gamma = train_args.contrastive_temperatures.gamma

        # Per-class weights: label value 77 gets `relative_weight_lesions`
        # (presumably the lesion label), all others 1; then normalised to sum to 1.
        weights_brainseg = torch.ones(gen_args.n_labels).to(device)
        weights_brainseg[gen_args.label_list_segmentation_with_csf==77] = train_args.relative_weight_lesions
        weights_brainseg = weights_brainseg / torch.sum(weights_brainseg)

        # Broadcast shapes: CE weights over 5-D (b, C, s, r, c) volumes,
        # Dice weights over the (b, C) per-class Dice matrix.
        self.weights_ce = weights_brainseg[None, :, None, None, None]
        self.weights_dice = weights_brainseg[None, :]

        self.loss_map = {
            'seg_ce': self.loss_seg_ce,
            'seg_dice': self.loss_seg_dice,
            'pathol_ce': self.loss_pathol_ce,
            'pathol_dice': self.loss_pathol_dice,
            'implicit_pathol_ce': self.loss_implicit_pathol_ce,
            'implicit_pathol_dice': self.loss_implicit_pathol_dice,
            'implicit_aux_pathol_ce': self.loss_implicit_aux_pathol_ce,
            'implicit_aux_pathol_dice': self.loss_implicit_aux_pathol_dice,

            'T1': self.loss_T1,
            'T1_grad': self.loss_T1_grad,
            'T2': self.loss_T2,
            'T2_grad': self.loss_T2_grad,
            'FLAIR': self.loss_FLAIR,
            'FLAIR_grad': self.loss_FLAIR_grad,
            'CT': self.loss_CT,
            'CT_grad': self.loss_CT_grad,
            'SR': self.loss_SR,
            'SR_grad': self.loss_SR_grad,

            "age": self.loss_age,
            "distance": self.loss_distance,
            "registration": self.loss_registration,
            "registration_grad": self.loss_registration_grad,
            "registration_hessian": self.loss_registration_hessian,
            "registration_smooth": self.loss_registration_smooth,
            "bias_field_log": self.loss_bias_field_log,
            'contrastive': self.loss_feat_contrastive,

            "surface": self.loss_surface,
        }

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _soft_ce(pred, target, weights=None):
        """Soft cross entropy: mean over voxels of -sum_c target * log(pred).

        `pred` is clamped to >= 1e-5 before the log. `weights`, when given, multiplies
        the log-probabilities before the channel (dim 1) sum.
        """
        log_pred = torch.log(torch.clamp(pred, min=1e-5))
        if weights is not None:
            log_pred = log_pred * weights
        return torch.mean(-torch.sum(log_pred * target, dim=1))

    @staticmethod
    def _soft_dice(pred, target, weights=None):
        """Soft Dice loss summed over batch and channels.

        Overlap and volume are reduced over dims 2-4; the denominator is clamped to
        >= 1e-5, so an empty class in both pred and target contributes a loss of 1.
        `weights`, when given, multiplies the per-(batch, class) losses before summing.
        """
        overlap = (pred * target).sum(dim=[2, 3, 4])
        volume = torch.clamp((pred + target).sum(dim=[2, 3, 4]), min=1e-5)
        per_class = 1.0 - 2.0 * overlap / volume
        if weights is not None:
            per_class = weights * per_class
        return torch.sum(per_class)

    @staticmethod
    def _modality_weights(targets, name):
        """Voxel weights for modality `name`: ``1 - targets[name + '_DM']`` if present, else 1."""
        dm_key = name + '_DM'
        return 1. - targets[dm_key] if dm_key in targets else 1.

    def _modality_loss(self, outputs, targets, name):
        """Image loss for modality `name`, using ``outputs[name + '_sigma']`` when available."""
        sigma_key = name + '_sigma'
        output_sigma = outputs[sigma_key] if sigma_key in outputs else None
        return self.loss_image(outputs[name], targets[name], output_sigma, self._modality_weights(targets, name))

    def _modality_grad_loss(self, outputs, targets, name):
        """Gradient loss for modality `name`, weighted like `_modality_loss`."""
        return self.loss_image_grad(outputs[name], targets[name], self._modality_weights(targets, name))

    # ------------------------------------------------------------------- losses

    def loss_feat_contrastive(self, outputs, *args):
        """
        Voxel-wise contrastive loss between the last feature maps of two outputs.

        outputs: [feat1, feat2] -- two output dicts; ``['feat'][-1]`` of each is used.
        feat shape: (b, feat_dim, s, r, c)
        """
        feat1, feat2 = outputs[0]['feat'][-1], outputs[1]['feat'][-1]
        # Numerator: agreement between the two views, summed over feature channels.
        num = torch.sum(torch.exp(feat1 * feat2 / self.temp_alpha), dim = 1)
        den = torch.zeros_like(feat1[:, 0])
        for i in range(feat1.shape[1]):
            den1 = torch.exp(feat1[:, i] ** 2 / self.temp_beta)
            # Products of channel i with all *other* channels of feat1 (self term removed).
            den2 = torch.exp((torch.sum(feat1[:, i][:, None] * feat1, dim = 1) - feat1[:, i] ** 2) / self.temp_gamma)
            den += den1 + den2
        loss_contrastive = torch.mean(- torch.log(num / den))
        return {'loss_contrastive': loss_contrastive}

    def loss_seg_ce(self, outputs, targets, *args):
        """
        Class-weighted cross entropy of segmentation
        """
        return {'loss_seg_ce': self._soft_ce(outputs['segmentation'], targets['segmentation'], self.weights_ce)}

    def loss_seg_dice(self, outputs, targets, *args):
        """
        Class-weighted Dice of segmentation
        """
        return {'loss_seg_dice': self._soft_dice(outputs['segmentation'], targets['segmentation'], self.weights_dice)}

    def loss_implicit_pathol_ce(self, outputs, targets, samples, *args):
        """
        Unweighted cross entropy of the implicit pathology prediction against
        ``outputs['implicit_pathol_orig']``; 0 when no prediction is present.
        """
        if 'implicit_pathol_pred' in outputs:
            loss_implicit_pathol_ce = self._soft_ce(outputs['implicit_pathol_pred'], outputs['implicit_pathol_orig'])
        else:
            loss_implicit_pathol_ce = 0.
        return {'loss_implicit_pathol_ce': loss_implicit_pathol_ce}

    def loss_implicit_pathol_dice(self, outputs, targets, samples, *args):
        """
        Unweighted Dice of the implicit pathology prediction against
        ``outputs['implicit_pathol_orig']``; 0 when no prediction is present.
        """
        if 'implicit_pathol_pred' in outputs:
            loss_implicit_pathol_dice = self._soft_dice(outputs['implicit_pathol_pred'], outputs['implicit_pathol_orig'])
        else:
            loss_implicit_pathol_dice = 0.
        return {'loss_implicit_pathol_dice': loss_implicit_pathol_dice}

    def loss_implicit_aux_pathol_ce(self, outputs, targets, samples):
        """
        Class-weighted cross entropy of the auxiliary implicit pathology prediction;
        0 when no prediction is present.
        """
        if 'implicit_aux_pathol_pred' in outputs:
            loss_implicit_aux_pathol_ce = self._soft_ce(outputs['implicit_aux_pathol_pred'], outputs['implicit_aux_pathol_orig'], self.weights_ce)
        else:
            loss_implicit_aux_pathol_ce = 0.
        return {'loss_implicit_aux_pathol_ce': loss_implicit_aux_pathol_ce}

    def loss_implicit_aux_pathol_dice(self, outputs, targets, samples):
        """
        Class-weighted Dice of the auxiliary implicit pathology prediction;
        0 when no prediction is present.
        """
        if 'implicit_aux_pathol_pred' in outputs:
            loss_implicit_aux_pathol_dice = self._soft_dice(outputs['implicit_aux_pathol_pred'], outputs['implicit_aux_pathol_orig'], self.weights_dice)
        else:
            loss_implicit_aux_pathol_dice = 0.
        return {'loss_implicit_aux_pathol_dice': loss_implicit_aux_pathol_dice}

    def loss_surface(self, outputs, targets, *args):
        return {'loss_surface': self.loss_image(outputs['surface'], targets['surface'])}

    def loss_distance(self, outputs, targets, *args):
        return {'loss_distance': self.loss_image(outputs['distance'], targets['distance'])}

    def loss_registration(self, outputs, targets, *args):
        return {'loss_registration': self.loss_image(outputs['registration'], targets['registration'])}

    def loss_registration_grad(self, outputs, targets, *args):
        return {'loss_registration_grad': self.loss_image_grad(outputs['registration'], targets['registration'])}

    def loss_registration_smooth(self, outputs, *args):
        return {'loss_registration_smooth': self.smoothness(outputs['registration'])}

    def loss_registration_hessian(self, outputs, *args):
        return {'loss_registration_hessian': self.hessian(outputs['registration'])}

    def loss_pathol_ce(self, outputs, targets, samples):
        """
        Cross entropy of pathology segmentation; 0 when the prediction is missing
        or its shape differs from the target's.
        """
        if 'pathology' in outputs and outputs['pathology'].shape == targets['pathology'].shape:
            loss_pathol_ce = self._soft_ce(outputs['pathology'], targets['pathology'])
        else:
            loss_pathol_ce = 0.
        return {'loss_pathol_ce': loss_pathol_ce}

    def loss_pathol_dice(self, outputs, targets, samples):
        """
        Dice of pathology segmentation; 0 when the prediction is missing or its
        shape differs from the target's.
        """
        if 'pathology' in outputs and outputs['pathology'].shape == targets['pathology'].shape:
            loss_pathol_dice = self._soft_dice(outputs['pathology'], targets['pathology'])
        else:
            loss_pathol_dice = 0.
        return {'loss_pathol_dice': loss_pathol_dice}

    def loss_T1(self, outputs, targets, *args):
        return {'loss_T1': self._modality_loss(outputs, targets, 'T1')}

    def loss_T1_grad(self, outputs, targets, *args):
        return {'loss_T1_grad': self._modality_grad_loss(outputs, targets, 'T1')}

    def loss_T2(self, outputs, targets, *args):
        return {'loss_T2': self._modality_loss(outputs, targets, 'T2')}

    def loss_T2_grad(self, outputs, targets, *args):
        return {'loss_T2_grad': self._modality_grad_loss(outputs, targets, 'T2')}

    def loss_FLAIR(self, outputs, targets, *args):
        return {'loss_FLAIR': self._modality_loss(outputs, targets, 'FLAIR')}

    def loss_FLAIR_grad(self, outputs, targets, *args):
        return {'loss_FLAIR_grad': self._modality_grad_loss(outputs, targets, 'FLAIR')}

    def loss_CT(self, outputs, targets, *args):
        return {'loss_CT': self._modality_loss(outputs, targets, 'CT')}

    def loss_CT_grad(self, outputs, targets, *args):
        return {'loss_CT_grad': self._modality_grad_loss(outputs, targets, 'CT')}

    def loss_SR(self, outputs, targets, samples):
        """Image loss on the predicted high-resolution residual vs ``samples``."""
        loss_SR = self.loss_image(outputs['high_res_residual'], samples['high_res_residual'])
        return {'loss_SR': loss_SR}

    def loss_SR_grad(self, outputs, targets, samples):
        """Gradient loss on the predicted high-resolution residual vs ``samples``."""
        loss_SR_grad = self.loss_image_grad(outputs['high_res_residual'], samples['high_res_residual'])
        return {'loss_SR_grad': loss_SR_grad}

    def loss_bias_field_log(self, outputs, targets, samples):
        """
        Log bias-field loss, masked by ``1 - targets['segmentation'][:, 0]``;
        0 when ``samples`` carries no bias field.
        """
        if 'bias_field_log' in samples:
            # NOTE(review): `[:, 0]` drops the channel dim; if 'bias_field_log' is
            # (b, 1, s, r, c) the product broadcasts to (b, b, ...) for b > 1 -- confirm shapes.
            bf_soft_mask = 1. - targets['segmentation'][:, 0]
            loss_bias_field_log = self.bflog_loss(outputs['bias_field_log'] * bf_soft_mask, samples['bias_field_log'] * bf_soft_mask)
        else:
            loss_bias_field_log = 0.
        return {'loss_bias_field_log': loss_bias_field_log}

    def loss_age(self, outputs, targets, *args):
        """Absolute age error (not reduced; a batched tensor input stays batched)."""
        loss_age = abs(outputs['age'] - targets['age'])
        return {'loss_age': loss_age}

    def loss_image(self, output, target, output_sigma = None, weights = 1., *args):
        """
        Regression loss between `output` and `target` (0. if their shapes differ).

        When `output_sigma` is provided, ``self.loss_regression(output, output_sigma, target)``
        is used (and `weights` is ignored); otherwise
        ``self.loss_regression(output, target, weights)``.
        """
        if output.shape != target.shape:
            return 0.
        # `is not None`: truth-testing a multi-element sigma tensor would raise.
        if output_sigma is not None:
            return self.loss_regression(output, output_sigma, target)
        return self.loss_regression(output, target, weights)

    def loss_image_grad(self, output, target, weights = 1., *args):
        """Gradient loss between `output` and `target` (0. if their shapes differ)."""
        return self.grad(output, target, weights) if output.shape == target.shape else 0.

    def loss_supervised_seg(self, outputs, targets, *args):
        """
        Supervised segmentation differences (for dataset_name == synth)

        NOTE(review): not registered in ``loss_map`` and relies on ``self.csf_v``,
        ``self.csf_ind`` and ``self.weights_dice_sup``, none of which are set in
        ``__init__``; a plain SetCriterion raises AttributeError here.
        """
        onehot_withoutcsf = targets['segmentation'].clone()
        onehot_withoutcsf = onehot_withoutcsf[:, self.csf_v, ...]
        # Add the channel at `csf_ind` (presumably CSF) into channel 0.
        onehot_withoutcsf[:, 0, :, :, :] = onehot_withoutcsf[:, 0, :, :, :] + targets['segmentation'][:, self.csf_ind, :, :, :]
        loss_supervised_seg = self._soft_dice(outputs['supervised_seg'], onehot_withoutcsf, self.weights_dice_sup)
        return {'loss_supervised_seg': loss_supervised_seg}

    def get_loss(self, loss_name, outputs, targets, *args):
        """Dispatch `loss_name` through ``self.loss_map`` and return its one-entry dict."""
        assert loss_name in self.loss_map, f'do you really want to compute {loss_name} loss?'
        return self.loss_map[loss_name](outputs, targets, *args)

    def forward(self, outputs, targets, *args):
        """ This performs the loss computation.
        Parameters:
             outputs: model outputs, indexed by name (e.g. outputs['segmentation']);
                      for 'contrastive' a list of two such output dicts.
             targets: mapping of target tensors indexed by name
                      (e.g. targets['segmentation']); the keys needed depend on the
                      losses applied, see each loss' doc.
             *args: extra positional arguments forwarded to every loss (e.g. samples).
        Returns:
             dict mapping 'loss_<name>' to its value for every name in loss_names.
        """
        losses = {}
        for loss_name in self.loss_names:
            losses.update(self.get_loss(loss_name, outputs, targets, *args))
        return losses
| | |
| |
|
| |
|
class SetMultiCriterion(SetCriterion):
    """
    Criterion for BrainID when the model yields one output dict per generated sample.

    Every configured loss is evaluated on each (outputs, samples) pair, the
    per-sample values are summed, and the sum is divided by
    ``gen_args.generator.all_samples``.
    """
    def __init__(self, gen_args, train_args, weight_dict, loss_names, device):
        """ Build the multi-sample criterion.
        Parameters:
            gen_args: generator config; ``generator.all_samples`` is read here in
                      addition to what SetCriterion reads.
            train_args, weight_dict, loss_names, device: passed straight through
                      to SetCriterion.__init__.
        """
        super().__init__(gen_args, train_args, weight_dict, loss_names, device)
        self.all_samples = gen_args.generator.all_samples

    def get_loss(self, loss_name, outputs_list, targets, samples_list):
        """Average loss `loss_name` over the per-sample outputs.

        The i-th entry of `outputs_list` is paired with ``samples_list[i]``; the sum is
        divided by ``self.all_samples``.
        """
        assert loss_name in self.loss_map, f'do you really want to compute {loss_name} loss?'
        key = 'loss_' + loss_name
        loss_fn = self.loss_map[loss_name]
        accumulated = 0.
        for idx, sample_outputs in enumerate(outputs_list):
            accumulated += loss_fn(sample_outputs, targets, samples_list[idx])[key]
        return {key: accumulated / self.all_samples}

    def forward(self, outputs_list, targets, samples_list):
        """ Compute every configured loss over all samples.
        Parameters:
             outputs_list: sequence of per-sample model output dicts.
             targets: mapping of target tensors shared by all samples, indexed by name.
             samples_list: per-sample generator data, aligned with outputs_list.
        Returns:
             dict mapping 'loss_<name>' to its sample-averaged value.
        """
        return {
            key: value
            for name in self.loss_names
            for key, value in self.get_loss(name, outputs_list, targets, samples_list).items()
        }
| |
|
| |
|