File size: 792 Bytes
f320de7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import keras
import json
import os

def exp_decay(epoch, lr, decay_rate=0.1, decay_epochs=40):
    """Per-epoch schedule for ``keras.callbacks.LearningRateScheduler``.

    The scheduler calls this with the epoch index and the optimizer's
    *current* learning rate, so the schedule is applied multiplicatively.
    Epoch 0 keeps ``lr`` unchanged, and every later epoch multiplies it by
    ``decay_rate ** (1 / decay_epochs)``. Starting from an initial rate
    ``lr0``, the rate used in epoch ``e`` is therefore
    ``lr0 * decay_rate ** (e / decay_epochs)``. In other words, the rate
    shrinks by a factor of ``decay_rate`` every ``decay_epochs`` epochs.

    Args:
        epoch: Zero-based index of the epoch about to start.
        lr: Learning rate currently set on the optimizer.
        decay_rate: Factor the rate is reduced by over ``decay_epochs`` epochs.
        decay_epochs: Number of epochs over which ``decay_rate`` is applied once.

    Returns:
        The learning rate to use for ``epoch``.
    """
    # ``lr`` has already been decayed by previous calls. Multiplying it by
    # ``decay_rate ** (epoch / decay_epochs)`` would compound to
    # lr0 * decay_rate ** (e * (e + 1) / 2 / decay_epochs), which is far
    # faster than exponential decay. Apply only a single epoch's step instead.
    step = 1.0 if epoch == 0 else decay_rate ** (1 / decay_epochs)
    return lr * step

def train_model(bach_model, train, val, n_epochs, ARTIFACTS_PATH, MODEL_PATH,
                *, restore_best_weights=False):
    """Fit ``bach_model``, save it, and write its training history to JSON.

    Training uses three callbacks:

    * the ``exp_decay`` learning-rate schedule;
    * early stopping (patience 3, ``min_delta=5e-5``) on Keras' default
      monitored quantity (``val_loss``);
    * a checkpoint written to ``ARTIFACTS_PATH/checkpoint.keras`` after
      every epoch.

    Args:
        bach_model: Keras model to train. It must already be compiled,
            because ``fit`` is called on it.
        train: Training data, passed to ``fit`` as its first argument.
        val: Validation data, passed to ``fit`` as ``validation_data``.
        n_epochs: Maximum number of epochs to train.
        ARTIFACTS_PATH: Directory for the checkpoint and ``train_logs.json``.
            Created if missing.
        MODEL_PATH: Directory for the final ``bach_model.keras``. Created if
            missing.
        restore_best_weights: Forwarded to ``EarlyStopping``. With the default
            ``False``, the saved model holds the weights from the last epoch
            that ran, not necessarily the best one.

    Returns:
        The ``History`` object returned by ``fit``.
    """
    # Create the output directories up front, so that writing the model and
    # the log file cannot fail on a missing directory after training has run.
    for directory in (ARTIFACTS_PATH, MODEL_PATH):
        if directory:  # "" means the current directory; os.makedirs("") raises
            os.makedirs(directory, exist_ok=True)

    callbacks = [
        keras.callbacks.LearningRateScheduler(exp_decay),
        keras.callbacks.EarlyStopping(
            patience=3,
            restore_best_weights=restore_best_weights,
            verbose=1,
            min_delta=5e-5,
        ),
        keras.callbacks.ModelCheckpoint(
            os.path.join(ARTIFACTS_PATH, "checkpoint.keras"), verbose=1
        ),
    ]

    train_logs = bach_model.fit(
        train, validation_data=val, epochs=n_epochs, callbacks=callbacks
    )

    bach_model.save(os.path.join(MODEL_PATH, "bach_model.keras"))

    # History values can include NumPy scalars (older tf.keras versions log the
    # learning rate as np.float32), which json cannot serialise natively.
    # ``default=float`` converts such values to plain Python floats.
    logs_path = os.path.join(ARTIFACTS_PATH, "train_logs.json")
    with open(logs_path, "w", encoding="utf-8") as f:
        json.dump(train_logs.history, f, indent=4, default=float)

    return train_logs