File size: 9,402 Bytes
ae29340
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import os
from data_gen import DataGenerator
from os import listdir
from utils import (
    iou,
    PlotLosses,
    dice_loss,
    focal_loss,
    categorical_loss,
    categorical_focal_loss,
    resolution2framesize3cha,
    resolution2framesize,
    bce_loss,
)
import matplotlib.pyplot as plt
import tensorflow as tf

tf.config.run_functions_eagerly(True)
# from keras.backend.tensorflow_backend import set_session
import argparse
import sys
import numpy as np
import thundernet_config as Thundernet_config
from datetime import datetime
from matplotlib import pyplot as plt

from model.model import Thundernet as Thundernet_original
from model.model_ppm_factors import Thundernet as Thundernet_ppm

from pathlib import Path
from collections import defaultdict
import copy

# Use the non-interactive "agg" backend for matplotlib.
plt.switch_backend("agg")


def _str2bool(value):
    """
    Convert a command-line token into a real boolean.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including "False") becomes True. This converter maps the
    usual spellings explicitly instead.
    Args:
     - value (str | bool): raw command-line token (or an already parsed bool)
    Returns:
    - bool: the parsed value
    Raises:
    - argparse.ArgumentTypeError: if the token is not a recognised boolean
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept as False because bool("") was False before.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def _parse_weights(value):
    """
    Convert a command-line token into a class-weights dictionary.

    ``type=dict`` cannot parse any non-empty string, so the value is read as a
    Python dict literal instead, e.g. "{0: 1.0, 1: 5.0}".
    Args:
     - value (str | dict): raw command-line token (or an already built dict)
    Returns:
    - dict: the parsed weights ({} for an empty string, as dict("") returned)
    Raises:
    - argparse.ArgumentTypeError: if the token is not a valid dict literal
    """
    import ast  # local import: only needed when --weights is passed as text

    if isinstance(value, dict):
        return value
    text = str(value).strip()
    if not text:
        return {}
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError) as err:
        raise argparse.ArgumentTypeError(
            f"Could not parse class weights {value!r}: {err}"
        ) from err
    if not isinstance(parsed, dict):
        raise argparse.ArgumentTypeError(
            f"Class weights must be a dict literal, got {value!r}."
        )
    return parsed


# Command-line interface; every default comes from thundernet_config.
parser = argparse.ArgumentParser()

parser.add_argument(
    "--train_dir",
    type=str,
    default=Thundernet_config.train_path,
    help="The directory containing the train image dataset.",
)

parser.add_argument(
    "--val_dir",
    type=str,
    default=Thundernet_config.val_path,
    help="The directory containing the validation image dataset.",
)

parser.add_argument(
    "--batch_size",
    type=int,
    default=Thundernet_config.batch_size,
    choices=[1, 2, 4, 8, 16],
    help="Batch size used for training Thundernet",
)

parser.add_argument(
    "--augment",
    type=_str2bool,
    default=Thundernet_config.augment,
    choices=[False, True],
    help="Whether to use color augmentation for training Thundernet.",
)

parser.add_argument(
    "--rand_crop",
    type=float,
    default=Thundernet_config.rand_crop,
    choices=[0, 0.02, 0.05, 0.1, 0.2, 0.5],
    help="Frequency of random crop data augmentation technique.",
)

parser.add_argument(
    "--loss",
    type=str,
    default=Thundernet_config.loss,
    choices=["BCE", "BFL", "CFL", "DCL", "FTL", "CAT"],
    help="Loss function to be used - Binary Cross Entropy (BCE), Binary Focal Loss (BFL), "
    "Categorical Focal Loss (CFL), Dice Coefficient Loss (DCL), "
    "Focal Tversky Loss (FTL) and Categorical Loss (CAT)",
)

parser.add_argument(
    "--model_dir",
    type=str,
    default=Thundernet_config.model_dir,
    help="Base directory for the models_repo. "
    "Make sure 'model_checkpoint_path' given in 'checkpoint' file matches "
    "with checkpoint name.",
)

parser.add_argument(
    "--weights",
    type=_parse_weights,
    default=Thundernet_config.weights,
    help="Class weights used for Weighted Binary Cross Entropy Loss, "
    "given as a Python dict literal, e.g. '{0: 1.0, 1: 5.0}'.",
)

parser.add_argument(
    "--lr", type=float, default=Thundernet_config.lr, help="Learning Rate."
)

parser.add_argument(
    "--epochs", type=int, default=Thundernet_config.epochs, help="Epochs"
)

parser.add_argument(
    "--classes",
    type=int,
    default=Thundernet_config.classes,
    help="Number of classes the model is built with.",
)

parser.add_argument(
    "--resolution",
    type=str,
    default=Thundernet_config.resolution,
    help="Input Resolution",
)

parser.add_argument(
    "--kernel_regularizer",
    type=float,
    default=Thundernet_config.kernel_regularizer,
    help="kernel_regularizer",
)

parser.add_argument(
    "--pretrained",
    type=_str2bool,
    default=Thundernet_config.pretrained_bool,
    help="Whether to load pretrained weights before training.",
)

# NOTE: the option name keeps its original spelling ("weigths") so existing
# command lines and the FLAGS.pretrained_weigths attribute keep working.
parser.add_argument(
    "--pretrained_weigths",
    type=str,
    default=Thundernet_config.pretrained_weigths,
    help="Path of the weights file loaded when --pretrained is enabled.",
)


def main(
    args: list,
    transformations: tuple = tuple(),
    model: str = "original",
    class_mappings: dict = None,
):
    """
    Train a Thundernet model and store its configuration and checkpoints.

    Each call creates a new numbered sub-folder inside ``--model_dir``, writes
    a ``config.txt`` describing the run there, and trains the selected model
    with the selected loss.
    Args:
     - args (list): command-line style arguments, parsed with the module-level parser
     - transformations (tuple): transformations forwarded to the data generators. Default: tuple()
     - model (str): architecture to train, "original" or "ppm". Default: "original"
     - class_mappings (dict): class mapper forwarded to the data generators. When given,
       the number of classes becomes the number of distinct mapped values plus one,
       overriding ``--classes``. Default: None
    Returns:
    - None
    Raises:
    - ValueError: if ``model`` is unknown or the selected loss has no implementation here
    """

    FLAGS: argparse.Namespace = parser.parse_args(args)

    # Resolve the architecture and the loss first, so an unsupported choice
    # fails before any experiment folder is created on disk.
    architectures = {"original": Thundernet_original, "ppm": Thundernet_ppm}
    if model not in architectures:
        raise ValueError(f"Unknown model: {model}")
    Thundernet = architectures[model]

    # "FTL" is accepted by the parser but has no implementation imported here;
    # without this check `loss` would be unbound at compile time.
    loss_factories = {
        "BCE": bce_loss,
        "BFL": focal_loss,
        "DCL": dice_loss,
        "CFL": categorical_focal_loss,
        "CAT": categorical_loss,
    }
    if FLAGS.loss not in loss_factories:
        raise ValueError(f"Unsupported loss: {FLAGS.loss}")

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # use id from $ nvidia-smi

    # First we assure that the dir for saving the experiments is created
    os.makedirs(FLAGS.model_dir, exist_ok=True)

    # For every trial of the same experiment we create a new subfolder.
    # Creating the folder and catching FileExistsError avoids the gap between
    # "check it does not exist" and "create it". os.path.join keeps the run
    # folder inside the base dir even when --model_dir has no trailing slash.
    k = 1
    while True:
        model_dir = os.path.join(FLAGS.model_dir, str(k)) + "/"
        try:
            os.makedirs(model_dir)
            break
        except FileExistsError:
            k += 1

    # Class mappings: distinct target ids plus one extra class
    if class_mappings is not None:
        FLAGS.classes = len(set(class_mappings.values())) + 1

    # Write the file configuration in model_dir
    config_lines = [
        f"Experiment num {k}",
        f"Fecha={datetime.now()}",
        f"Train with={FLAGS.train_dir}",
        f"Val with={FLAGS.val_dir}",
        f"Input Resoltuion with={FLAGS.resolution}",
        f"Batch Size={FLAGS.batch_size}",
        f"Batch augment={FLAGS.augment}",
        f"Rand Crop={FLAGS.rand_crop}",
        f"Loss={FLAGS.loss}",
        f"Model dir={FLAGS.model_dir}",
        f"weights={FLAGS.weights}",
        f"lr={FLAGS.lr}",
        f"epochs={FLAGS.epochs}",
        f"classes={FLAGS.classes}",
        f"kernel_regularizer={FLAGS.kernel_regularizer}",
        f"pretrained={FLAGS.pretrained}",
        f"pretrained_weigths={FLAGS.pretrained_weigths}",
        f"Class mappings={class_mappings}",
        f"Model={model}",
        f"Transformations: {transformations}",
        "Comentarios=",
    ]
    # The context manager closes the file even if a write fails.
    with open(model_dir + "config.txt", "w") as config_file:
        config_file.write("\n".join(config_lines) + "\n")

    input_shape = resolution2framesize3cha(FLAGS.resolution)
    print(
        "resolution2framesize3cha(FLAGS.resolution) ",
        input_shape,
    )
    thundernet = Thundernet(
        input_shape=input_shape,
        n_classes=FLAGS.classes,
        resnet_trainable=True,
        kernel_regularizer=FLAGS.kernel_regularizer,
    )

    if FLAGS.pretrained:
        print("loading weights from", FLAGS.pretrained_weigths)
        thundernet.model.load_weights(
            FLAGS.pretrained_weigths, by_name=True, skip_mismatch=True
        )

    lr = FLAGS.lr
    opt = tf.keras.optimizers.Adam(learning_rate=lr)  # for keras 2.6.0

    if not model_dir.endswith(os.path.sep):
        model_dir += os.path.sep

    # {epoch}, {val_loss}, ... are left as placeholders in the checkpoint name
    # (doubled braces) so ModelCheckpoint can fill them in.
    callbacks = [
        PlotLosses(model_dir),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.normpath(
                os.path.join(
                    model_dir,
                    f"BS{FLAGS.batch_size}_loss{FLAGS.loss}_weights_lr_{lr}_reg-{FLAGS.kernel_regularizer}-ep-{{epoch}}-val_loss{{val_loss}}-train_loss{{loss}}-val_iou{{val_iou}}-train_iou{{iou}}.hdf5",
                )
            ),
            save_best_only=True,
            save_weights_only=True,
        ),
    ]

    loss = loss_factories[FLAGS.loss]()

    thundernet.model.compile(loss=loss, optimizer=opt, metrics=[iou])

    # Honour the command-line values: the dataset root and the batch size come
    # from FLAGS (their defaults are the thundernet_config values), so they
    # match what config.txt and the checkpoint name record.
    # NOTE(review): only the parent of --train_dir is handed to the generators;
    # --val_dir is recorded in config.txt but not passed here.
    dataset_dir = Path(FLAGS.train_dir).parent

    training_generator, validation_generator = DataGenerator.create_generators(
        dataset_dir,
        FLAGS.classes,
        training_batch_size=FLAGS.batch_size,
        validation_batch_size=FLAGS.batch_size,
        to_stereo=False,
        transformations=transformations,
        class_mappings=class_mappings,
    )

    # NOTE(review): --weights is recorded in config.txt only; training runs
    # with class_weight=None.
    # NOTE(review): fit_generator is deprecated in TF 2.x in favour of
    # Model.fit - confirm the target TensorFlow version before switching.
    thundernet.model.fit_generator(
        generator=training_generator,
        validation_data=validation_generator,
        callbacks=callbacks,
        use_multiprocessing=False,
        workers=6,
        epochs=FLAGS.epochs,
        class_weight=None,
    )


if __name__ == "__main__":
    # Script entry point: train the "original" architecture with the
    # command-line arguments that follow the script name.
    cli_arguments = sys.argv[1:]

    # Label 1 maps to 1; looking up any other key yields 0 (defaultdict(int)).
    default_mapping = defaultdict(int, {1: 1})

    main(cli_arguments, model="original", class_mappings=default_mapping)

    # Alternative runs, kept for reference:
    # main(sys.argv[1:], model="ppm", class_mappings=defaultdict(int, {1: 1}))
    # Also segment two specific object types (original class_id=2 and class_id=5):
    # main(sys.argv[1:], model='original', class_mappings=defaultdict(int, {1: 1, 2: 2, 5: 3}))
    # Treat both of those object types as the same class:
    # main(sys.argv[1:], model='ppm', class_mappings=defaultdict(int, {1: 1, 2: 2, 5: 2}))