| import os |
| from data_gen import DataGenerator |
| from os import listdir |
| from utils import ( |
| iou, |
| PlotLosses, |
| dice_loss, |
| focal_loss, |
| categorical_loss, |
| categorical_focal_loss, |
| resolution2framesize3cha, |
| resolution2framesize, |
| bce_loss, |
| ) |
| import matplotlib.pyplot as plt |
| import tensorflow as tf |
|
|
| tf.config.run_functions_eagerly(True) |
| |
| import argparse |
| import sys |
| import numpy as np |
| import thundernet_config as Thundernet_config |
| from datetime import datetime |
| from matplotlib import pyplot as plt |
|
|
| from model.model import Thundernet as Thundernet_original |
| from model.model_ppm_factors import Thundernet as Thundernet_ppm |
|
|
| from pathlib import Path |
| from collections import defaultdict |
| import copy |
|
|
| plt.switch_backend("agg") |
|
|
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` cannot be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so the flag could never be
    switched off with an explicit value.

    Args:
        - value (str | bool): raw command-line token (or an actual bool).
    Returns:
        - bool: the parsed value.
    Raises:
        - argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string is kept as a "false" spelling: with the former
    # ``type=bool`` it was the only token that evaluated to False.
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _parse_class_weights(value):
    """Parse a dict literal such as ``"{0: 1.0, 1: 5.0}"`` from the command line.

    ``type=dict`` cannot work here because ``dict("...")`` raises for any
    string, so the option was unusable from the command line.

    Args:
        - value (str | dict): raw command-line token (or an actual dict).
    Returns:
        - dict: the parsed class weights.
    Raises:
        - argparse.ArgumentTypeError: if the token is not a valid dict literal.
    """
    import ast  # local import: only needed by this converter

    if isinstance(value, dict):
        return value
    try:
        parsed = ast.literal_eval(value)
    except (ValueError, SyntaxError) as err:
        raise argparse.ArgumentTypeError(
            f"expected a dict literal, got {value!r}"
        ) from err
    if not isinstance(parsed, dict):
        raise argparse.ArgumentTypeError(f"expected a dict literal, got {value!r}")
    return parsed


# Command-line interface of the training script. Every default comes from
# ``thundernet_config`` so that running without arguments reproduces the
# configuration stored in that module.
parser = argparse.ArgumentParser()

parser.add_argument(
    "--train_dir",
    type=str,
    default=Thundernet_config.train_path,
    help="The directory containing the train image dataset.",
)

parser.add_argument(
    "--val_dir",
    type=str,
    default=Thundernet_config.val_path,
    help="The directory containing the validation image dataset.",
)

parser.add_argument(
    "--batch_size",
    type=int,
    default=Thundernet_config.batch_size,
    choices=[1, 2, 4, 8, 16],
    help="Batch size used for training Thundernet",
)

parser.add_argument(
    "--augment",
    type=_str2bool,  # was ``type=bool``, which parsed "False" as True
    default=Thundernet_config.augment,
    choices=[False, True],
    help="Whether to use color augmentation for training Thundernet.",
)

parser.add_argument(
    "--rand_crop",
    type=float,
    default=Thundernet_config.rand_crop,
    choices=[0, 0.02, 0.05, 0.1, 0.2, 0.5],
    help="Frequency of random crop data augmentation technique.",
)

parser.add_argument(
    "--loss",
    type=str,
    default=Thundernet_config.loss,
    choices=["BCE", "BFL", "CFL", "DCL", "FTL", "CAT"],
    help="Loss function to be used - Binary Cross Entropy (BCE), Focal Loss (FL) , Dice Coefficient Loss (DCL) and Focal Tversky Loss (FTL)",
)

parser.add_argument(
    "--model_dir",
    type=str,
    default=Thundernet_config.model_dir,
    help="Base directory for the models_repo. "
    "Make sure 'model_checkpoint_path' given in 'checkpoint' file matches "
    "with checkpoint name.",
)

parser.add_argument(
    "--weights",
    type=_parse_class_weights,  # was ``type=dict``, which rejects every string
    default=Thundernet_config.weights,
    help="Class weights used for Weighted Binary Cross Entropy Loss. "
    'Given as a Python dict literal, e.g. "{0: 1.0, 1: 5.0}".',
)

parser.add_argument(
    "--lr", type=float, default=Thundernet_config.lr, help="Learning Rate."
)

parser.add_argument(
    "--epochs", type=int, default=Thundernet_config.epochs, help="Epochs"
)

parser.add_argument(
    "--classes",
    type=int,
    default=Thundernet_config.classes,
    help="Number of classes",  # was a copy-pasted "Epochs"
)

parser.add_argument(
    "--resolution",
    type=str,
    default=Thundernet_config.resolution,
    help="Input Resolution",
)

parser.add_argument(
    "--kernel_regularizer",
    type=float,
    default=Thundernet_config.kernel_regularizer,
    help="kernel_regularizer",
)

parser.add_argument(
    "--pretrained",
    type=_str2bool,  # was ``type=bool``, which parsed "False" as True
    default=Thundernet_config.pretrained_bool,
    help="In case you want to train",
)

# NOTE: the option name keeps its historical spelling ("weigths") because the
# rest of the script reads ``FLAGS.pretrained_weigths``.
parser.add_argument(
    "--pretrained_weigths",
    type=str,
    default=Thundernet_config.pretrained_weigths,
    help="In case you want to train",
)
|
|
|
|
def _require_image_dirs(*split_dirs):
    """Fail fast when a dataset split directory has no ``images/`` folder."""
    for split_dir in split_dirs:
        images_dir = split_dir + "images/"
        if not os.path.isdir(images_dir):
            raise FileNotFoundError(f"Image directory not found: {images_dir}")


def _resolve_loss_factory(loss_name):
    """Return the loss factory registered for ``loss_name`` or raise ValueError."""
    factories = {
        "BCE": bce_loss,
        "BFL": focal_loss,
        "DCL": dice_loss,
        "CFL": categorical_focal_loss,
        "CAT": categorical_loss,
    }
    try:
        return factories[loss_name]
    except KeyError:
        raise ValueError(
            f"Unsupported loss: {loss_name!r}. Supported: {sorted(factories)}"
        ) from None


def _create_experiment_dir(base_dir):
    """Create the first unused numbered run directory and return ``(number, path)``."""
    if not os.path.exists(base_dir):
        os.makedirs(base_dir, exist_ok=True)
    experiment_num = 1
    while True:
        # Plain concatenation is kept on purpose so the generated paths are
        # identical to the ones produced by earlier runs.
        candidate = base_dir + str(experiment_num) + "/"
        try:
            # Create-and-catch instead of check-then-create: two concurrent
            # runs can no longer pick the same directory.
            os.makedirs(candidate)
        except FileExistsError:
            experiment_num += 1
        else:
            return experiment_num, candidate


def _write_config(
    model_dir, experiment_num, flags, model, class_mappings, transformations
):
    """Write the run configuration to ``config.txt`` inside ``model_dir``."""
    # The keys (including "Fecha" = date, "Comentarios" = comments, and the
    # "Resoltuion" spelling) are kept verbatim so files stay comparable with
    # those of earlier experiments.
    lines = [
        f"Experiment num {experiment_num}",
        f"Fecha={datetime.now()}",
        f"Train with={flags.train_dir}",
        f"Val with={flags.val_dir}",
        f"Input Resoltuion with={flags.resolution}",
        f"Batch Size={flags.batch_size}",
        f"Batch augment={flags.augment}",
        f"Rand Crop={flags.rand_crop}",
        f"Loss={flags.loss}",
        f"Model dir={flags.model_dir}",
        f"weights={flags.weights}",
        f"lr={flags.lr}",
        f"epochs={flags.epochs}",
        f"classes={flags.classes}",
        f"kernel_regularizer={flags.kernel_regularizer}",
        f"pretrained={flags.pretrained}",
        f"pretrained_weigths={flags.pretrained_weigths}",
        f"Class mappings={class_mappings}",
        f"Model={model}",
        f"Transformations: {transformations}",
        "Comentarios=",
    ]
    # The context manager closes the file even if a write fails.
    with open(model_dir + "config.txt", "w", encoding="utf-8") as config_file:
        config_file.write("\n".join(lines) + "\n")


def main(
    args: list,
    transformations: tuple = tuple(),
    model: str = "original",
    class_mappings: dict = None,
):
    """
    Train the model
    Args:
        - args (list): command-line style arguments handed to the module parser
        - transformations (tuple): transformations forwarded to the data generators. Default: tuple()
        - model (str): architecture to train, "original" or "ppm". Default: "original"
        - class_mappings (dict): class mapper forwarded to the data generators; when
          given, the number of classes becomes the number of distinct mapped
          values plus one. Default: None
    Returns:
        - None
    Raises:
        - FileNotFoundError: if a dataset split has no "images/" directory
        - ValueError: if `model` or the selected loss is not supported
    """
    FLAGS = parser.parse_args(args)

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    # Validate everything that can be validated before any directory is
    # created, so that a bad invocation leaves no empty experiment folder.
    _require_image_dirs(FLAGS.train_dir, FLAGS.val_dir)

    architectures = {"original": Thundernet_original, "ppm": Thundernet_ppm}
    if model not in architectures:
        raise ValueError(f"Unknown model: {model}")
    Thundernet = architectures[model]

    # "FTL" is accepted by the parser but has no implementation here; this
    # used to surface as an UnboundLocalError after the model was built.
    loss_factory = _resolve_loss_factory(FLAGS.loss)

    experiment_num, model_dir = _create_experiment_dir(FLAGS.model_dir)

    if class_mappings is not None:
        FLAGS.classes = len(set(class_mappings.values())) + 1

    _write_config(
        model_dir, experiment_num, FLAGS, model, class_mappings, transformations
    )

    input_shape = resolution2framesize3cha(FLAGS.resolution)
    print("resolution2framesize3cha(FLAGS.resolution) ", input_shape)
    thundernet = Thundernet(
        input_shape=input_shape,
        n_classes=FLAGS.classes,
        resnet_trainable=True,
        kernel_regularizer=FLAGS.kernel_regularizer,
    )

    if FLAGS.pretrained:
        print("loading weights from", FLAGS.pretrained_weigths)
        thundernet.model.load_weights(
            FLAGS.pretrained_weigths, by_name=True, skip_mismatch=True
        )

    lr = FLAGS.lr
    opt = tf.keras.optimizers.Adam(learning_rate=lr)

    if not model_dir.endswith(os.path.sep):
        model_dir += os.path.sep

    # The doubled braces are left for Keras, which fills in the epoch number
    # and the logged metrics when it saves a checkpoint.
    checkpoint_name = f"BS{FLAGS.batch_size}_loss{FLAGS.loss}_weights_lr_{lr}_reg-{FLAGS.kernel_regularizer}-ep-{{epoch}}-val_loss{{val_loss}}-train_loss{{loss}}-val_iou{{val_iou}}-train_iou{{iou}}.hdf5"
    callbacks = [
        PlotLosses(model_dir),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.normpath(os.path.join(model_dir, checkpoint_name)),
            save_best_only=True,
            save_weights_only=True,
        ),
    ]

    thundernet.model.compile(loss=loss_factory(), optimizer=opt, metrics=[iou])

    # Honour the command-line values: the dataset root and the batch size used
    # to be read from ``Thundernet_config`` directly, silently ignoring
    # ``--train_dir`` and ``--batch_size`` although both are written to
    # config.txt and the batch size is part of the checkpoint name.
    # NOTE(review): ``create_generators`` only receives the parent of the train
    # directory; presumably it locates the validation split below that same
    # root - confirm that ``--val_dir`` does not need to be forwarded as well.
    dataset_dir = Path(FLAGS.train_dir).parent

    training_generator, validation_generator = DataGenerator.create_generators(
        dataset_dir,
        FLAGS.classes,
        training_batch_size=FLAGS.batch_size,
        validation_batch_size=FLAGS.batch_size,
        to_stereo=False,
        transformations=transformations,
        class_mappings=class_mappings,
    )

    # ``Model.fit_generator`` is deprecated; ``Model.fit`` accepts generators
    # directly. Class weights were never applied (``class_weight=None``).
    thundernet.model.fit(
        x=training_generator,
        validation_data=validation_generator,
        callbacks=callbacks,
        use_multiprocessing=False,
        workers=6,
        epochs=FLAGS.epochs,
        class_weight=None,
    )
|
|
|
|
if __name__ == "__main__":
    # Script entry point: train the "original" architecture with the arguments
    # given on the command line. The mapping sends key 1 to 1 and, being a
    # defaultdict(int), yields 0 for every other key that is looked up.
    cli_args = sys.argv[1:]
    default_mappings = defaultdict(int, {1: 1})
    main(cli_args, model="original", class_mappings=default_mappings)
| |
| |
| |
|
|
|
|