File size: 3,428 Bytes
c6dfc69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import numpy


class BinaryMIoU:
    """Running pixel accuracy and per-class IoU for two-class segmentation.

    Each call on an instance scores one batch, folds that batch's counts into
    running totals, and returns both the cumulative metrics and the metrics of
    the batch alone.

    Args:
        ignore_index: Target value marking pixels to exclude from every count.
            ``None`` disables ignoring.
    """

    def __init__(self, ignore_index):
        self.num_classes = 2
        self.ignore_index = ignore_index
        # Running totals accumulated by get_metric_results across calls.
        self.inter, self.union = 0, 0
        self.correct, self.label = 0, 0
        # Most recent cumulative per-class IoU and pixel accuracy.
        self.iou = numpy.zeros(self.num_classes)
        self.acc = 0.0

    def get_metric_results(self, curr_correct_, curr_label_, curr_inter_, curr_union_):
        """Add one batch's counts to the running totals; return cumulative metrics.

        Returns:
            ``(iou, acc)``. ``iou`` is the per-class IoU array (one entry per
            class, not averaged). ``acc`` is the pixel accuracy over every
            batch seen so far. Both are rounded to 4 decimals.
        """
        # Plain rebinding (not +=) so caller-owned arrays are never mutated.
        self.correct = self.correct + curr_correct_
        self.label = self.label + curr_label_
        self.inter = self.inter + curr_inter_
        self.union = self.union + curr_union_
        self.iou, self.acc = self.get_current_image_results(
            self.correct, self.label, self.inter, self.union)
        return numpy.round(self.iou, 4), numpy.round(self.acc, 4)

    @staticmethod
    def get_current_image_results(curr_correct_, curr_label_, curr_inter_, curr_union_):
        """Return unrounded ``(iou, acc)`` ratios for the given counts.

        ``numpy.spacing(1)`` is added to each denominator so that empty counts
        give 0 instead of dividing by zero.
        """
        eps = numpy.spacing(1)
        curr_acc = 1.0 * curr_correct_ / (eps + curr_label_)
        curr_iou = 1.0 * curr_inter_ / (eps + curr_union_)
        return curr_iou, curr_acc

    def __call__(self, x, y):
        """Score one batch of predictions ``x`` against targets ``y``.

        Returns:
            ``((cumulative_iou, cumulative_acc), (batch_iou, batch_acc))``.
            The cumulative pair is rounded to 4 decimals; the batch pair is not.
        """
        counts = self.calculate_current_sample(x, y)
        return self.get_metric_results(*counts), self.get_current_image_results(*counts)

    def calculate_current_sample(self, output, target):
        """Count correct/labelled pixels and per-class intersection/union for a batch.

        ``target`` is not modified. Ignored pixels are remapped on a new tensor.

        Returns:
            ``[correct, labeled, inter, union]`` as numpy values rounded to 5
            decimals. ``inter`` and ``union`` hold one entry per class.
        """
        # NOTE(review): the original comments describe `output` as BxCxHxW
        # logits and `target` as Bx1xHxW. However, no argmax is applied (it is
        # commented out in the helpers), so `output` must already hold class
        # indices {0, 1} with a shape that matches `target`. Confirm with callers.
        # Signed int64 so that -1 is representable even for uint8 masks.
        target = target.long()
        if self.ignore_index is not None:
            # Out-of-place: the caller's tensor is left untouched. After the
            # helpers' +1 shift, -1 becomes 0 and is excluded as unlabelled.
            target = target.masked_fill(target == self.ignore_index, -1)
        predict = output.detach()
        correct, labeled = self.batch_pix_accuracy(predict, target)
        inter, union = self.batch_intersection_union(predict, target, self.num_classes)
        return [numpy.round(correct, 5), numpy.round(labeled, 5),
                numpy.round(inter, 5), numpy.round(union, 5)]

    @staticmethod
    def batch_pix_accuracy(predict, target):
        """Return ``(correct, labelled)`` pixel counts as numpy scalars.

        Labels are shifted by +1, so target pixels marked -1 become 0 and drop
        out of both counts.
        """
        predict = predict.int() + 1
        target = target.int() + 1

        labelled = target > 0
        pixel_labeled = labelled.sum()
        pixel_correct = ((predict == target) & labelled).sum()
        assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
        return pixel_correct.cpu().numpy(), pixel_labeled.cpu().numpy()

    @staticmethod
    def batch_intersection_union(predict, target, num_class):
        """Return per-class ``(intersection, union)`` pixel areas as numpy arrays.

        After the +1 shift, classes occupy the values ``1..num_class``.
        Unlabelled pixels (target -1 -> 0) are zeroed in the prediction and
        fall outside the ``torch.histc`` range, so they are not counted.
        """
        predict = predict + 1
        target = target + 1

        predict = predict * (target > 0).long()
        intersection = predict * (predict == target).long()

        area_inter = torch.histc(intersection.float(), bins=num_class, max=num_class, min=1)
        area_pred = torch.histc(predict.float(), bins=num_class, max=num_class, min=1)
        area_lab = torch.histc(target.float(), bins=num_class, max=num_class, min=1)
        area_union = area_pred + area_lab - area_inter
        assert (area_inter <= area_union).all(), "Intersection area should be smaller than Union area"
        return area_inter.cpu().numpy(), area_union.cpu().numpy()