import torch
from collections import defaultdict
from .comm import distributed, all_gather


def format_dict(res_dict):
    """Render a metrics mapping as a single comma-separated line.

    Args:
        res_dict: mapping of metric name -> value; values are formatted
            with ``str()``.

    Returns:
        str: entries joined as ``'key: value, key: value'`` in dict
        insertion order ('' for an empty mapping).
    """
    return ', '.join(f'{key}: {val}' for key, val in res_dict.items())


class Counter:
    """Accumulates per-key metric histories and reports their means.

    Values are appended per key into ``self.metric_dict`` (a
    ``defaultdict(list)``); each key's history is optionally capped at the
    most recent ``cache_nums`` entries. ``format_mean`` can merge histories
    across distributed workers before averaging.
    """

    def __init__(self, cache_nums=1000):
        # cache_nums: keep only the newest `cache_nums` values per key;
        # None keeps the full, unbounded history.
        self.cache_nums = cache_nums
        self.reset()

    def update(self, metric):
        """Append each key's value from `metric` (a dict-like of scalars).

        torch scalar tensors are converted to Python numbers via ``.item()``.
        """
        for key, val in metric.items():
            if isinstance(val, torch.Tensor):
                val = val.item()
            history = self.metric_dict[key]
            history.append(val)
            if self.cache_nums is not None and len(history) > self.cache_nums:
                # Trim in place rather than rebuilding via slicing — O(excess)
                # instead of O(cache) per update, and unlike the old
                # `[-1*cache_nums:]` slice it also trims correctly when
                # cache_nums == 0 (a `[-0:]` slice keeps everything).
                del history[:len(history) - self.cache_nums]

    def reset(self):
        """Drop all recorded values for all keys."""
        self.metric_dict = defaultdict(list)

    def _sync(self):
        # Gather every rank's metric_dict and concatenate per-key histories.
        # NOTE(review): assumes the process group is initialized — only
        # reached via format_mean(sync=True) when distributed() is true.
        metric_dicts = all_gather(self.metric_dict)
        total_metric_dict = defaultdict(list)
        for metric_dict in metric_dicts:
            for key, val in metric_dict.items():
                total_metric_dict[key].extend(val)
        return total_metric_dict

    def format_mean(self, sync=True):
        """Return a 'key: mean' summary string (means formatted to 4 d.p.).

        Args:
            sync: when True and running distributed, average over the
                histories gathered from all ranks; otherwise use only the
                local history.
        """
        if sync and distributed():
            metric_dict = self._sync()
        else:
            metric_dict = self.metric_dict
        # Skip empty histories to avoid ZeroDivisionError (possible after a
        # defaultdict access that recorded no values, e.g. cache_nums == 0).
        res_dict = {key: '%.4f' % (sum(val) / len(val))
                    for key, val in metric_dict.items() if val}
        return format_dict(res_dict)