File size: 2,182 Bytes
c6dfc69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy
import torch


class AverageMeter:
    """Track running means for a fixed set of named quantities.

    Every key maps to a two-element list ``[running_sum, count]``; the mean
    reported for a key is ``running_sum / count``.
    """

    def __init__(self, *keys):
        # One fresh [sum, count] accumulator per key given at construction.
        self.__data = {name: [0.0, 0] for name in keys}

    def add(self, dict):
        """Add one observation for every key in the given mapping.

        Args:
            dict: mapping of key -> value. Each key must have been passed to
                ``__init__``; an unknown key raises ``KeyError``.
        """
        for name, value in dict.items():
            slot = self.__data[name]
            slot[0] += value
            slot[1] += 1

    def get(self, *keys):
        """Return the mean for one key, or a tuple of means for several keys.

        A single key yields a bare value; zero or several keys yield a tuple
        in the order requested. A key with no recorded observation raises
        ``ZeroDivisionError`` (its sum is ``0.0`` and its count is ``0``).
        """
        means = tuple(self.__data[name][0] / self.__data[name][1] for name in keys)
        return means[0] if len(means) == 1 else means

    def get_entire_dict_for_ddp_calculation(self):
        """Return the live internal ``{key: [running_sum, count]}`` dict.

        This is the meter's own storage, not a copy.
        """
        return self.__data

    def pop(self, key=None):
        """Reset accumulators, returning the mean of ``key`` before the reset.

        With ``key=None`` every accumulator is reset and ``None`` is returned.
        """
        if key is not None:
            mean = self.get(key)
            self.__data[key] = [0.0, 0]
            return mean
        for name in self.__data:
            # Only values are replaced (no keys added/removed), so mutating
            # the dict while iterating over it is safe.
            self.__data[name] = [0.0, 0]


class ForegroundIoU(AverageMeter):
    """Running average of per-batch foreground IoU for binary hard masks.

    Each call to ``calculate_iou`` records one value under the key
    ``'foreground_iou'``: the mean IoU over the samples of that batch. The
    reported result is the average of those per-batch means.
    """

    def __init__(self):
        super().__init__('foreground_iou')

    def calculate_iou(self, pred, target, eps=1e-7, get_entire_list=False):
        r"""Record the mean IoU of one batch and return the running result.

        Args:
            pred: hard mask of size ``[N x H x W]`` with dtype ``torch.long``;
                values are assumed to be 0/1 (this is not checked).
            target: hard mask with the same size and dtype as ``pred``.
            eps: added to each per-sample union before dividing.
            get_entire_list: when True, return the meter's internal
                ``{'foreground_iou': [running_sum, count]}`` dict instead of
                the running mean.

        Returns:
            The running mean over all recorded batches (a 0-d tensor once a
            batch has been recorded), or the internal dict when
            ``get_entire_list`` is True.

        Raises:
            AssertionError: on a shape or dtype mismatch.
            ZeroDivisionError: when an empty batch (``N == 0``) is passed
                before any batch has been recorded and ``get_entire_list`` is
                False, since there is no mean to report yet.

        For a sample whose target has no foreground pixel, the IoU is the
        fraction of pixels that are background in both masks, so a correct
        all-background prediction scores ~1 rather than 0.
        """
        assert len(pred.shape) == 3 and pred.shape == target.shape, 'shape mismatch.'
        assert pred.dtype is torch.long and target.dtype is torch.long, 'type mismatch.'

        N = pred.size(0)
        # Averaging over an empty batch is 0 / 0 = NaN, and a NaN added to the
        # running sum would corrupt every later result until reset(); an
        # empty batch therefore records nothing.
        if N > 0:
            num_pixels = pred.size(-1) * pred.size(-2)
            # Per-sample reductions over the two spatial dims -> shape [N].
            no_obj_flag = target.sum(dim=(1, 2)) == 0
            inter = (pred * target).sum(dim=(1, 2))
            union = torch.max(pred, target).sum(dim=(1, 2))

            # Foreground-free targets: score agreement on background
            # (pixels that are 0 in both masks) over all pixels instead.
            inter_no_obj = ((1 - target) * (1 - pred)).sum(dim=(1, 2))
            inter[no_obj_flag] = inter_no_obj[no_obj_flag]
            union[no_obj_flag] = num_pixels

            iou = torch.sum(inter / (union + eps)) / N
            self.add({'foreground_iou': iou})

        if get_entire_list:
            return self.get_entire_dict_for_ddp_calculation()
        return self.get('foreground_iou')

    def reset(self):
        """Discard all recorded batches.

        Re-runs the base initializer, which installs a new internal dict; a
        dict obtained earlier from ``get_entire_dict_for_ddp_calculation`` is
        not cleared by this call.
        """
        super().__init__('foreground_iou')