| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| from munch import DefaultMunch |
| from omegaconf import DictConfig |
| import numpy as np |
| import tensorflow as tf |
|
|
| from common.utils import LRTensorBoard, collect_callback_args |
| from common.training import lr_schedulers |
| from object_detection.tf.src.utils import calculate_objdet_metrics |
|
|
|
|
class _ObjectDetectionMetrics(tf.keras.callbacks.Callback):
    """Callback computing object-detection metrics at the end of each epoch.

    During the validation pass, detection data is accumulated batch by batch
    on the model (via its project-specific `set_tmp_metrics_data` method).
    At epoch end, the accumulated groundtruths/detections are converted into
    mean precision, mean recall and mAP, which are written into the `logs`
    dict under the keys 'val_mpre', 'val_mrec' and 'val_map' so that other
    callbacks (ModelCheckpoint, EarlyStopping, ...) can monitor them.
    """

    def __init__(self, num_classes, iou_threshold, name=None):
        """
        Args:
            num_classes: number of object classes of the detector.
            iou_threshold: IOU threshold used when matching detections to
                groundtruth boxes in the metrics computation.
            name: optional name for this callback.
        """
        super().__init__()
        self.num_classes = num_classes
        self.iou_threshold = iou_threshold
        self.name = name

    def on_epoch_end(self, epoch, logs=None):
        # The model is expected to expose get_metrics_data() /
        # reset_metrics_data() (project-specific API) returning the data
        # accumulated during the validation pass.
        groundtruths, detections = self.model.get_metrics_data()
        mpre, mrec, mAP = calculate_objdet_metrics(
            groundtruths, detections, self.iou_threshold, averages_only=True)

        # Guard against a missing logs dict (Keras normally supplies one,
        # but the default is None and writing into it would raise).
        if logs is not None:
            logs['val_mpre'] = mpre
            logs['val_mrec'] = mrec
            logs['val_map'] = mAP

        # Clear the accumulated data so the next epoch starts fresh.
        self.model.reset_metrics_data()

    def on_test_batch_end(self, batch, logs=None):
        # Accumulate the detection data produced by this validation batch.
        self.model.set_tmp_metrics_data(batch)
|
|
|
|
class MultiResCallback(tf.keras.callbacks.Callback):
    """Callback cycling the model input resolution during training.

    Every `period` training batches, the model is switched to the next
    resolution in `image_sizes` (wrapping around after the last one),
    implementing multi-resolution training.
    """

    def __init__(self, image_sizes, period, name=None):
        """
        Args:
            image_sizes: sequence of input resolutions to cycle through.
            period: number of batches between two resolution changes.
            name: optional name for this callback.
        """
        super().__init__()
        self.resolutions = image_sizes
        self.period = period
        # Store the name like the other custom callbacks in this file do
        # (it was previously accepted but silently dropped).
        self.name = name

    def on_train_batch_begin(self, batch, logs=None):
        # NOTE(review): Keras batch indices start at 0, so (batch - 1) is -1
        # on the first batch; Python floor division/modulo then selects the
        # LAST resolution first. Confirm this wrap-around is intended and not
        # an off-by-one.
        res = self.resolutions[((batch - 1) // self.period) % len(self.resolutions)]
        self.model.set_resolution(res)
|
|
|
|
def _get_callback_monitoring(args: DictConfig, callback_name: str = None, message: str = None) -> tuple:
    """Validate and resolve the `monitor`/`mode` arguments of a callback.

    Falls back to 'val_loss' monitoring and 'auto' mode when the user did
    not provide them, resolves 'auto' to 'min' for the loss and 'max' for
    the detection metrics, and rejects unsupported values.

    Args:
        args (DictConfig): the user-supplied arguments of the callback.
        callback_name (str): name of the callback, used in error messages.
        message (str): extra text appended to error messages.

    Returns:
        tuple: the resolved (monitor, mode) pair.

    Raises:
        ValueError: if `monitor` or `mode` has an unsupported value.
    """
    supported_monitors = ("val_loss", "val_mpre", "val_mrec", "val_map")
    supported_modes = ("min", "max", "auto")

    monitor = args.monitor if args.monitor else "val_loss"
    if monitor not in supported_monitors:
        raise ValueError(f"\nSupported values for the `monitor` argument of the '{callback_name}' "
                         f"callback are {supported_monitors}. Received {monitor}{message}")

    mode = args.mode if args.mode else "auto"
    if mode not in supported_modes:
        raise ValueError("\nSupported values for the `mode` argument of the "
                         f"'{callback_name}' callback are {supported_modes}. Received {mode}{message}")
    if mode == "auto":
        # The loss is minimized; all detection metrics are maximized.
        mode = "min" if monitor == "val_loss" else "max"

    # Minimizing a detection metric is almost certainly a mistake: warn.
    if mode == "min" and monitor in ("val_mpre", "val_mrec", "val_map"):
        print(f"WARNING: The `mode` argument of the '{callback_name}' callback is set to 'min'.",
              f"The '{monitor}' metrics will be minimized.")

    return monitor, mode
|
|
|
|
def _get_best_model_callback(cfg: DictConfig,
                             saved_models_dir: str = None,
                             message: str = None) -> tf.keras.callbacks.Callback:
    """Create the built-in ModelCheckpoint callback saving the best weights.

    The callback itself cannot be redefined by the user; only its `monitor`
    and `mode` arguments may be customized through the configuration.

    Args:
        cfg (DictConfig): the 'training.callbacks' section of the configuration.
        saved_models_dir (str): directory where the weights file is written.
        message (str): extra text appended to error messages.

    Returns:
        tf.keras.callbacks.Callback: the configured ModelCheckpoint callback.

    Raises:
        ValueError: if the user supplied arguments other than `monitor`/`mode`.
    """
    # Defaults when the user did not customize the built-in callback.
    monitor, mode = "val_loss", "min"

    if "ModelCheckpoint" in cfg:
        args = cfg["ModelCheckpoint"]
        # Only `monitor` and `mode` may be overridden by the user.
        unexpected = [k for k in args.keys() if k not in ("monitor", "mode")]
        if unexpected:
            raise ValueError("\nThe 'ModelCheckpoint' callback is built-in. Only the "
                             f"`monitor` and `mode` arguments can be specified.{message}")
        monitor, mode = _get_callback_monitoring(args, callback_name="ModelCheckpoint", message=message)

    return tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(saved_models_dir, "best_weights.weights.h5"),
        save_best_only=True,
        save_weights_only=True,
        monitor=monitor,
        mode=mode)
|
|
|
|
def get_callbacks(cfg: DictConfig,
                  num_classes=None,
                  iou_eval_threshold=None,
                  image_sizes=None,
                  period=None,
                  saved_models_dir: str = None,
                  log_dir: str = None,
                  metrics_dir: str = None) -> list[tf.keras.callbacks.Callback]:
    """
    This function creates the list of Keras callbacks to be passed to
    the fit() function including:
    - the Model Zoo built-in callbacks that can't be redefined by the
      user (ModelCheckpoint, TensorBoard, CSVLogger).
    - the callbacks specified in the config file that can be either Keras
      callbacks or custom callbacks (learning rate schedulers).

    For each callback, the attributes and their values used in the config
    file are used to create a string that is the callback instantiation as
    it would be written in a Python script. Then, the string is evaluated.
    If the evaluation succeeds, the callback object is returned. If it fails,
    an error is thrown with a message saying that the name and/or arguments
    of the callback are incorrect.

    The function also checks that there is only one learning rate scheduler
    in the list of callbacks.

    Args:
        cfg (DictConfig): the 'training.callbacks' section of the configuration.
        num_classes: number of object classes, forwarded to the metrics callback.
        iou_eval_threshold: IOU threshold used by the object-detection
            metrics callback.
        image_sizes: optional list of input resolutions enabling
            multi-resolution training (used with `period`).
        period: number of batches between two resolution changes.
        saved_models_dir (str): directory where the trained models are saved.
        log_dir (str): directory where the TensorBoard logs are saved.
        metrics_dir (str): directory where the training metrics are saved.

    Returns:
        list[tf.keras.callbacks.Callback]: A list of Keras callbacks to pass
        to the fit() function.
    """
    # Normalize empty callback sections to empty dicts so the code below
    # can uniformly iterate over their (possibly absent) arguments.
    for name in cfg.keys():
        if not cfg[name]:
            cfg[name] = DefaultMunch.fromDict({})

    message = "\nPlease check the 'training.callbacks' section of your configuration file."
    lr_scheduler_names = lr_schedulers.get_scheduler_names()
    num_lr_schedulers = 0

    callback_list = []

    # Built-in object-detection metrics callback, always present.
    callback_list.append(_ObjectDetectionMetrics(num_classes, iou_eval_threshold))
    if image_sizes is not None and period is not None:
        callback_list.append(MultiResCallback(image_sizes, period))

    for name, args in cfg.items():
        # ModelCheckpoint is built-in and handled separately below.
        if name == "ModelCheckpoint":
            continue

        if name in ("TensorBoard", "CSVLogger"):
            raise ValueError(f"\nThe `{name}` callback is built-in and can't be redefined.{message}")

        if name in ("ReduceLROnPlateau", "EarlyStopping"):
            monitor, mode = _get_callback_monitoring(args, callback_name=name, message=message)
            args.monitor = monitor
            args.mode = mode

        if name in lr_scheduler_names:
            text = "lr_schedulers.{}".format(name)
        else:
            text = "tf.keras.callbacks.{}".format(name)

        # Build the instantiation string and evaluate it.
        # NOTE(security): eval() executes arbitrary code taken from the
        # configuration file; this assumes the config is trusted input.
        text += collect_callback_args(name, args, message=message)
        try:
            callback = eval(text)
        except Exception as e:
            # Chain the original exception so the real cause is preserved;
            # the previous bare `except:` also swallowed KeyboardInterrupt.
            raise ValueError("\nThe callback name `{}` is unknown, or its arguments are incomplete "
                             "or invalid\nReceived: {}{}".format(name, text, message)) from e
        callback_list.append(callback)

        # Count learning rate schedulers to enforce at most one.
        if name in ["ReduceLROnPlateau", "LearningRateScheduler"] + lr_scheduler_names:
            num_lr_schedulers += 1

    if num_lr_schedulers > 1:
        raise ValueError(f"\nFound more than one learning rate scheduler{message}")

    # Built-in callback saving the best weights seen so far.
    callback = _get_best_model_callback(cfg, saved_models_dir=saved_models_dir, message=message)
    callback_list.append(callback)

    # Built-in callback saving the last weights at the end of every epoch.
    callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(saved_models_dir, "last_weights.weights.h5"),
        save_best_only=False,
        save_weights_only=True,
        save_freq='epoch')
    callback_list.append(callback)

    # TensorBoard logging (includes the learning rate).
    callback = LRTensorBoard(log_dir=log_dir)
    callback_list.append(callback)

    # CSV logging of the per-epoch training metrics.
    callback = tf.keras.callbacks.CSVLogger(os.path.join(metrics_dir, "train_metrics.csv"))
    callback_list.append(callback)

    return callback_list
|
|
|
|