import torch
import torch.nn.functional as F
from TEED.utils.AF.Fsmish import smish as Fsmish


def bdcn_loss2(inputs, targets, l_weight=1.1):
    """Class-balanced binary cross-entropy (BDCN loss, as modified in DexiNed).

    Args:
        inputs: 4-D tensor of raw logits; a sigmoid is applied here.
        targets: 4-D tensor with the same shape as ``inputs``. It is cast to
            ``long`` first, so fractional values are truncated (0.7 becomes 0
            and is treated as a negative pixel).
        l_weight: scalar multiplier applied to the final loss.

    Returns:
        A 0-dim tensor: the weighted BCE averaged over dims (1, 2, 3) of each
        sample, summed over dim 0 and scaled by ``l_weight``.
    """
    targets = targets.long()
    labels = targets.float()

    # Decide class membership from the labels *before* any weight is written.
    # Comparing against a tensor that already holds weights mis-classifies
    # positives whose weight is 0 (a batch without negatives) as negatives.
    positive = labels > 0.0
    num_positive = positive.float().sum()
    num_negative = (~positive).float().sum()
    total = num_positive + num_negative

    # Each class is weighted by the frequency of the *other* class; the
    # negative class gets an extra 1.1 factor.
    weights = torch.where(
        positive,
        1.0 * num_negative / total,
        1.1 * num_positive / total,
    )

    probs = torch.sigmoid(inputs)
    cost = F.binary_cross_entropy(probs, labels, weight=weights,
                                  reduction='none')
    cost = torch.sum(cost.float().mean((1, 2, 3)))
    return l_weight * cost


# ------------ cats losses ----------

def bdrloss(prediction, label, radius, device='cpu'):
    """Boundary tracing loss that handles the confusing pixels.

    For every labelled pixel, the prediction mass on labelled pixels inside a
    ``(2*radius+1)`` square window is compared with the prediction mass on the
    unlabelled pixels of that window, and the negative log of that ratio is
    penalised.

    Args:
        prediction: 4-D float tensor of probabilities with a single channel
            (the box filter has shape ``(1, 1, k, k)``).
        label: 4-D float tensor with the same shape as ``prediction``; the code
            treats 1 as boundary and 0 as background.
        radius: half-width of the square window.
        device: device on which the box filter is created; it has to match the
            device of ``prediction`` and ``label``.

    Returns:
        A 0-dim tensor: per-sample mean over dims (1, 2, 3), summed over dim 0.
    """
    size = 2 * radius + 1
    # Constant box filter; a freshly created tensor does not require grad.
    filt = torch.ones(1, 1, size, size, device=device)

    # Prediction mass on labelled pixels, summed per window and kept only at
    # labelled positions.
    bdr_pred = prediction * label
    pred_bdr_sum = label * F.conv2d(bdr_pred, filt, bias=None, stride=1,
                                    padding=radius)

    # Pixels that have a label inside their window but are not labelled 1.
    texture_mask = F.conv2d(label.float(), filt, bias=None, stride=1,
                            padding=radius)
    mask = (texture_mask != 0).float()
    mask[label == 1] = 0
    pred_texture_sum = F.conv2d(prediction * (1 - label) * mask, filt,
                                bias=None, stride=1, padding=radius)

    # The epsilon and the clamp keep the ratio and its log finite.
    softmax_map = torch.clamp(
        pred_bdr_sum / (pred_texture_sum + pred_bdr_sum + 1e-10),
        1e-10, 1 - 1e-10)
    cost = -label * torch.log(softmax_map)
    cost[label == 0] = 0

    return torch.sum(cost.float().mean((1, 2, 3)))


def textureloss(prediction, label, mask_radius, device='cpu'):
    """Texture suppression loss that smooths the texture regions.

    Penalises ``-log(1 - s / 9)`` where ``s`` is the 3x3 window sum of the
    prediction, but only at pixels whose ``(2*mask_radius+1)`` square window
    has a label sum that is not positive.

    Args:
        prediction: 4-D tensor of probabilities with a single channel.
        label: 4-D tensor with the same shape as ``prediction``.
        mask_radius: half-width of the window used to find label-free pixels.
        device: device on which the box filters are created; it has to match
            the device of ``prediction`` and ``label``.

    Returns:
        A 0-dim tensor: per-sample mean over dims (1, 2, 3), summed over dim 0.
    """
    mask_size = 2 * mask_radius + 1
    filt1 = torch.ones(1, 1, 3, 3, device=device)
    filt2 = torch.ones(1, 1, mask_size, mask_size, device=device)

    pred_sums = F.conv2d(prediction.float(), filt1, bias=None, stride=1,
                         padding=1)
    label_sums = F.conv2d(label.float(), filt2, bias=None, stride=1,
                          padding=mask_radius)

    # 1 where the label window sum is not positive, 0 elsewhere.
    mask = 1 - torch.gt(label_sums, 0).float()

    # 9 is the number of taps of the 3x3 filter, so pred_sums / 9 is the
    # (zero-padded) local mean of the prediction.
    loss = -torch.log(torch.clamp(1 - pred_sums / 9, 1e-10, 1 - 1e-10))
    loss[mask == 0] = 0

    return torch.sum(loss.float().mean((1, 2, 3)))


def cats_loss(prediction, label, l_weight=(0., 0.), device='cpu'):
    """CATS tracing loss: balanced BCE plus boundary and texture terms.

    Args:
        prediction: 4-D tensor of raw logits; a sigmoid is applied here.
        label: 4-D tensor with the same shape as ``prediction``. Pixels equal
            to 1 are positives, 0 negatives, and 2 get zero weight in the BCE
            term. Any other value keeps its own value as its BCE weight.
        l_weight: pair ``(tex_factor, bdr_factor)`` scaling the texture and
            boundary terms respectively.
        device: forwarded to ``textureloss`` and ``bdrloss``.

    Returns:
        A 0-dim tensor:
        ``bce + bdr_factor * bdrloss + tex_factor * textureloss``.
    """
    tex_factor, bdr_factor = l_weight
    balanced_w = 1.1
    label = label.float()
    prediction = prediction.float()

    with torch.no_grad():
        # Class membership is taken from the labels, never from the partially
        # written weight map: with no negatives beta is 0, and re-reading the
        # map would hand the positives the negative-class weight.
        positive = label == 1
        negative = label == 0
        ignored = label == 2
        num_positive = positive.float().sum()
        num_negative = negative.float().sum()
        beta = num_negative / (num_positive + num_negative)

        mask = label.clone()
        mask[positive] = beta
        mask[negative] = balanced_w * (1 - beta)
        mask[ignored] = 0

    prediction = torch.sigmoid(prediction)

    cost = F.binary_cross_entropy(prediction.float(), label.float(),
                                  weight=mask, reduction='none')
    cost = torch.sum(cost.float().mean((1, 2, 3)))

    # The auxiliary terms treat every non-zero label as boundary.
    label_w = (label != 0).float()
    textcost = textureloss(prediction.float(), label_w.float(),
                           mask_radius=4, device=device)
    bdrcost = bdrloss(prediction.float(), label_w.float(),
                      radius=4, device=device)

    return cost + bdr_factor * bdrcost + tex_factor * textcost