| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| | from utils.geometry_utils import edge_acc |
| |
|
| |
|
class CornerCriterion(nn.Module):
    """Weighted BCE loss plus a recall metric for stage-1 corner predictions.

    ``forward`` scores predicted probabilities against a binary target map
    (for recall) and a soft target map (for the loss).
    """

    def __init__(self, image_size, loss_rate=9):
        """
        Args:
            image_size: Accepted for interface compatibility; not used.
            loss_rate: Extra weight added to the BCE term of entries whose soft
                target exceeds 0.5 (such entries get weight ``loss_rate + 1``,
                all others weight 1). Defaults to 9, the previously hard-coded
                value.
        """
        super().__init__()
        self.loss_rate = loss_rate

    def forward(self, outputs_s1, targets, gauss_targets, epoch=0):
        """Compute the weighted BCE loss and the recall of positive entries.

        Args:
            outputs_s1: Predicted probabilities in [0, 1] (fed directly to
                ``F.binary_cross_entropy``, so these must not be logits).
            targets: Binary target map; entries equal to 1 are the positives
                used for recall. Must be indexable by a mask of its own shape
                against ``outputs_s1``.
            gauss_targets: Soft targets for the BCE loss, same shape as
                ``outputs_s1``.
            epoch: Unused; kept for interface compatibility.

        Returns:
            ``(loss_s1, recall_s1)`` where ``loss_s1`` is the weighted BCE summed
            over the last two dimensions and averaged over the remaining ones,
            and ``recall_s1`` is the fraction of positive entries whose predicted
            probability is >= 0.5 (0 when there are no positive entries).
        """
        pos_mask = targets == 1
        num_pos = int(pos_mask.sum())
        # A positive entry counts as recalled when its probability reaches 0.5.
        correct = (outputs_s1[pos_mask] >= 0.5).float().sum()
        # Guard batches without any positive corner: dividing by zero would
        # yield NaN and poison any averaged recall metric.
        recall_s1 = correct / max(num_pos, 1)

        # Up-weight entries whose soft target is above 0.5.
        loss_weight = (gauss_targets > 0.5).float() * self.loss_rate + 1
        loss_s1 = F.binary_cross_entropy(outputs_s1, gauss_targets, weight=loss_weight, reduction='none')
        loss_s1 = loss_s1.sum(-1).sum(-1).mean()

        return loss_s1, recall_s1
| |
|
| |
|
class EdgeCriterion(nn.Module):
    """Class-weighted cross-entropy losses and accuracies for edge classification.

    Produces one loss/accuracy pair for stage 1 and two for stage 2 (the
    "hybrid" and "rel" logits), each loss masked and normalized by batch size.
    """

    def __init__(self):
        super().__init__()
        # Class weights 0.33 / 1.0 for class indices 0 / 1. Placed on the GPU
        # when one is available (matching the previous hard-coded ``.cuda()``),
        # otherwise on CPU so the criterion can still be constructed on
        # CPU-only hosts. The weight is a registered buffer of the loss module,
        # so it also follows ``EdgeCriterion.to(device)``.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.edge_loss = nn.CrossEntropyLoss(weight=torch.tensor([0.33, 1.0], device=device), reduction='none')

    def _masked_loss(self, logits, labels, ignore, num_samples):
        """Sum the per-entry CE loss over non-ignored entries, divided by ``num_samples``.

        Args:
            logits: Class logits accepted by ``nn.CrossEntropyLoss``.
            labels: Integer class labels matching ``logits``.
            ignore: Boolean mask; entries where it is True contribute nothing.
            num_samples: Normalizer (the batch size in ``forward``).
        """
        losses = self.edge_loss(logits, labels)
        losses[ignore] = 0
        # CE terms are non-negative, so keeping only the strictly positive ones
        # equals summing the non-ignored entries (it additionally drops NaNs).
        return losses[losses > 0].sum() / num_samples

    def forward(self, logits_s1, logits_s2_hybrid, logits_s2_rel, s2_ids, s2_edge_mask, edge_labels, edge_lengths,
                edge_mask, s2_gt_values):
        """Compute stage-1 and stage-2 edge losses and accuracies.

        Returns:
            ``(s1_losses, s1_acc, s2_losses_hybrid, s2_acc_hybrid, s2_losses_rel, s2_acc_rel)``.
            Each loss is the masked CE sum divided by the batch size
            (``mask.shape[0]``); accuracies come from ``edge_acc``.
        """
        # --- Stage 1: ignore entries flagged in edge_mask. ---
        s1_losses = self._masked_loss(logits_s1, edge_labels, edge_mask == True, edge_mask.shape[0])  # noqa: E712
        # Every stage-1 entry is scored with gt value 2.
        gt_values = torch.ones_like(edge_mask).long() * 2
        s1_acc = edge_acc(logits_s1, edge_labels, edge_lengths, gt_values)

        # --- Stage 2: labels of the selected edges. ---
        s2_labels = torch.gather(edge_labels, 1, s2_ids)
        # Ignore masked entries and entries whose gt value is not 2; shared by
        # both stage-2 heads.
        s2_ignore = (s2_edge_mask == True) | (s2_gt_values != 2)  # noqa: E712
        s2_edge_lengths = (s2_edge_mask == 0).sum(dim=-1)
        s2_batch = s2_edge_mask.shape[0]

        s2_losses_hybrid = self._masked_loss(logits_s2_hybrid, s2_labels, s2_ignore, s2_batch)
        s2_acc_hybrid = edge_acc(logits_s2_hybrid, s2_labels, s2_edge_lengths, s2_gt_values)

        s2_losses_rel = self._masked_loss(logits_s2_rel, s2_labels, s2_ignore, s2_batch)
        s2_acc_rel = edge_acc(logits_s2_rel, s2_labels, s2_edge_lengths, s2_gt_values)

        return s1_losses, s1_acc, s2_losses_hybrid, s2_acc_hybrid, s2_losses_rel, s2_acc_rel
| |
|