File size: 4,383 Bytes
bd096d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2023 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: mica@tue.mpg.de


import os

import numpy as np
from loguru import logger


class BestModel:
    """Track best validation scores and checkpoint the trainer when they improve.

    Each criterion is tracked independently and has its own checkpoint file
    under ``<trainer.cfg.output_dir>/best_models``:

    * ``best_model_0.tar``   -- lowest ``weighted_average`` seen so far.
    * ``best_model_1.tar``   -- lowest ``average`` seen so far.
    * ``best_model_3.tar``   -- lowest moving average of ``average``
      (window ``trainer.cfg.running_average``).
    * ``best_model_now.tar`` -- lowest NoW mean (see :meth:`now`).

    For every criterion, a lower value counts as an improvement.
    """

    def __init__(self, trainer):
        # np.inf (lowercase): the np.Inf alias was removed in NumPy 2.0.
        self.average = np.inf
        self.weighted_average = np.inf
        self.smoothed_average = np.inf
        # NOTE(review): not updated anywhere in this class; kept for callers that may read it.
        self.smoothed_weighted_average = np.inf
        self.running_average = np.inf
        self.running_weighted_average = np.inf
        self.now_mean = None

        self.trainer = trainer
        # None until the first __call__; afterwards the number of evaluations seen.
        self.counter = None

        # Window length of the moving average (assumed to be > 0).
        self.N = trainer.cfg.running_average

        os.makedirs(os.path.join(self.trainer.cfg.output_dir, 'best_models'), exist_ok=True)

    def _checkpoint_path(self, filename):
        """Return the path of ``filename`` inside the ``best_models`` directory."""
        return os.path.join(self.trainer.cfg.output_dir, 'best_models', filename)

    def state_dict(self):
        """Return the tracker state so it can be stored alongside a checkpoint.

        Both weighted trackers are included. Without them, a resumed run would
        treat the next evaluation as a weighted-average improvement and
        overwrite ``best_model_0.tar`` unconditionally.
        """
        return {
            'average': self.average,
            'weighted_average': self.weighted_average,
            'smoothed_average': self.smoothed_average,
            'running_average': self.running_average,
            'running_weighted_average': self.running_weighted_average,
            'now_mean': self.now_mean,
            'counter': self.counter,
        }

    def load_state_dict(self, dict):
        """Restore state previously produced by :meth:`state_dict`.

        Args:
            dict: Mapping as returned by :meth:`state_dict`. The name shadows the
                builtin but is kept for keyword-argument compatibility. The
                ``weighted_average`` and ``running_weighted_average`` keys are
                optional. When missing (state saved by older code), they fall
                back to ``np.inf``.
        """
        self.average = dict['average']
        self.weighted_average = dict.get('weighted_average', np.inf)
        self.smoothed_average = dict['smoothed_average']
        self.running_average = dict['running_average']
        self.running_weighted_average = dict.get('running_weighted_average', np.inf)
        self.now_mean = dict['now_mean']
        self.counter = dict['counter']

        # now_mean stays None until now() has been called at least once.
        now_mean = 'n/a' if self.now_mean is None else f'{self.now_mean:.6f}'
        logger.info(f'[BEST] Best scores: '
                    f'NoW mean: {now_mean} | '
                    f'weighted average: {self.weighted_average:.6f} | '
                    f'average: {self.average:.6f} | '
                    f'smoothed average: {self.smoothed_average:.6f}')

    def __call__(self, weighted_average, average):
        """Record one evaluation and save a checkpoint for each improved criterion.

        The first call only initialises the trackers and writes no checkpoint.

        Args:
            weighted_average: Weighted validation score (lower is better).
            average: Unweighted validation score (lower is better).

        Returns:
            tuple: ``(running_weighted_average, running_average)`` after this
            update. On the first call, these are the raw inputs.
        """
        if self.counter is None:
            self.counter = 1
            self.average = average
            self.weighted_average = weighted_average
            self.running_weighted_average = weighted_average
            self.running_average = average

            return weighted_average, average

        if weighted_average < self.weighted_average:
            delta = self.weighted_average - weighted_average
            self.weighted_average = weighted_average
            logger.info(f'[BEST] (Average weighted)   {self.trainer.global_step} | {delta:.6f} improvement and value: {self.weighted_average:.6f}')
            self.trainer.save_checkpoint(self._checkpoint_path('best_model_0.tar'))

        if average < self.average:
            delta = self.average - average
            self.average = average
            logger.info(f'[BEST] (Average)   {self.trainer.global_step} | {delta:.6f} improvement and value: {self.average:.6f}')
            self.trainer.save_checkpoint(self._checkpoint_path('best_model_1.tar'))

        n = self.N
        decay = (n - 1) / n

        # Moving average with window n: new = old * (n - 1) / n + value / n.
        self.running_average = self.running_average * decay + (average / n)
        # Update the weighted tracker on every call as well, so the returned
        # value reflects recent evaluations instead of only the first one.
        self.running_weighted_average = self.running_weighted_average * decay + (weighted_average / n)

        if self.running_average < self.smoothed_average:
            delta = self.smoothed_average - self.running_average
            self.smoothed_average = self.running_average
            logger.info(f'[BEST] (Average Smoothed) {self.trainer.global_step} | {delta:.6f} improvement and value: {self.smoothed_average:.6f} | counter: {self.counter} | window: {n}')
            self.trainer.save_checkpoint(self._checkpoint_path('best_model_3.tar'))

        self.counter += 1

        return self.running_weighted_average, self.running_average

    def now(self, median, mean, std):
        """Save ``best_model_now.tar`` when the NoW mean improves.

        The first call only records ``mean`` as the baseline. ``median`` and
        ``std`` are used for logging only.
        """
        if self.now_mean is None:
            self.now_mean = mean
            return

        if mean < self.now_mean:
            delta = self.now_mean - mean
            self.now_mean = mean
            logger.info(f'[BEST] (NoW)   {self.trainer.global_step} | {delta:.6f} improvement and mean: {self.now_mean:.6f} std: {std} median: {median}')
            self.trainer.save_checkpoint(self._checkpoint_path('best_model_now.tar'))