# Thundernet / train_config.py
# Uploaded by ExtendedRealityLab using the upload-large-folder tool
# (commit ae29340, verified).
import os
from data_gen import DataGenerator
from os import listdir
from utils import (
    iou,
    PlotLosses,
    dice_loss,
    focal_loss,
    categorical_loss,
    categorical_focal_loss,
    resolution2framesize3cha,
    resolution2framesize,
    bce_loss,
)
import matplotlib.pyplot as plt
import tensorflow as tf
# Eager execution makes custom losses/metrics step-debuggable; it trades away
# graph-mode speed, so this is a deliberate debugging/training-convenience choice.
tf.config.run_functions_eagerly(True)
# from keras.backend.tensorflow_backend import set_session
import argparse
import sys
import numpy as np
import thundernet_config as Thundernet_config
from datetime import datetime
# NOTE(review): pyplot is already imported above as `plt`; this re-import is
# harmless but redundant.
from matplotlib import pyplot as plt
from model.model import Thundernet as Thundernet_original
from model.model_ppm_factors import Thundernet as Thundernet_ppm
from pathlib import Path
from collections import defaultdict
import copy
# Non-interactive backend: training runs headless, plots are saved to disk.
plt.switch_backend("agg")
def _str2bool(value):
    """Parse a command-line string into a bool.

    argparse's ``type=bool`` is broken for flags: ``bool("False")`` is
    ``True`` because every non-empty string is truthy.  This helper accepts
    the usual textual spellings, and passes real bools through unchanged
    (the defaults coming from ``thundernet_config`` are already bools).
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes", "y"):
        return True
    if lowered in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean, got {value!r}")


def _parse_weights(value):
    """Parse a CLI string such as ``"{0: 1.0, 1: 2.0}"`` into a dict.

    The previous ``type=dict`` raised ``ValueError``/``TypeError`` on any CLI
    string; ``ast.literal_eval`` safely evaluates a Python dict literal
    without allowing arbitrary code execution.  Defaults from
    ``thundernet_config`` (already dicts) never go through ``type=``.
    """
    import ast  # function-scope import: only needed when --weights is given on the CLI

    return ast.literal_eval(value)


parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_dir",
    type=str,
    default=Thundernet_config.train_path,
    help="The directory containing the train image dataset.",
)
parser.add_argument(
    "--val_dir",
    type=str,
    default=Thundernet_config.val_path,
    help="The directory containing the validation image dataset.",
)
parser.add_argument(
    "--batch_size",
    type=int,
    default=Thundernet_config.batch_size,
    choices=[1, 2, 4, 8, 16],
    help="Batch size used for training Thundernet",
)
parser.add_argument(
    "--augment",
    type=_str2bool,  # type=bool would make every non-empty string truthy
    default=Thundernet_config.augment,
    choices=[False, True],
    help="Whether to use color augmentation for training Thundernet.",
)
parser.add_argument(
    "--rand_crop",
    type=float,
    default=Thundernet_config.rand_crop,
    choices=[0, 0.02, 0.05, 0.1, 0.2, 0.5],
    help="Frequency of random crop data augmentation technique.",
)
parser.add_argument(
    "--loss",
    type=str,
    default=Thundernet_config.loss,
    choices=["BCE", "BFL", "CFL", "DCL", "FTL", "CAT"],
    help="Loss function to be used - Binary Cross Entropy (BCE), Binary Focal "
    "Loss (BFL), Categorical Focal Loss (CFL), Dice Coefficient Loss (DCL), "
    "Focal Tversky Loss (FTL) and Categorical Cross Entropy (CAT).",
)
parser.add_argument(
    "--model_dir",
    type=str,
    default=Thundernet_config.model_dir,
    help="Base directory for the models_repo. "
    "Make sure 'model_checkpoint_path' given in 'checkpoint' file matches "
    "with checkpoint name.",
)
parser.add_argument(
    "--weights",
    type=_parse_weights,  # type=dict cannot parse a CLI string
    default=Thundernet_config.weights,
    help="Class weights used for Weighted Binary Cross Entropy Loss.",
)
parser.add_argument(
    "--lr", type=float, default=Thundernet_config.lr, help="Learning Rate."
)
parser.add_argument(
    "--epochs", type=int, default=Thundernet_config.epochs, help="Epochs"
)
parser.add_argument(
    "--classes",
    type=int,
    default=Thundernet_config.classes,
    help="Number of output classes.",  # was a copy-paste of "Epochs"
)
parser.add_argument(
    "--resolution",
    type=str,
    default=Thundernet_config.resolution,
    help="Input Resolution",
)
parser.add_argument(
    "--kernel_regularizer",
    type=float,
    default=Thundernet_config.kernel_regularizer,
    help="kernel_regularizer",
)
parser.add_argument(
    "--pretrained",
    type=_str2bool,  # type=bool would make every non-empty string truthy
    default=Thundernet_config.pretrained_bool,
    help="Whether to initialise the model from pretrained weights.",
)
parser.add_argument(
    "--pretrained_weigths",  # NOTE(review): misspelling kept — it is the public flag name
    type=str,
    default=Thundernet_config.pretrained_weigths,
    help="Path to the pretrained weights file loaded when --pretrained is true.",
)
def main(
    args: list,
    transformations: tuple = tuple(),
    model: str = "original",
    class_mappings: dict = None,
):
    """
    Train the model
    Args:
    - args (list): raw CLI arguments (e.g. ``sys.argv[1:]``), parsed with the
      module-level ``parser``
    - transformations (tuple): list of transformations to execute in the data. Default: tuple()
    - model (str): type of model, "original" or "ppm". Default: "original"
    - class_mappings (dict): class mapper; when given, ``FLAGS.classes`` is
      recomputed as the number of distinct mapped values plus one. Default: None
    Returns:
    - None
    Raises:
    - ValueError: if ``model`` or ``FLAGS.loss`` is not a supported option
    """
    FLAGS = parser.parse_args(args)  # argparse.Namespace (the old `: list` hint was wrong)
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # use id from $ nvidia-smi

    # NOTE(review): these id lists are not used below (the generators rescan
    # the dataset dir themselves); the listdir calls are kept as a cheap
    # sanity check that the expected images/ layout exists.
    mypath_train = FLAGS.train_dir + "images/"
    list_IDs_train = [f[:-4] for f in listdir(mypath_train) if f[-4:] == ".jpg"]
    mypath_val = FLAGS.val_dir + "images/"
    list_IDs_val = [f[:-4] for f in listdir(mypath_val) if f[-4:] == ".jpg"]

    # First we assure that the dir for saving the experiments is created
    os.makedirs(FLAGS.model_dir, exist_ok=True)

    # For every trial of the same experiment we create a new numbered subfolder
    k = 1
    while True:
        model_dir = FLAGS.model_dir + str(k) + "/"
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
            break
        k += 1

    # Model variant
    if model == "original":
        Thundernet = Thundernet_original
    elif model == "ppm":
        Thundernet = Thundernet_ppm
    else:
        raise ValueError(f"Unknown model: {model}")

    # Class mappings: distinct target ids plus one extra class
    if class_mappings is not None:
        FLAGS.classes = len(set(class_mappings.values())) + 1

    # Dump the run configuration next to the checkpoints for reproducibility.
    with open(model_dir + "config.txt", "w") as file:
        file.write("Experiment num " + str(k) + "\n")
        file.write("Fecha=" + str(datetime.now()) + "\n")
        file.write("Train with=" + FLAGS.train_dir + "\n")
        file.write("Val with=" + FLAGS.val_dir + "\n")
        file.write("Input Resolution with=" + FLAGS.resolution + "\n")  # typo fixed
        file.write("Batch Size=" + str(FLAGS.batch_size) + "\n")
        file.write("Batch augment=" + str(FLAGS.augment) + "\n")
        file.write("Rand Crop=" + str(FLAGS.rand_crop) + "\n")
        file.write("Loss=" + FLAGS.loss + "\n")
        file.write("Model dir=" + FLAGS.model_dir + "\n")
        file.write("weights=" + str(FLAGS.weights) + "\n")
        file.write("lr=" + str(FLAGS.lr) + "\n")
        file.write("epochs=" + str(FLAGS.epochs) + "\n")
        file.write("classes=" + str(FLAGS.classes) + "\n")
        file.write("kernel_regularizer=" + str(FLAGS.kernel_regularizer) + "\n")
        file.write("pretrained=" + str(FLAGS.pretrained) + "\n")
        file.write("pretrained_weigths=" + str(FLAGS.pretrained_weigths) + "\n")
        file.write("Class mappings=" + str(class_mappings) + "\n")
        file.write("Model=" + model + "\n")
        file.write(f"Transformations: {transformations}\n")
        file.write("Comentarios=" + "" + "\n")

    print(
        "resolution2framesize3cha(FLAGS.resolution) ",
        resolution2framesize3cha(FLAGS.resolution),
    )
    thundernet = Thundernet(
        input_shape=resolution2framesize3cha(FLAGS.resolution),
        n_classes=FLAGS.classes,
        resnet_trainable=True,
        kernel_regularizer=FLAGS.kernel_regularizer,
    )
    if FLAGS.pretrained:
        print("loading weights from", FLAGS.pretrained_weigths)
        # by_name/skip_mismatch: tolerate architecture differences between the
        # checkpoint and the current model (e.g. different head sizes).
        thundernet.model.load_weights(
            FLAGS.pretrained_weigths, by_name=True, skip_mismatch=True
        )

    lr = FLAGS.lr
    opt = tf.keras.optimizers.Adam(learning_rate=lr)  # for keras 2.6.0
    if not model_dir.endswith(os.path.sep):
        model_dir += os.path.sep
    callbacks = [
        PlotLosses(model_dir),
        # Checkpoint filename encodes the hyperparameters plus per-epoch
        # metrics (Keras fills in the {epoch}/{val_loss}/... placeholders).
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.normpath(
                os.path.join(
                    model_dir,
                    f"BS{FLAGS.batch_size}_loss{FLAGS.loss}_weights_lr_{lr}_reg-{FLAGS.kernel_regularizer}-ep-{{epoch}}-val_loss{{val_loss}}-train_loss{{loss}}-val_iou{{val_iou}}-train_iou{{iou}}.hdf5",
                )
            ),
            save_best_only=True,
            save_weights_only=True,
        ),
    ]

    # Loss selection. Previously there was no else branch, so "FTL" (which is
    # advertised in the parser choices but has no implementation here) caused
    # a NameError further down; fail loudly instead.
    if FLAGS.loss == "BCE":
        loss = bce_loss()
    elif FLAGS.loss == "BFL":
        loss = focal_loss()
    elif FLAGS.loss == "DCL":
        loss = dice_loss()
    elif FLAGS.loss == "CFL":
        loss = categorical_focal_loss()
    elif FLAGS.loss == "CAT":
        loss = categorical_loss()
    else:
        raise ValueError(f"Unsupported loss: {FLAGS.loss}")

    thundernet.model.compile(loss=loss, optimizer=opt, metrics=[iou])

    dataset_dir = Path(Thundernet_config.train_path).parent
    training_generator, validation_generator = DataGenerator.create_generators(
        dataset_dir,
        FLAGS.classes,
        training_batch_size=Thundernet_config.batch_size,
        validation_batch_size=Thundernet_config.batch_size,
        to_stereo=False,
        transformations=transformations,
        class_mappings=class_mappings,
    )

    # NOTE(review): a dead if/else previously assigned FLAGS.weights to an
    # unused local in both branches; training runs with class_weight=None,
    # so the class weights are not actually applied here.
    # Model.fit accepts generators directly; fit_generator is deprecated.
    thundernet.model.fit(
        training_generator,
        validation_data=validation_generator,
        callbacks=callbacks,
        use_multiprocessing=False,
        workers=6,
        epochs=FLAGS.epochs,
        class_weight=None,
    )
if __name__ == "__main__":
    # Default run: original architecture, segmenting class 1 against
    # background (defaultdict(int) maps every other original class id to 0).
    main(sys.argv[1:], model="original", class_mappings=defaultdict(int, {1: 1}))
    # main(sys.argv[1:], model="ppm", class_mappings=defaultdict(int, {1: 1}))
    # main(sys.argv[1:], model='original', class_mappings=defaultdict(int, {1: 1, 2: 2, 5: 3})) # In case you also want to segment two specific type of objects (original class_id=2 and class_id=5)
    # main(sys.argv[1:], model='ppm', class_mappings=defaultdict(int, {1: 1, 2: 2, 5: 2})) # In case you want to treat both objects as the same class