File size: 10,275 Bytes
b4c5d36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
"""This module is for retraining the base (et) models. It contains the Tuning class."""
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import to_categorical
from sklearn.utils import shuffle
from joblib import load as j_load
from joblib import dump as j_dump
import pickle
import numpy as np
import os
from codes.base import eyeing as ey
from openpyxl import Workbook
# Path to the parent of the directory that holds this module, built by plain
# string concatenation: it is not normalised and always ends with "/../".
PATH2ROOT_ABS = os.path.dirname(__file__) + "/../"
class Tuning:
    """Retraining (fine-tuning) of the base eye-tracking (et) models on subject calibration data."""

    @staticmethod
    def et_mdl(
            subjects,
            models_list=None,
            r_train_list=None,
            n_epochs_patience=None,
            trainable_layers=None,
            shift_samples=None,
            blinking_threshold='uo',
            show_model=False,
            delete_files=False
    ):
        """
        Retrain the base (et) models on each subject's calibration data.

        Every list argument is treated as a grid: for each subject, one pair of
        models (horizontal + vertical) is retrained and saved for every
        combination of base model, training ratio, epochs/patience pair and
        number of trainable layers.

        Parameters:
            subjects: list of subjects' numbers
            models_list: list of base models' numbers (default: [1])
            r_train_list: list of ratios of the samples used for training; the
                remaining samples are used for validation (default: [0.99])
            n_epochs_patience: list of [epochs, patience] pairs; patience is
                used for early stopping (default: [[3, 3]])
            trainable_layers: list of numbers of ending layers that are left
                trainable; all preceding layers are frozen (default: [1])
            shift_samples: optional list with one entry per element of
                `subjects` (same order). A non-zero entry shifts that subject's
                inputs against the outputs by that many samples to compensate
                for the delay. None or a falsy entry means no shift.
            blinking_threshold: blinking threshold setting, resolved separately
                for every subject through `ey.get_threshold`
            show_model: show the structure of the (horizontal) model
            delete_files: remove the subject's calibration folder after retraining

        Returns:
            None
        """
        # None sentinels instead of mutable default arguments; resolved values
        # are the same as the historical defaults.
        if models_list is None:
            models_list = [1]
        if r_train_list is None:
            r_train_list = [0.99]
        if n_epochs_patience is None:
            n_epochs_patience = [[3, 3]]
        if trainable_layers is None:
            trainable_layers = [1]

        print("\nStarting to retrain eye_tracking model...")
        # Loading the scalers
        x1_scaler, x2_scaler, y_scaler = j_load(ey.scalers_dir + "scalers_et_main.bin")

        # Going through each subject's folder. The index is taken from the
        # position in `subjects`, so `shift_samples` stays aligned even when a
        # subject without calibration data is skipped.
        for sbj_idx, num in enumerate(subjects):
            print(f"Subject number {num} in process...")
            sbj_dir = ey.create_dir([ey.subjects_dir, f"{num}"])

            # ### Retraining 'eye_tracking' model with subject calibration data
            clb_dir = ey.create_dir([sbj_dir, ey.CLB])
            if not ey.file_existing(clb_dir, ey.X1+".pickle"):
                print(f"Data does not exist in {clb_dir}")
                continue

            print(f"Loading subject data in {clb_dir}")
            (
                x1_load0,
                x2_load0,
                y_load0,
                t_mat,
                eyes_ratio
            ) = ey.load(clb_dir, [ey.X1, ey.X2, ey.Y, ey.T, ey.ER])

            shift = shift_samples[sbj_idx] if shift_samples else 0
            if shift:
                Tuning._shift_trials(x1_load0, x2_load0, y_load0, t_mat, eyes_ratio, shift)

            er_dir = ey.create_dir([sbj_dir, ey.ER])
            # Keep the resolved value in its own variable: overwriting the
            # `blinking_threshold` argument would leak this subject's threshold
            # into all following subjects.
            sbj_threshold = ey.get_threshold(er_dir, blinking_threshold)
            blinking = ey.get_blinking(t_mat, eyes_ratio, sbj_threshold)[1]

            # Removing the samples that are during blink
            x1_load, x2_load, y_load, n_all = Tuning._drop_blinks(x1_load0, x2_load0, y_load0, blinking)
            print(f"All samples of subjects: {n_all}, Not blinking: {len(x1_load)}")
            if not x1_load:
                # Nothing to train on; the calibration data is kept untouched.
                print(f"No non-blinking samples for subject {num}; skipping retraining.")
                continue

            x1_load = np.array(x1_load)
            x2_load = np.array(x2_load)
            y_load = np.array(y_load)
            n_smp = x1_load.shape[0]
            print(f"Samples number: {n_smp}")

            # ### Preparing modified calibration data to feeding in eye_tracking model
            print("Normalizing modified calibration data to feeding in eye_tracking model...")
            # Does not depend on the base model, so it is computed once per subject.
            x1 = x1_load / x1_scaler

            # Going through each base model
            for mdl_num in models_list:
                print("Loading public eye_tracking models...")
                mdl_name = ey.MDL + f"{mdl_num}"
                info = ey.load(ey.et_trained_dir, [mdl_name])[0]
                # Each base model uses its own subset of the x2 features.
                x2 = x2_scaler.transform(x2_load[:, info["x2_chosen_features"]])

                # One shuffle per base model; every training ratio below splits
                # this same permutation.
                x1_shf, x2_shf, y_hrz_shf, y_vrt_shf = shuffle(x1, x2, y_load[:, 0], y_load[:, 1])

                # Going through each training ratio in the ratio list
                for rt in r_train_list:
                    n_train = int(rt * n_smp)
                    x_train = [x1_shf[:n_train], x2_shf[:n_train]]
                    x_val = [x1_shf[n_train:], x2_shf[n_train:]]
                    y_hrz_train, y_hrz_val = y_hrz_shf[:n_train], y_hrz_shf[n_train:]
                    y_vrt_train, y_vrt_val = y_vrt_shf[:n_train], y_vrt_shf[n_train:]

                    print(x_train[0].shape, x_val[0].shape, y_hrz_train.shape, y_hrz_val.shape,
                          x_train[1].shape, x_val[1].shape, y_vrt_train.shape, y_vrt_val.shape)

                    # Going through each epochs/patience pair
                    for nep in n_epochs_patience:
                        # Going through each number in trainable_layers list
                        for tl in trainable_layers:
                            # Always start from the untouched base models
                            model_hrz = load_model(ey.et_trained_dir + mdl_name + "-hrz.h5")
                            model_vrt = load_model(ey.et_trained_dir + mdl_name + "-vrt.h5")

                            info["trained_mdl_num"] = mdl_num
                            info["r_retrain"] = rt
                            info["n_epochs_patience_retrain"] = nep
                            info["trainable_layers"] = tl

                            # Freeze everything except the last `tl` layers.
                            # Note: tl == 0 gives an empty slice, so in that
                            # case no layer is frozen.
                            # NOTE(review): the models are not recompiled after
                            # changing `trainable` -- confirm that the Keras
                            # version in use picks the change up at fit time.
                            for layer_hrz, layer_vrt in zip(model_hrz.layers[:-tl], model_vrt.layers[:-tl]):
                                layer_hrz.trainable = False
                                layer_vrt.trainable = False

                            if show_model:
                                # summary() prints by itself and returns None
                                model_hrz.summary()

                            sbj_mdl_dir = ey.create_dir([sbj_dir, ey.MDL])
                            retrained_mdl_num = ey.find_max_mdl(sbj_mdl_dir, b=-7) + 1
                            retrained_mdl_name = ey.MDL + f"{retrained_mdl_num}"

                            # Horizontal model first, then the vertical one
                            for axis, model, y_tr, y_vl in (
                                    ("hrz", model_hrz, y_hrz_train, y_hrz_val),
                                    ("vrt", model_vrt, y_vrt_train, y_vrt_val),
                            ):
                                print(f"\n<<<<<<< {retrained_mdl_num}-sbj:{num}-model-{axis}:{mdl_num}-r_train:{rt}-epoch_patience:{nep}-trainable_layers:{tl} >>>>>>>>")
                                train_loss, val_loss = Tuning._fit_and_save(
                                    model,
                                    axis,
                                    x_train,
                                    y_tr * y_scaler,
                                    x_val,
                                    y_vl * y_scaler,
                                    nep,
                                    sbj_mdl_dir + retrained_mdl_name + f"-{axis}.h5"
                                )
                                info[f"{axis}_retrain_train_loss"] = train_loss
                                info[f"{axis}_retrain_val_loss"] = val_loss

                            ey.save([info], sbj_mdl_dir, [retrained_mdl_name])

            if delete_files:
                ey.remove(clb_dir)

    @staticmethod
    def _shift_trials(x1_trials, x2_trials, y_trials, t_trials, er_trials, shift):
        """Shift, in place, the inputs of every trial against its outputs by `shift` samples."""
        trials = zip(x1_trials, x2_trials, y_trials, t_trials, er_trials)
        for i, (x1, x2, y, t, er) in enumerate(trials):
            # Inputs lose their first `shift` samples, time and outputs their last ones.
            t_trials[i] = t[:-shift]
            x1_trials[i] = x1[shift:]
            x2_trials[i] = x2[shift:]
            y_trials[i] = y[:-shift]
            er_trials[i] = er[shift:]

    @staticmethod
    def _drop_blinks(x1_trials, x2_trials, y_trials, blinking):
        """Flatten the trials into sample lists without the samples flagged as blink.

        Returns (x1, x2, y, n_all), where n_all is the number of visited samples.
        """
        x1_kept, x2_kept, y_kept = [], [], []
        n_all = 0
        for x1_t, x2_t, y_t, blink_t in zip(x1_trials, x2_trials, y_trials, blinking):
            for x1_s, x2_s, y_s, is_blink in zip(x1_t, x2_t, y_t, blink_t):
                n_all += 1
                if not is_blink:
                    x1_kept.append(x1_s)
                    x2_kept.append(x2_s)
                    y_kept.append(y_s)
        return x1_kept, x2_kept, y_kept, n_all

    @staticmethod
    def _fit_and_save(model, axis, x_train, y_train, x_val, y_val, nep, mdl_path):
        """Fit one single-axis model with early stopping, save it and return (train_loss, val_loss)."""
        # A fresh callback per fit, so no early-stopping state is shared between models.
        early_stop = EarlyStopping(patience=nep[1], verbose=1, restore_best_weights=True)
        model.fit(x_train,
                  y_train,
                  validation_data=(x_val, y_val),
                  epochs=nep[0],
                  callbacks=[early_stop])
        train_loss = model.evaluate(x_train, y_train)
        val_loss = model.evaluate(x_val, y_val)
        model.save(mdl_path)
        print(f"Saving model-et-{axis} in " + mdl_path)
        return train_loss, val_loss
|