# AMontiB
# Your original commit message (now includes LFS pointer)
# 9c4b1c4
import os
import tqdm
from utils import TrainingModel, create_dataloader, EarlyStopping
from sklearn.metrics import balanced_accuracy_score, roc_auc_score
from utils.processing import add_processing_arguments
from parser import get_parser
def main():
    """Run the training loop: train, checkpoint, validate, early-stop.

    Command-line options come from ``get_parser()`` extended by
    ``add_processing_arguments``. The ``checkpoint/<opt.name>/weights``
    directory is created up front; checkpoints themselves are written via
    ``model.save_networks`` (one per trained epoch, plus a ``'best'`` tag).

    The first pass of the epoch loop (``epoch == start_epoch``) skips
    training and only validates; that balanced accuracy seeds
    ``EarlyStopping``.

    Raises:
        ValueError: if the training data loader contains no batches.
    """
    opt = add_processing_arguments(get_parser()).parse_args()

    os.makedirs(os.path.join('checkpoint', opt.name, 'weights'), exist_ok=True)

    valid_data_loader = create_dataloader(opt, split="val")
    train_data_loader = create_dataloader(opt, split="train")
    batches_per_epoch = len(train_data_loader)
    print()
    print(f"# validation batches = {len(valid_data_loader):d}")
    print(f"# training batches = {batches_per_epoch:d}")

    # Fail with a clear message instead of a ZeroDivisionError below.
    if batches_per_epoch == 0:
        raise ValueError(
            "training data loader is empty; cannot derive the starting "
            "epoch from model.total_steps"
        )

    model = TrainingModel(opt)
    early_stopping = None
    # Resume point: number of whole epochs already covered by total_steps.
    start_epoch = model.total_steps // batches_per_epoch
    print()

    for epoch in range(start_epoch, opt.num_epoches + 1):
        if epoch > start_epoch:
            # Training
            progress = tqdm.tqdm(train_data_loader)
            for batch in progress:
                loss = model.train_on_batch(batch).item()
                progress.set_description(f"Train loss: {loss:.4f}")

            # Save model
            # NOTE(review): nesting was reconstructed from a source with no
            # indentation; the per-epoch save is assumed to belong to the
            # "trained this epoch" branch -- confirm.
            model.save_networks(epoch)

        # Validation
        print("Validation ...", flush=True)
        y_true, y_pred, _ = model.predict(valid_data_loader)
        # Scores above 0.0 count as the positive class for accuracy;
        # AUC is computed on the raw scores.
        acc = balanced_accuracy_score(y_true, y_pred > 0.0)
        auc = roc_auc_score(y_true, y_pred)
        lr = model.get_learning_rate()
        print(f"After {epoch} epoches: val acc = {acc}; val auc = {auc}", flush=True)
        print(f"Current learning rate = {lr}", flush=True)

        # Early Stopping
        if early_stopping is None:
            # First validation: its score is the baseline and the current best.
            early_stopping = EarlyStopping(
                init_score=acc,
                patience=opt.earlystop_epoch,
                delta=0.001,
                verbose=True,
            )
            print('Save best model', flush=True)
            model.save_networks('best')
            continue

        # A truthy result from the tracker triggers a new 'best' checkpoint.
        if early_stopping(acc):
            print('Save best model', flush=True)
            model.save_networks('best')

        if not early_stopping.early_stop:
            continue

        # Patience exhausted: continue only if adjust_learning_rate() says so.
        if model.adjust_learning_rate():
            print("Learning rate dropped by 10, continue training ...", flush=True)
            early_stopping.reset_counter()
        else:
            print("Early stopping.", flush=True)
            break


if __name__ == "__main__":
    main()