| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class DepthLoss(nn.Module):
    """L1 loss between predicted and ground-truth depth, optionally masked.

    Args:
        type: Loss identifier, stored on ``self.type``. Only an L1 loss is
            implemented and ``forward`` never reads this value. The name
            shadows the builtin but is kept so keyword callers keep working.
    """

    def __init__(self, type='l1'):  # noqa: A002 - public keyword kept for compatibility
        super().__init__()
        self.type = type

    def forward(self, depth_pred, depth_gt, mask=None):
        """Compute the L1 depth error.

        Args:
            depth_pred: Predicted depth tensor.
            depth_gt: Ground-truth depth tensor; combined with ``depth_pred``
                by elementwise subtraction.
            mask: Optional weight/validity tensor, multiplied elementwise with
                ``(depth_gt > 0)`` so pixels whose ground truth is not positive
                never contribute.

        Returns:
            Scalar tensor.
            - With a mask: the sum of masked absolute errors divided by
              ``mask.sum() + 1e-5``. The epsilon avoids dividing by zero when
              no pixel is valid.
            - Without a mask: the mean absolute error over all elements. This
              includes pixels whose ground truth is zero or negative; they are
              not rejected or skipped.
        """
        depth_error = depth_pred - depth_gt
        if mask is None:
            return depth_error.abs().mean()

        # Restrict the caller's mask to pixels with positive ground-truth depth.
        mask = mask * (depth_gt > 0).float()
        return (depth_error * mask).abs().sum() / (mask.sum() + 1e-5)
|
|
class DepthSmoothLoss(nn.Module):
    """Edge-aware first-order smoothness penalty on a disparity map.

    Each neighbouring-pixel disparity difference is scaled by
    ``exp(-|image difference|)``, so the penalty weakens across image edges.
    """

    def __init__(self):
        super().__init__()

    def forward(self, disp, img, mask):
        """Return the masked, edge-aware smoothness loss.

        :param disp: disparity map, [B, 1, H, W]
        :param img: guidance image, [B, C, H, W] (e.g. a color image); its
            absolute differences are averaged over the channel dimension
        :param mask: [B, 1, H, W]. For each neighbouring pair, only the mask
            value at the left (horizontal) or top (vertical) pixel gates it.
        :return: scalar tensor, the mean weighted horizontal term plus the
            mean weighted vertical term. The means run over all difference
            positions, not only the masked ones.
        """
        def _diff_w(t):
            # Difference between each column and the next one (along W).
            return t[:, :, :, :-1] - t[:, :, :, 1:]

        def _diff_h(t):
            # Difference between each row and the next one (along H).
            return t[:, :, :-1, :] - t[:, :, 1:, :]

        # Edge-aware weights: close to 1 in flat image regions, small at edges.
        weight_w = torch.exp(-_diff_w(img).abs().mean(dim=1, keepdim=True))
        weight_h = torch.exp(-_diff_h(img).abs().mean(dim=1, keepdim=True))

        smooth_w = _diff_w(disp).abs() * weight_w * mask[:, :, :, :-1]
        smooth_h = _diff_h(disp).abs() * weight_h * mask[:, :, :-1, :]

        return smooth_w.mean() + smooth_h.mean()
|
|