| import keras |
| import json |
| import os |
|
|
| def exp_decay(epoch, lr): |
| return lr * 0.1 ** (epoch / 40) |
|
|
| def train_model(bach_model, train, val, n_epochs, ARTIFACTS_PATH, MODEL_PATH): |
| callbacks = [ |
| keras.callbacks.LearningRateScheduler(exp_decay), |
| keras.callbacks.EarlyStopping(patience= 3, restore_best_weights= False, verbose= True, min_delta= 5e-5), |
| keras.callbacks.ModelCheckpoint(os.path.join(ARTIFACTS_PATH , "checkpoint.keras"), verbose= 1), |
| ] |
|
|
| train_logs = bach_model.fit(train, validation_data= val, epochs= n_epochs, callbacks= callbacks) |
| |
| bach_model.save(os.path.join(MODEL_PATH, "bach_model.keras")) |
| |
| with open(os.path.join(ARTIFACTS_PATH, "train_logs.json"), "w") as f: |
| json.dump(train_logs.history, f, indent=4) |
|
|