"""Command-line training script for the Thundernet segmentation models."""

import argparse
import ast
import copy
import os
import sys
from collections import defaultdict
from datetime import datetime
from os import listdir
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

# Global TensorFlow switch; set before the project modules are imported.
tf.config.run_functions_eagerly(True)
# from keras.backend.tensorflow_backend import set_session

import thundernet_config as Thundernet_config
from data_gen import DataGenerator
from model.model import Thundernet as Thundernet_original
from model.model_ppm_factors import Thundernet as Thundernet_ppm
from utils import (
    iou,
    PlotLosses,
    dice_loss,
    focal_loss,
    categorical_loss,
    categorical_focal_loss,
    resolution2framesize3cha,
    resolution2framesize,
    bce_loss,
)

plt.switch_backend("agg")

_TRUE_STRINGS = frozenset({"true", "t", "yes", "y", "1"})
# The empty string is kept as "false" because it was the only value that
# produced False when the flags were declared with ``type=bool``.
_FALSE_STRINGS = frozenset({"false", "f", "no", "n", "0", ""})

# Model name accepted by ``main`` -> model class.
_MODEL_CLASSES = {
    "original": Thundernet_original,
    "ppm": Thundernet_ppm,
}

# ``--loss`` value -> factory returning the loss callable.
# NOTE: "FTL" is an accepted ``--loss`` choice but has no implementation here.
_LOSS_FACTORIES = {
    "BCE": bce_loss,
    "BFL": focal_loss,
    "DCL": dice_loss,
    "CFL": categorical_focal_loss,
    "CAT": categorical_loss,
}


def _str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True,
    so every non-empty value would enable the flag.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in _TRUE_STRINGS:
        return True
    if token in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _parse_weights(value):
    """Parse the ``--weights`` argument from a Python dict literal.

    ``type=dict`` cannot work with argparse because ``dict("...")`` rejects
    any string; ``ast.literal_eval`` safely evaluates literals only.
    """
    if isinstance(value, dict):
        return value
    try:
        parsed = ast.literal_eval(value)
    except (ValueError, SyntaxError, TypeError) as err:
        raise argparse.ArgumentTypeError(
            f"expected a dict literal such as '{{0: 1.0, 1: 5.0}}', got {value!r}"
        ) from err
    if not isinstance(parsed, dict):
        raise argparse.ArgumentTypeError(f"expected a dict literal, got {value!r}")
    return parsed


parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_dir",
    type=str,
    default=Thundernet_config.train_path,
    help="The directory containing the train image dataset.",
)
parser.add_argument(
    "--val_dir",
    type=str,
    default=Thundernet_config.val_path,
    help="The directory containing the validation image dataset.",
)
parser.add_argument(
    "--batch_size",
    type=int,
    default=Thundernet_config.batch_size,
    choices=[1, 2, 4, 8, 16],
    help="Batch size used for training Thundernet",
)
parser.add_argument(
    "--augment",
    type=_str2bool,
    default=Thundernet_config.augment,
    choices=[False, True],
    help="Whether to use color augmentation for training Thundernet.",
)
parser.add_argument(
    "--rand_crop",
    type=float,
    default=Thundernet_config.rand_crop,
    choices=[0, 0.02, 0.05, 0.1, 0.2, 0.5],
    help="Frequency of random crop data augmentation technique.",
)
parser.add_argument(
    "--loss",
    type=str,
    default=Thundernet_config.loss,
    choices=["BCE", "BFL", "CFL", "DCL", "FTL", "CAT"],
    help="Loss function to be used - Binary Cross Entropy (BCE), Focal Loss (FL) , Dice Coefficient Loss (DCL) and Focal Tversky Loss (FTL)",
)
parser.add_argument(
    "--model_dir",
    type=str,
    default=Thundernet_config.model_dir,
    help="Base directory for the models_repo. "
    "Make sure 'model_checkpoint_path' given in 'checkpoint' file matches "
    "with checkpoint name.",
)
parser.add_argument(
    "--weights",
    type=_parse_weights,
    default=Thundernet_config.weights,
    help="Class weights used for Weighted Binary Cross Entropy Loss, "
    "given as a Python dict literal.",
)
parser.add_argument(
    "--lr", type=float, default=Thundernet_config.lr, help="Learning Rate."
)
parser.add_argument(
    "--epochs", type=int, default=Thundernet_config.epochs, help="Epochs"
)
parser.add_argument(
    "--classes",
    type=int,
    default=Thundernet_config.classes,
    help="Number of classes",
)
parser.add_argument(
    "--resolution",
    type=str,
    default=Thundernet_config.resolution,
    help="Input Resolution",
)
parser.add_argument(
    "--kernel_regularizer",
    type=float,
    default=Thundernet_config.kernel_regularizer,
    help="kernel_regularizer",
)
parser.add_argument(
    "--pretrained",
    type=_str2bool,
    default=Thundernet_config.pretrained_bool,
    help="In case you want to train",
)
parser.add_argument(
    "--pretrained_weigths",
    type=str,
    default=Thundernet_config.pretrained_weigths,
    help="In case you want to train",
)


def _create_trial_dir(base_dir):
    """Create the next free numbered sub-folder of *base_dir*.

    Returns a ``(trial_dir, trial_number)`` tuple. The folder is created
    directly and ``FileExistsError`` is used to detect taken numbers, which
    avoids the race of checking for existence first.
    """
    os.makedirs(base_dir, exist_ok=True)
    trial = 1
    while True:
        # Plain concatenation is kept on purpose: the experiment layout is
        # "<model_dir><k>/", so model_dir is expected to end with a separator.
        trial_dir = base_dir + str(trial) + "/"
        try:
            os.makedirs(trial_dir)
        except FileExistsError:
            trial += 1
        else:
            return trial_dir, trial


def _write_config(config_path, flags, trial, model, transformations, class_mappings):
    """Write the experiment settings to *config_path*, one ``key=value`` per line."""
    lines = [
        f"Experiment num {trial}\n",
        f"Fecha={datetime.now()}\n",
        f"Train with={flags.train_dir}\n",
        f"Val with={flags.val_dir}\n",
        f"Input Resoltuion with={flags.resolution}\n",
        f"Batch Size={flags.batch_size}\n",
        f"Batch augment={flags.augment}\n",
        f"Rand Crop={flags.rand_crop}\n",
        f"Loss={flags.loss}\n",
        f"Model dir={flags.model_dir}\n",
        f"weights={flags.weights}\n",
        f"lr={flags.lr}\n",
        f"epochs={flags.epochs}\n",
        f"classes={flags.classes}\n",
        f"kernel_regularizer={flags.kernel_regularizer}\n",
        f"pretrained={flags.pretrained}\n",
        f"pretrained_weigths={flags.pretrained_weigths}\n",
        f"Class mappings={class_mappings}\n",
        f"Model={model}\n",
        f"Transformations: {transformations}\n",
        "Comentarios=\n",
    ]
    # The context manager closes the file even if a write fails.
    with open(config_path, "w") as config_file:
        config_file.writelines(lines)


def main(
    args: list,
    transformations: tuple = tuple(),
    model: str = "original",
    class_mappings: dict = None,
) -> None:
    """
    Train the model

    Args:
        - args (list): command-line arguments, parsed with the module-level ``parser``
        - transformations (tuple): transformations forwarded to the data generators.
            Default: tuple()
        - model (str): type of model, "original" or "ppm". Default: "original"
        - class_mappings (dict): class mapper. When given, the number of classes
            becomes the number of distinct mapped values plus one. Default: None

    Returns:
        - None

    Raises:
        - FileNotFoundError: if the train or validation "images/" folder is missing
        - ValueError: if ``model`` is unknown or the selected loss is not implemented
    """
    flags: argparse.Namespace = parser.parse_args(args)

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # use id from $ nvidia-smi

    # Fail fast, before any experiment folder is created, when the image
    # folders are missing.
    for image_dir in (flags.train_dir + "images/", flags.val_dir + "images/"):
        if not os.path.isdir(image_dir):
            raise FileNotFoundError(f"Image directory not found: {image_dir}")

    model_cls = _MODEL_CLASSES.get(model)
    if model_cls is None:
        raise ValueError(f"Unknown model: {model}")

    loss_factory = _LOSS_FACTORIES.get(flags.loss)
    if loss_factory is None:
        raise ValueError(f"Loss '{flags.loss}' is not implemented")

    # For every trial of the same experiment we create a new subfolder
    model_dir, trial = _create_trial_dir(flags.model_dir)

    # Class mappings
    if class_mappings is not None:
        flags.classes = len(set(class_mappings.values())) + 1

    # Write the file configuration in model_dir
    _write_config(
        model_dir + "config.txt",
        flags,
        trial,
        model,
        transformations,
        class_mappings,
    )

    input_shape = resolution2framesize3cha(flags.resolution)
    print("resolution2framesize3cha(FLAGS.resolution) ", input_shape)

    thundernet = model_cls(
        input_shape=input_shape,
        n_classes=flags.classes,
        resnet_trainable=True,
        kernel_regularizer=flags.kernel_regularizer,
    )

    if flags.pretrained:
        print("loading weights from", flags.pretrained_weigths)
        thundernet.model.load_weights(
            flags.pretrained_weigths, by_name=True, skip_mismatch=True
        )

    lr = flags.lr
    opt = tf.keras.optimizers.Adam(learning_rate=lr)  # for keras 2.6.0

    if not model_dir.endswith(os.path.sep):
        model_dir += os.path.sep

    # The doubled braces survive the f-string and are filled in by
    # ModelCheckpoint with the epoch number and logged metrics.
    checkpoint_name = f"BS{flags.batch_size}_loss{flags.loss}_weights_lr_{lr}_reg-{flags.kernel_regularizer}-ep-{{epoch}}-val_loss{{val_loss}}-train_loss{{loss}}-val_iou{{val_iou}}-train_iou{{iou}}.hdf5"
    callbacks = [
        PlotLosses(model_dir),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.normpath(os.path.join(model_dir, checkpoint_name)),
            save_best_only=True,
            save_weights_only=True,
        ),
    ]

    thundernet.model.compile(loss=loss_factory(), optimizer=opt, metrics=[iou])

    # Use the parsed command-line values so --train_dir and --batch_size take
    # effect; their defaults are the Thundernet_config values used before.
    dataset_dir = Path(flags.train_dir).parent
    training_generator, validation_generator = DataGenerator.create_generators(
        dataset_dir,
        flags.classes,
        training_batch_size=flags.batch_size,
        validation_batch_size=flags.batch_size,
        to_stereo=False,
        transformations=transformations,
        class_mappings=class_mappings,
    )

    # NOTE(review): flags.weights is recorded in config.txt but is not passed
    # to training (class_weight stays None) - confirm whether that is intended.
    # Model.fit accepts generators; fit_generator is deprecated.
    thundernet.model.fit(
        x=training_generator,
        validation_data=validation_generator,
        callbacks=callbacks,
        use_multiprocessing=False,
        workers=6,
        epochs=flags.epochs,
        class_weight=None,
    )


if __name__ == "__main__":
    main(sys.argv[1:], model="original", class_mappings=defaultdict(int, {1: 1}))
    # main(sys.argv[1:], model="ppm", class_mappings=defaultdict(int, {1: 1}))
    # main(sys.argv[1:], model='original', class_mappings=defaultdict(int, {1: 1, 2: 2, 5: 3}))  # In case you also want to segment two specific type of objects (original class_id=2 and class_id=5)
    # main(sys.argv[1:], model='ppm', class_mappings=defaultdict(int, {1: 1, 2: 2, 5: 2}))  # In case you want to treat both objects as the same class