# Thundernet / train_optuna.py
import argparse
import os
from collections import defaultdict
from os import listdir
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import optuna
import tensorflow as tf

import thundernet_config as Thundernet_config
from data_gen import DataGenerator
from model.model import Thundernet as Thundernet_original
from models_repo.model_attention import Thundernet as Thundernet_attention
from models_repo.model_attention_2 import Thundernet as Thundernet_attention2
from models_repo.model_ppm_factors import Thundernet as Thundernet_ppm
from utils import (
    iou,
    PlotLosses,
    dice_loss,
    focal_loss,
    categorical_loss,
    categorical_focal_loss,
    resolution2framesize3cha,
    bce_loss,
)

# Run tf.functions eagerly so custom losses and metrics can be debugged step by step.
tf.config.run_functions_eagerly(True)

# Non-interactive backend so figures can be written to disk on headless machines.
plt.switch_backend("agg")
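
# Optional sketch (an assumption, not part of the original script): allow
# TensorFlow to grow GPU memory on demand instead of reserving it all up
# front, which helps when several Optuna trials share a single GPU.
for _gpu in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(_gpu, True)
    except RuntimeError:
        # Memory growth must be configured before the GPU is initialized.
        pass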

def objective(trial):
    """Optuna objective: train for one epoch and return the mean IoU."""
    # Hyperparameters to tune.
    batch_size = trial.suggest_categorical("batch_size", [1, 2, 4])
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)  # Learning rate (log scale)
    kernel_regularizer = trial.suggest_float("kernel_regularizer", 1e-5, 1e-2, log=True)

    # Train the 'ppm' variant with the sampled hyperparameters.
    return main(
        model="ppm",
        class_mappings=defaultdict(int, {1: 1}),
        batch_size=batch_size,
        lr=lr,
        kernel_regularizer=kernel_regularizer,
        epochs=1,  # A single epoch per trial keeps the search cheap.
        loss="BCE",
        transformations=(),  # Add data augmentations here as needed.
    )
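
def enqueue_baseline_trial(study):
    """Optional helper, not called by default: seed the search with a
    known-good configuration so the first trial reproduces a hand-tuned
    baseline. The values below are illustrative assumptions, not taken
    from the original configuration."""
    study.enqueue_trial({"batch_size": 4, "lr": 1e-4, "kernel_regularizer": 1e-3})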

def main(
    model="original",
    class_mappings=None,
    batch_size=8,
    lr=1e-4,
    kernel_regularizer=0.001,
    epochs=1,
    loss="BCE",
    transformations=(),
):
    # Collect all settings in a Namespace, mirroring the original CLI flags.
    FLAGS = argparse.Namespace(
        train_dir=Thundernet_config.train_path,
        val_dir=Thundernet_config.val_path,
        batch_size=batch_size,
        augment=Thundernet_config.augment,
        rand_crop=Thundernet_config.rand_crop,
        loss=loss,
        model_dir=Thundernet_config.model_dir,
        weights=Thundernet_config.weights,
        lr=lr,
        epochs=epochs,
        classes=Thundernet_config.classes,
        resolution=Thundernet_config.resolution,
        kernel_regularizer=kernel_regularizer,
        pretrained=Thundernet_config.pretrained_bool,
        # The config attribute keeps its original (misspelled) name.
        pretrained_weights=Thundernet_config.pretrained_weigths,
    )
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
mypath_train = FLAGS.train_dir + "images/"
label_path_train = FLAGS.train_dir + "labels/"
list_IDs_train = [f[:-4] for f in listdir(mypath_train) if f[-4:] == ".jpg"]
mypath_val = FLAGS.val_dir + "images/"
label_path_val = FLAGS.val_dir + "labels/"
list_IDs_val = [f[:-4] for f in listdir(mypath_val) if f[-4:] == ".jpg"]
    # Select the model variant.
    if model == "original":
        Thundernet = Thundernet_original
    elif model == "attention":
        Thundernet = Thundernet_attention
    elif model == "attention2":
        Thundernet = Thundernet_attention2
    elif model == "ppm":
        Thundernet = Thundernet_ppm
    else:
        raise ValueError(f"Unknown model: {model}")
    # Ensure the model directory exists.
    model_dir = FLAGS.model_dir
    os.makedirs(model_dir, exist_ok=True)

    thundernet = Thundernet(
        input_shape=resolution2framesize3cha(FLAGS.resolution),
        n_classes=FLAGS.classes,
        resnet_trainable=True,
        kernel_regularizer=FLAGS.kernel_regularizer,
    )
    if FLAGS.pretrained:
        thundernet.model.load_weights(
            FLAGS.pretrained_weights, by_name=True, skip_mismatch=True
        )
    # Optimizer setup
    opt = tf.keras.optimizers.Adam(learning_rate=FLAGS.lr)

    # Select the loss function.
    if FLAGS.loss == "BCE":
        loss = bce_loss()
    elif FLAGS.loss == "BFL":
        loss = focal_loss()
    elif FLAGS.loss == "DCL":
        loss = dice_loss()
    elif FLAGS.loss == "CFL":
        loss = categorical_focal_loss()
    elif FLAGS.loss == "CAT":
        loss = categorical_loss()
    else:
        raise ValueError(f"Unknown loss: {FLAGS.loss}")

    thundernet.model.compile(loss=loss, optimizer=opt, metrics=[iou])
    # Data generators setup
    dataset_dir = Path(Thundernet_config.train_path).parent
    training_generator, validation_generator = DataGenerator.create_generators(
        dataset_dir,
        FLAGS.classes,
        training_batch_size=FLAGS.batch_size,
        to_stereo=False,
        transformations=transformations,
        class_mappings=class_mappings,
    )
    # Train the model
    history = thundernet.model.fit(
        training_generator,
        validation_data=validation_generator,
        epochs=FLAGS.epochs,
        class_weight=None,
        callbacks=[PlotLosses(model_dir)],
        use_multiprocessing=False,
        workers=6,
    )

    # Return the mean training IoU across epochs for Optuna to maximize.
    print(history.history)
    return np.mean(history.history["iou"])
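
# Example (assumption, for quick smoke tests outside the Optuna loop): main()
# can be called directly with default hyperparameters, e.g.:
#
#   main(model="ppm", class_mappings=defaultdict(int, {1: 1}), epochs=1)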

if __name__ == "__main__":
    # Maximize the mean IoU; persist trials to a local SQLite database.
    study = optuna.create_study(
        direction="maximize", storage="sqlite:///db.sqlite3"
    )
    study.optimize(objective, n_trials=100)
    print("Best hyperparameters found: ", study.best_params)

    import optuna.visualization as vis

    # Save the parameter-importance plot (write_image requires the kaleido package).
    fig = vis.plot_param_importances(study)
    fig.write_image("param_importance_IoU.png")

    # Save the optimization-history plot.
    fig = vis.plot_optimization_history(study)
    fig.write_image("optimization_history_IoU.png")

    # Export all trial results to CSV.
    df = study.trials_dataframe()
    df.to_csv("results_optuna_IoU.csv")
    # Plot learning rate vs. objective value (mean IoU).
    plt.figure(figsize=(8, 6))
    plt.scatter(df["params_lr"], df["value"], color="blue", alpha=0.7)
    plt.xscale("log")  # lr was sampled on a log scale
    plt.title("Learning Rate vs Mean IoU")
    plt.xlabel("Learning Rate")
    plt.ylabel("Mean IoU")
    plt.grid(True)
    plt.savefig("lr_vs_loss_IoU.png")
    plt.close()

    # Plot batch size vs. objective value.
    plt.figure(figsize=(8, 6))
    plt.scatter(df["params_batch_size"], df["value"], color="green", alpha=0.7)
    plt.title("Batch Size vs Mean IoU")
    plt.xlabel("Batch Size")
    plt.ylabel("Mean IoU")
    plt.grid(True)
    plt.savefig("batch_size_vs_loss_IoU.png")
    plt.close()

    # Plot kernel regularizer vs. objective value.
    plt.figure(figsize=(8, 6))
    plt.scatter(df["params_kernel_regularizer"], df["value"], color="red", alpha=0.7)
    plt.xscale("log")  # kernel_regularizer was sampled on a log scale
    plt.title("Kernel Regularizer vs Mean IoU")
    plt.xlabel("Kernel Regularizer")
    plt.ylabel("Mean IoU")
    plt.grid(True)
    plt.savefig("kernel_regularizer_vs_loss_IoU.png")
    plt.close()
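
    # Sketch (assumption): the study above is created without an explicit
    # study_name, so Optuna assigns one automatically. To revisit the results
    # in a later session, list the studies stored in db.sqlite3 and reload by name:
    #
    #   summaries = optuna.get_all_study_summaries(storage="sqlite:///db.sqlite3")
    #   study = optuna.load_study(
    #       study_name=summaries[-1].study_name, storage="sqlite:///db.sqlite3"
    #   )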