alexnasa's picture
Upload 82 files
bd096d2 verified
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2023 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: mica@tue.mpg.de
import os
import numpy as np
from loguru import logger
class BestModel:
    """Track the best validation scores seen so far and checkpoint on improvement.

    All scores are treated as "lower is better". Each criterion writes its own
    checkpoint into ``<trainer.cfg.output_dir>/best_models`` through
    ``trainer.save_checkpoint``:

    * ``best_model_0.tar``   -- lowest weighted average,
    * ``best_model_1.tar``   -- lowest plain average,
    * ``best_model_3.tar``   -- lowest running (smoothed) plain average,
    * ``best_model_now.tar`` -- lowest NoW mean, see :meth:`now`.
    """

    def __init__(self, trainer):
        """Create the tracker and the ``best_models`` output directory.

        Args:
            trainer: Object providing ``cfg.running_average`` (the smoothing
                factor N), ``cfg.output_dir``, ``global_step`` and
                ``save_checkpoint(path)``.
        """
        # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.average = np.inf
        self.weighted_average = np.inf
        self.smoothed_average = np.inf
        # Not compared against anywhere in this class; persisted in state_dict only.
        self.smoothed_weighted_average = np.inf
        self.running_average = np.inf
        self.running_weighted_average = np.inf
        self.now_mean = None
        self.trainer = trainer
        # None until the first __call__; afterwards the number of evaluations processed.
        self.counter = None
        self.N = trainer.cfg.running_average
        os.makedirs(os.path.join(self.trainer.cfg.output_dir, 'best_models'), exist_ok=True)

    def _save_best(self, filename):
        """Ask the trainer to write a checkpoint named ``filename`` into ``best_models``."""
        self.trainer.save_checkpoint(os.path.join(self.trainer.cfg.output_dir, 'best_models', filename))

    @staticmethod
    def _ema(previous, value, n):
        """Exponential moving average update giving ``value`` a weight of ``1 / n``.

        An infinite ``previous`` (the initial placeholder, e.g. after loading an
        older state dict that did not store this average) is replaced by
        ``value``; otherwise the average would stay infinite forever.
        """
        if previous == np.inf:
            return value
        return previous * ((n - 1) / n) + (value / n)

    def state_dict(self):
        """Return the tracker state as a plain dict (see :meth:`load_state_dict`)."""
        return {
            'average': self.average,
            'weighted_average': self.weighted_average,
            'smoothed_average': self.smoothed_average,
            'smoothed_weighted_average': self.smoothed_weighted_average,
            'running_average': self.running_average,
            'running_weighted_average': self.running_weighted_average,
            'now_mean': self.now_mean,
            'counter': self.counter,
        }

    def load_state_dict(self, dict):
        """Restore tracker state produced by :meth:`state_dict`.

        The weighted entries were not stored by older versions; when they are
        missing the current values are kept, so older state dicts still load.

        Args:
            dict: Mapping as returned by :meth:`state_dict`. The name shadows the
                builtin but is kept so keyword callers keep working.
        """
        state = dict
        self.average = state['average']
        self.smoothed_average = state['smoothed_average']
        self.running_average = state['running_average']
        self.now_mean = state['now_mean']
        self.counter = state['counter']
        self.weighted_average = state.get('weighted_average', self.weighted_average)
        self.smoothed_weighted_average = state.get('smoothed_weighted_average', self.smoothed_weighted_average)
        self.running_weighted_average = state.get('running_weighted_average', self.running_weighted_average)

        # now_mean stays None until now() has been called once; ':.6f' on None raises.
        now_mean = 'n/a' if self.now_mean is None else f'{self.now_mean:.6f}'
        logger.info(f'[BEST] Best score weighted average: '
                    f'NoW mean: {now_mean} | '
                    f'average: {self.average:.6f} | '
                    f'smoothed average: {self.smoothed_average:.6f}')

    def __call__(self, weighted_average, average):
        """Record one evaluation and checkpoint every criterion that improved.

        The first call only seeds the best values and running averages. Later
        calls compare against the best values so far, then update both running
        averages with an exponential moving average (newest value weight 1/N).

        Args:
            weighted_average: Weighted validation score of this evaluation.
            average: Plain (unweighted) validation score of this evaluation.

        Returns:
            Tuple ``(running_weighted_average, running_average)`` after the update.
        """
        if self.counter is None:
            self.counter = 1
            self.average = average
            self.weighted_average = weighted_average
            self.running_weighted_average = weighted_average
            self.running_average = average
            return weighted_average, average

        if weighted_average < self.weighted_average:
            delta = self.weighted_average - weighted_average
            self.weighted_average = weighted_average
            logger.info(f'[BEST] (Average weighted) {self.trainer.global_step} | {delta:.6f} improvement and value: {self.weighted_average:.6f}')
            self._save_best('best_model_0.tar')

        if average < self.average:
            delta = self.average - average
            self.average = average
            logger.info(f'[BEST] (Average) {self.trainer.global_step} | {delta:.6f} improvement and value: {self.average:.6f}')
            self._save_best('best_model_1.tar')

        n = self.N
        self.running_average = self._ema(self.running_average, average, n)
        # Previously never updated after the first call, so a stale value was returned.
        self.running_weighted_average = self._ema(self.running_weighted_average, weighted_average, n)

        if self.running_average < self.smoothed_average:
            delta = self.smoothed_average - self.running_average
            self.smoothed_average = self.running_average
            logger.info(f'[BEST] (Average Smoothed) {self.trainer.global_step} | {delta:.6f} improvement and value: {self.smoothed_average:.6f} | counter: {self.counter} | window: {n}')
            self._save_best('best_model_3.tar')

        self.counter += 1
        return self.running_weighted_average, self.running_average

    def now(self, median, mean, std):
        """Checkpoint when the NoW mean improves on the best seen so far.

        The first call only records ``mean``. ``median`` and ``std`` are used
        only for logging.
        """
        if self.now_mean is None:
            self.now_mean = mean
            return
        if mean < self.now_mean:
            delta = self.now_mean - mean
            self.now_mean = mean
            logger.info(f'[BEST] (NoW) {self.trainer.global_step} | {delta:.6f} improvement and mean: {self.now_mean:.6f} std: {std} median: {median}')
            self._save_best('best_model_now.tar')