File size: 2,867 Bytes
3cc53ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
Metrics for segmentation evaluation.
"""
import numpy as np


def compute_iou(pred, gt):
    """Intersection over Union (Jaccard index) of two masks.

    Both inputs are cast to ``bool`` (any non-zero value is foreground) and
    combined element-wise with NumPy broadcasting.

    Args:
        pred: Predicted mask (NumPy array).
        gt: Ground-truth mask (NumPy array).

    Returns:
        ``|pred & gt| / |pred | gt|`` as a Python float. When both masks are
        entirely background (empty union) the prediction is considered a
        perfect match and 1.0 is returned.
    """
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        # An empty union implies an empty intersection: both masks are all
        # background, so the prediction agrees exactly with the ground truth.
        return 1.0
    intersection = np.logical_and(pred, gt).sum()
    return float(intersection / union)


def compute_dice(pred, gt):
    """Dice coefficient (equivalently the F1 score) of two masks.

    Inputs are cast to ``bool`` so any non-zero value counts as foreground.

    Returns:
        ``2 * |pred & gt| / (|pred| + |gt|)`` as a Python float, or 1.0 when
        both masks contain no foreground at all.
    """
    pred_mask = pred.astype(bool)
    gt_mask = gt.astype(bool)
    overlap = np.logical_and(pred_mask, gt_mask).sum()
    denominator = pred_mask.sum() + gt_mask.sum()
    if not denominator:
        # Neither mask has foreground: treat as a perfect match.
        return 1.0
    return float(2.0 * overlap / denominator)


def compute_pixel_accuracy(pred, gt):
    """Fraction of pixels where the binarized prediction equals the ground truth.

    Both inputs are cast to ``bool`` and compared element-wise (with NumPy
    broadcasting).

    Args:
        pred: Predicted mask (NumPy array).
        gt: Ground-truth mask (NumPy array).

    Returns:
        Number of matching pixels divided by the number of compared pixels,
        as a Python float in [0, 1]. For empty masks (no pixels) 1.0 is
        returned, matching the empty-mask convention of compute_iou and
        compute_dice instead of producing NaN from 0 / 0.
    """
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    matches = pred == gt
    # Normalize by the size of the compared (possibly broadcast) array, not
    # pred.size, so the ratio cannot exceed 1 when broadcasting enlarges it.
    total = matches.size
    if total == 0:
        return 1.0
    return float(matches.sum() / total)


def compute_precision_recall(pred, gt):
    """Precision and recall of a predicted mask against the ground truth.

    Inputs are cast to ``bool`` so any non-zero value counts as foreground.

    Returns:
        ``(precision, recall)`` as Python floats, where
        precision = TP / (TP + FP) and recall = TP / (TP + FN). Each value is
        0.0 when its denominator is zero.
    """
    def _safe_ratio(numerator, denominator):
        # Zero-denominator convention for this metric: report 0.0.
        if denominator > 0:
            return float(numerator / denominator)
        return 0.0

    p = pred.astype(bool)
    g = gt.astype(bool)
    true_pos = np.logical_and(p, g).sum()
    false_pos = np.logical_and(p, ~g).sum()
    false_neg = np.logical_and(~p, g).sum()

    return (
        _safe_ratio(true_pos, true_pos + false_pos),
        _safe_ratio(true_pos, true_pos + false_neg),
    )


def compute_fpr_fnr(pred, gt):
    """False positive rate and false negative rate of a predicted mask.

    Inputs are cast to ``bool`` so any non-zero value counts as foreground.

    Returns:
        ``(fpr, fnr)`` as Python floats, where fpr = FP / (FP + TN) and
        fnr = FN / (FN + TP). Each value is 0.0 when its denominator is zero.
    """
    p = pred.astype(bool)
    g = gt.astype(bool)
    not_p = ~p
    not_g = ~g

    true_neg = np.logical_and(not_p, not_g).sum()
    false_pos = np.logical_and(p, not_g).sum()
    false_neg = np.logical_and(not_p, g).sum()
    true_pos = np.logical_and(p, g).sum()

    negatives = false_pos + true_neg
    positives = false_neg + true_pos

    fpr = 0.0
    if negatives > 0:
        fpr = float(false_pos / negatives)
    fnr = 0.0
    if positives > 0:
        fnr = float(false_neg / positives)
    return fpr, fnr


def compute_all_metrics(pred_batch, gt_batch, *, threshold=0.5):
    """Binarize a batch of masks and average per-image segmentation metrics.

    Args:
        pred_batch: (N, H, W) binary/float array-like of predictions.
        gt_batch: (N, H, W) binary/float array-like of ground-truth masks,
            with the same number of images N as ``pred_batch``.
        threshold: Values strictly greater than this are treated as
            foreground in both inputs. Defaults to 0.5.

    Returns:
        Dict with keys "iou", "miou", "dice_score", "pixel_accuracy",
        "precision", "recall", "fpr" and "fnr". Each value is the unweighted
        mean of the per-image metric over the N images. "iou" and "miou"
        hold the same value (a mean over images, not over classes). For an
        empty batch (N == 0) every value is NaN.

    Raises:
        ValueError: If the two batches contain a different number of images.
    """
    pred_batch = np.asarray(pred_batch) > threshold
    gt_batch = np.asarray(gt_batch) > threshold

    n_images = pred_batch.shape[0]
    if gt_batch.shape[0] != n_images:
        # Without this check extra ground-truth images would be silently
        # ignored (or an opaque IndexError raised when gt is shorter).
        raise ValueError(
            f"pred_batch has {n_images} images but gt_batch has "
            f"{gt_batch.shape[0]}"
        )

    per_image = {
        "iou": [],
        "dice_score": [],
        "pixel_accuracy": [],
        "precision": [],
        "recall": [],
        "fpr": [],
        "fnr": [],
    }
    for pred, gt in zip(pred_batch, gt_batch):
        per_image["iou"].append(compute_iou(pred, gt))
        per_image["dice_score"].append(compute_dice(pred, gt))
        per_image["pixel_accuracy"].append(compute_pixel_accuracy(pred, gt))
        precision, recall = compute_precision_recall(pred, gt)
        per_image["precision"].append(precision)
        per_image["recall"].append(recall)
        fpr, fnr = compute_fpr_fnr(pred, gt)
        per_image["fpr"].append(fpr)
        per_image["fnr"].append(fnr)

    if n_images == 0:
        # np.mean([]) would also give NaN, but with a RuntimeWarning.
        means = {key: float("nan") for key in per_image}
    else:
        means = {key: float(np.mean(values)) for key, values in per_image.items()}

    return {
        "iou": means["iou"],
        "miou": means["iou"],
        "dice_score": means["dice_score"],
        "pixel_accuracy": means["pixel_accuracy"],
        "precision": means["precision"],
        "recall": means["recall"],
        "fpr": means["fpr"],
        "fnr": means["fnr"],
    }