| import torch | |
| import torch.nn as nn | |
class CoordLoss(nn.Module):
    """Unreduced, masked L1 loss between predicted and target coordinates.

    The element-wise absolute error is multiplied by ``valid``. If ``is_3D`` is
    supplied, the slice of the error from index 2 onward along dim 2 (presumably
    the z component of (batch, joints, xyz) coordinates -- TODO confirm) is further
    multiplied by ``is_3D`` broadcast over the first axis, so samples whose
    ``is_3D`` entry is 0 contribute nothing from those channels.
    """

    def __init__(self):
        super().__init__()

    def forward(self, coord_out, coord_gt, valid, is_3D=None):
        """Return the element-wise masked absolute error (no reduction).

        Args:
            coord_out: predicted coordinates; must be at least 3-D because
                dim 2 is sliced below.
            coord_gt: target coordinates, broadcast-compatible with ``coord_out``.
            valid: multiplicative mask broadcast against the absolute error.
            is_3D: optional per-sample gate indexed along dim 0; assumed 1-D of
                length batch (it is expanded with two trailing singleton axes)
                -- TODO confirm. Converted with ``.float()`` before use.

        Returns:
            Tensor with the same shape as the masked absolute error.
        """
        masked = (coord_out - coord_gt).abs() * valid
        if is_3D is None:
            return masked

        xy_part = masked[:, :, :2]
        # Expand the per-sample flag to (N, 1, 1) so it broadcasts over the
        # remaining axes of the z slice.
        z_gate = is_3D[:, None, None].float()
        z_part = masked[:, :, 2:] * z_gate
        return torch.cat((xy_part, z_part), dim=2)
class ParamLoss(nn.Module):
    """Unreduced, masked L1 loss between predicted and target parameters."""

    def __init__(self):
        super().__init__()

    def forward(self, param_out, param_gt, valid):
        """Return ``|param_out - param_gt| * valid`` element-wise (no reduction).

        Args:
            param_out: predicted parameter tensor.
            param_gt: target parameter tensor, broadcast-compatible with
                ``param_out``.
            valid: multiplicative mask broadcast against the absolute error.
        """
        residual = param_out - param_gt
        return residual.abs() * valid
class CELoss(nn.Module):
    """Cross-entropy loss that returns one value per target element (no reduction)."""

    def __init__(self):
        super().__init__()
        # reduction='none' keeps the per-element losses; no averaging or
        # summing happens inside this module.
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, out, gt_index):
        """Apply ``nn.CrossEntropyLoss(reduction='none')`` to ``out`` and ``gt_index``.

        Args:
            out: unnormalized logits, in the layout ``nn.CrossEntropyLoss``
                accepts (e.g. (N, C)).
            gt_index: class-index targets matching ``out``.

        Returns:
            The unreduced loss tensor produced by ``nn.CrossEntropyLoss``.
        """
        return self.ce_loss(out, gt_index)