# Source: STM32 AI Experimentation Hub (FBAGSTM), revision 747451d
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2022 STMicroelectronics.
# * All rights reserved.
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
import os
from munch import DefaultMunch
from omegaconf import DictConfig
import numpy as np
import tensorflow as tf
from common.utils import LRTensorBoard, collect_callback_args
from common.training import lr_schedulers
from object_detection.tf.src.utils import calculate_objdet_metrics
class _ObjectDetectionMetrics(tf.keras.callbacks.Callback):
    """Callback that computes object detection metrics after each validation run.

    The model is expected to expose `get_metrics_data()`, `reset_metrics_data()`
    and `set_tmp_metrics_data()` methods that accumulate groundtruth and
    detection data during evaluation (via the model's test_step).

    Args:
        num_classes: number of object classes.
        iou_threshold: IOU threshold passed to `calculate_objdet_metrics`.
        name: optional callback name (stored for reference, not used here).
    """
    def __init__(self, num_classes, iou_threshold, name=None):
        super().__init__()
        self.num_classes = num_classes
        self.iou_threshold = iou_threshold
        self.name = name

    def on_epoch_end(self, epoch, logs=None):
        # Calculate the average metrics for the data
        # collected by the test_step function
        groundtruths, detections = self.model.get_metrics_data()
        mpre, mrec, mAP = calculate_objdet_metrics(
            groundtruths, detections, self.iou_threshold, averages_only=True)
        # Add the metrics to the callbacks logs dictionary.
        # Keras may call this hook with logs=None; guard to avoid a TypeError.
        if logs is not None:
            logs['val_mpre'] = mpre
            logs['val_mrec'] = mrec
            logs['val_map'] = mAP
        # Reset the metrics data collected by the test_step function
        self.model.reset_metrics_data()

    def on_test_batch_end(self, batch, logs=None):
        # Forward each evaluation batch so the model can accumulate
        # the per-batch metrics data.
        self.model.set_tmp_metrics_data(batch)
class MultiResCallback(tf.keras.callbacks.Callback):
    """Callback implementing multi-resolution training.

    Every `period` training batches, the next image size from `image_sizes`
    is selected (cycling) and applied through the model's `set_resolution()`
    method.

    Args:
        image_sizes: sequence of image resolutions to cycle through.
        period: number of batches each resolution is kept before switching.
        name: optional callback name (stored for reference, not used here).
    """
    def __init__(self, image_sizes, period, name=None):
        super().__init__()
        self.resolutions = image_sizes
        self.period = period
        # Stored for consistency with the other custom callbacks
        # (the original accepted `name` but discarded it).
        self.name = name

    def on_train_batch_begin(self, batch, logs=None):
        # NOTE(review): `batch - 1` makes batch 0 select index -1, i.e. the
        # *last* resolution in the list — presumably intentional; confirm.
        res = self.resolutions[((batch - 1) // self.period) % len(self.resolutions)]
        self.model.set_resolution(res)
def _get_callback_monitoring(args: DictConfig, callback_name: str = None, message: str = None) -> tuple:
    """Resolve and validate the `monitor` and `mode` arguments of a callback.

    `monitor` defaults to "val_loss" and `mode` defaults to "auto". In "auto"
    mode, "val_loss" is minimized and the detection metrics are maximized.
    A warning is printed when a detection metric is explicitly minimized.

    Args:
        args (DictConfig): the callback's arguments from the config file.
        callback_name (str): callback name, used in error messages.
        message (str): extra text appended to error messages.

    Returns:
        tuple: the resolved (monitor, mode) pair.

    Raises:
        ValueError: if `monitor` or `mode` has an unsupported value.
    """
    monitor = args.monitor if args.monitor else "val_loss"
    valid = ("val_loss", "val_mpre", "val_mrec", "val_map")
    if monitor not in valid:
        raise ValueError(f"\nSupported values for the `monitor` argument of the '{callback_name}' "
                         f"callback are {valid}. Received {monitor}{message}")

    mode = args.mode if args.mode else "auto"
    valid = ("min", "max", "auto")
    if mode not in valid:
        raise ValueError("\nSupported values for the `mode` argument of the "
                         f"'{callback_name}' callback are {valid}. Received {mode}{message}")

    # "auto" means: minimize the loss, maximize any detection metric.
    if mode == "auto":
        mode = "min" if monitor == "val_loss" else "max"

    # Monitor has already been validated, so != "val_loss" means it is
    # one of the detection metrics.
    if mode == "min" and monitor != "val_loss":
        print(f"WARNING: The `mode` argument of the '{callback_name}' callback is set to 'min'.",
              f"The '{monitor}' metrics will be minimized.")
    return monitor, mode
def _get_best_model_callback(cfg: DictConfig,
                             saved_models_dir: str = None,
                             message: str = None) -> tf.keras.callbacks.Callback:
    """Build the built-in ModelCheckpoint callback that tracks the best model.

    Only the `monitor` and `mode` arguments may be customized through an
    optional 'ModelCheckpoint' entry in the callbacks configuration; any
    other argument raises a ValueError. Weights are saved under
    `saved_models_dir` as "best_weights.weights.h5".

    Args:
        cfg (DictConfig): the 'training.callbacks' configuration section.
        saved_models_dir (str): directory where model weights are saved.
        message (str): extra text appended to error messages.

    Returns:
        tf.keras.callbacks.Callback: the configured ModelCheckpoint callback.
    """
    # Defaults used when the config does not customize the callback.
    monitor, mode = "val_loss", "min"
    if "ModelCheckpoint" in cfg:
        args = cfg["ModelCheckpoint"]
        unsupported = [k for k in args.keys() if k not in ("monitor", "mode")]
        if unsupported:
            raise ValueError("\nThe 'ModelCheckpoint' callback is built-in. Only the "
                             f"`monitor` and `mode` arguments can be specified.{message}")
        monitor, mode = _get_callback_monitoring(args, callback_name="ModelCheckpoint", message=message)

    return tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(saved_models_dir, "best_weights.weights.h5"),
        save_best_only=True,
        save_weights_only=True,
        monitor=monitor,
        mode=mode)
def get_callbacks(cfg: DictConfig,
                  num_classes=None,
                  iou_eval_threshold=None,
                  image_sizes=None,
                  period=None,
                  saved_models_dir: str = None,
                  log_dir: str = None,
                  metrics_dir: str = None) -> list[tf.keras.callbacks.Callback]:
    """
    This function creates the list of Keras callbacks to be passed to
    the fit() function including:
    - the Model Zoo built-in callbacks that can't be redefined by the
      user (ModelCheckpoint, TensorBoard, CSVLogger).
    - the callbacks specified in the config file that can be either Keras
      callbacks or custom callbacks (learning rate schedulers).

    For each callback, the attributes and their values used in the config
    file are used to create a string that is the callback instantiation as
    it would be written in a Python script. Then, the string is evaluated.
    If the evaluation succeeds, the callback object is returned. If it fails,
    an error is thrown with a message saying that the name and/or arguments
    of the callback are incorrect.

    The function also checks that there is only one learning rate scheduler
    in the list of callbacks.

    Args:
        cfg (DictConfig): the 'training.callbacks' section of the configuration.
        num_classes: number of object classes (for the metrics callback).
        iou_eval_threshold: IOU threshold used when evaluating detections.
        image_sizes: optional image resolutions enabling multi-resolution training.
        period: number of batches each resolution is kept (multi-resolution).
        saved_models_dir (str): directory where the trained models are saved.
        log_dir (str): directory where the TensorBoard logs are saved.
        metrics_dir (str): directory where the training metrics CSV is saved.

    Returns:
        List[tf.keras.callbacks.Callback]: A list of Keras callbacks to pass
        to the fit() function.
    """
    # Make sure every callback has an arguments dictionary
    for name in cfg.keys():
        if not cfg[name]:
            cfg[name] = DefaultMunch.fromDict({})

    message = "\nPlease check the 'training.callbacks' section of your configuration file."
    lr_scheduler_names = lr_schedulers.get_scheduler_names()
    num_lr_schedulers = 0

    callback_list = []
    # The callback that calculates the object detection
    # metrics must be first in the list of callbacks.
    callback_list.append(_ObjectDetectionMetrics(num_classes, iou_eval_threshold))
    if image_sizes is not None and period is not None:
        callback_list.append(MultiResCallback(image_sizes, period))

    for name, args in cfg.items():
        # The ModelCheckpoint callback is handled separately.
        if name == "ModelCheckpoint":
            continue
        if name in ("TensorBoard", "CSVLogger"):
            raise ValueError(f"\nThe `{name}` callback is built-in and can't be redefined.{message}")
        if name in ("ReduceLROnPlateau", "EarlyStopping"):
            monitor, mode = _get_callback_monitoring(args, callback_name=name, message=message)
            args.monitor = monitor
            args.mode = mode

        if name in lr_scheduler_names:
            text = "lr_schedulers.{}".format(name)
        else:
            text = "tf.keras.callbacks.{}".format(name)

        # Add the arguments to the callback string
        # and evaluate it to get the callback object.
        # NOTE: eval() is only run on text built from the user's own config
        # file, never on external input.
        text += collect_callback_args(name, args, message=message)
        try:
            callback = eval(text)
        except Exception as e:
            # Chain the original exception so the root cause is not lost,
            # and don't mask KeyboardInterrupt/SystemExit (the original
            # bare `except:` swallowed everything).
            raise ValueError("\nThe callback name `{}` is unknown, or its arguments are incomplete "
                             "or invalid\nReceived: {}{}".format(name, text, message)) from e
        callback_list.append(callback)

        # Count LR schedulers
        if name in ["ReduceLROnPlateau", "LearningRateScheduler"] + lr_scheduler_names:
            num_lr_schedulers += 1

    # Check that there is only one scheduler
    if num_lr_schedulers > 1:
        raise ValueError(f"\nFound more than one learning rate scheduler{message}")

    # Add the Keras callback that saves the best model obtained so far
    callback = _get_best_model_callback(cfg, saved_models_dir=saved_models_dir, message=message)
    callback_list.append(callback)

    # Add the Keras callback that saves the model at the end of the epoch
    callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(saved_models_dir, "last_weights.weights.h5"),
        save_best_only=False,
        save_weights_only=True,
        save_freq='epoch')
    callback_list.append(callback)

    # Add the TensorBoard callback
    callback = LRTensorBoard(log_dir=log_dir)
    callback_list.append(callback)

    # Add the CSVLogger callback (must be last in the list
    # of callbacks to make sure it records the learning rate)
    callback = tf.keras.callbacks.CSVLogger(os.path.join(metrics_dir, "train_metrics.csv"))
    callback_list.append(callback)

    return callback_list