File size: 9,228 Bytes
747451d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
# /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2022 STMicroelectronics.
#  * All rights reserved.
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/

import os
from munch import DefaultMunch
from omegaconf import DictConfig
import numpy as np
import tensorflow as tf

from common.utils import LRTensorBoard, collect_callback_args
from common.training import lr_schedulers
from object_detection.tf.src.utils import calculate_objdet_metrics


class _ObjectDetectionMetrics(tf.keras.callbacks.Callback):
    """Callback that computes object detection metrics at the end of each epoch.

    It reads the groundtruth/detection data that the model accumulates during
    validation, computes averaged precision/recall/mAP with
    `calculate_objdet_metrics`, and publishes them in the epoch `logs` dict
    under 'val_mpre', 'val_mrec' and 'val_map' so that other callbacks
    (e.g. ModelCheckpoint, EarlyStopping) can monitor them.

    NOTE(review): relies on custom model methods `get_metrics_data()`,
    `reset_metrics_data()` and `set_tmp_metrics_data()` — presumably added by
    the project's model class; verify against the model implementation.
    """

    def __init__(self, num_classes, iou_threshold, name=None):
        # num_classes: number of object classes (stored; not used in this class).
        # iou_threshold: IoU threshold passed to the metrics computation.
        # name: optional callback name (stored; not used in this class).
        super().__init__()
        self.num_classes = num_classes
        self.iou_threshold = iou_threshold
        self.name = name

    def on_epoch_end(self, epoch, logs=None):

        # Calculate the average metrics for the data
        # collected by the test_step function
        groundtruths, detections = self.model.get_metrics_data()
        mpre, mrec, mAP = calculate_objdet_metrics(groundtruths, detections, self.iou_threshold, averages_only=True)
        
        # Add the metrics to the callbacks logs dictionary
        # (assumes Keras passes a dict here, as it does during fit()).
        logs['val_mpre'] = mpre
        logs['val_mrec'] = mrec
        logs['val_map'] = mAP

        # Reset the metrics data collected by the test_step function
        self.model.reset_metrics_data()

    def on_test_batch_end(self, batch, logs=None):
        # NOTE(review): presumably snapshots the per-batch metrics data
        # gathered by the model's test_step — confirm against the model class.
        self.model.set_tmp_metrics_data(batch)


class MultiResCallback(tf.keras.callbacks.Callback):
    """Callback that cycles the training resolution during training.

    Every `period` training batches, the model is switched to the next
    resolution in `image_sizes`, wrapping around at the end of the list
    via the model's `set_resolution()` method.
    """

    def __init__(self, image_sizes, period, name=None):
        # image_sizes: list of resolutions to cycle through.
        # period: number of batches between resolution changes.
        # name: accepted for interface uniformity; not stored.
        super().__init__()
        self.resolutions = image_sizes
        self.period = period

    def on_train_batch_begin(self, batch, logs=None):
        # NOTE(review): with `batch - 1`, batch 0 yields slot -1 (floor
        # division), i.e. the LAST resolution is used first — confirm this
        # offset is intentional.
        slot = (batch - 1) // self.period
        index = slot % len(self.resolutions)
        self.model.set_resolution(self.resolutions[index])


def _get_callback_monitoring(args: DictConfig, callback_name: str = None, message: str = None) -> tuple:

    if not args.monitor:
        monitor = "val_loss"
    else:
        monitor = args.monitor
        valid = ("val_loss", "val_mpre", "val_mrec", "val_map")
        if monitor not in valid:
            raise ValueError(f"\nSupported values for the `monitor` argument of the '{callback_name}' "
                             f"callback are {valid}. Received {monitor}{message}")
    if not args.mode:
        mode = "auto"
    else:
        mode = args.mode
        valid = ("min", "max", "auto")
        if mode not in valid:
            raise ValueError("\nSupported values for the `mode` argument of the "
                            f"'{callback_name}' callback are {valid}. Received {mode}{message}")
    if mode == "auto":
        mode = "min" if monitor == "val_loss" else "max"

    if monitor in ("val_mpre", "val_mrec", "val_map") and mode == "min":
        print(f"WARNING: The `mode` argument of the '{callback_name}' callback is set to 'min'.",
              f"The '{monitor}' metrics will be minimized.")

    return monitor, mode


def _get_best_model_callback(cfg: DictConfig, 
                            saved_models_dir: str = None,
                            message: str = None) -> tf.keras.callbacks.Callback:
    """Build the built-in ModelCheckpoint callback that tracks the best weights.

    The weights of the best model seen so far are saved to
    `<saved_models_dir>/best_weights.weights.h5`. Only the `monitor` and
    `mode` arguments may be customized via a 'ModelCheckpoint' section in
    the config; any other argument raises a ValueError. Defaults are
    monitor="val_loss", mode="min".
    """
    monitor, mode = "val_loss", "min"
    if "ModelCheckpoint" in cfg:
        args = cfg["ModelCheckpoint"]

        # This callback is built-in: reject any argument other than
        # `monitor` and `mode`.
        unexpected = [k for k in args.keys() if k not in ("monitor", "mode")]
        if unexpected:
            raise ValueError("\nThe 'ModelCheckpoint' callback is built-in. Only the "
                            f"`monitor` and `mode` arguments can be specified.{message}")
        monitor, mode = _get_callback_monitoring(args, callback_name="ModelCheckpoint", message=message)

    return tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(saved_models_dir, "best_weights.weights.h5"),
        save_best_only=True,
        save_weights_only=True,
        monitor=monitor,
        mode=mode)


def get_callbacks(cfg: DictConfig,
                  num_classes=None,
                  iou_eval_threshold=None,
                  image_sizes=None,
                  period=None,
                  saved_models_dir: str = None,
                  log_dir: str = None, 
                  metrics_dir: str = None) -> list[tf.keras.callbacks.Callback]:
                  
    """
    This function creates the list of Keras callbacks to be passed to 
    the fit() function including:
      - the Model Zoo built-in callbacks that can't be redefined by the
        user (ModelCheckpoint, TensorBoard, CSVLogger).
      - the callbacks specified in the config file that can be either Keras
        callbacks or custom callbacks (learning rate schedulers).

    For each callback, the attributes and their values used in the config
    file are used to create a string that is the callback instantiation as
    it would be written in a Python script. Then, the string is evaluated.
    If the evaluation succeeds, the callback object is returned. If it fails,
    an error is thrown with a message saying that the name and/or arguments
    of the callback are incorrect.

    The function also checks that there is only one learning rate scheduler
    in the list of callbacks.

    Args:
        cfg (DictConfig): the 'training.callbacks' section of the configuration.
        num_classes: number of object classes, forwarded to the metrics callback.
        iou_eval_threshold: IoU threshold used when computing the detection metrics.
        image_sizes: optional list of resolutions for multi-resolution training.
        period: number of batches between resolution changes (used with image_sizes).
        saved_models_dir (str): directory where the trained weights are saved.
        log_dir (str): directory where the TensorBoard logs are saved.
        metrics_dir (str): directory where the training metrics CSV file is saved.

    Returns:
        list[tf.keras.callbacks.Callback]: A list of Keras callbacks to pass
                                           to the fit() function.

    Raises:
        ValueError: if a callback is unknown or has invalid arguments, if a
            built-in callback is redefined, or if more than one learning
            rate scheduler is configured.
    """

    # Make sure every callback has an arguments dictionary
    for name in cfg.keys():
        if not cfg[name]:
            cfg[name] = DefaultMunch.fromDict({})

    message = "\nPlease check the 'training.callbacks' section of your configuration file."
    lr_scheduler_names = lr_schedulers.get_scheduler_names()
    num_lr_schedulers = 0

    callback_list = []
    
    # The callback that calculates the object detection
    # metrics must be first in the list of callbacks.
    callback_list.append(_ObjectDetectionMetrics(num_classes, iou_eval_threshold))
    if image_sizes is not None and period is not None:
        callback_list.append(MultiResCallback(image_sizes, period))

    for name, args in cfg.items():
        # The ModelCheckpoint callback is handled separately.
        if name == "ModelCheckpoint":
            continue

        if name in ("TensorBoard", "CSVLogger"):
            raise ValueError(f"\nThe `{name}` callback is built-in and can't be redefined.{message}")

        if name in ("ReduceLROnPlateau", "EarlyStopping"):
            monitor, mode = _get_callback_monitoring(args, callback_name=name, message=message)
            args.monitor = monitor
            args.mode = mode
    
        if name in lr_scheduler_names:
            text = "lr_schedulers.{}".format(name)
        else:
            text = "tf.keras.callbacks.{}".format(name)

        # Add the arguments to the callback string and evaluate it to get
        # the callback object. The eval must stay in this scope so that it
        # can resolve the `tf` and `lr_schedulers` names.
        # SECURITY NOTE: eval() executes arbitrary code taken from the
        # configuration file; only use configuration files you trust.
        text += collect_callback_args(name, args, message=message)
        try:
            callback = eval(text)
        except Exception as e:
            # Chain the original exception so the real cause of the failure
            # is not hidden from the user (a bare `except:` would also have
            # swallowed KeyboardInterrupt/SystemExit).
            raise ValueError("\nThe callback name `{}` is unknown, or its arguments are incomplete "
                             "or invalid\nReceived: {}{}".format(name, text, message)) from e
        callback_list.append(callback)

        # Count LR schedulers
        if name in ["ReduceLROnPlateau", "LearningRateScheduler"] + lr_scheduler_names:
            num_lr_schedulers += 1
    
    # Check that there is only one scheduler
    if num_lr_schedulers > 1:
        raise ValueError(f"\nFound more than one learning rate scheduler{message}")

    # Add the Keras callback that saves the best model obtained so far
    callback = _get_best_model_callback(cfg, saved_models_dir=saved_models_dir, message=message)
    callback_list.append(callback)

    # Add the Keras callback that saves the weights at the end of each epoch
    callback = tf.keras.callbacks.ModelCheckpoint(
                    filepath=os.path.join(saved_models_dir, "last_weights.weights.h5"),
                    save_best_only=False,
                    save_weights_only=True,
                    save_freq='epoch')
    callback_list.append(callback)

    # Add the TensorBoard callback
    callback = LRTensorBoard(log_dir=log_dir)
    callback_list.append(callback)

    # Add the CSVLogger callback (must be last in the list 
    # of callbacks to make sure it records the learning rate)
    callback = tf.keras.callbacks.CSVLogger(os.path.join(metrics_dir, "train_metrics.csv"))
    callback_list.append(callback)

    return callback_list