# File size: 3,333 Bytes
# f43af3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from collections import defaultdict

import numpy as np

from easy_tpp.utils.log_utils import default_logger as logger


class MetricsHelper:
    """Class-level registry of named metric functions.

    Each registered name maps to a ``(function, direction)`` tuple. The
    registry is a class attribute, so registrations are shared by every user
    of ``MetricsHelper``.
    """
    MAXIMIZE = 'maximize'
    MINIMIZE = 'minimize'
    # name -> (metric function, direction)
    _registry_center = defaultdict(tuple)

    @staticmethod
    def get_metric_function(name):
        """Look up the metric function registered under ``name``.

        Args:
            name (str): registered metric name.

        Returns:
            callable or None: the metric function, or None (after logging a
            warning) when the name is not registered.
        """
        # Membership test (not indexing) so the defaultdict never inserts
        # an empty entry for an unknown name.
        if name not in MetricsHelper._registry_center:
            # NOTE(review): `warn` is a deprecated alias of `warning` on
            # stdlib loggers - confirm the logger type before switching.
            logger.warn(f'Metric is not found: {name}')
            return None
        return MetricsHelper._registry_center[name][0]

    @staticmethod
    def get_metric_direction(name):
        """Look up the direction registered under ``name``.

        Args:
            name (str): registered metric name.

        Returns:
            The direction given at registration time, or None when the name
            is not registered.
        """
        if name not in MetricsHelper._registry_center:
            return None
        return MetricsHelper._registry_center[name][1]

    @staticmethod
    def get_all_registered_metric():
        """Return all registered ``(function, direction)`` entries.

        Returns:
            A dict values view over the registry entries.
        """
        # The method must be *called*; returning the bare bound method
        # `values` hands callers a function object instead of the entries.
        return MetricsHelper._registry_center.values()

    @staticmethod
    def register(name, direction, overwrite=True):
        """Build a decorator that registers a metric function.

        Args:
            name (str): name to register the metric under.
            direction: direction stored next to the function (e.g.
                ``MetricsHelper.MAXIMIZE`` or ``MetricsHelper.MINIMIZE``);
                it is stored as given, without validation.
            overwrite (bool): whether an existing registration under the same
                name may be replaced. When False and the name already exists,
                a warning is logged and the registry is left unchanged.

        Returns:
            callable: a decorator that returns the decorated function
            unchanged.
        """
        registry_center = MetricsHelper._registry_center

        def _add_metric_to_registry(func):
            if name in registry_center and not overwrite:
                logger.warn(f'The metric {name} is already registered, and cannot be overwritten!')
            else:
                registry_center[name] = (func, direction)
            return func

        return _add_metric_to_registry

    @staticmethod
    def metrics_dict_to_str(metrics_dict):
        """Convert metrics to a string to show in console.

        Args:
            metrics_dict (dict): mapping of metric name to value.

        Returns:
            str: ``'<name> is <value>'`` items joined by ``', '``; an empty
            string for an empty mapping.
        """
        return ', '.join('{0} is {1}'.format(k, v) for k, v in metrics_dict.items())

    @staticmethod
    def get_metrics_callback_from_names(metric_names):
        """Build a callback that evaluates the named metrics.

        Names that are not registered are skipped (a warning is logged by
        ``get_metric_function``).

        Args:
            metric_names (iterable of str): names of the metrics to evaluate.

        Returns:
            callable: ``metrics(preds, labels, **kwargs)`` returning a dict
            that maps each lower-cased metric name to that metric's result.
        """
        resolved = []
        for name in metric_names:
            metric = MetricsHelper.get_metric_function(name)
            if metric is not None:
                resolved.append((name, metric))

        def metrics(preds, labels, **kwargs):
            """Call every resolved metric function on the same inputs."""
            return {
                metric_name.lower(): metric_func(preds, labels, **kwargs)
                for metric_name, metric_func in resolved
            }

        return metrics


class MetricsTracker:
    """Remember the best metric value seen so far and when it was seen.
    """

    def __init__(self):
        float_limits = np.finfo(float)
        # Start from the worst representable values so the first real
        # measurement always wins the comparison.
        self.current_best = dict(
            loglike=float_limits.min,
            distance=float_limits.max,
        )
        self.episode_best = 'NeverUpdated'

    def update_best(self, key, value, epoch):
        """Record ``value`` as the new best for ``key`` if it is larger.

        Args:
            key (str): metrics key; only 'loglike' is supported.
            value (float): newly measured metrics value.
            epoch (int): epoch at which ``value`` was measured.

        Raises:
            NotImplementedError: when ``key`` is anything but 'loglike'.

        Returns:
            bool: True when the stored best value was replaced, else False.
        """
        if key != 'loglike':
            raise NotImplementedError

        # Strict comparison: a tie with the current best is not an update.
        if not value > self.current_best[key]:
            return False

        self.current_best[key] = value
        self.episode_best = epoch
        return True