File size: 16,479 Bytes
747451d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373

# /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2022-2023 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/

# Import necessary libraries
import os
import sys
from timeit import default_timer as timer
from datetime import timedelta
from typing import Tuple, List, Dict, Optional

import mlflow
from hydra.core.hydra_config import HydraConfig
from munch import DefaultMunch
from omegaconf import DictConfig
import numpy as np 
import tensorflow as tf

# Suppress TensorFlow warnings to reduce log clutter
import logging
logging.getLogger('mlflow.tensorflow').setLevel(logging.ERROR)
logging.getLogger('tensorflow').setLevel(logging.ERROR)

# Import utility functions and modules
from common.utils import (
    log_to_file, log_last_epoch_history, LRTensorBoard, check_training_determinism,
    model_summary, collect_callback_args, vis_training_curves
)
from common.training import (
    set_frozen_layers, set_dropout_rate, get_optimizer, lr_schedulers,
    set_all_layers_trainable_parameter
)
from image_classification.tf.src.utils import get_loss, change_model_number_of_classes, change_model_input_shape
from image_classification.tf.src.data_augmentation import DataAugmentationLayer


# Define a custom callback for multi-resolution training
class MultiResCallback(tf.keras.callbacks.Callback):
    """
    A custom Keras callback to dynamically change the input resolution
    of the model during training.

    Each entry of `image_sizes` is used for `period` consecutive training
    batches, cycling through the list. Keras passes the index of the batch
    within the current epoch, so the cycle restarts at every epoch.

    Args:
        image_sizes (List[int]): List of resolutions to cycle through.
        period (int): Number of batches before changing resolution.
        name (str, optional): Accepted for configuration compatibility; not used.

    Raises:
        ValueError: If `image_sizes` is empty or `period` is not a positive integer.
    """
    def __init__(self, image_sizes, period, name=None):
        super().__init__()
        # Validate eagerly: an empty list or a non-positive period would otherwise
        # only fail (ZeroDivisionError) at the first training batch.
        if not image_sizes:
            raise ValueError("MultiResCallback: `image_sizes` must contain at least one resolution.")
        if period is None or period < 1:
            raise ValueError("MultiResCallback: `period` must be a positive integer.")
        self.resolutions = image_sizes
        self.period = period

    def on_train_batch_begin(self, batch, logs=None):
        # `batch` is 0-based, so batches [k*period, (k+1)*period) use resolution k
        # (modulo the number of resolutions).
        res = self.resolutions[(batch // self.period) % len(self.resolutions)]
        # NOTE(review): assumes layers[0] of the trained model exposes `change_res`
        # (the data augmentation layer added by _add_preprocessing_layers) — confirm.
        self.model.layers[0].change_res(res)


# Function to add preprocessing layers to the model
def _compute_pixels_range(scale, offset, mean, std):
    """
    Return the (min, max) pixel values produced by rescaling and normalization.

    A raw pixel in [0, 255] is mapped to `(pixel * scale + offset - mean) / std`.
    `mean` and `std` may each be a number or a 3-element list/tuple (one value
    per channel); a single number is applied to every channel. `None` falls
    back to 0.0 for `mean` and 1.0 for `std` (no normalization).

    Raises:
        ValueError: If a list/tuple `mean` or `std` does not have three elements.
        TypeError: If `mean` or `std` is neither a number nor a list/tuple.
    """
    mean = 0.0 if mean is None else mean
    std = 1.0 if std is None else std

    def _as_channels(value):
        # Normalize a scalar or per-channel value to a list of channel values.
        if isinstance(value, (list, tuple)):
            if len(value) != 3:
                raise ValueError("If std and mean are lists, they must have three elements each.")
            return list(value)
        if isinstance(value, (int, float, np.integer, np.floating)) and not isinstance(value, bool):
            return [value]
        raise TypeError("std and mean must be either floats or lists of length 3.")

    means = _as_channels(mean)
    stds = _as_channels(std)
    if len(means) != len(stds):
        # One side is a scalar: broadcast it to every channel of the other side.
        if len(means) == 1:
            means = means * len(stds)
        else:
            stds = stds * len(means)

    low = min((offset - m) / s for m, s in zip(means, stds))
    high = max((scale * 255 + offset - m) / s for m, s in zip(means, stds))
    return (low, high)


def _add_preprocessing_layers(
        model: tf.keras.Model,
        input_shape: Tuple = None,
        scale: float = None,
        offset: float = None,
        mean: float = None,
        std: float = None,
        data_augmentation: Dict = None,
        batches_per_epoch: float = None):
    """
    Wraps the model in a Sequential model with an Input layer and, when
    configured, a data augmentation layer in front of it.

    No rescaling/normalization layer is added: `scale`, `offset`, `mean` and
    `std` are only used to compute the pixel value range passed to the data
    augmentation layer.

    Args:
        model (tf.keras.Model): The base model.
        input_shape (Tuple): Input shape of the model (without the batch dimension).
        scale (float): Scaling factor for rescaling.
        offset (float): Offset for rescaling.
        mean (float | list): Mean for normalization (scalar or 3 per-channel values).
            None means 0.0.
        std (float | list): Standard deviation for normalization (scalar or 3
            per-channel values). None means 1.0.
        data_augmentation (Dict): Data augmentation configuration
            (`function_name` and `config`); may be None/empty.
        batches_per_epoch (float): Number of training batches per epoch.

    Returns:
        tf.keras.Model: Sequential model named "augmented_model" whose last
        layer is the base model.

    Raises:
        ValueError / TypeError: If `mean`/`std` have an unsupported form
        (only checked when data augmentation is configured).
    """
    # Only dereference the augmentation config when one is provided.
    aug_args = DefaultMunch.fromDict(data_augmentation.config) if data_augmentation else None
    if aug_args is not None and aug_args.random_periodic_resizing is not None:
        # The input resolution changes during training (see MultiResCallback),
        # so make the base model's spatial dimensions dynamic.
        model, _ = change_model_input_shape(model, (None, None, None, 3))

    model_layers = [tf.keras.Input(shape=input_shape)]

    if data_augmentation:
        model_layers.append(
            DataAugmentationLayer(
                data_augmentation_fn=data_augmentation.function_name,
                config=data_augmentation.config,
                pixels_range=_compute_pixels_range(scale, offset, mean, std),
                batches_per_epoch=batches_per_epoch
            )
        )
    model_layers.append(model)

    return tf.keras.Sequential(model_layers, name="augmented_model")


# Function to create Keras callbacks
def _get_callbacks(callbacks_dict: DictConfig, output_dir: str = None, logs_dir: str = None,
                   saved_models_dir: str = None) -> List[tf.keras.callbacks.Callback]:
    """
    Creates a list of Keras callbacks for training.

    For each callback in the config, its attributes and values are turned into
    a string that is the callback instantiation as it would be written in a
    Python script, and that string is evaluated. If evaluation fails (unknown
    name, missing or invalid arguments), a ValueError explains that the name
    and/or arguments of the callback are incorrect.

    Built-in callbacks are always appended after the user-defined ones:
    two ModelCheckpoint callbacks (best model by `val_accuracy`, and last
    model), an LRTensorBoard callback, and a CSVLogger (kept last so that it
    records the learning rate).

    Args:
        callbacks_dict (DictConfig): Configuration for callbacks; must be a
            DefaultMunch when not None.
        output_dir (str): Directory for saving outputs.
        logs_dir (str): Directory for saving logs (relative to `output_dir`).
        saved_models_dir (str): Directory for saving models (relative to `output_dir`).

    Returns:
        List[tf.keras.callbacks.Callback]: List of callbacks.

    Raises:
        ValueError: On invalid callbacks syntax, redefinition of a built-in
            callback, an unknown/invalid callback, or more than one learning
            rate scheduler.
    """
    message = "\nPlease check the 'training.callbacks' section of your configuration file."
    lr_scheduler_names = set(lr_schedulers.get_scheduler_names())
    # Callbacks that drive the learning rate; at most one may be configured.
    lr_callback_names = lr_scheduler_names | {"ReduceLROnPlateau", "LearningRateScheduler"}
    num_lr_schedulers = 0

    # Generate the callbacks used in the config file (there may be none)
    callback_list = []
    if callbacks_dict is not None:
        if not isinstance(callbacks_dict, DefaultMunch):
            raise ValueError(f"\nInvalid callbacks syntax{message}")
        for name, callback_args in callbacks_dict.items():
            if name in ("ModelCheckpoint", "TensorBoard", "CSVLogger"):
                raise ValueError(f"\nThe `{name}` callback is built-in and can't be redefined.{message}")
            if name in lr_scheduler_names:
                text = f"lr_schedulers.{name}"
            elif name == 'MultiResCallback':
                text = name
            else:
                text = f"tf.keras.callbacks.{name}"

            # Append the argument list and evaluate the string to get the callback object.
            # SECURITY: eval() executes text built from the configuration file;
            # only use trusted configuration files.
            text += collect_callback_args(name, args=callback_args, message=message)
            try:
                callback = eval(text)
            except (ValueError, TypeError, AttributeError, NameError, SyntaxError) as error:
                # Unknown names surface as AttributeError/NameError and bad arguments as
                # TypeError; report all of them with the same configuration hint.
                raise ValueError(f"\nThe callback name `{name}` is unknown, or its arguments are incomplete "
                                 f"or invalid\nReceived: {text}{message}") from error
            callback_list.append(callback)

            if name in lr_callback_names:
                num_lr_schedulers += 1

    # Check that there is only one scheduler
    if num_lr_schedulers > 1:
        raise ValueError(f"\nFound more than one learning rate scheduler{message}")

    models_dir = os.path.join(output_dir, saved_models_dir)
    # Built-in callback that saves the best model obtained so far
    callback_list.append(tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(models_dir, "best_augmented_model.keras"),
        save_best_only=True,
        monitor="val_accuracy",
        mode="max"
    ))
    # Built-in callback that saves the model at the end of every epoch
    callback_list.append(tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(models_dir, "last_augmented_model.keras"),
        save_best_only=False,
        monitor="val_accuracy",
        mode="max"
    ))
    # TensorBoard callback
    callback_list.append(LRTensorBoard(log_dir=os.path.join(output_dir, logs_dir)))
    # CSVLogger callback (must be last in the list of callbacks
    # to make sure it records the learning rate)
    callback_list.append(tf.keras.callbacks.CSVLogger(
        os.path.join(output_dir, logs_dir, "metrics", "train_metrics.csv")
    ))

    return callback_list


# Main class for training image classification models
class ICTrainer:
    """
    Trainer for image classification models.

    Runs the pipeline driven by `cfg`: dataset statistics, model preparation
    (preprocessing/augmentation layers and compilation), callbacks, optional
    deterministic-ops check, fitting, and export of the best model.
    """

    def __init__(self, cfg, model=None, dataloaders=None):
        """
        Initializes the trainer with configuration, model, and datasets.

        Args:
            cfg: Configuration object.
            model: TensorFlow model to train.
            dataloaders: Mapping with 'train' and 'valid' datasets and an
                optional 'test' dataset.

        Raises:
            ValueError: If `dataloaders` is None.
        """
        if dataloaders is None:
            raise ValueError("`dataloaders` must provide 'train' and 'valid' datasets.")
        self.cfg = cfg
        self.model = model
        self.train_ds = dataloaders['train']
        self.valid_ds = dataloaders['valid']
        # The test set is optional: prepare() reports "no test set" when it is absent.
        self.test_ds = dataloaders.get('test')

        # Hydra run output directory; logs and saved models are written under it.
        self.output_dir = HydraConfig.get().runtime.output_dir
        self.saved_models_dir = cfg.general.saved_models_dir
        self.class_names = cfg.dataset.class_names
        self.num_classes = len(self.class_names)
        # Populated by prepare() and fit().
        self.augmented_model = None
        self.callbacks = None
        self.history = None

    @staticmethod
    def _count_samples(dataset) -> int:
        """Return the total number of samples in an iterable of (images, labels) batches."""
        return sum(images.shape[0] for images, _ in dataset)

    def _report_dataset_stats(self):
        """Print the dataset sizes and log the dataset name."""
        print("Dataset stats:")
        train_size = self._count_samples(self.train_ds)
        valid_size = self._count_samples(self.valid_ds)
        test_size = self._count_samples(self.test_ds) if self.test_ds else None

        print("  classes:", self.num_classes)
        print("  training set size:", train_size)
        print("  validation set size:", valid_size)
        if self.test_ds:
            print("  test set size:", test_size)
        else:
            print("  no test set")

        if self.cfg.dataset.dataset_name:
            log_to_file(self.output_dir, f"Dataset : {self.cfg.dataset.dataset_name}")

    def _prepare_model(self):
        """Log the model origin, then build and compile the augmented model."""
        cfm = self.cfg.model
        # A named model is only logged; a model loaded from a file gets its
        # classifier adapted to the number of dataset classes.
        if cfm and cfm.model_name:
            print(f"[INFO] : Using `{cfm.model_name}` model")
            log_to_file(self.output_dir, f"Model name : {cfm.model_name}")
        elif cfm and cfm.model_path:
            self.model = change_model_number_of_classes(self.model, self.num_classes)
            print(f"[INFO] : Initialized model with weights from model file {cfm.model_path}")
            log_to_file(self.output_dir, f"Weights from model file : {cfm.model_path}")

        model_summary(self.model)
        if self.cfg.training.resume_training_from:
            # When resuming, the model is used as-is: no preprocessing layers are
            # added and it is not recompiled.
            self.augmented_model = self.model
            return

        normalization = self.cfg.preprocessing.normalization
        # A config mapping may return None for a missing key instead of raising,
        # so coalesce explicitly to "no normalization".
        mean = getattr(normalization, 'mean', None)
        std = getattr(normalization, 'std', None)
        self.augmented_model = _add_preprocessing_layers(
            self.model,
            input_shape=tuple(self.model.inputs[0].shape[1:]),
            scale=self.cfg.preprocessing.rescaling.scale,
            offset=self.cfg.preprocessing.rescaling.offset,
            mean=0.0 if mean is None else mean,
            std=1.0 if std is None else std,
            data_augmentation=self.cfg.data_augmentation,
            batches_per_epoch=len(self.train_ds)
        )
        self.augmented_model.compile(
            loss=get_loss(num_classes=self.num_classes),
            metrics=['accuracy'],
            optimizer=get_optimizer(cfg=self.cfg.training.optimizer)
        )

    def _register_multires_callback(self):
        """Add a `MultiResCallback` entry to the callbacks config when random periodic resizing is set."""
        data_augmentation = self.cfg.data_augmentation
        if not data_augmentation:
            return
        data_aug_args = DefaultMunch.fromDict(data_augmentation.config)
        if data_aug_args is None or data_aug_args.random_periodic_resizing is None:
            return

        rpr = DefaultMunch.fromDict(data_aug_args.random_periodic_resizing)
        if rpr.image_sizes is None:
            print("[WARNING]: 'random_periodic_resizing' can't be used because [image_sizes] argument is missing.")
            return

        # The callbacks section may be absent from the config: create it.
        if self.cfg.training.callbacks is None:
            self.cfg.training.callbacks = DefaultMunch.fromDict({})
        self.cfg.training.callbacks['MultiResCallback'] = DefaultMunch.fromDict({
            'image_sizes': rpr.image_sizes,
            'period': rpr.period if rpr.period is not None else 10
        })

    def prepare(self):
        """
        Prepares the model, datasets, and callbacks for training.

        Prints dataset statistics, logs the model origin, wraps and compiles the
        model (unless training is resumed), registers `MultiResCallback` when
        random periodic resizing is configured, and builds `self.callbacks`.
        """
        self._report_dataset_stats()
        self._prepare_model()
        self._register_multires_callback()
        self.callbacks = _get_callbacks(
            callbacks_dict=self.cfg.training.callbacks,
            output_dir=self.output_dir,
            saved_models_dir=self.saved_models_dir,
            logs_dir=self.cfg.general.logs_dir
        )

    def enable_determinism(self):
        """
        Enables deterministic TensorFlow ops when `general.deterministic_ops` is set.

        After enabling, a one-batch training check is run; if it fails, op
        determinism is switched off again.
        """
        if self.cfg.general.deterministic_ops:
            sample_ds = self.train_ds.take(1)
            tf.config.experimental.enable_op_determinism()
            if not check_training_determinism(self.augmented_model, sample_ds):
                print("[WARNING]: Some operations cannot be run deterministically. Setting deterministic_ops to False.")
                # NOTE(review): relies on TensorFlow's private `_pywrap_determinism`
                # module (there is no public disable API); may break across TF versions.
                tf.config.experimental.enable_op_determinism.__globals__["_pywrap_determinism"].enable(False)

    def fit(self):
        """
        Trains the augmented model, logs the runtime, and plots training curves.

        When `training.dryrun` is set, it is used as `steps_per_epoch`.
        """
        print("Starting training...")
        start_time = timer()
        steps_per_epoch = self.cfg.training.dryrun if self.cfg.training.dryrun else None
        self.history = self.augmented_model.fit(
            self.train_ds,
            validation_data=self.valid_ds,
            epochs=self.cfg.training.epochs,
            steps_per_epoch=steps_per_epoch,
            callbacks=self.callbacks
        )
        last_epoch = log_last_epoch_history(self.cfg, self.output_dir)
        fit_run_time = int(timer() - start_time)
        # NOTE(review): `+ 1` assumes last_epoch is a 0-based epoch index — confirm.
        average_time_per_epoch = round(fit_run_time / (int(last_epoch) + 1), 2)
        print("Training runtime: " + str(timedelta(seconds=fit_run_time)))
        log_to_file(self.output_dir, (
            f"Training runtime : {fit_run_time} s\n" +
            f"Average time per epoch : {average_time_per_epoch} s"
        ))
        vis_training_curves(history=self.history, output_dir=self.output_dir)

    def save_and_evaluate(self):
        """
        Extracts the best trained model from its checkpoint and saves it.

        Loads `best_augmented_model.keras` from the saved-models directory,
        takes its last layer (the base model, appended last by
        `_add_preprocessing_layers`), restores the original model input shape,
        compiles it with the loss and accuracy metric, and saves it as
        `best_model.keras`. No evaluation is performed here despite the name.

        Returns:
            The best model, with a `model_path` attribute pointing to the saved file.
        """
        models_dir = os.path.join(self.output_dir, self.saved_models_dir)
        checkpoint_model = tf.keras.models.load_model(
            os.path.join(models_dir, "best_augmented_model.keras"),
            custom_objects={'DataAugmentationLayer': DataAugmentationLayer}
        )
        output_model_input_shape = tuple(self.model.inputs[0].shape)
        best_model, _ = change_model_input_shape(checkpoint_model.layers[-1], output_model_input_shape)
        best_model.compile(loss=get_loss(num_classes=self.num_classes), metrics=['accuracy'])

        best_model_path = os.path.join(models_dir, "best_model.keras")
        best_model.save(best_model_path)
        best_model.model_path = best_model_path
        print('[INFO] : Training complete.')
        return best_model

    def train(self):
        """
        Executes the full training pipeline: prepare, determinism setup, fit,
        and best-model export.

        Returns:
            The best model returned by `save_and_evaluate()`.
        """
        self.prepare()
        self.enable_determinism()
        self.fit()
        return self.save_and_evaluate()