# --------------------------------------------------------
# OpenVQA
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------

import os, torch, datetime, shutil, time
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data

from openvqa.models.model_loader import ModelLoader
from openvqa.utils.optim import get_optim, adjust_lr
from utils.test_engine import test_engine, ckpt_proc
from utils.extract_engine import extract_engine


def train_engine(__C, dataset, dataset_eval=None):
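    """Train a VQA model end to end.

    A summary inferred from the body below, not an official docstring:
    __C is the global config object, dataset is the training split, and
    dataset_eval is an optional validation split that is evaluated after
    every epoch via test_engine.
    """
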
    data_size = dataset.data_size
    token_size = dataset.token_size
    ans_size = dataset.ans_size
    pretrained_emb = dataset.pretrained_emb

    net = ModelLoader(__C).Net(
        __C,
        pretrained_emb,
        token_size,
        ans_size
    )
    net.cuda()
    net.train()

    if __C.N_GPU > 1:
        net = nn.DataParallel(net, device_ids=__C.DEVICES)

    # Define Loss Function
    loss_fn = eval('torch.nn.' + __C.LOSS_FUNC_NAME_DICT[__C.LOSS_FUNC] + "(reduction='" + __C.LOSS_REDUCTION + "').cuda()")
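    # The eval() above builds the criterion from config strings. Assuming
    # LOSS_FUNC_NAME_DICT maps short names to torch.nn class names (e.g.
    # 'bce' -> 'BCEWithLogitsLoss'), the line is equivalent to something like:
    #     loss_fn = torch.nn.BCEWithLogitsLoss(reduction=__C.LOSS_REDUCTION).cuda()
    # getattr(torch.nn, class_name)(reduction=...) would achieve the same
    # without eval().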

    # Load checkpoint if resume training
    if __C.RESUME:
        print(' ========== Resume training')

        if __C.CKPT_PATH is not None:
            print('Warning: CKPT_PATH is set, so '
                  'CKPT_VERSION and CKPT_EPOCH will be ignored')
            path = __C.CKPT_PATH
        else:
            path = __C.CKPTS_PATH + \
                   '/ckpt_' + __C.CKPT_VERSION + \
                   '/epoch' + str(__C.CKPT_EPOCH) + '.pkl'

        # Load the network parameters
        print('Loading ckpt from {}'.format(path))
        ckpt = torch.load(path)
        print('Finish!')

        if __C.N_GPU > 1:
            net.load_state_dict(ckpt_proc(ckpt['state_dict']))
        else:
            net.load_state_dict(ckpt['state_dict'])
        start_epoch = ckpt['epoch']

        # Load the optimizer parameters
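        # optim._step is rebuilt from the epoch count below. This assumes
        # get_optim returns a warmup-style wrapper whose learning rate is a
        # function of the global step (as the _step/_rate attributes used
        # throughout this file suggest); restoring the step keeps the LR
        # schedule consistent after a resume.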
        optim = get_optim(__C, net, data_size, ckpt['lr_base'])
        optim._step = int(data_size / __C.BATCH_SIZE * start_epoch)
        optim.optimizer.load_state_dict(ckpt['optimizer'])

        if ('ckpt_' + __C.VERSION) not in os.listdir(__C.CKPTS_PATH):
            os.mkdir(__C.CKPTS_PATH + '/ckpt_' + __C.VERSION)

    else:
        if ('ckpt_' + __C.VERSION) not in os.listdir(__C.CKPTS_PATH):
            # shutil.rmtree(__C.CKPTS_PATH + '/ckpt_' + __C.VERSION)
            os.mkdir(__C.CKPTS_PATH + '/ckpt_' + __C.VERSION)

        optim = get_optim(__C, net, data_size)
        start_epoch = 0

    loss_sum = 0
    named_params = list(net.named_parameters())
    grad_norm = np.zeros(len(named_params))

    # Define multi-thread dataloader
    # if __C.SHUFFLE_MODE in ['external']:
    #     dataloader = Data.DataLoader(
    #         dataset,
    #         batch_size=__C.BATCH_SIZE,
    #         shuffle=False,
    #         num_workers=__C.NUM_WORKERS,
    #         pin_memory=__C.PIN_MEM,
    #         drop_last=True
    #     )
    # else:
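    # drop_last=True keeps every batch at exactly BATCH_SIZE samples, so the
    # fixed-size sub-batch slicing in the gradient-accumulation loop below
    # never has to handle a short final batch.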
    dataloader = Data.DataLoader(
        dataset,
        batch_size=__C.BATCH_SIZE,
        shuffle=True,
        num_workers=__C.NUM_WORKERS,
        pin_memory=__C.PIN_MEM,
        drop_last=True
    )

    logfile = open(
        __C.LOG_PATH +
        '/log_run_' + __C.VERSION + '.txt',
        'a+'
    )
    logfile.write(str(__C))
    logfile.close()

    # Training script
    for epoch in range(start_epoch, __C.MAX_EPOCH):

        # Save log to file
        logfile = open(
            __C.LOG_PATH +
            '/log_run_' + __C.VERSION + '.txt',
            'a+'
        )
        logfile.write(
            '=====================================\nnowTime: ' +
            datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') +
            '\n'
        )
        logfile.close()

        # Learning Rate Decay
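        # At each epoch listed in LR_DECAY_LIST, adjust_lr is assumed to scale
        # the base learning rate by LR_DECAY_R; the exact behaviour lives in
        # openvqa/utils/optim.py.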
        if epoch in __C.LR_DECAY_LIST:
            adjust_lr(optim, __C.LR_DECAY_R)

        # Externally shuffle data list
        # if __C.SHUFFLE_MODE == 'external':
        #     dataset.shuffle_list(dataset.ans_list)

        time_start = time.time()
        # Iteration
        for step, (
                frcn_feat_iter,
                grid_feat_iter,
                bbox_feat_iter,
                ques_ix_iter,
                ans_iter
        ) in enumerate(dataloader):

            optim.zero_grad()

            frcn_feat_iter = frcn_feat_iter.cuda()
            grid_feat_iter = grid_feat_iter.cuda()
            bbox_feat_iter = bbox_feat_iter.cuda()
            ques_ix_iter = ques_ix_iter.cuda()
            ans_iter = ans_iter.cuda()

            loss_tmp = 0
            for accu_step in range(__C.GRAD_ACCU_STEPS):
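                # Gradient accumulation: each batch of BATCH_SIZE samples is
                # processed as GRAD_ACCU_STEPS sub-batches of SUB_BATCH_SIZE
                # (the config presumably guarantees
                # BATCH_SIZE == GRAD_ACCU_STEPS * SUB_BATCH_SIZE), calling
                # backward() per sub-batch so gradients sum before the single
                # optim.step() below.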
                sub_frcn_feat_iter = \
                    frcn_feat_iter[accu_step * __C.SUB_BATCH_SIZE:
                                   (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_grid_feat_iter = \
                    grid_feat_iter[accu_step * __C.SUB_BATCH_SIZE:
                                   (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_bbox_feat_iter = \
                    bbox_feat_iter[accu_step * __C.SUB_BATCH_SIZE:
                                   (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_ques_ix_iter = \
                    ques_ix_iter[accu_step * __C.SUB_BATCH_SIZE:
                                 (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_ans_iter = \
                    ans_iter[accu_step * __C.SUB_BATCH_SIZE:
                             (accu_step + 1) * __C.SUB_BATCH_SIZE]

                pred = net(
                    sub_frcn_feat_iter,
                    sub_grid_feat_iter,
                    sub_bbox_feat_iter,
                    sub_ques_ix_iter
                )

                loss_item = [pred, sub_ans_iter]
                loss_nonlinear_list = __C.LOSS_FUNC_NONLINEAR[__C.LOSS_FUNC]
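                # LOSS_FUNC_NONLINEAR presumably holds one transform per loss
                # argument: 'flat' flattens the tensor, while any other truthy
                # entry names a torch.nn.functional op applied over dim=1
                # (e.g. 'log_softmax' on predictions when the criterion is
                # KLDivLoss, which expects log-probabilities).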
                for item_ix, loss_nonlinear in enumerate(loss_nonlinear_list):
                    if loss_nonlinear in ['flat']:
                        loss_item[item_ix] = loss_item[item_ix].view(-1)
                    elif loss_nonlinear:
                        loss_item[item_ix] = eval('F.' + loss_nonlinear + '(loss_item[item_ix], dim=1)')

                loss = loss_fn(loss_item[0], loss_item[1])
                if __C.LOSS_REDUCTION == 'mean':
                    # only a mean reduction needs to be divided by grad_accu_steps
                    loss /= __C.GRAD_ACCU_STEPS
                loss.backward()
                loss_tmp += loss.cpu().data.numpy() * __C.GRAD_ACCU_STEPS
                loss_sum += loss.cpu().data.numpy() * __C.GRAD_ACCU_STEPS

                if __C.VERBOSE:
                    if dataset_eval is not None:
                        mode_str = __C.SPLIT['train'] + '->' + __C.SPLIT['val']
                    else:
                        mode_str = __C.SPLIT['train'] + '->' + __C.SPLIT['test']

                    print("\r[Version %s][Model %s][Dataset %s][Epoch %2d][Step %4d/%4d][%s] Loss: %.4f, Lr: %.2e" % (
                        __C.VERSION,
                        __C.MODEL_USE,
                        __C.DATASET,
                        epoch + 1,
                        step,
                        int(data_size / __C.BATCH_SIZE),
                        mode_str,
                        loss_tmp / __C.SUB_BATCH_SIZE,
                        optim._rate
                    ), end=' ')

            # Gradient norm clipping
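            # Clipping runs once per effective batch, after all accumulated
            # backward() calls, so the threshold applies to the summed
            # gradients rather than to each sub-batch separately.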
            if __C.GRAD_NORM_CLIP > 0:
                nn.utils.clip_grad_norm_(
                    net.parameters(),
                    __C.GRAD_NORM_CLIP
                )

            # Save the gradient information
            for name in range(len(named_params)):
                norm_v = torch.norm(named_params[name][1].grad).cpu().data.numpy() \
                    if named_params[name][1].grad is not None else 0
                grad_norm[name] += norm_v * __C.GRAD_ACCU_STEPS
                # print('Param %-3s Name %-80s Grad_Norm %-20s' %
                #       (str(grad_wt),
                #        params[grad_wt][0],
                #        str(norm_v)))

            optim.step()

        time_end = time.time()
        elapse_time = time_end - time_start
        print('Finished in {}s'.format(int(elapse_time)))
        epoch_finish = epoch + 1

        # Save checkpoint
        if not __C.SAVE_LAST or epoch_finish == __C.MAX_EPOCH:
            if __C.N_GPU > 1:
                state = {
                    'state_dict': net.module.state_dict(),
                    'optimizer': optim.optimizer.state_dict(),
                    'lr_base': optim.lr_base,
                    'epoch': epoch_finish
                }
            else:
                state = {
                    'state_dict': net.state_dict(),
                    'optimizer': optim.optimizer.state_dict(),
                    'lr_base': optim.lr_base,
                    'epoch': epoch_finish
                }
            torch.save(
                state,
                __C.CKPTS_PATH +
                '/ckpt_' + __C.VERSION +
                '/epoch' + str(epoch_finish) +
                '.pkl'
            )

        # Logging
        logfile = open(
            __C.LOG_PATH +
            '/log_run_' + __C.VERSION + '.txt',
            'a+'
        )
        logfile.write(
            'Epoch: ' + str(epoch_finish) +
            ', Loss: ' + str(loss_sum / data_size) +
            ', Lr: ' + str(optim._rate) + '\n' +
            'Elapsed time: ' + str(int(elapse_time)) +
            ', Speed(s/batch): ' + str(elapse_time / step) +
            '\n\n'
        )
        logfile.close()

        # Eval after every epoch
        if dataset_eval is not None:
            test_engine(
                __C,
                dataset_eval,
                state_dict=net.state_dict(),
                validation=True
            )

        # if __C.VERBOSE:
        #     logfile = open(
        #         __C.LOG_PATH +
        #         '/log_run_' + __C.VERSION + '.txt',
        #         'a+'
        #     )
        #     for name in range(len(named_params)):
        #         logfile.write(
        #             'Param %-3s Name %-80s Grad_Norm %-25s\n' % (
        #                 str(name),
        #                 named_params[name][0],
        #                 str(grad_norm[name] / data_size * __C.BATCH_SIZE)
        #             )
        #         )
        #     logfile.write('\n')
        #     logfile.close()

        loss_sum = 0
        grad_norm = np.zeros(len(named_params))

    # Modification - optionally run full result extract after training ends
    if __C.EXTRACT_AFTER:
        extract_engine(__C, state_dict=net.state_dict())
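

# Example driver (a minimal sketch, not part of the original file): the
# config and dataset loader names below follow the usual OpenVQA layout
# (openvqa.datasets.dataset_loader, run.py) but are assumptions here.
#
#     from openvqa.datasets.dataset_loader import DatasetLoader
#     __C = ...  # build and freeze the config object as run.py does
#     dataset = DatasetLoader(__C).DataSet()
#     train_engine(__C, dataset)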