| import numpy as np |
| from cleverhans.future.tf2.attacks import fast_gradient_method |
| import pandas as pd |
| from sklearn.model_selection import KFold |
| import sys |
| import tensorflow |
| import tensorflow as tf |
|
|
| from _utility import print_test, get_adversarial_examples |
|
|
| import pickle |
|
|
# Directory (relative to the working directory) that receives the per-fold
# ``.h5`` model files saved by train(). Keep the trailing slash: file paths
# are formed by plain string concatenation with this prefix.
folder_name = (
    "./adversarial_examples_parseval_net/"
    "src/logs/saved_models/"
)
|
|
|
|
def train(
    instance,
    X_train,
    Y_train,
    X_test,
    y_test,
    epochs,
    BS,
    sgd,
    generator,
    callbacks_list,
    model_name="ResNet",
):
    """Train one fresh wide residual network per fold of a 10-fold split.

    For every fold of ``KFold(n_splits=10)`` over ``X_train`` a new model is
    built with ``instance.create_wide_residual_network()``, compiled with
    categorical cross-entropy and the given optimizer, and fitted on batches
    drawn from ``generator.flow`` over the fold's training part. The fold's
    held-out part is used as validation data.

    After each fold, the Keras training history (``hist.history``) is pickled
    to ``"history_<model_name><fold>"`` in the current working directory, and
    the model is saved to ``folder_name + "<model_name>_<fold>.h5"``.

    Args:
        instance: Object exposing ``create_wide_residual_network()``; the
            returned object is compiled, fitted and saved here (presumably an
            uncompiled Keras model — confirm against the caller).
        X_train: Training inputs, indexable by an integer index array.
        Y_train: Training targets aligned with ``X_train`` (presumably one-hot,
            given the categorical cross-entropy loss — confirm).
        X_test: Unused; kept for interface compatibility.
        y_test: Unused; kept for interface compatibility.
        epochs: Number of epochs to train each fold's model.
        BS: Batch size for the training generator.
        sgd: Optimizer passed to ``model.compile``.
        generator: Object with an ``ImageDataGenerator``-style
            ``flow(x, y, batch_size=...)`` method.
        callbacks_list: Callbacks passed to ``model.fit`` for every fold.
        model_name: Prefix used for the history and model file names.
    """
    import os

    # Folds are taken in data order (no shuffling). ``random_state`` must not
    # be combined with ``shuffle=False``: it never had any effect, and
    # scikit-learn >= 0.24 rejects that combination with a ValueError.
    kfold = KFold(n_splits=10, shuffle=False)

    # model.save() does not create missing parent directories.
    os.makedirs(folder_name, exist_ok=True)

    for fold, (train_idx, val_idx) in enumerate(kfold.split(X_train)):
        model = instance.create_wide_residual_network()
        model.compile(
            loss="categorical_crossentropy", optimizer=sgd, metrics=["acc"]
        )
        print("Finished compiling")

        x_train, y_train = X_train[train_idx], Y_train[train_idx]
        x_val, y_val = X_train[val_idx], Y_train[val_idx]

        hist = model.fit(
            generator.flow(x_train, y_train, batch_size=BS),
            # At least one step, so a fold smaller than one batch still trains.
            steps_per_epoch=max(1, len(x_train) // BS),
            epochs=epochs,
            callbacks=callbacks_list,
            # The validation data are in-memory arrays, so Keras evaluates the
            # whole fold. The former ``validation_steps=len(x_val) // BS`` was
            # paired with Keras' default evaluation batch size (no batch_size
            # is given when fitting from a generator). As a result it covered
            # only part of the fold, and nothing at all when the fold held
            # fewer than BS samples.
            validation_data=(x_val, y_val),
        )

        # Derive per-fold file names from the caller's model_name. The old
        # code reassigned model_name to the full save path, so from fold 1 on
        # both file names kept growing across iterations.
        with open("history_" + model_name + str(fold), "wb") as file_pi:
            pickle.dump(hist.history, file_pi)

        save_path = folder_name + model_name + "_" + str(fold) + ".h5"
        model.save(save_path)
|
|