| import torch | |
| import torch.nn as nn | |
class ConfidenceLoss(nn.Module):
    """
    Confidence loss.

    Supervises a predicted confidence map with a pseudo ground truth derived
    from the relative error of ``prediction`` w.r.t. ``target``:

        gt_confidence = 1 - |prediction - target| / target

    Only elements where ``|target - prediction| < target`` (which implies
    ``target > 0`` and ``gt_confidence`` in ``(0, 1]``) and, if given, where
    ``mask`` is set, contribute. The loss is the mean absolute difference
    between ``confidence`` and ``gt_confidence`` over those elements.

    NOTE(review): ``prediction``/``target`` are presumably depth maps of the
    same (or broadcastable) shape -- confirm against the caller.
    """

    def __init__(self, loss_weight=1, data_type=None, **kwargs):
        """
        Args:
            loss_weight: Scalar multiplier applied to the returned loss.
            data_type: Stored on the instance as ``self.data_type``; not read
                anywhere inside this class. Defaults to
                ``['stereo', 'lidar', 'denselidar']``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.loss_weight = loss_weight
        # Build the default per instance so instances never share (and cannot
        # accidentally mutate) one list object.
        if data_type is None:
            data_type = ['stereo', 'lidar', 'denselidar']
        self.data_type = data_type
        # Keeps the normalisation finite when no element is selected.
        self.eps = 1e-6

    def forward(self, prediction, target, confidence, mask=None, **kwargs):
        """
        Args:
            prediction: Predicted values.
            target: Reference values; elements with ``target <= 0`` can never
                satisfy the validity condition and are ignored.
            confidence: Predicted confidence, compared against the derived
                ``gt_confidence``.
            mask: Optional validity mask; non-zero/True marks usable elements.
                When ``None`` only the relative-error condition is applied.
            **kwargs: Accepted and ignored.

        Returns:
            Scalar tensor: masked mean absolute confidence error multiplied
            by ``loss_weight``.
        """
        conf_mask = torch.abs(target - prediction) < target
        if mask is not None:
            conf_mask = conf_mask & mask.bool()

        # Divide by 1 outside the valid region: dividing by a zero target
        # yields inf, and inf * 0 is nan, which would poison the whole sum.
        safe_target = torch.where(conf_mask, target, torch.ones_like(target))
        rel_err = torch.abs((prediction - target) / safe_target)
        gt_confidence = torch.where(
            conf_mask, 1 - rel_err, torch.zeros_like(rel_err)
        )

        abs_err = torch.abs(confidence - gt_confidence)
        # Select instead of multiplying by the mask so that non-finite values
        # at ignored elements cannot leak into the sum.
        abs_err = torch.where(conf_mask, abs_err, torch.zeros_like(abs_err))
        loss = torch.sum(abs_err) / (torch.sum(conf_mask) + self.eps)

        if not torch.isfinite(loss).item():
            # Report the offending value before it is replaced.
            print(f'ConfidenceLoss NAN error, {loss}')
            # Zero loss that is still computed from `confidence`
            # (presumably so backward() keeps a graph to traverse).
            loss = 0 * torch.sum(confidence)
        return loss * self.loss_weight