from typing import Tuple

import yaml
from yaml.loader import SafeLoader
from omegaconf import DictConfig
from munch import DefaultMunch
import tensorflow as tf
import matplotlib.pyplot as plt

from common.utils import postprocess_config_dict, collect_callback_args
from common.training import lr_schedulers


def _get_learning_rate_scheduler(cfg: DictConfig) -> Tuple[str, tf.keras.callbacks.Callback]:
    """
    Extracts the learning rate scheduler callback and its attributes.

    Args:
        cfg (DictConfig): Dictionary containing the 'callbacks' section
            of the configuration file.

    Returns:
        scheduler_name (str): The name of the scheduler.
        scheduler (tf.keras.callbacks.Callback): The scheduler callback to be plotted.
    """

    lr_scheduler_names = lr_schedulers.get_scheduler_names() + ["LearningRateScheduler"]

    message = "\nPlease check the 'training.callbacks' section of your configuration file."

    scheduler_name = None
    scheduler = None
    num_lr_schedulers = 0
    for name, args in cfg.items():
        if name == "ReduceLROnPlateau":
            raise ValueError("\nUnable to plot the learning rate before training when using "
                             "the `ReduceLROnPlateau` scheduler.\n"
                             f"The learning rate schedule is only available after training.{message}")
        if name in lr_scheduler_names:
            scheduler_name = name
            num_lr_schedulers += 1
            if name == "LearningRateScheduler":
                text = "tf.keras.callbacks.LearningRateScheduler"
            else:
                text = "lr_schedulers.{}".format(name)

            # Append the callback arguments gathered from the configuration file.
            text += collect_callback_args(name, args=cfg[name], message=message)

            # Instantiate the scheduler from the assembled expression.
            try:
                scheduler = eval(text)
            except Exception:
                raise ValueError(f"\nThe arguments of the `{name}` callback are incomplete or invalid.\n"
                                 f"Received: {text}{message}")

    if scheduler is None:
        raise ValueError(f"\nCould not find a learning rate scheduler.{message}")
    if num_lr_schedulers > 1:
        raise ValueError(f"\nFound more than one learning rate scheduler.{message}")

    return scheduler_name, scheduler
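
# Example (hypothetical scheduler name and arguments): a 'callbacks' section
# such as
#
#   callbacks:
#       LRExponentialDecay:
#           initial_lr: 0.001
#           decay_rate: 0.9
#
# would be resolved by evaluating the assembled string
# "lr_schedulers.LRExponentialDecay(initial_lr=0.001, decay_rate=0.9)",
# assuming collect_callback_args renders the arguments this way.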
|
|
|
def _get_initial_learning_rate(cfg: DictConfig) -> float:
    """
    The learning rate scheduler may need the initial learning rate that was
    provided as an argument to the optimizer. If a 'learning_rate' (or 'lr')
    attribute is present in the optimizer section, its value is returned.
    Otherwise, the Keras default for the optimizer is used.

    Args:
        cfg (DictConfig): Dictionary containing the 'optimizer' section
            of the configuration file.

    Returns:
        lr (float): The value of the 'learning_rate' argument of the optimizer.
    """

    optimizer_name = list(cfg.keys())[0]
    optimizer_args = cfg[optimizer_name]
    optimizer_lr = None
    if optimizer_args:
        for k, v in optimizer_args.items():
            if k in ("learning_rate", "lr"):
                optimizer_lr = v

    if optimizer_lr is None:
        # Fall back to the Keras defaults: 0.01 for SGD, 0.001 otherwise.
        optimizer_lr = 0.01 if optimizer_name == "SGD" else 0.001
    return optimizer_lr
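
# Example (hypothetical values): with the 'optimizer' section below, the
# function returns 0.005; with a bare "Adam" entry and no arguments, it
# falls back to the Keras default of 0.001.
#
#   optimizer:
#       Adam:
#           learning_rate: 0.005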
|
|
|
def _plot_lr_schedule(scheduler_name: str,
                      scheduler: tf.keras.callbacks.Callback,
                      epochs: int = None,
                      initial_lr: float = None,
                      fname: str = None) -> None:
    """
    Plots the learning rate schedule over a given number of epochs.

    Args:
        scheduler_name (str): Name of the scheduler callback.
        scheduler (tf.keras.callbacks.Callback): Learning rate scheduler callback.
        epochs (int): Number of epochs to plot.
        initial_lr (float): Initial learning rate given as an argument to the optimizer.
        fname (str): Filename used to save the plot.

    Returns:
        None
    """

    # Replay the schedule epoch by epoch, starting from the optimizer's
    # initial learning rate.
    learning_rate = []
    lr = initial_lr
    for e in range(epochs):
        lr = scheduler.schedule(e, lr)
        learning_rate.append(lr)

    plt.plot(learning_rate)
    plt.title(f"{scheduler_name} Learning Rate Schedule")
    plt.xlabel("epochs")
    plt.ylabel("learning rate")
    if fname:
        plt.savefig(fname)
    plt.show()
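
# Example (hypothetical): plotting a simple exponential decay schedule
# without going through a configuration file. The lambda is a stand-in
# schedule, not one of the repository's schedulers.
#
#   scheduler = tf.keras.callbacks.LearningRateScheduler(
#       schedule=lambda epoch, lr: lr * 0.95)
#   _plot_lr_schedule("LearningRateScheduler", scheduler,
#                     epochs=100, initial_lr=0.001)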
|
|
|
def plot_learning_rate_schedule(config_file_path: str = None,
                                fname: str = None) -> None:
    """
    Top-level routine: reads the configuration file, extracts the learning
    rate scheduler, and plots the schedule over the number of training epochs.

    Args:
        config_file_path (str): Path of the .yaml configuration file with the
            training information.
        fname (str): Filename used to save the plot.

    Returns:
        None
    """

    # Load the YAML configuration file into an attribute-accessible dictionary.
    with open(config_file_path) as f:
        config = yaml.load(f, Loader=SafeLoader)
    postprocess_config_dict(config)
    cfg = DefaultMunch.fromDict(config)

    if not cfg.training:
        raise ValueError("\nThe configuration file should include a 'training' section.")
    if not cfg.training.optimizer:
        raise ValueError("\nThe configuration file should include a 'training.optimizer' section.")
    if not cfg.training.epochs:
        raise ValueError("\nThe configuration file should include a 'training.epochs' attribute.")

    # The optimizer may be given as a plain string when it has no arguments.
    if isinstance(cfg.training.optimizer, str):
        cfg.training.optimizer = DefaultMunch.fromDict({cfg.training.optimizer: None})

    if "callbacks" not in cfg.training:
        raise ValueError("\nThe configuration file should include a 'training.callbacks' section.")
    if cfg.training.callbacks is None:
        cfg.training.callbacks = {}

    scheduler_name, scheduler = _get_learning_rate_scheduler(cfg.training.callbacks)
    initial_lr = _get_initial_learning_rate(cfg.training.optimizer)
    _plot_lr_schedule(scheduler_name, scheduler, epochs=cfg.training.epochs,
                      initial_lr=initial_lr, fname=fname)
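
# Minimal usage sketch (hypothetical): the repository may drive this module
# through its own scripts; the call below only illustrates the intended use.
# "user_config.yaml" is a placeholder for a configuration file whose
# 'training' section could look like:
#
#   training:
#       epochs: 200
#       optimizer:
#           Adam:
#               learning_rate: 0.001
#       callbacks:
#           LRExponentialDecay:
#               initial_lr: 0.001
#               decay_rate: 0.9
#
if __name__ == "__main__":
    plot_learning_rate_schedule(config_file_path="user_config.yaml",
                                fname="learning_rate_schedule.png")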
|
|
|