Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| import monai.metrics as metrics | |
| from common.constants import DIST_MATRIX_PATH | |
# Loaded once at import time from the path configured in common.constants.
# det_error_metric indexes it as a 2-D matrix: rows are selected with the
# prediction mask and columns with the ground-truth mask.
# NOTE(review): presumably a square pairwise-distance matrix -- confirm.
DIST_MATRIX = np.load(DIST_MATRIX_PATH)
def metric(mask, pred, back=True):
    """Return the IoU between ``pred`` and ``mask`` averaged to one value.

    Delegates to ``monai.metrics.compute_meaniou`` and reduces whatever it
    returns with ``.mean()``.

    Args:
        mask: Reference segmentation; forwarded as the second argument.
        pred: Predicted segmentation; forwarded as the first argument.
        back: Forwarded as the third positional argument (presumably
            MONAI's ``include_background`` flag -- TODO confirm against the
            installed MONAI version).

    Returns:
        The mean over all values produced by ``compute_meaniou``.
    """
    # The trailing ``False`` is the fourth positional argument of
    # compute_meaniou (presumably ``ignore_empty`` -- TODO confirm).
    iou_scores = metrics.compute_meaniou(pred, mask, back, False)
    return iou_scores.mean()
def precision_recall_f1score(gt, pred):
    """Compute per-sample precision, recall and F1 score.

    An entry of ``pred`` counts as a positive prediction when it is
    ``>= 0.5``. Degenerate samples get fixed scores instead of nan/inf:

    * no positive prediction -> precision 1, recall 0, F1 0
    * no ground-truth positive -> precision 0, recall 1, F1 0
    * precision + recall ~ 0 -> precision 0, recall 0, F1 0

    Args:
        gt: Ground-truth labels, indexed as ``gt[b, :]`` (assumed to hold
            0/1 values so that sums are counts -- TODO confirm).
        pred: Prediction scores, indexed as ``pred[b, :]``.

    Returns:
        Tuple ``(precision, recall, f1)`` of 1-D tensors, one entry per
        sample along the first axis of ``gt``.
    """
    batch_size = gt.shape[0]
    precision, recall, f1 = (torch.zeros(batch_size) for _ in range(3))

    for idx in range(batch_size):
        predicted_positive = pred[idx, :] >= 0.5

        true_positives = gt[idx, predicted_positive].sum()
        num_predicted = predicted_positive.sum()
        num_actual = gt[idx, :].sum()

        # The division is done before the zero checks; for tensor inputs a
        # zero denominator yields nan/inf here, which the branches below
        # overwrite with fixed values.
        sample_precision = true_positives / num_predicted
        sample_recall = true_positives / num_actual

        if num_predicted == 0:
            # Nothing was predicted as positive.
            sample_precision, sample_recall, sample_f1 = 1., 0., 0.
        elif num_actual == 0:
            # The ground truth contains no positives.
            sample_precision, sample_recall, sample_f1 = 0., 1., 0.
        elif (sample_precision + sample_recall) <= 1e-10:
            # Guard against a vanishing denominator in the F1 formula.
            sample_precision, sample_recall, sample_f1 = 0., 0., 0.
        else:
            sample_f1 = (
                2 * sample_precision * sample_recall
                / (sample_precision + sample_recall)
            )

        precision[idx] = sample_precision
        recall[idx] = sample_recall
        f1[idx] = sample_f1

    return precision, recall, f1
def acc_precision_recall_f1score(gt, pred):
    """Compute per-sample accuracy, precision, recall and F1 score.

    An entry of ``pred`` counts as a positive prediction when it is
    ``>= 0.5``. A small epsilon (``1e-10``) is added to every denominator
    so that samples without predictions or without ground-truth positives
    yield 0 instead of nan.

    Args:
        gt: Ground-truth labels, indexed as ``gt[b, :]`` (assumed to hold
            0/1 values so that sums are counts -- TODO confirm).
        pred: Prediction scores, indexed as ``pred[b, :]``.

    Returns:
        Tuple ``(acc, precision, recall, f1)`` of 1-D tensors, one entry
        per sample along the first axis of ``gt``.
    """
    batch_size = gt.shape[0]
    num_elements = gt.shape[-1]

    acc = torch.zeros(batch_size)
    precision = torch.zeros(batch_size)
    recall = torch.zeros(batch_size)
    f1 = torch.zeros(batch_size)

    for b in range(batch_size):
        predicted_positive = pred[b, :] >= 0.5

        tp_num = gt[b, predicted_positive].sum()
        precision_denominator = predicted_positive.sum()
        recall_denominator = gt[b, :].sum()

        # Inclusion-exclusion: TN = total - predicted positives
        # - actual positives + true positives.
        tn_num = (
            num_elements - precision_denominator - recall_denominator + tp_num
        )

        acc_ = (tp_num + tn_num) / num_elements
        precision_ = tp_num / (precision_denominator + 1e-10)
        recall_ = tp_num / (recall_denominator + 1e-10)
        f1_ = 2 * precision_ * recall_ / (precision_ + recall_ + 1e-10)

        acc[b] = acc_
        precision[b] = precision_
        recall[b] = recall_
        # Previously the computed F1 was never stored, so the returned
        # tensor was always all zeros.
        f1[b] = f1_

    return acc, precision, recall, f1
def det_error_metric(pred, gt):
    """Compute per-sample false-positive and false-negative distances.

    For every sample a sub-matrix of the module-level ``DIST_MATRIX`` is
    selected: rows where ``pred >= 0.5`` and columns where ``gt == 1``.
    If a sample has no ground-truth positives all columns are kept; if it
    has no positive predictions all rows are kept.

    Args:
        pred: Prediction scores, indexed as ``pred[b, :]``; detached and
            moved to CPU before use.
        gt: Ground-truth labels, indexed as ``gt[b, :]``; detached and
            moved to CPU before use.

    Returns:
        Tuple ``(false_positive_dist, false_negative_dist)`` of 1-D
        tensors with one entry per sample:

        * false-positive distance: mean over selected rows of the minimum
          along each row.
        * false-negative distance: mean over selected columns of the
          minimum along each column.
    """
    gt = gt.detach().cpu()
    pred = pred.detach().cpu()
    dist_matrix = torch.tensor(DIST_MATRIX)

    batch_size = gt.shape[0]
    false_positive_dist = torch.zeros(batch_size)
    false_negative_dist = torch.zeros(batch_size)

    for b in range(batch_size):
        gt_mask = gt[b, :] == 1
        pred_mask = pred[b, :] >= 0.5

        # Tensor.any() reduces in native code; the builtin any() would
        # iterate the mask element by element in Python.
        gt_columns = dist_matrix[:, gt_mask] if gt_mask.any() else dist_matrix
        error_matrix = gt_columns[pred_mask, :] if pred_mask.any() else gt_columns

        false_positive_dist[b] = error_matrix.min(dim=1)[0].mean()
        false_negative_dist[b] = error_matrix.min(dim=0)[0].mean()

    return false_positive_dist, false_negative_dist