Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
| # holder of all proprietary rights on this computer program. | |
| # You can only use this computer program if you have closed | |
| # a license agreement with MPG or you get the right to use the computer | |
| # program from someone who is authorized to grant you that right. | |
| # Any use of the computer program without a valid license is prohibited and | |
| # liable to prosecution. | |
| # | |
| # Copyright©2023 Max-Planck-Gesellschaft zur Förderung | |
| # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| # for Intelligent Systems. All rights reserved. | |
| # | |
| # Contact: mica@tue.mpg.de | |
| import os | |
| import numpy as np | |
| from loguru import logger | |
| class BestModel: | |
| def __init__(self, trainer): | |
| self.average = np.Inf | |
| self.weighted_average = np.Inf | |
| self.smoothed_average = np.Inf | |
| self.smoothed_weighted_average = np.Inf | |
| self.running_average = np.Inf | |
| self.running_weighted_average = np.Inf | |
| self.now_mean = None | |
| self.trainer = trainer | |
| self.counter = None | |
| self.N = trainer.cfg.running_average | |
| os.makedirs(os.path.join(self.trainer.cfg.output_dir, 'best_models'), exist_ok=True) | |
| def state_dict(self): | |
| return { | |
| 'average': self.average, | |
| 'smoothed_average': self.smoothed_average, | |
| 'running_average': self.running_average, | |
| 'now_mean': self.now_mean, | |
| 'counter': self.counter, | |
| } | |
| def load_state_dict(self, dict): | |
| self.average = dict['average'] | |
| self.smoothed_average = dict['smoothed_average'] | |
| self.running_average = dict['running_average'] | |
| self.now_mean = dict['now_mean'] | |
| self.counter = dict['counter'] | |
| logger.info(f'[BEST] Best score weighted average: ' | |
| f'NoW mean: {self.now_mean:.6f} | ' | |
| f'average: {self.average:.6f} | ' | |
| f'smoothed average: {self.running_average:.6f}') | |
| def __call__(self, weighted_average, average): | |
| if self.counter is None: | |
| self.counter = 1 | |
| self.average = average | |
| self.weighted_average = weighted_average | |
| self.running_weighted_average = weighted_average | |
| self.running_average = average | |
| return weighted_average, average | |
| if weighted_average < self.weighted_average: | |
| delta = self.weighted_average - weighted_average | |
| self.weighted_average = weighted_average | |
| logger.info(f'[BEST] (Average weighted) {self.trainer.global_step} | {delta:.6f} improvement and value: {self.weighted_average:.6f}') | |
| self.trainer.save_checkpoint(os.path.join(self.trainer.cfg.output_dir, 'best_models', f'best_model_0.tar')) | |
| if average < self.average: | |
| delta = self.average - average | |
| self.average = average | |
| logger.info(f'[BEST] (Average) {self.trainer.global_step} | {delta:.6f} improvement and value: {self.average:.6f}') | |
| self.trainer.save_checkpoint(os.path.join(self.trainer.cfg.output_dir, 'best_models', f'best_model_1.tar')) | |
| n = self.N | |
| self.running_average = self.running_average * ((n - 1) / n) + (average / n) | |
| if self.running_average < self.smoothed_average: | |
| delta = self.smoothed_average - self.running_average | |
| self.smoothed_average = self.running_average | |
| logger.info(f'[BEST] (Average Smoothed) {self.trainer.global_step} | {delta:.6f} improvement and value: {self.smoothed_average:.6f} | counter: {self.counter} | window: {n}') | |
| self.trainer.save_checkpoint(os.path.join(self.trainer.cfg.output_dir, 'best_models', f'best_model_3.tar')) | |
| self.counter += 1 | |
| return self.running_weighted_average, self.running_average | |
| def now(self, median, mean, std): | |
| if self.now_mean is None: | |
| self.now_mean = mean | |
| return | |
| if mean < self.now_mean: | |
| delta = self.now_mean - mean | |
| self.now_mean = mean | |
| logger.info(f'[BEST] (NoW) {self.trainer.global_step} | {delta:.6f} improvement and mean: {self.now_mean:.6f} std: {std} median: {median}') | |
| self.trainer.save_checkpoint(os.path.join(self.trainer.cfg.output_dir, 'best_models', f'best_model_now.tar')) | |