File size: 1,762 Bytes
c6dfc69 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 | import numpy
import torch
class AverageMeter:
    """Accumulate per-key running sums and report their means.

    Each key given to the constructor owns a two-item list
    ``[running_sum, update_count]`` stored in a private dict.
    """

    def __init__(self, *keys):
        """Register every name in ``keys`` with an empty accumulator."""
        self.__data = {name: [0.0, 0] for name in keys}

    def _average(self, key):
        """Return ``running_sum / update_count`` for a single key."""
        return self.__data[key][0] / self.__data[key][1]

    def add(self, dict):
        """Fold one observation per key into the accumulators.

        ``dict`` maps key -> value. Each value is added to that key's
        running sum and the key's update count grows by one. A key that
        has no accumulator raises ``KeyError``.
        """
        for name, value in dict.items():
            slot = self.__data[name]
            # Augmented assignment on purpose: once the sum is a tensor
            # this updates it in place, exactly like the direct form.
            slot[0] += value
            slot[1] += 1

    def get(self, *keys):
        """Return the mean for the requested key(s).

        Exactly one key yields the bare mean; any other number of keys
        yields a tuple of means in the order requested. A key that has
        received no updates has a count of 0, so a fresh accumulator
        raises ``ZeroDivisionError``.
        """
        if len(keys) != 1:
            return tuple([self._average(name) for name in keys])
        return self._average(keys[0])

    def get_entire_dict_for_ddp_calculation(self):
        """Return the internal ``{key: [running_sum, update_count]}`` dict.

        This is the live object, not a copy.
        """
        return self.__data

    def pop(self, key=None):
        """Reset accumulators.

        With a ``key``: return that key's current mean, then reset it.
        Without one: reset every key and return ``None``.
        """
        if key is not None:
            mean = self.get(key)
            self.__data[key] = [0.0, 0]
            return mean

        # Only values are rebound, so iterating the dict itself is safe.
        for name in self.__data:
            self.__data[name] = [0.0, 0]
        return None
class ForegroundS(AverageMeter):
    """Running ratio of predicted-foreground pixels to total pixels.

    Two accumulators are kept: ``'foreground_p'`` (number of pixels
    predicted as foreground) and ``'foreground_n'`` (number of pixels
    seen).
    """

    def __init__(self):
        super(ForegroundS, self).__init__('foreground_p', 'foreground_n')

    def metric_s_for_null(self, pred, get_entire_list=False):
        """Accumulate foreground statistics for one batch of logits.

        Args:
            pred: 4-D tensor of raw logits. The original code unpacks the
                shape as ``(NF, bsz, H, W)`` — presumably frames x batch x
                height x width; confirm against the caller. The tensor
                does not need to be contiguous.
            get_entire_list: when true, return the raw accumulator dict
                instead of the ratio.

        Returns:
            By default, the cumulative ratio (foreground pixels / total
            pixels) over every call since the last reset, as a 0-dim
            tensor. With ``get_entire_list=True``, the live internal dict
            ``{key: [running_sum, update_count]}``.

        Raises:
            ValueError: if ``pred`` is not 4-dimensional.
        """
        if pred.dim() != 4:
            raise ValueError(
                'pred must be a 4-D tensor, got shape {}'.format(
                    tuple(pred.shape)))

        # Only whole-tensor totals are needed, so no reshape is done;
        # this also keeps non-contiguous inputs (e.g. permuted tensors)
        # working, which the former ``view`` call rejected.
        num_pixels = pred.numel()
        # A pixel is foreground when sigmoid(logit) > 0.5.
        foreground = (torch.sigmoid(pred) > 0.5).int().sum()

        self.add({'foreground_p': foreground})
        self.add({'foreground_n': num_pixels})

        if get_entire_list:
            return self.get_entire_dict_for_ddp_calculation()
        # Both keys share the same update count, so the ratio of the two
        # means equals total foreground pixels / total pixels.
        return self.get('foreground_p') / self.get('foreground_n')

    def reset(self, ):
        """Discard all accumulated statistics.

        Re-running the base initializer rebinds the internal dict, so a
        dict previously handed out via ``get_entire_list=True`` keeps its
        old values instead of being zeroed in place.
        """
        super(ForegroundS, self).__init__('foreground_p', 'foreground_n')
|