Spaces:
Build error
Build error
| # code from: https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/util/loss_sdf.py | |
| import torch | |
| import numpy as np | |
| from scipy.ndimage import distance_transform_edt as distance | |
| from skimage import segmentation as skimage_seg | |
| import matplotlib.pyplot as plt | |
def dice_loss(score, target):
    """
    Soft Dice loss, implemented from the V-Net paper
    https://arxiv.org/pdf/1606.04797.pdf

    loss = 1 - (2 * sum(score * target) + eps)
               / (sum(score ** 2) + sum(target ** 2) + eps)

    Args:
        score: tensor of predicted probabilities.
        target: tensor of ground-truth labels (cast to float), same shape
            as ``score``.

    Returns:
        Scalar tensor; 0 for a perfect match.
    """
    eps = 1e-5
    tgt = target.float()
    overlap = (score * tgt).sum()
    target_energy = (tgt * tgt).sum()
    score_energy = (score * score).sum()
    # eps keeps the ratio defined (and equal to 1) when both inputs are empty.
    dice = (2 * overlap + eps) / (score_energy + target_energy + eps)
    return 1 - dice
class tversky_loss(torch.nn.Module):
    """
    Tversky loss, implemented from https://arxiv.org/pdf/1706.05721.pdf

    loss = 1 - (TP + eps) / (TP + alpha * FP + beta * FN + eps)

    where TP, FP and FN are soft counts computed from the probabilities
    ``score`` and the ground truth ``target``. With alpha = beta = 0.5 the
    ratio equals 2TP / (2TP + FP + FN).
    """

    def __init__(self, alpha, beta):
        '''
        Args:
            alpha: coefficient for false positive prediction
            beta: coefficient for false negative prediction
        '''
        super(tversky_loss, self).__init__()
        self.alpha = alpha
        self.beta = beta

    # Implemented as ``forward`` (not by overriding ``__call__``) so that
    # nn.Module.__call__ still runs its machinery (e.g. forward hooks);
    # calling the instance, ``loss_fn(score, target)``, works as before.
    def forward(self, score, target):
        """
        Args:
            score: tensor of predicted probabilities.
            target: tensor of ground-truth labels (cast to float), same shape
                as ``score``.

        Returns:
            Scalar tensor; 0 for a perfect match.
        """
        target = target.float()
        smooth = 1e-5
        tp = torch.sum(score * target)
        fn = torch.sum(target * (1 - score))
        fp = torch.sum((1 - target) * score)
        loss = (tp + smooth) / (tp + self.alpha * fp + self.beta * fn + smooth)
        loss = 1 - loss
        return loss
def compute_sdf1_1(img_gt, out_shape):
    """
    Compute the normalized signed distance map (SDM) of a binary mask.

    For each batch item, ``img_gt[b]`` (non-zero = foreground) is turned into

        sdf(x) = 0            x on the inner boundary of the mask
        sdf(x) in [-1, 0)     x inside the mask  (distance to background)
        sdf(x) in (0, 1]      x outside the mask (distance to foreground)

    Inside and outside distances are min-max normalized separately, so the
    map spans [-1, 1].

    Args:
        img_gt: numpy array of ground-truth masks indexed by batch on axis 0.
            It is cast to uint8 first; non-zero values count as foreground.
        out_shape: shape of the result, ``(batch, channels, *spatial)``.
            ``img_gt[b]`` must be broadcastable into ``out_shape[2:]``.

    Returns:
        float64 numpy array of shape ``out_shape``. Channel 0 (background) is
        left at zero; channels 1.. all receive the same map. Items whose mask
        has no foreground or no background pixel are left at zero, because the
        normalization would divide by zero.

    Raises:
        AssertionError: if a normalized map does not reach exactly -1 and +1
            (e.g. very thin masks where every foreground pixel lies on the
            boundary and is therefore zeroed).
    """
    img_gt = img_gt.astype(np.uint8)
    normalized_sdf = np.zeros(out_shape)
    if out_shape[1] < 2:
        # Only the background channel exists, and it is never filled.
        return normalized_sdf
    for b in range(out_shape[0]):  # batch size
        # Binarize explicitly: ``1 - posmask`` on raw uint8 labels wraps around
        # for values > 1 (e.g. 255 -> 2) and would corrupt the background mask.
        posmask = (img_gt[b] != 0).astype(np.uint8)
        negmask = 1 - posmask
        if not posmask.any() or not negmask.any():
            # Without both foreground and background the distance maps are
            # constant and min-max normalization would produce NaN.
            continue
        posdis = distance(posmask)
        negdis = distance(negmask)
        boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
        neg_norm = (negdis - np.min(negdis)) / (np.max(negdis) - np.min(negdis))
        pos_norm = (posdis - np.min(posdis)) / (np.max(posdis) - np.min(posdis))
        sdf = neg_norm - pos_norm
        sdf[boundary == 1] = 0
        assert np.min(sdf) == -1.0 and np.max(sdf) == 1.0, (
            f"normalized SDF spans [{np.min(sdf)}, {np.max(sdf)}] instead of "
            f"[-1, 1] (posdis range [{np.min(posdis)}, {np.max(posdis)}], "
            f"negdis range [{np.min(negdis)}, {np.max(negdis)}])"
        )
        # The map depends only on the mask, not on the channel: compute it once
        # and broadcast it into every foreground channel (channel 0 stays zero).
        normalized_sdf[b, 1:] = sdf
    return normalized_sdf
def compute_sdf(img_gt, out_shape, *, debug=False):
    """
    Compute the (unnormalized) signed distance map (SDM) of a binary mask.

    For each batch item, ``img_gt[b]`` (non-zero = foreground) is turned into

        sdf(x) = 0                          x on the inner boundary of the mask
        sdf(x) = -(distance to background)  x inside the mask
        sdf(x) = +(distance to foreground)  x outside the mask

    Args:
        img_gt: numpy array of ground-truth masks indexed by batch on axis 0.
            It is cast to uint8 first; non-zero values count as foreground.
        out_shape: shape of the result, ``(batch, channels, *spatial)``.
            ``img_gt[b]`` must be broadcastable into ``out_shape[2:]``.
        debug: if True, plot mask and SDF of every item with matplotlib.
            The plot indexes ``img_gt[b, 0]``, so it assumes a 4-D
            ``img_gt`` — TODO confirm for other layouts.

    Returns:
        float64 numpy array of shape ``out_shape``. Every channel of item ``b``
        holds the same map (the channel index does not select the mask).
        Items whose mask has no foreground or no background pixel are left at
        zero, since a signed distance needs both regions.
    """
    img_gt = img_gt.astype(np.uint8)
    gt_sdf = np.zeros(out_shape)
    for b in range(out_shape[0]):  # batch size
        # Binarize explicitly: ``1 - posmask`` on raw uint8 labels wraps around
        # for values > 1 (e.g. 255 -> 2) and would corrupt the background mask.
        posmask = (img_gt[b] != 0).astype(np.uint8)
        negmask = 1 - posmask
        if posmask.any() and negmask.any():
            posdis = distance(posmask)
            negdis = distance(negmask)
            boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
            sdf = negdis - posdis
            sdf[boundary == 1] = 0
            # Same map for every channel: compute once, broadcast over channels.
            gt_sdf[b, :] = sdf
        if debug:
            plt.figure()
            plt.subplot(1, 2, 1)
            plt.imshow(img_gt[b, 0, :, :])
            plt.colorbar()
            plt.subplot(1, 2, 2)
            plt.imshow(gt_sdf[b, 0, :, :])
            plt.colorbar()
            plt.show()
    return gt_sdf
def boundary_loss(output, gt):
    """
    Boundary loss for segmentation, adopted from
    http://proceedings.mlr.press/v102/kervadec19a/kervadec19a.pdf

    Computes the mean of the element-wise product between the predicted
    probabilities and the ground-truth signed distance map.

    Args:
        output: tensor of predicted probabilities (softmax results), e.g.
            shape (b, c, x, y) or (b, c, x, y, z).
        gt: tensor holding the SDF of the ground truth (original or
            normalized SDF), same shape as ``output`` or broadcastable to it.

    Returns:
        Scalar tensor.
    """
    # Plain element-wise product: identical to the former
    # einsum('bcxy,bcxy->bcxy') for 4-D inputs, but also accepts 3-D volumes.
    return (output * gt).mean()