File size: 4,547 Bytes
840ef2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import numpy as np

def calcuate_confusion_matrix(num_class: int, gt: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Build a ``num_class x num_class`` confusion matrix from label tensors.

    Rows index the ground-truth class and columns the predicted class, so
    ``cm[i, j]`` counts the elements labelled ``i`` that were predicted as ``j``.

    Args:
        num_class: Number of valid classes; labels are expected in
            ``[0, num_class)``.
        gt: Ground-truth labels of any shape. Elements outside
            ``[0, num_class)`` (e.g. an ignore label) are skipped.
        pred: Predicted labels with the same number of elements as ``gt``
            (and on the same device). Only positions with a valid
            ground-truth label are counted. Both inputs are cast to int64.

    Returns:
        An int64 tensor of shape ``(num_class, num_class)``.

    Raises:
        ValueError: if ``gt`` and ``pred`` have different numbers of elements,
            or a counted prediction lies outside ``[0, num_class)``.
    """
    gt_vector = gt.flatten()
    pred_vector = pred.flatten()
    if gt_vector.numel() != pred_vector.numel():
        raise ValueError(
            'gt and pred must have the same number of elements, got {} and {}'.format(
                gt_vector.numel(), pred_vector.numel()))

    mask = (gt_vector >= 0) & (gt_vector < num_class)
    gt_valid = gt_vector[mask].long()
    pred_valid = pred_vector[mask].long()

    # An out-of-range prediction would otherwise be silently counted in a
    # neighbouring cell (num_class * gt + pred aliases) or overflow the reshape.
    if ((pred_valid < 0) | (pred_valid >= num_class)).any():
        raise ValueError('pred contains labels outside [0, {})'.format(num_class))

    cm = torch.bincount(num_class * gt_valid + pred_valid, minlength=num_class ** 2)
    return cm.reshape(num_class, num_class)

class segmengtion_metric(object):
    """Accumulate a confusion matrix and derive segmentation scores from it.

    ``update_confusion_matrix`` adds one batch into a running matrix that
    ``get_metric_dict_per_epoch`` summarises; ``get_matrix_per_batch`` scores
    a single batch without touching the running matrix. Confusion-matrix rows
    index the ground-truth class and columns the predicted class.

    Args:
        num_class: Number of classes (size of each confusion-matrix axis).
        device: Device on which the running confusion matrix is kept.
    """

    def __init__(self, num_class: int, device: str):
        self.num_class = num_class
        self.device = device
        self.clear()

    def clear(self):
        """Reset the running confusion matrix to zeros."""
        # int64 keeps the counts exact; a float32 accumulator can no longer
        # represent every integer once a cell exceeds 2**24 pixels.
        self.confusion_matrix = torch.zeros(
            (self.num_class, self.num_class), dtype=torch.long, device=self.device)

    def update_confusion_matrix(self, gt, pred):
        """Add the confusion matrix of one (gt, pred) batch to the running total."""
        cm = calcuate_confusion_matrix(self.num_class, gt, pred)
        # The batch matrix is created on the inputs' device, which need not be
        # the device the running matrix was allocated on.
        self.confusion_matrix += cm.to(self.confusion_matrix.device)

    @staticmethod
    def _compute_scores(confusion_matrix):
        """Return (acc, precision, recall, f1, iou, gt_count, pred_count, union).

        Per-class quantities are 1-D tensors indexed by class; ``eps`` keeps
        classes with no pixels at 0 instead of producing 0/0.
        """
        eps = torch.finfo(torch.float32).eps
        tp = torch.diag(confusion_matrix)
        gt_count = torch.sum(confusion_matrix, dim=1)    # pixels labelled as each class
        pred_count = torch.sum(confusion_matrix, dim=0)  # pixels predicted as each class
        union = gt_count + pred_count - tp

        acc = tp.sum() / (confusion_matrix.sum() + eps)
        recall = tp / (gt_count + eps)
        precision = tp / (pred_count + eps)
        f1 = (2 * recall * precision) / (recall + precision + eps)
        iou = tp / (union + eps)
        return acc, precision, recall, f1, iou, gt_count, pred_count, union

    def _per_class(self, template, values):
        """Map ``template.format(i)`` to the i-th entry of ``values``."""
        return {template.format(i): value for i, value in zip(range(self.num_class), values)}

    def get_matrix_per_batch(self, gt, pred):
        """Score one batch without touching the running confusion matrix.

        Returns a dict with 'acc', 'mean_pre', 'mean_rec', 'mIoU', 'mF1'
        followed by per-class precision, recall, IoU and F1 entries.

        Class means skip only classes that do not occur in the batch:
        precision averages classes predicted at least once, recall classes
        present in ``gt``, and IoU/F1 classes present in either. A class that
        occurs but is never hit contributes 0 rather than being dropped.
        A mean over no classes is NaN.
        """
        confusion_matrix = calcuate_confusion_matrix(self.num_class, gt, pred)
        acc, precision, recall, f1, iou, gt_count, pred_count, union = \
            self._compute_scores(confusion_matrix)

        present = union > 0
        score_dict_batch = {
            'acc': acc,
            'mean_pre': precision[pred_count > 0].mean(),
            'mean_rec': recall[gt_count > 0].mean(),
            'mIoU': iou[present].mean(),
            'mF1': f1[present].mean(),
        }
        score_dict_batch.update(self._per_class('pre_class[{}]', precision))
        score_dict_batch.update(self._per_class('rec_class[{}]', recall))
        score_dict_batch.update(self._per_class('iou_class[{}]', iou))
        score_dict_batch.update(self._per_class('f1_class[{}]', f1))
        return score_dict_batch

    def get_metric_dict_per_epoch(self):
        """Summarise the running confusion matrix.

        Returns a dict with 'Accuracy', 'mean_Precision', 'mean_Recall',
        'mIoU', 'mF1' followed by per-class entries. Means are taken over all
        ``num_class`` classes, so a class absent from the whole run counts as 0.
        """
        acc, precision, recall, f1, iou, _, _, _ = self._compute_scores(self.confusion_matrix)

        score_dict_epoch = {
            'Accuracy': acc,
            'mean_Precision': precision.mean(),
            'mean_Recall': recall.mean(),
            'mIoU': iou.mean(),
            'mF1': f1.mean(),
        }
        score_dict_epoch.update(self._per_class('Precision_Class[{}]', precision))
        score_dict_epoch.update(self._per_class('Recall_Class[{}]', recall))
        score_dict_epoch.update(self._per_class('IoU_Class[{}]', iou))
        score_dict_epoch.update(self._per_class('F1_Class[{}]', f1))
        return score_dict_epoch






if __name__ == "__main__":
    # Small demo: score one batch directly, then accumulate the same batch
    # and summarise it as an "epoch".
    num_class = 6
    # Fall back to CPU so the demo also runs on machines without a GPU.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    gt_label = torch.tensor([[0, 1, 2, 3, 1],
                             [1, 2, 2, 3, 4]], device=device)

    pre_label = torch.tensor([[0, 1, 2, 3, 1],
                              [5, 1, 2, 1, 4]], device=device)

    metric = segmengtion_metric(num_class, device)
    res = metric.get_matrix_per_batch(gt_label, pre_label)
    # Without an update the running matrix is still all zeros and the
    # epoch summary would report nothing.
    metric.update_confusion_matrix(gt_label, pre_label)
    res1 = metric.get_metric_dict_per_epoch()
    print(res)
    print(res1)