# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2022-2023 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
from typing import Tuple
import yaml
from yaml.loader import SafeLoader
from omegaconf import DictConfig
from munch import DefaultMunch
import tensorflow as tf
import matplotlib.pyplot as plt
from common.utils import postprocess_config_dict, collect_callback_args
from common.training import lr_schedulers


def _get_learning_rate_scheduler(cfg: DictConfig) -> Tuple[str, tf.keras.callbacks.Callback]:
"""
Extracts the learning rate callback and attributes.
Args:
cfg (DictConfig): Dictionary containing the 'callback' section
of the configuration file.
Returns:
scheduler_name (str): The name of the scheduler
scheduler (tf.keras.callbacks.Callback) : The scheduler call back to be plotted
"""
# Get the list of all learning rate scheduler callback names.
# This includes the Keras LearningRateScheduler callback.
lr_scheduler_names = lr_schedulers.get_scheduler_names() + ["LearningRateScheduler"]
message = "\nPlease check the 'training.callbacks' section of your configuration file."
scheduler_name = None
scheduler = None
num_lr_schedulers = 0
for name, args in cfg.items():
if name == "ReduceLROnPlateau":
raise ValueError("\nUnable to plot the learning rate before training when using "
"the `ReduceLROnPlateau` scheduler\n"
f"The learning rate schedule is only available after training.{message}")
if name in lr_scheduler_names:
scheduler_name = name
num_lr_schedulers += 1
if name == "LearningRateScheduler":
text = "tf.keras.callbacks.LearningRateScheduler"
else:
text = "lr_schedulers.{}".format(name)
# Collect the callback arguments
text += collect_callback_args(name, args=cfg[name], message=message)
# Evaluate the callback string
try:
scheduler = eval(text)
            except Exception:
raise ValueError(f"\nThe arguments of the `{name}` callback are incomplete or invalid.\n"
f"Received: {text}{message}")
if scheduler is None:
raise ValueError(f"\nCould not find a learning rate scheduler{message}")
if num_lr_schedulers > 1:
raise ValueError(f"\nFound more than one learning rate scheduler{message}")
return scheduler_name, scheduler


def _get_initial_learning_rate(cfg: DictConfig) -> float:
    """
    Gets the initial learning rate that the scheduler may need. If the
    'learning_rate' (or 'lr') argument is present in the optimizer section,
    its value is used. Otherwise, the Keras default is assumed.

    Args:
        cfg (DictConfig): Dictionary containing the 'optimizer' section
            of the configuration file.

    Returns:
        lr (float): The value of the 'learning_rate' argument of the optimizer,
            or the optimizer's default learning rate if it is not set.
    """
# Look for the 'lr' or 'learning_rate' argument
optimizer_name = list(cfg.keys())[0]
optimizer_args = cfg[optimizer_name]
optimizer_lr = None
    if optimizer_args:
        for k, v in optimizer_args.items():
            if k in ("learning_rate", "lr"):
                optimizer_lr = v
    if optimizer_lr is None:
        # Fall back to the Keras defaults: 0.01 for SGD, 0.001 for other optimizers.
        optimizer_lr = 0.01 if optimizer_name == "SGD" else 0.001
    return optimizer_lr


def _plot_lr_schedule(scheduler_name: str,
scheduler: tf.keras.callbacks.Callback,
epochs: int = None,
initial_lr: float = None,
fname: str = None) -> None:
"""
This function plots the learning rate schedule for a given number of epochs.
Args:
scheduler_name (str): name of the scheduler callback.
scheduler (tf.keras.callbacks.Callback): learning rate scheduler callback.
epochs (int): number of epochs to plot.
initial_lr (float): initial learning given in argument to the optimizer.
fname (str): filename to use the save the plot.
Returns:
None
"""
learning_rate = []
lr = initial_lr
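    # Replay the schedule epoch by epoch without training: the Keras
    # LearningRateScheduler callback stores its schedule function in the
    # `schedule` attribute, and the schedulers from common.training.lr_schedulers
    # are assumed to expose the same `schedule(epoch, lr)` interface.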
for e in range(epochs):
lr = scheduler.schedule(e, lr)
learning_rate.append(lr)
plt.plot(learning_rate)
plt.title(f"{scheduler_name} Learning Rate Schedule")
plt.xlabel("epochs")
plt.ylabel("learning rate")
if fname:
plt.savefig(fname)
plt.show()


def plot_learning_rate_schedule(config_file_path: str = None, fname: str = None) -> None:
"""
This function is the top routine to get and plot the learning rate schedule for a given number of epochs.
Args:
config_file_path (str): path of the .yaml file with training information.
fname (str): filename to use the save the plot.
Returns:
None
"""
# Load and postprocess the config file
with open(config_file_path) as f:
config = yaml.load(f, Loader=SafeLoader)
postprocess_config_dict(config)
cfg = DefaultMunch.fromDict(config)
# Check that the required sections of the config file are present
if not cfg.training:
raise ValueError("\nThe configuration file should include a 'training' section.")
if not cfg.training.optimizer:
raise ValueError("\nThe configuration file should include a 'training.optimizer' section.")
    # If it has no arguments, the optimizer may be written as below:
    #    optimizer: Adam
    if isinstance(cfg.training.optimizer, str):
        cfg.training.optimizer = DefaultMunch.fromDict({cfg.training.optimizer: None})
if "callbacks" not in cfg.training:
raise ValueError("\nThe configuration file should include a 'training.callbacks' section.")
if cfg.training.callbacks is None:
cfg.training.callbacks = {}
scheduler_name, scheduler = _get_learning_rate_scheduler(cfg.training.callbacks)
initial_lr = _get_initial_learning_rate(cfg.training.optimizer)
_plot_lr_schedule(scheduler_name, scheduler, epochs=cfg.training.epochs, initial_lr=initial_lr, fname=fname)
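

# Example usage (illustrative sketch; the configuration file name below is
# hypothetical, and the .yaml file must contain 'training.epochs',
# 'training.optimizer' and 'training.callbacks' sections):
if __name__ == "__main__":
    plot_learning_rate_schedule(config_file_path="user_config.yaml",
                                fname="learning_rate_schedule.png")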