osu_mapper2 / osuT5 /utils /log_utils.py
Tiger14n's picture
Upload folder using huggingface_hub
7ef7abb verified
from collections import defaultdict
import numpy as np
import torch
class Averager:
    """Accumulate named statistics and report their per-element means.

    Each key keeps a running sum (``total``) and a running element count
    (``counter``). Tensors and NumPy arrays contribute the sum of all their
    elements and their element count; any other value contributes itself
    with a count of one.
    """

    def __init__(self):
        self.reset()

    # noinspection PyAttributeOutsideInit
    def reset(self):
        """Discard every accumulated sum and count."""
        self.total = {}
        self.counter = {}

    @staticmethod
    def _reduce(value):
        """Return ``(sum, element_count)`` for a single logged value."""
        if isinstance(value, torch.Tensor):
            # Detach before summing: the running total is only used for
            # reporting, and an attached sum would keep the autograd graph of
            # every logged step alive until the next ``average()`` call.
            return value.detach().sum(), value.numel()
        if isinstance(value, np.ndarray):
            return value.sum(), value.size
        return value, 1

    def update(self, stats):
        """Fold one mapping of statistics into the running totals.

        Args:
            stats: Mapping from statistic name to a ``torch.Tensor``, a
                ``numpy.ndarray``, or any other value that supports ``+``
                and division by an int.
        """
        for key, value in stats.items():
            amount, count = self._reduce(value)
            if key in self.total:
                self.total[key] = self.total[key] + amount
                self.counter[key] = self.counter[key] + count
            else:
                # Seed with the first value itself rather than ``0 + value``
                # so the accumulator keeps the value's own type.
                self.total[key] = amount
                self.counter[key] = count

    def average(self):
        """Return the mean of every tracked statistic, then reset.

        Returns:
            Dict mapping each key to ``total / counter``. Tensor totals are
            converted to Python scalars with ``.item()``; other totals are
            returned as produced by the division.
        """
        averaged_stats = {}
        for key, tot in self.total.items():
            mean = tot / self.counter[key]
            averaged_stats[key] = mean.item() if isinstance(tot, torch.Tensor) else mean
        self.reset()
        return averaged_stats