'''
2021/2/3
Guowang Xie

args:
    n_epoch: epoch values for training
    optimizer: various optimization algorithms
    l_rate: initial learning rate
    resume: the path of a previously saved model checkpoint to restore
    data_path_train: datasets path for training
    data_path_validate: datasets path for validating
    data_path_test: datasets path for testing
    output-path: output path
    batch_size: batch size
    schema: train, validate, test or eval
    parallel: ids of the gpus used, like 0, or, 0123
'''
import os
import sys
import argparse
import torch
from torch.autograd import Variable
import warnings
import time
import re
from pathlib import Path

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]

from network import FiducialPoints, DilatedResnetForFlatByFiducialPointsS2
# import utilsV3 as utils
import utilsV4 as utils
from dataloader import PerturbedDatastsForFiducialPoints_pickle_color_v2_v2
from loss import Losses


def _log(message, log_file):
    """Print ``message`` to stdout and mirror it into ``log_file``."""
    print(message)
    print(message, file=log_file)


def train(args):
    """Open the run log and execute the schema selected by ``args.schema``.

    Reads the module-level names ``path``, ``date`` and ``date_time`` (they are
    assigned in the ``__main__`` section before this function is called) and
    sets the module-level ``_re_date`` to the ``YYYY-M-D`` date found in
    ``args.resume``, or ``None`` when there is no resume path or no date in it.

    Args:
        args: parsed command-line namespace built in the ``__main__`` section.

    Raises:
        SystemExit: for an unknown optimizer, when no GPU option is given, and
            at the end of the 'validate', 'test' and 'eval' schemas.
    """
    global _re_date

    _re_date = None
    if args.resume is not None:
        # A resume path without a date used to crash on ``None.group(0)``;
        # fall back to the plain log name instead.
        date_match = re.search(r'\d{4}-\d{1,2}-\d{1,2}', str(args.resume))
        if date_match is not None:
            _re_date = date_match.group(0)

    if _re_date is not None:
        log_name = date + date_time + ' @' + _re_date + '_' + args.arch + '.log'
    else:
        log_name = date + date_time + '_' + args.arch + '.log'

    reslut_file = open(path + '/' + log_name, 'w')
    try:
        _run_schema(args, reslut_file)
    finally:
        # Also reached on the exit() paths, so the log is always flushed/closed.
        reslut_file.close()


def _run_schema(args, reslut_file):
    """Build model/optimizer, optionally restore a checkpoint, run the schema."""
    # Setup Dataloader
    data_path = str(args.data_path_train) + '/'
    data_path_validate = str(args.data_path_validate) + '/'
    data_path_test = str(args.data_path_test) + '/'

    _log(args, reslut_file)

    n_classes = 2

    model = FiducialPoints(n_classes=n_classes, num_filter=32,
                           architecture=DilatedResnetForFlatByFiducialPointsS2,
                           BatchNorm='BN', in_channels=3)

    if args.parallel is not None:
        # ``args.parallel`` is a sequence of single-digit characters, e.g. '01'.
        device_ids = list(map(int, args.parallel))
        args.device = torch.device('cuda:' + str(device_ids[0]))
        torch.cuda.set_device(args.device)
        model = torch.nn.DataParallel(model, device_ids=device_ids)
        model.cuda(args.device)
    elif getattr(args, 'distributed', False):
        # The parser defines no ``--distributed`` option, so the attribute may
        # be missing; ``args.device`` is needed further down in either case.
        args.device = torch.device('cuda')
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        warnings.warn('no found gpu')
        sys.exit()

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.8, weight_decay=1e-12)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.l_rate, weight_decay=1e-10)
    else:
        # The original ``assert '<text>'`` could never fail; report the reason
        # through the exit message instead.
        sys.exit("error: please choice optimizer ('SGD' or 'adam')")

    checkpoint = None
    if args.resume is not None:
        # ``args.resume`` is a Path for the default value and a str when given
        # on the command line; normalise before taking the file name.
        resume_name = Path(str(args.resume)).name
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from a trusted source.
            checkpoint = torch.load(args.resume, map_location=args.device)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            print("Loaded checkpoint '{}' (epoch {})".format(resume_name, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(resume_name))

    # Epoch stored in the checkpoint, or 0 when nothing was loaded.
    loaded_epoch = checkpoint['epoch'] if checkpoint is not None else 0

    loss_fun_classes = Losses(classify_size_average=True, args_gpu=args.device)
    loss_fun = loss_fun_classes.loss_fn4_v5_r_4     # *
    # loss_fun = loss_fun_classes.loss_fn4_v5_r_3   # *
    loss_fun2 = loss_fun_classes.loss_fn_l1_loss

    FlatImg = utils.FlatImg(args=args, path=path, date=date, date_time=date_time, _re_date=_re_date, model=model,
                            reslut_file=reslut_file, n_classes=n_classes, optimizer=optimizer,
                            loss_fn=loss_fun, loss_fn2=loss_fun2,
                            data_loader=PerturbedDatastsForFiducialPoints_pickle_color_v2_v2,
                            data_path=data_path, data_path_validate=data_path_validate,
                            data_path_test=data_path_test, data_preproccess=False)

    ''' load data '''
    FlatImg.loadTestData()

    train_time = AverageMeter()
    losses = AverageMeter()
    FlatImg.lambda_loss = 1
    FlatImg.lambda_loss_segment = 0.01
    FlatImg.lambda_loss_a = 0.1
    FlatImg.lambda_loss_b = 0.001
    FlatImg.lambda_loss_c = 0.01

    scheduler = torch.optim.lr_scheduler.MultiStepLR(FlatImg.optimizer, milestones=[40, 90, 150, 200], gamma=0.5)

    if args.schema == 'train':
        trainloader = FlatImg.loadTrainData(data_split='train', is_shuffle=True)
        FlatImg.loadValidateAndTestData(is_shuffle=True)
        trainloader_len = len(trainloader)

        progress_fmt = ('[{0}][{1}/{2}]\t\t'
                        '[{3:.2f} {4:.4f} {5:.2f}]\t'
                        '[l1:{6:.4f} l:{7:.4f} e:{8:.4f} r:{9:.4f} s:{10:.4f}]\t'
                        '{loss.avg:.4f}')

        for epoch in range(loaded_epoch, args.n_epoch):
            _log('* lambda_loss :' + str(FlatImg.lambda_loss) + '\t' + 'learning_rate :'
                 + str(optimizer.param_groups[0]['lr']), reslut_file)

            begin_train = time.time()
            # Running sums over the current print window.
            loss_segment_list = 0
            loss_l1_list = 0
            loss_local_list = 0
            loss_edge_list = 0
            loss_rectangles_list = 0
            loss_list = []

            model.train()
            for i, (images, labels, segment) in enumerate(trainloader):
                images = Variable(images)
                labels = Variable(labels.cuda(args.device))
                segment = Variable(segment.cuda(args.device))

                optimizer.zero_grad()
                outputs, outputs_segment = FlatImg.model(images, is_softmax=False)

                loss_l1, loss_local, loss_edge, loss_rectangles = loss_fun(outputs, labels, size_average=True)
                loss_segment = loss_fun2(outputs_segment, segment)

                loss = FlatImg.lambda_loss * (loss_l1
                                              + loss_local * FlatImg.lambda_loss_a
                                              + loss_edge * FlatImg.lambda_loss_b
                                              + loss_rectangles * FlatImg.lambda_loss_c) \
                    + FlatImg.lambda_loss_segment * loss_segment

                losses.update(loss.item())
                loss.backward()
                optimizer.step()

                loss_list.append(loss.item())
                loss_segment_list += loss_segment.item()
                loss_l1_list += loss_l1.item()
                loss_local_list += loss_local.item()
                # The edge / rectangle sums are not accumulated here, so the
                # 'e:' and 'r:' fields below always print 0.
                # loss_edge_list += loss_edge.item()
                # loss_rectangles_list += loss_rectangles.item()

                if (i + 1) % args.print_freq == 0 or (i + 1) == trainloader_len:
                    list_len = len(loss_list)
                    _log(progress_fmt.format(
                        epoch + 1, i + 1, trainloader_len,
                        min(loss_list), sum(loss_list) / list_len, max(loss_list),
                        loss_l1_list / list_len, loss_local_list / list_len,
                        loss_edge_list / list_len, loss_rectangles_list / list_len,
                        loss_segment_list / list_len, loss=losses), reslut_file)

                    # Start a fresh print window.
                    del loss_list[:]
                    loss_segment_list = 0
                    loss_l1_list = 0
                    loss_local_list = 0
                    loss_edge_list = 0
                    loss_rectangles_list = 0

            FlatImg.saveModel_epoch(epoch)      # FlatImg.saveModel(epoch, save_path=path)

            model.eval()
            trian_t = time.time() - begin_train
            losses.reset()
            train_time.update(trian_t)

            # Validation stays best-effort (a failure must not abort training),
            # but Ctrl-C / SystemExit now propagate and the cause is reported.
            try:
                FlatImg.validateOrTestModelV3(epoch, trian_t, validate_test='v_l4')
                FlatImg.validateOrTestModelV3(epoch, 0, validate_test='t')
            except Exception as err:
                print(' Error: validate or test')
                print(' ' + repr(err))

            try:
                scheduler.step()
            except Exception as err:
                warnings.warn('scheduler.step() failed: ' + repr(err))

            print('\n')

    elif args.schema == 'validate':
        model.eval()
        FlatImg.validateOrTestModelV3(loaded_epoch, 0, validate_test='t_all')
        sys.exit()
    elif args.schema == 'test':
        model.eval()
        FlatImg.validateOrTestModelV3(loaded_epoch, 0, validate_test='t_all')
        sys.exit()
    elif args.schema == 'eval':
        FlatImg.evalData(is_shuffle=True)
        model.eval()
        FlatImg.evalModelGreyC1(loaded_epoch, is_scaling=False)
        sys.exit()

    m, s = divmod(train_time.sum, 60)
    h, m = divmod(m, 60)
    _log("All Train Time : %02d:%02d:%02d\n" % (h, m, s), reslut_file)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the current value, running sum, sample count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--arch', nargs='?', type=str, default='Document-Dewarping-with-Control-Points',
                        help='Architecture')

    parser.add_argument('--img_shrink', nargs='?', type=int, default=None,
                        help='short edge of the input image')

    parser.add_argument('--n_epoch', nargs='?', type=int, default=300,
                        help='# of the epochs')

    parser.add_argument('--optimizer', type=str, default='adam',
                        help='optimization')

    parser.add_argument('--l_rate', nargs='?', type=float, default=0.0002,
                        help='Learning Rate')

    parser.add_argument('--print-freq', '-p', default=60, type=int,
                        metavar='N', help='print frequency (default: 60)')  # print frequency

    parser.add_argument('--data_path_train', default=ROOT / 'dataset/fiducial1024/fiducial1024_v1/', type=str,
                        help='the path of train images.')  # train image path

    parser.add_argument('--data_path_validate', default=ROOT / 'dataset/fiducial1024/fiducial1024_v1/validate/', type=str,
                        help='the path of validate images.')  # validate image path

    parser.add_argument('--data_path_test', default=ROOT / 'data/', type=str,
                        help='the path of test images.')

    parser.add_argument('--output-path', default=ROOT / 'flat/', type=str,
                        help='the path is used to save output --img or result.')

    parser.add_argument('--resume', default=ROOT / 'ICDAR2021/2021-02-03 16:15:55/143/2021-02-03 16_15_55flat_img_by_fiducial_points-fiducial1024_v1.pkl', type=str,
                        help='Path to previous saved model to restart from')

    parser.add_argument('--batch_size', nargs='?', type=int, default=2,
                        help='Batch Size')  # 28

    parser.add_argument('--schema', type=str, default='test',
                        help='train, validate, test or eval')

    # parser.set_defaults(resume='./ICDAR2021/2021-02-03 16:15:55/143/2021-02-03 16_15_55flat_img_by_fiducial_points-fiducial1024_v1.pkl')

    # ``type=list`` splits the value into characters, e.g. '01' -> ['0', '1'].
    parser.add_argument('--parallel', default='1', type=list,
                        help='choice the gpu id for parallel ')

    args = parser.parse_args()

    # argparse does not apply ``type=str`` to non-string defaults, so the path
    # options are Path objects unless overridden; str() keeps the error
    # messages working for both cases.
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise Exception(str(args.resume) + ' -- not exist')

    if args.data_path_test is None:
        raise Exception('-- No test path')
    else:
        if not os.path.exists(args.data_path_test):
            raise Exception(str(args.data_path_test) + ' -- no find')

    # train() reads these three module-level names.
    date = time.strftime('%Y-%m-%d', time.localtime(time.time()))
    date_time = time.strftime(' %H:%M:%S', time.localtime(time.time()))
    path = os.path.join(args.output_path, date)

    os.makedirs(path, exist_ok=True)

    train(args)