| import torch | |
| from collections import defaultdict | |
| from .comm import distributed, all_gather | |
def format_dict(res_dict):
    """Render a mapping as one ``"key: value, key: value"`` string.

    Each key and value is converted with ``%s`` formatting. Pairs appear in
    the mapping's iteration order and are separated by ``", "``. An empty
    mapping yields ``""``.
    """
    return ', '.join('%s: %s' % (name, shown) for name, shown in res_dict.items())
class Counter:
    """Accumulate scalar metric values per key and report their means.

    Args:
        cache_nums: Maximum number of most-recent values kept per key.
            ``None`` keeps every value; ``0`` keeps none. Expected to be
            ``None`` or a non-negative int.
    """

    def __init__(self, cache_nums=1000):
        self.cache_nums = cache_nums
        self.reset()

    def update(self, metric):
        """Append one value per key from ``metric`` to that key's history.

        Args:
            metric: Mapping of metric name to value. ``torch.Tensor`` values
                are converted with ``.item()`` (so they must hold a single
                element); other values are stored unchanged.
        """
        for key, val in metric.items():
            if isinstance(val, torch.Tensor):
                val = val.item()
            history = self.metric_dict[key]
            history.append(val)
            if self.cache_nums is not None:
                # Trim in place instead of re-slicing, which avoids building a
                # new list on every update. The old ``[-0:]`` slice also kept
                # everything when cache_nums == 0.
                excess = len(history) - self.cache_nums
                if excess > 0:
                    del history[:excess]

    def reset(self):
        """Discard all recorded values."""
        self.metric_dict = defaultdict(list)

    def _sync(self):
        """Merge the per-key histories of every process into one dict.

        NOTE(review): ``all_gather`` comes from the project's ``.comm`` module.
        It presumably is a collective that returns one ``metric_dict`` per
        rank. If so, every rank must reach this call. Confirm against ``.comm``.
        """
        metric_dicts = all_gather(self.metric_dict)
        total_metric_dict = defaultdict(list)
        for metric_dict in metric_dicts:
            for key, val in metric_dict.items():
                total_metric_dict[key].extend(val)
        return total_metric_dict

    def format_mean(self, sync=True):
        """Return ``"key: mean, ..."`` with each mean formatted to 4 decimals.

        Args:
            sync: When True and ``distributed()`` is truthy, average over the
                values gathered from all processes via ``_sync``. Otherwise,
                use only this process's values.

        Keys with no recorded values are omitted. They could otherwise raise
        ``ZeroDivisionError``, e.g. after a plain ``metric_dict[key]`` read
        created an empty entry in the defaultdict.
        """
        if sync and distributed():
            metric_dict = self._sync()
        else:
            metric_dict = self.metric_dict
        res_dict = {key: '%.4f' % (sum(val) / len(val))
                    for key, val in metric_dict.items() if val}
        return format_dict(res_dict)