File size: 1,626 Bytes
7ef7abb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 | from collections import defaultdict
import numpy as np
import torch
class Averager:
    """Accumulate running sums of named statistics and report their means.

    Each call to :meth:`update` adds one batch of values per key:

    * a ``torch.Tensor`` contributes the sum of its elements and ``numel()``;
    * a ``np.ndarray`` contributes the sum of its elements and ``size``;
    * any other value contributes itself with a count of one.

    :meth:`average` returns the per-key mean and then clears the accumulator.
    """

    def __init__(self):
        self.reset()

    # noinspection PyAttributeOutsideInit
    def reset(self) -> None:
        """Discard all accumulated totals and counts."""
        self.total = {}    # key -> running sum (0-d tensor, NumPy scalar, or plain value)
        self.counter = {}  # key -> number of elements summed into total[key]

    @staticmethod
    def _sum_and_count(value):
        """Return ``(sum, element_count)`` for a single incoming statistic."""
        if isinstance(value, torch.Tensor):
            # Detach so the running total does not keep every batch's autograd
            # graph alive; only the numeric value is ever reported (see average).
            return value.detach().sum(), value.numel()
        if isinstance(value, np.ndarray):
            return value.sum(), value.size
        return value, 1

    def update(self, stats) -> None:
        """Add a batch of statistics.

        Args:
            stats: mapping from key to a tensor, NumPy array, or scalar value.
                Keys not seen since the last reset start a new running sum.
        """
        for key, value in stats.items():
            subtotal, count = self._sum_and_count(value)
            if key in self.total:
                self.total[key] = self.total[key] + subtotal
                self.counter[key] = self.counter[key] + count
            else:
                self.total[key] = subtotal
                self.counter[key] = count

    def average(self) -> dict:
        """Return ``{key: total / count}`` for every key, then reset.

        Tensor and NumPy-scalar means are converted to plain Python numbers via
        ``.item()`` so every result has a uniform, serializable type.
        """
        averaged_stats = {}
        for key, tot in self.total.items():
            mean = tot / self.counter[key]
            if isinstance(mean, (torch.Tensor, np.generic)):
                mean = mean.item()
            averaged_stats[key] = mean
        self.reset()
        return averaged_stats
|