| """ |
| Metrics for segmentation evaluation. |
| """ |
| import numpy as np |
|
|
|
|
def compute_iou(pred, gt):
    """Return the Intersection over Union (Jaccard index) of two masks.

    Both inputs are binarized with ``astype(bool)``, so any nonzero entry
    counts as foreground.

    Args:
        pred: Predicted mask (NumPy array).
        gt: Ground-truth mask; normally the same shape as ``pred``.

    Returns:
        ``|pred AND gt| / |pred OR gt|`` as a Python float in [0, 1]. When
        both masks are empty the union is zero and the score is defined as
        1.0 (the masks agree perfectly on "no foreground").
    """
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        # An empty union implies an empty intersection, so this is always a
        # perfect match; a 0.0 outcome is impossible here.
        return 1.0
    intersection = np.logical_and(pred, gt).sum()
    return float(intersection / union)
|
|
|
|
def compute_dice(pred, gt):
    """Return the Dice coefficient (equivalently the F1 score) of two masks.

    Inputs are binarized with ``astype(bool)``. The score is
    ``2 * |pred AND gt| / (|pred| + |gt|)``; if both masks are empty the
    result is 1.0.
    """
    p = pred.astype(bool)
    g = gt.astype(bool)
    denom = p.sum() + g.sum()
    if not denom:
        # Nothing predicted and nothing to find: treat as perfect agreement.
        return 1.0
    overlap = (p & g).sum()
    return float(2.0 * overlap / denom)
|
|
|
|
def compute_pixel_accuracy(pred, gt):
    """Return the fraction of pixels on which two binarized masks agree.

    Args:
        pred: Predicted mask (NumPy array); nonzero entries are foreground.
        gt: Ground-truth mask; must have exactly the same shape as ``pred``.

    Returns:
        ``(# pixels where pred == gt) / (# pixels)`` as a Python float.
        Returns 1.0 for empty masks (zero pixels), in line with the
        empty-mask convention used by ``compute_iou`` and ``compute_dice``.

    Raises:
        ValueError: if ``pred`` and ``gt`` have different shapes. Broadcasting
            would otherwise count matches over the broadcast shape while
            dividing by ``pred.size``, which can yield accuracies above 1.
    """
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    if pred.shape != gt.shape:
        raise ValueError(
            f"pred and gt must have the same shape, got {pred.shape} and {gt.shape}"
        )
    total = pred.size
    if total == 0:
        # No pixels to classify; avoid 0/0 (NaN + RuntimeWarning).
        return 1.0
    correct = np.count_nonzero(pred == gt)
    return float(correct / total)
|
|
|
|
def compute_precision_recall(pred, gt):
    """Return ``(precision, recall)`` for a binary prediction mask.

    precision = TP / (TP + FP) and recall = TP / (TP + FN), where the counts
    come from the masks binarized with ``astype(bool)``. Each ratio falls
    back to 0.0 when its denominator is zero (nothing predicted, or nothing
    in the ground truth, respectively).
    """
    p = pred.astype(bool)
    g = gt.astype(bool)

    def _safe_ratio(num, den):
        # Zero denominator -> 0.0 rather than NaN.
        return float(num / den) if den > 0 else 0.0

    hits = np.logical_and(p, g).sum()
    predicted = hits + np.logical_and(p, ~g).sum()
    actual = hits + np.logical_and(~p, g).sum()
    return _safe_ratio(hits, predicted), _safe_ratio(hits, actual)
|
|
|
|
def compute_fpr_fnr(pred, gt):
    """Return ``(fpr, fnr)``: the false-positive and false-negative rates.

    fpr = FP / (FP + TN) and fnr = FN / (FN + TP), computed on the masks
    binarized with ``astype(bool)``. A rate is 0.0 when its denominator is
    zero (no ground-truth negatives / positives).
    """
    p = pred.astype(bool)
    g = gt.astype(bool)
    neg = ~g

    false_pos = np.logical_and(p, neg).sum()
    true_neg = np.logical_and(~p, neg).sum()
    false_neg = np.logical_and(~p, g).sum()
    true_pos = np.logical_and(p, g).sum()

    negatives = false_pos + true_neg
    positives = false_neg + true_pos
    fpr = 0.0 if negatives <= 0 else float(false_pos / negatives)
    fnr = 0.0 if positives <= 0 else float(false_neg / positives)
    return fpr, fnr
|
|
|
|
def compute_all_metrics(pred_batch, gt_batch):
    """Compute per-image segmentation metrics and average them over a batch.

    Args:
        pred_batch: ``(N, H, W)`` array (or array-like) of predictions,
            binary or float scores; values ``> 0.5`` count as foreground.
        gt_batch: ``(N, H, W)`` array (or array-like) of ground-truth masks,
            thresholded the same way. Must have the same shape as
            ``pred_batch``.

    Returns:
        Dict mapping each metric name to its mean over the N images. Note
        that ``"iou"`` and ``"miou"`` hold the same value: the mean of the
        per-image IoU (not a per-class mean). For an empty batch (N == 0)
        every value is NaN.

    Raises:
        ValueError: if ``pred_batch`` and ``gt_batch`` have different shapes.
    """
    pred_batch = np.asarray(pred_batch) > 0.5
    gt_batch = np.asarray(gt_batch) > 0.5
    if pred_batch.shape != gt_batch.shape:
        # Without this check extra ground-truth images would be silently
        # ignored, or a shorter gt batch would fail with an IndexError.
        raise ValueError(
            "pred_batch and gt_batch must have the same shape, "
            f"got {pred_batch.shape} and {gt_batch.shape}"
        )

    ious, dices, accs = [], [], []
    precisions, recalls = [], []
    fprs, fnrs = [], []

    # Iterate over the leading (image) axis of both batches in lockstep.
    for pred, gt in zip(pred_batch, gt_batch):
        ious.append(compute_iou(pred, gt))
        dices.append(compute_dice(pred, gt))
        accs.append(compute_pixel_accuracy(pred, gt))
        p, r = compute_precision_recall(pred, gt)
        precisions.append(p)
        recalls.append(r)
        fpr, fnr = compute_fpr_fnr(pred, gt)
        fprs.append(fpr)
        fnrs.append(fnr)

    def _mean(values):
        # Explicit NaN for an empty batch instead of np.mean's
        # "Mean of empty slice" RuntimeWarning.
        return float(np.mean(values)) if values else float("nan")

    mean_iou = _mean(ious)
    return {
        "iou": mean_iou,
        "miou": mean_iou,
        "dice_score": _mean(dices),
        "pixel_accuracy": _mean(accs),
        "precision": _mean(precisions),
        "recall": _mean(recalls),
        "fpr": _mean(fprs),
        "fnr": _mean(fnrs),
    }
|
|