'''
2021/2/3
Guowang Xie

args:
    n_epoch: number of epochs for training
    optimizer: optimization algorithm ('SGD' or 'adam')
    l_rate: initial learning rate
    resume: path of a previously trained model checkpoint to load
    data_path_train: dataset path for training
    data_path_validate: dataset path for validating
    data_path_test: dataset path for testing
    output-path: output path
    batch_size: batch size
    schema: test or train
    parallel: ids of the gpus used, like 0, or, 0123
'''
import os
import sys
import argparse
import torch
import time
import re
from pathlib import Path

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]

from network import FiducialPoints, DilatedResnetForFlatByFiducialPointsS2
# import utilsV3 as utils
import utilsV4 as utils
from dataloader import PerturbedDatastsForFiducialPoints_pickle_color_v2_v2


def train(args):
    """Build the model, optionally restore a checkpoint, and run one test pass.

    Despite its name this entry point does not train: it constructs the
    network and optimizer, loads ``args.resume`` when it points to an existing
    file, loads the test data and calls ``validateOrTestModelV3`` once.

    The function reads the module-level globals ``path``, ``date`` and
    ``date_time``, which the ``__main__`` section assigns before calling it,
    and it writes a ``.log`` file into ``path``.

    Args:
        args: parsed command-line namespace (see the parser in ``__main__``).
            ``args.device`` is set here as a side effect.

    Raises:
        SystemExit: if ``args.optimizer`` is neither ``'SGD'`` nor ``'adam'``.
    """
    global _re_date

    # Date embedded in the checkpoint path (e.g. "2021-02-03"); it tags the
    # log file name and is forwarded to utils.FlatImg.
    _re_date = None
    if args.resume is not None:
        date_match = re.search(r'\d{4}-\d{1,2}-\d{1,2}', str(args.resume))
        # NOTE(review): a checkpoint path without a date used to crash with
        # AttributeError here; it now falls back to the plain log name.
        # Confirm utils.FlatImg accepts _re_date=None together with a resume path.
        if date_match is not None:
            _re_date = date_match.group(0)

    if _re_date is not None:
        log_name = date + date_time + ' @' + _re_date + '_' + args.arch + '.log'
    else:
        log_name = date + date_time + '_' + args.arch + '.log'
    reslut_file = open(path + '/' + log_name, 'w')

    try:
        # Setup Dataloader
        data_path = str(args.data_path_train) + '/'
        data_path_validate = str(args.data_path_validate) + '/'
        data_path_test = str(args.data_path_test) + '/'

        print(args)
        print(args, file=reslut_file)

        n_classes = 2
        model = FiducialPoints(n_classes=n_classes, num_filter=32,
                               architecture=DilatedResnetForFlatByFiducialPointsS2,
                               BatchNorm='BN', in_channels=3)

        if args.parallel is not None:
            # args.parallel is a list of single characters, one GPU id each.
            device_ids = list(map(int, args.parallel))
            args.device = torch.device('cuda:' + str(device_ids[0]))
            torch.cuda.set_device(args.device)
            model = torch.nn.DataParallel(model, device_ids=device_ids)
            model.cuda(args.device)
        else:
            args.device = torch.device('cpu')
            print('using CPU!')

        if args.optimizer == 'SGD':
            # NOTE(review): the SGD learning rate is hard-coded and ignores
            # args.l_rate -- kept as in the original, confirm this is intended.
            optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.8, weight_decay=1e-12)
        elif args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=args.l_rate, weight_decay=1e-10)
        else:
            # The original `assert '<non-empty string>'` could never fail;
            # exit with an explanatory message instead.
            sys.exit("unknown optimizer '{}': please choose 'SGD' or 'adam'".format(args.optimizer))

        # Epoch reported to the test routine; stays 0 unless a checkpoint is loaded.
        epoch = 0
        if args.resume is not None:
            # args.resume is a Path when the parser default is used and a str
            # when given on the command line, so normalise before using .name.
            resume_name = Path(args.resume).name
            if os.path.isfile(args.resume):
                print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
                # NOTE: torch.load unpickles the file -- only load checkpoints
                # from a trusted source.
                checkpoint = torch.load(args.resume, map_location=args.device)
                if args.parallel is not None:
                    model.load_state_dict(checkpoint['model_state'])
                else:
                    # CPU: the model is not wrapped in DataParallel here, so
                    # drop the 'module.' prefix from the saved parameter names.
                    cpu_state = {
                        key.replace('module.', ''): value
                        for key, value in checkpoint['model_state'].items()
                    }
                    model.load_state_dict(cpu_state)
                epoch = checkpoint['epoch']
                print("Loaded checkpoint '{}' (epoch {})".format(resume_name, epoch))
            else:
                print("No checkpoint found at '{}'".format(resume_name))

        flat_img = utils.FlatImg(args=args, path=path, date=date, date_time=date_time, _re_date=_re_date, model=model,
                                 reslut_file=reslut_file, n_classes=n_classes, optimizer=optimizer,
                                 data_loader=PerturbedDatastsForFiducialPoints_pickle_color_v2_v2,
                                 data_path=data_path, data_path_validate=data_path_validate,
                                 data_path_test=data_path_test, data_preproccess=False)

        # load data
        flat_img.loadTestData()

        model.eval()
        flat_img.validateOrTestModelV3(epoch, 0, validate_test='t_all')
    finally:
        # Previously an exit() call made the close unreachable; always release
        # the log file, also when an error or SystemExit propagates.
        reslut_file.close()


if __name__ == '__main__':
    print(FILE)
    print(ROOT)

    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--arch', nargs='?', type=str, default='Document-Dewarping-with-Control-Points',
                        help='Architecture')

    parser.add_argument('--img_shrink', nargs='?', type=int, default=None,
                        help='short edge of the input image')

    parser.add_argument('--n_epoch', nargs='?', type=int, default=300,
                        help='# of the epochs')

    parser.add_argument('--optimizer', type=str, default='adam',
                        help='optimization')

    parser.add_argument('--l_rate', nargs='?', type=float, default=0.0002,
                        help='Learning Rate')

    parser.add_argument('--print-freq', '-p', default=60, type=int,
                        metavar='N', help='print frequency (default: 60)')  # print frequency

    # NOTE: argparse applies `type` only to string defaults, so the path
    # options below are pathlib.Path objects unless overridden on the command
    # line (then they are str). Code using them must accept both.
    parser.add_argument('--data_path_train', default=ROOT / 'dataset/fiducial1024/fiducial1024_v1/color/', type=str,
                        help='the path of train images.')  # train image path

    parser.add_argument('--data_path_validate', default=ROOT / 'dataset/fiducial1024/fiducial1024_v1/validate/', type=str,
                        help='the path of validate images.')  # validate image path

    parser.add_argument('--data_path_test', default=ROOT / 'data/', type=str, help='the path of test images.')

    parser.add_argument('--output-path', default=ROOT / 'flat/', type=str,
                        help='the path is used to save output --img or result.')

    parser.add_argument('--resume', default=ROOT / 'ICDAR2021/2021-02-03 16:15:55/143/2021-02-03 16_15_55flat_img_by_fiducial_points-fiducial1024_v1.pkl', type=str,
                        help='Path to previous saved model to restart from')

    parser.add_argument('--batch_size', nargs='?', type=int, default=1,
                        help='Batch Size')  # 28

    parser.add_argument('--schema', type=str, default='test',
                        help='train or test')  # train validate

    # type=list splits the value into characters: "0123" -> ['0', '1', '2', '3'],
    # so every GPU id must be a single digit.
    parser.add_argument('--parallel', default=None, type=list,
                        help='choice the gpu id for parallel ')

    args = parser.parse_args()

    if args.resume is not None:
        if not os.path.isfile(args.resume):
            # format() instead of `+`: args.resume may be a Path, and
            # Path + str raised TypeError, hiding the intended message.
            raise Exception('{} -- not exist'.format(args.resume))

    if args.data_path_test is None:
        raise Exception('-- No test path')
    if not os.path.exists(args.data_path_test):
        raise Exception('{} -- no find'.format(args.data_path_test))

    # Module-level globals read by train(). Both strings come from a single
    # clock reading so they cannot disagree across midnight.
    now = time.localtime()
    date = time.strftime('%Y-%m-%d', now)
    date_time = time.strftime(' %H:%M:%S', now)
    path = os.path.join(args.output_path, date)

    os.makedirs(path, exist_ok=True)

    train(args)