# Owleye / codes / tune_models_params.py
# owleyetracker's picture
# Upload 45 files
# b4c5d36 verified
"""This module is for retraining the base (et) models. It contains the Tuning class."""
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import to_categorical
from sklearn.utils import shuffle
from joblib import load as j_load
from joblib import dump as j_dump
import pickle
import numpy as np
import os
from codes.base import eyeing as ey
from openpyxl import Workbook
# Path to the project root: the parent of this module's directory, ending in a slash.
PATH2ROOT_ABS = f"{os.path.dirname(__file__)}/../"
class Tuning(object):
    """Retrain the base eye-tracking (et) models on each subject's calibration data."""

    @staticmethod
    def et_mdl(
            subjects,
            models_list=None,
            r_train_list=None,
            n_epochs_patience=None,
            trainable_layers=None,
            shift_samples=None,
            blinking_threshold='uo',
            show_model=False,
            delete_files=False
    ):
        """
        Retrain the base (et) models on subject calibration data.

        Every combination of base model, training ratio, ``[epochs, patience]`` pair and
        trainable-layer count is trained separately, so lists can be passed to compare
        hyper-parameters. Each combination produces a horizontal (``-hrz.h5``) and a
        vertical (``-vrt.h5``) model plus an info record, saved under the next free model
        number in the subject's ``ey.MDL`` folder.

        Parameters:
            subjects: list of subject numbers (folder names under ``ey.subjects_dir``).
            models_list: numbers of the base models to retrain. Default ``[1]``.
            r_train_list: fractions of the non-blinking samples used for training; the
                remainder is the validation set. Default ``[0.99]``.
            n_epochs_patience: list of ``[n_epochs, patience]`` pairs; patience is the
                early-stopping patience. Default ``[[3, 3]]``.
            trainable_layers: numbers of trailing layers left trainable; the layers before
                them are frozen. ``0`` leaves every layer trainable. Default ``[1]``.
            shift_samples: optional list parallel to ``subjects``. ``shift_samples[i]`` is
                how many samples the inputs of ``subjects[i]`` are shifted relative to the
                outputs to compensate for delay; a falsy entry means no shift.
            blinking_threshold: value passed to ``ey.get_threshold`` for every subject to
                obtain the threshold used to drop samples recorded during a blink.
            show_model: print the summary of the horizontal model.
            delete_files: delete the subject's calibration folder after retraining.
        Returns:
            None
        """
        # ``None`` defaults avoid shared mutable default lists; the effective
        # defaults are unchanged.
        models_list = [1] if models_list is None else models_list
        r_train_list = [0.99] if r_train_list is None else r_train_list
        n_epochs_patience = [[3, 3]] if n_epochs_patience is None else n_epochs_patience
        trainable_layers = [1] if trainable_layers is None else trainable_layers

        print("\nStarting to retrain eye_tracking model...")
        x1_scaler, x2_scaler, y_scaler = j_load(ey.scalers_dir + "scalers_et_main.bin")  # Loading the scalers

        # Going through each subject's folder
        for sbj_idx, num in enumerate(subjects):
            print(f"Subject number {num} in process...")
            sbj_dir = ey.create_dir([ey.subjects_dir, f"{num}"])
            clb_dir = ey.create_dir([sbj_dir, ey.CLB])
            if not ey.file_existing(clb_dir, ey.X1 + ".pickle"):
                print(f"Data does not exist in {clb_dir}")
                continue

            print(f"Loading subject data in {clb_dir}")
            x1_load0, x2_load0, y_load0, t_mat, eyes_ratio = ey.load(
                clb_dir, [ey.X1, ey.X2, ey.Y, ey.T, ey.ER])

            # Index shifts by the subject's position in `subjects`: a missing
            # calibration folder or a zero shift for one subject must not
            # misalign the shifts applied to the following subjects.
            if shift_samples and shift_samples[sbj_idx]:
                Tuning._apply_shift(x1_load0, x2_load0, y_load0, t_mat, eyes_ratio,
                                    shift_samples[sbj_idx])

            # Resolve the threshold into a separate name: reassigning the
            # parameter would feed this subject's resolved threshold into the
            # lookup of every later subject instead of the caller's value.
            er_dir = ey.create_dir([sbj_dir, ey.ER])
            sbj_threshold = ey.get_threshold(er_dir, blinking_threshold)
            blinking = ey.get_blinking(t_mat, eyes_ratio, sbj_threshold)[1]
            x1_load, x2_load, y_load = Tuning._remove_blinking(x1_load0, x2_load0, y_load0, blinking)

            n_smp = x1_load.shape[0]
            print(f"Samples number: {n_smp}")
            if n_smp == 0:
                # Nothing to train on; keep the data and move to the next subject.
                print(f"No non-blinking samples for subject {num}; skipping retraining.")
                continue

            # ### Preparing modified calibration data to feed the eye_tracking models
            print("Normalizing modified calibration data to feeding in eye_tracking model...")
            x1 = x1_load / x1_scaler  # does not depend on the base model, so computed once
            sbj_mdl_dir = ey.create_dir([sbj_dir, ey.MDL])

            for mdl_num in models_list:
                print("Loading public eye_tracking models...")
                mdl_name = ey.MDL + f"{mdl_num}"
                info = ey.load(ey.et_trained_dir, [mdl_name])[0]
                x2 = x2_scaler.transform(x2_load[:, info["x2_chosen_features"]])

                # Shuffle once per base model, then scale the targets once here
                # instead of on every fit/evaluate call.
                x1_shf, x2_shf, y_hrz_shf, y_vrt_shf = shuffle(x1, x2, y_load[:, 0], y_load[:, 1])
                y_hrz_shf = y_hrz_shf * y_scaler
                y_vrt_shf = y_vrt_shf * y_scaler

                # Going through each training ratio in the ratio list
                for rt in r_train_list:
                    n_train = int(rt * n_smp)
                    if n_train == 0:
                        print(f"r_train={rt} leaves no training samples; skipping.")
                        continue
                    x1_train, x1_val = x1_shf[:n_train], x1_shf[n_train:]
                    x2_train, x2_val = x2_shf[:n_train], x2_shf[n_train:]
                    y_hrz_train, y_hrz_val = y_hrz_shf[:n_train], y_hrz_shf[n_train:]
                    y_vrt_train, y_vrt_val = y_vrt_shf[:n_train], y_vrt_shf[n_train:]
                    x_train = [x1_train, x2_train]
                    x_val = [x1_val, x2_val]
                    print(x1_train.shape, x1_val.shape, y_hrz_train.shape, y_hrz_val.shape,
                          x2_train.shape, x2_val.shape, y_vrt_train.shape, y_vrt_val.shape)

                    for nep in n_epochs_patience:
                        for tl in trainable_layers:
                            # Fresh copies of the base models for every combination.
                            model_hrz = load_model(ey.et_trained_dir + mdl_name + "-hrz.h5")
                            model_vrt = load_model(ey.et_trained_dir + mdl_name + "-vrt.h5")
                            info["trained_mdl_num"] = mdl_num
                            info["r_retrain"] = rt
                            info["n_epochs_patience_retrain"] = nep
                            info["trainable_layers"] = tl
                            Tuning._freeze_leading_layers(model_hrz, tl)
                            Tuning._freeze_leading_layers(model_vrt, tl)
                            if show_model:
                                model_hrz.summary()  # summary() prints itself and returns None

                            retrained_mdl_num = ey.find_max_mdl(sbj_mdl_dir, b=-7) + 1
                            retrained_mdl_name = ey.MDL + f"{retrained_mdl_num}"
                            directions = (
                                ("hrz", model_hrz, y_hrz_train, y_hrz_val),
                                ("vrt", model_vrt, y_vrt_train, y_vrt_val),
                            )
                            for direction, model, y_train, y_val in directions:
                                print(f"\n<<<<<<< {retrained_mdl_num}-sbj:{num}-model-{direction}:{mdl_num}"
                                      f"-r_train:{rt}-epoch_patience:{nep}-trainable_layers:{tl} >>>>>>>>")
                                train_loss, val_loss = Tuning._fit_and_evaluate(
                                    model, x_train, y_train, x_val, y_val, nep[0], nep[1])
                                info[f"{direction}_retrain_train_loss"] = train_loss
                                info[f"{direction}_retrain_val_loss"] = val_loss
                                mdl_path = sbj_mdl_dir + retrained_mdl_name + f"-{direction}.h5"
                                model.save(mdl_path)
                                print(f"Saving model-et-{direction} in " + mdl_path)
                            ey.save([info], sbj_mdl_dir, [retrained_mdl_name])

            if delete_files:
                ey.remove(clb_dir)

    @staticmethod
    def _apply_shift(x1_load0, x2_load0, y_load0, t_mat, eyes_ratio, shift):
        """Shift inputs against outputs by ``shift`` samples, in place, per recording.

        The first ``shift`` samples of the inputs (``x1``, ``x2``) and of
        ``eyes_ratio`` are dropped. The last ``shift`` samples of the outputs
        (``y``) and of the times (``t``) are dropped.
        """
        for i, (x11, x21, y1, t1, eyr1) in enumerate(zip(x1_load0, x2_load0, y_load0, t_mat, eyes_ratio)):
            t_mat[i] = t1[:-shift]
            x1_load0[i] = x11[shift:]
            x2_load0[i] = x21[shift:]
            y_load0[i] = y1[:-shift]
            eyes_ratio[i] = eyr1[shift:]

    @staticmethod
    def _remove_blinking(x1_load0, x2_load0, y_load0, blinking):
        """Flatten all recordings into arrays, keeping only samples whose blink flag is falsy.

        Returns ``(x1, x2, y)`` as numpy arrays and prints the total and kept counts.
        """
        x1_kept, x2_kept, y_kept = [], [], []
        n_all = 0
        for x11, x21, y1, b1 in zip(x1_load0, x2_load0, y_load0, blinking):
            for x10, x20, y0, b0 in zip(x11, x21, y1, b1):
                n_all += 1
                if not b0:
                    x1_kept.append(x10)
                    x2_kept.append(x20)
                    y_kept.append(y0)
        print(f"All samples of subjects: {n_all}, Not blinking: {len(x1_kept)}")
        return np.array(x1_kept), np.array(x2_kept), np.array(y_kept)

    @staticmethod
    def _freeze_leading_layers(model, n_trainable):
        """Mark every layer except the last ``n_trainable`` ones as not trainable.

        Uses ``layers[:-n_trainable]``, so ``n_trainable == 0`` (empty slice) or a
        value >= the number of layers freezes nothing.
        """
        # NOTE(review): Keras documents that changing `trainable` after compile()
        # requires re-compiling for the change to take effect; the loaded model is
        # not re-compiled here. Confirm the Keras version in use honours the freeze.
        for layer in model.layers[:-n_trainable]:
            layer.trainable = False

    @staticmethod
    def _fit_and_evaluate(model, x_train, y_train, x_val, y_val, n_epochs, patience):
        """Fit ``model`` with early stopping and return ``(train_loss, val_loss)``.

        If the validation split is empty (e.g. ``r_train == 1``), training runs
        without validation data, early stopping monitors the training loss, and
        ``val_loss`` is ``None``. Keras raises on empty input arrays, so the
        validation evaluation is skipped in that case.
        """
        has_val = len(y_val) > 0
        cb = EarlyStopping(monitor="val_loss" if has_val else "loss", patience=patience,
                           verbose=1, restore_best_weights=True)
        model.fit(x_train,
                  y_train,
                  validation_data=(x_val, y_val) if has_val else None,
                  epochs=n_epochs,
                  callbacks=[cb])
        train_loss = model.evaluate(x_train, y_train)
        val_loss = model.evaluate(x_val, y_val) if has_val else None
        return train_loss, val_loss