# training_sem/libs/utils/counter.py
import torch
from collections import defaultdict
from .comm import distributed, all_gather
def format_dict(res_dict):
    """Render a mapping as a single ``'key: value, key: value'`` string."""
    return ', '.join('%s: %s' % (name, value) for name, value in res_dict.items())
class Counter:
    """Accumulates scalar training metrics and reports their running means.

    Each call to :meth:`update` appends the latest value per metric key;
    at most ``cache_nums`` recent values are retained per key
    (``None`` disables truncation).
    """

    def __init__(self, cache_nums=1000):
        # Size of the per-key history window; None keeps every value.
        self.cache_nums = cache_nums
        self.reset()

    def update(self, metric):
        """Record one value per key; torch tensors are unwrapped via ``.item()``."""
        for name, value in metric.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            history = self.metric_dict[name]
            history.append(value)
            if self.cache_nums is not None:
                # Keep only the most recent cache_nums entries.
                self.metric_dict[name] = history[-1 * self.cache_nums:]

    def reset(self):
        """Drop all recorded metric values."""
        self.metric_dict = defaultdict(list)

    def _sync(self):
        # Concatenate every rank's per-key histories into one dict.
        gathered = all_gather(self.metric_dict)
        merged = defaultdict(list)
        for rank_dict in gathered:
            for name, values in rank_dict.items():
                merged[name].extend(values)
        return merged

    def format_mean(self, sync=True):
        """Return a ``'key: mean'`` summary string, optionally synced across ranks."""
        source = self._sync() if sync and distributed() else self.metric_dict
        res_dict = {name: '%.4f' % (sum(values) / len(values))
                    for name, values in source.items()}
        return format_dict(res_dict)