"""Train a model with periodic validation, TensorBoard logging and early stopping.

Training data is read from ``<dataroot>/<train_split>/`` and validation data
from ``<dataroot>/<val_split>/``. When early stopping triggers, the script asks
``Trainer.adjust_learning_rate`` for another learning-rate drop. If that returns
a falsy value, training stops.
"""
import os
import time

from data import create_dataloader
from earlystop import EarlyStopping
from networks.trainer import Trainer
from options.train_options import TrainOptions
from tensorboardX import SummaryWriter
from validate import validate

# ``delta`` values handed to EarlyStopping. The first is used for the initial
# schedule; the second replaces it after every successful learning-rate drop.
# NOTE(review): exact meaning of a negative delta lives in earlystop.py.
_INITIAL_EARLYSTOP_DELTA = -0.001
_AFTER_LR_DROP_EARLYSTOP_DELTA = -0.002


def get_val_opt():
    """Build the validation options from the command-line training options.

    Currently assumes jpg_prob, blur_prob 0 or 1.

    The options are re-parsed (without printing) and adjusted for validation:
    dataroot points at the ``val_split`` sub-directory, ``isTrain`` is off,
    resize/crop are not disabled, ``serial_batches`` is on, and the JPEG
    method is fixed to ``'pil'``. A two-value ``blur_sig`` range collapses to
    its midpoint, and a multi-value ``jpg_qual`` collapses to the integer
    midpoint of its first and last entries.

    Returns:
        The options object produced by ``TrainOptions().parse``.
    """
    val_opt = TrainOptions().parse(print_options=False)
    val_opt.dataroot = '{}/{}/'.format(val_opt.dataroot, val_opt.val_split)
    val_opt.isTrain = False
    val_opt.no_resize = False
    val_opt.no_crop = False
    val_opt.serial_batches = True
    val_opt.jpg_method = ['pil']
    if len(val_opt.blur_sig) == 2:
        b_sig = val_opt.blur_sig
        val_opt.blur_sig = [(b_sig[0] + b_sig[1]) / 2]
    if len(val_opt.jpg_qual) != 1:
        j_qual = val_opt.jpg_qual
        val_opt.jpg_qual = [int((j_qual[0] + j_qual[-1]) / 2)]
    return val_opt


def _make_early_stopping(opt, delta):
    """Create an EarlyStopping monitor with patience ``opt.earlystop_epoch``."""
    return EarlyStopping(patience=opt.earlystop_epoch, delta=delta, verbose=True)


def _train_one_epoch(model, data_loader, opt, train_writer, epoch):
    """Run one pass over ``data_loader``, logging loss and saving 'latest'."""
    for data in data_loader:
        model.total_steps += 1
        model.set_input(data)
        model.optimize_parameters()

        if model.total_steps % opt.loss_freq == 0:
            print(
                'Train loss: {} at step: {}'.format(model.loss, model.total_steps)
            )
            train_writer.add_scalar('loss', model.loss, model.total_steps)

        if model.total_steps % opt.save_latest_freq == 0:
            print(
                'saving the latest model %s (epoch %d, model.total_steps %d)'
                % (opt.name, epoch, model.total_steps)
            )
            model.save_networks('latest')


def _validate_epoch(model, val_opt, val_writer, epoch):
    """Switch to eval mode, validate, log accuracy/AP, and return the accuracy.

    The caller is responsible for switching the model back to training mode.
    """
    model.eval()
    acc, ap = validate(model.model, val_opt)[:2]
    val_writer.add_scalar('accuracy', acc, model.total_steps)
    val_writer.add_scalar('ap', ap, model.total_steps)
    print('(Val @ epoch {}) acc: {}; ap: {}'.format(epoch, acc, ap))
    return acc


def main():
    """Parse options, train for up to ``opt.niter`` epochs, and validate each epoch."""
    opt = TrainOptions().parse()
    opt.dataroot = '{}/{}/'.format(opt.dataroot, opt.train_split)
    val_opt = get_val_opt()

    data_loader = create_dataloader(opt)
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    train_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'train'))
    val_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'val'))
    try:
        model = Trainer(opt)
        early_stopping = _make_early_stopping(opt, _INITIAL_EARLYSTOP_DELTA)

        for epoch in range(opt.niter):
            epoch_start_time = time.time()
            _train_one_epoch(model, data_loader, opt, train_writer, epoch)
            print(
                'Epoch %d finished in %d sec'
                % (epoch, time.time() - epoch_start_time)
            )

            if epoch % opt.save_epoch_freq == 0:
                print(
                    'saving the model at the end of epoch %d, iters %d'
                    % (epoch, model.total_steps)
                )
                model.save_networks('latest')
                model.save_networks(epoch)

            acc = _validate_epoch(model, val_opt, val_writer, epoch)

            early_stopping(acc, model)
            if early_stopping.early_stop:
                # Plateau reached: try a learning-rate drop before giving up.
                cont_train = model.adjust_learning_rate()
                if cont_train:
                    print('Learning rate dropped by 10, continue training...')
                    early_stopping = _make_early_stopping(
                        opt, _AFTER_LR_DROP_EARLYSTOP_DELTA
                    )
                else:
                    print('Early stopping.')
                    break
            model.train()
    finally:
        # Flush buffered TensorBoard events even if training raises or stops early.
        train_writer.close()
        val_writer.close()


if __name__ == '__main__':
    main()