File size: 1,626 Bytes
7ef7abb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from collections import defaultdict

import numpy as np
import torch


class Averager:
    """Accumulate running sums of named statistics and report their means.

    ``update`` may be called repeatedly with a mapping of stat name -> value;
    ``average`` returns the per-key mean over everything accumulated since the
    last reset, and then clears the accumulator.

    How a value contributes to its key's mean:
      * ``torch.Tensor``: every element is a sample (sum / numel);
      * ``np.ndarray``: every element is a sample (sum / size);
      * anything else (e.g. a Python number): counts as one sample.
    """

    def __init__(self):
        self.reset()

    # noinspection PyAttributeOutsideInit
    def reset(self):
        """Discard all accumulated totals and sample counts."""
        self.total = {}    # key -> running sum of samples
        self.counter = {}  # key -> number of samples summed into total[key]

    @staticmethod
    def _sum_and_count(value):
        """Return ``(sum, number_of_samples)`` contributed by a single value."""
        if isinstance(value, torch.Tensor):
            # Detach so that accumulating e.g. a loss tensor does not keep its
            # autograd graph alive until the next reset; only the numeric
            # value is ever read back (via .item() in average()).
            return value.detach().sum(), value.numel()
        if isinstance(value, np.ndarray):
            return value.sum(), value.size
        return value, 1

    def update(self, stats):
        """Add every entry of ``stats`` (mapping name -> value) to the totals."""
        for key, value in stats.items():
            value_sum, count = self._sum_and_count(value)
            if key in self.total:
                self.total[key] = self.total[key] + value_sum
                self.counter[key] = self.counter[key] + count
            else:
                self.total[key] = value_sum
                self.counter[key] = count

    def average(self):
        """Return ``{key: total / count}`` for all accumulated stats, then reset.

        Torch-tensor and NumPy-scalar means are converted to plain Python
        scalars with ``.item()``; other totals are returned as the result of
        ``total / count`` unchanged.
        """
        averaged_stats = {}
        for key, tot in self.total.items():
            mean = tot / self.counter[key]
            # Both branches of type dispatch should yield Python scalars, not
            # 0-d tensors or numpy scalar types.
            if isinstance(mean, (torch.Tensor, np.generic)):
                mean = mean.item()
            averaged_stats[key] = mean
        self.reset()

        return averaged_stats