# Source / train.py
# Richard1231's picture
# Upload folder using huggingface_hub
# 470ac18
'''
2021/2/3
Guowang Xie
args:
    n_epoch: number of epochs for training
    optimizer: optimization algorithm ('SGD' or 'adam')
    l_rate: initial learning rate
    resume: path of a previously saved model checkpoint to restart from
    data_path_train: dataset path for training
    data_path_validate: dataset path for validating
    data_path_test: dataset path for testing
    output-path: output path
    batch_size: batch size
    schema: 'train', 'validate', 'test' or 'eval'
    parallel: ids of the GPUs to use, e.g. 0 or 0123 (one digit per GPU)
'''
import os, sys
import argparse
import torch
from torch.autograd import Variable
import warnings
import time
import re
from pathlib import Path
# Absolute, symlink-resolved path of this script file.
FILE = Path(__file__).resolve()
# Directory containing this script; the command-line defaults for the dataset,
# output and checkpoint paths are built relative to it.
ROOT = FILE.parents[0]
from network import FiducialPoints, DilatedResnetForFlatByFiducialPointsS2
# import utilsV3 as utils
import utilsV4 as utils
from dataloader import PerturbedDatastsForFiducialPoints_pickle_color_v2_v2
from loss import Losses
def _find_resume_date(resume):
    """Return the first ``YYYY-M-D`` date found in the checkpoint path, or None."""
    if resume is None:
        return None
    found = re.search(r'\d{4}-\d{1,2}-\d{1,2}', str(resume))
    return found.group(0) if found else None


def _log(message, reslut_file):
    """Print *message* to stdout and mirror it into the open log file."""
    print(message)
    print(message, file=reslut_file)


def _build_model(args, n_classes):
    """Create the network, move it to the GPU and record the device in ``args.device``."""
    model = FiducialPoints(n_classes=n_classes, num_filter=32, architecture=DilatedResnetForFlatByFiducialPointsS2, BatchNorm='BN', in_channels=3)

    if args.parallel is not None:
        # args.parallel is a sequence of single-character GPU ids, e.g. ['0', '1'].
        device_ids = list(map(int, args.parallel))
        args.device = torch.device('cuda:' + str(device_ids[0]))
        torch.cuda.set_device(args.device)
        model = torch.nn.DataParallel(model, device_ids=device_ids)
        model.cuda(args.device)
    elif getattr(args, 'distributed', False):
        # The argument parser in this file defines no --distributed option, so
        # read it defensively instead of raising AttributeError.
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
        # Later code reads args.device (checkpoint loading, losses, batches).
        args.device = torch.device('cuda')
    else:
        warnings.warn('no found gpu')
        sys.exit()
    return model


def _build_optimizer(args, model):
    """Create the optimizer selected by ``args.optimizer`` ('SGD' or 'adam')."""
    if args.optimizer == 'SGD':
        # NOTE: the SGD branch uses a fixed learning rate, not args.l_rate.
        return torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.8, weight_decay=1e-12)
    if args.optimizer == 'adam':
        return torch.optim.Adam(model.parameters(), lr=args.l_rate, weight_decay=1e-10)
    # The original `assert '<message>'` never fired (a non-empty string is
    # truthy); report the problem in the exit message instead.
    sys.exit("error: please choose an optimizer ('SGD' or 'adam')")


def _load_checkpoint(args, model, optimizer):
    """Restore model/optimizer state from ``args.resume``.

    Returns the loaded checkpoint dict, or None when no checkpoint was
    requested or the file does not exist.
    """
    if args.resume is None:
        return None
    # args.resume may be a str (given on the command line) or a Path (default).
    resume_name = Path(args.resume).name
    if not os.path.isfile(args.resume):
        print("No checkpoint found at '{}'".format(resume_name))
        return None

    print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(args.resume, map_location=args.device)
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    print("Loaded checkpoint '{}' (epoch {})"
          .format(resume_name, checkpoint['epoch']))
    return checkpoint


def _train_one_epoch(args, FlatImg, model, optimizer, loss_fun, loss_fun2, trainloader, epoch, losses, reslut_file):
    """Run one pass over ``trainloader``, logging windowed loss statistics."""
    trainloader_len = len(trainloader)
    loss_segment_list = 0
    loss_l1_list = 0
    loss_local_list = 0
    # The edge and rectangle terms are not accumulated (their accumulation is
    # disabled below), so they are always logged as 0.
    loss_edge_list = 0
    loss_rectangles_list = 0
    loss_list = []

    model.train()
    for i, (images, labels, segment) in enumerate(trainloader):
        images = Variable(images)
        labels = Variable(labels.cuda(args.device))
        segment = Variable(segment.cuda(args.device))

        optimizer.zero_grad()
        outputs, outputs_segment = FlatImg.model(images, is_softmax=False)

        loss_l1, loss_local, loss_edge, loss_rectangles = loss_fun(outputs, labels, size_average=True)
        loss_segment = loss_fun2(outputs_segment, segment)
        loss = FlatImg.lambda_loss*(loss_l1 + loss_local*FlatImg.lambda_loss_a + loss_edge*FlatImg.lambda_loss_b + loss_rectangles*FlatImg.lambda_loss_c) + FlatImg.lambda_loss_segment*loss_segment

        losses.update(loss.item())
        loss.backward()
        optimizer.step()

        loss_list.append(loss.item())
        loss_segment_list += loss_segment.item()
        loss_l1_list += loss_l1.item()
        loss_local_list += loss_local.item()
        # loss_edge_list += loss_edge.item()
        # loss_rectangles_list += loss_rectangles.item()

        if (i + 1) % args.print_freq == 0 or (i + 1) == trainloader_len:
            list_len = len(loss_list)
            # [epoch][batch/total] [min mean max of window] [per-term means] running mean
            _log('[{0}][{1}/{2}]\t\t'
                 '[{3:.2f} {4:.4f} {5:.2f}]\t'
                 '[l1:{6:.4f} l:{7:.4f} e:{8:.4f} r:{9:.4f} s:{10:.4f}]\t'
                 '{loss.avg:.4f}'.format(
                     epoch + 1, i + 1, trainloader_len,
                     min(loss_list), sum(loss_list) / list_len, max(loss_list),
                     loss_l1_list / list_len, loss_local_list / list_len, loss_edge_list / list_len, loss_rectangles_list / list_len, loss_segment_list / list_len,
                     loss=losses), reslut_file)

            # Start a fresh statistics window.
            del loss_list[:]
            loss_segment_list = 0
            loss_l1_list = 0
            loss_local_list = 0
            loss_edge_list = 0
            loss_rectangles_list = 0


def train(args):
    """Train, validate, test or evaluate the model according to ``args.schema``.

    Reads the module-level globals ``path``, ``date`` and ``date_time`` (set by
    the script entry point) to name the log file, and sets the module-level
    global ``_re_date`` to the date found in the checkpoint path (or None).
    Also sets ``args.device`` as a side effect.

    Args:
        args: parsed command-line namespace; the fields read here are
            ``resume``, ``arch``, ``data_path_train``, ``data_path_validate``,
            ``data_path_test``, ``parallel``, ``optimizer``, ``l_rate``,
            ``schema``, ``n_epoch`` and ``print_freq``.

    Raises:
        SystemExit: when no GPU is configured, the optimizer name is unknown,
            or after the 'validate', 'test' and 'eval' schemas finish.
    """
    global _re_date

    # A checkpoint path without a date falls back to the plain log name
    # instead of crashing on a failed regex match.
    _re_date = _find_resume_date(args.resume)
    if _re_date is not None:
        log_path = path + '/' + date + date_time + ' @' + _re_date + '_' + args.arch + '.log'
    else:
        log_path = path + '/' + date + date_time + '_' + args.arch + '.log'

    # The context manager closes the log on every exit path, including the
    # SystemExit raised by the non-training schemas and any exception.
    with open(log_path, 'w') as reslut_file:
        # Setup Dataloader
        data_path = str(args.data_path_train) + '/'
        data_path_validate = str(args.data_path_validate) + '/'
        data_path_test = str(args.data_path_test) + '/'

        _log(args, reslut_file)

        n_classes = 2
        model = _build_model(args, n_classes)
        optimizer = _build_optimizer(args, model)
        checkpoint = _load_checkpoint(args, model, optimizer)

        loss_fun_classes = Losses(classify_size_average=True, args_gpu=args.device)
        loss_fun = loss_fun_classes.loss_fn4_v5_r_4
        # loss_fun = loss_fun_classes.loss_fn4_v5_r_3
        loss_fun2 = loss_fun_classes.loss_fn_l1_loss

        FlatImg = utils.FlatImg(args=args, path=path, date=date, date_time=date_time, _re_date=_re_date, model=model,
                                reslut_file=reslut_file, n_classes=n_classes, optimizer=optimizer,
                                loss_fn=loss_fun, loss_fn2=loss_fun2, data_loader=PerturbedDatastsForFiducialPoints_pickle_color_v2_v2,
                                data_path=data_path, data_path_validate=data_path_validate, data_path_test=data_path_test, data_preproccess=False)
        # load data
        FlatImg.loadTestData()

        train_time = AverageMeter()
        losses = AverageMeter()

        # Weights of the individual loss terms in the total loss.
        FlatImg.lambda_loss = 1
        FlatImg.lambda_loss_segment = 0.01
        FlatImg.lambda_loss_a = 0.1
        FlatImg.lambda_loss_b = 0.001
        FlatImg.lambda_loss_c = 0.01

        scheduler = torch.optim.lr_scheduler.MultiStepLR(FlatImg.optimizer, milestones=[40, 90, 150, 200], gamma=0.5)
        # Start from epoch 0 when nothing was actually loaded (the original
        # referenced an unbound `checkpoint` if the resume file was missing).
        epoch_start = checkpoint['epoch'] if checkpoint is not None else 0

        if args.schema == 'train':
            trainloader = FlatImg.loadTrainData(data_split='train', is_shuffle=True)
            FlatImg.loadValidateAndTestData(is_shuffle=True)

            for epoch in range(epoch_start, args.n_epoch):
                _log('* lambda_loss :'+str(FlatImg.lambda_loss)+'\t'+'learning_rate :'+str(optimizer.param_groups[0]['lr']), reslut_file)

                begin_train = time.time()
                _train_one_epoch(args, FlatImg, model, optimizer, loss_fun, loss_fun2, trainloader, epoch, losses, reslut_file)
                FlatImg.saveModel_epoch(epoch)

                model.eval()
                trian_t = time.time() - begin_train
                losses.reset()
                train_time.update(trian_t)

                # Validation is best-effort: a failure must not abort training,
                # but Ctrl-C / SystemExit must still propagate.
                try:
                    FlatImg.validateOrTestModelV3(epoch, trian_t, validate_test='v_l4')
                    FlatImg.validateOrTestModelV3(epoch, 0, validate_test='t')
                except Exception as err:
                    _log(' Error: validate or test ({!r})'.format(err), reslut_file)

                try:
                    scheduler.step()
                except Exception as err:
                    # Keep training, but do not hide the failure.
                    warnings.warn('scheduler.step() failed: {!r}'.format(err))

                print('\n')

        elif args.schema in ('validate', 'test'):
            model.eval()
            FlatImg.validateOrTestModelV3(epoch_start, 0, validate_test='t_all')
            sys.exit()
        elif args.schema == 'eval':
            FlatImg.evalData(is_shuffle=True)
            model.eval()
            FlatImg.evalModelGreyC1(epoch_start, is_scaling=False)
            sys.exit()

        m, s = divmod(train_time.sum, 60)
        h, m = divmod(m, 60)
        _log("All Train Time : %02d:%02d:%02d\n" % (h, m, s), reslut_file)
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes (all reset to 0 by ``reset``):
        val: the most recent value passed to ``update``.
        sum: running sum of ``val * n`` over all updates.
        count: running sum of the weights ``n``.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. update(x, n=0) on a fresh
        # meter), which previously raised ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0
# Script entry point: parse the command line, validate the paths, create the
# dated output directory and start train().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--arch', nargs='?', type=str, default='Document-Dewarping-with-Control-Points',
                        help='Architecture')

    parser.add_argument('--img_shrink', nargs='?', type=int, default=None,
                        help='short edge of the input image')

    parser.add_argument('--n_epoch', nargs='?', type=int, default=300,
                        help='# of the epochs')

    parser.add_argument('--optimizer', type=str, default='adam',
                        help='optimization')

    parser.add_argument('--l_rate', nargs='?', type=float, default=0.0002,
                        help='Learning Rate')

    # The help text previously claimed a default of 10; the real default is 60.
    parser.add_argument('--print-freq', '-p', default=60, type=int,
                        metavar='N', help='print frequency (default: 60)')  # print frequency

    # NOTE: argparse applies `type` only to string defaults, so the Path
    # defaults below reach the code as Path objects, not str.
    parser.add_argument('--data_path_train', default=ROOT / 'dataset/fiducial1024/fiducial1024_v1/', type=str,
                        help='the path of train images.')  # train image path

    parser.add_argument('--data_path_validate', default=ROOT / 'dataset/fiducial1024/fiducial1024_v1/validate/', type=str,
                        help='the path of validate images.')  # validate image path

    parser.add_argument('--data_path_test', default=ROOT / 'data/', type=str, help='the path of test images.')

    parser.add_argument('--output-path', default=ROOT / 'flat/', type=str, help='the path is used to save output --img or result.')

    parser.add_argument('--resume', default=ROOT / 'ICDAR2021/2021-02-03 16:15:55/143/2021-02-03 16_15_55flat_img_by_fiducial_points-fiducial1024_v1.pkl', type=str,
                        help='Path to previous saved model to restart from')

    parser.add_argument('--batch_size', nargs='?', type=int, default=2,
                        help='Batch Size')  # 28

    parser.add_argument('--schema', type=str, default='test',
                        help='train or test')  # train validate

    # parser.set_defaults(resume='./ICDAR2021/2021-02-03 16:15:55/143/2021-02-03 16_15_55flat_img_by_fiducial_points-fiducial1024_v1.pkl')

    # type=list splits the value into single characters ('0123' -> GPUs 0,1,2,3),
    # so every GPU id must be a single digit.
    parser.add_argument('--parallel', default='1', type=list,
                        help='choice the gpu id for parallel ')

    args = parser.parse_args()

    # str() is required here: concatenating a Path default with a str raised
    # TypeError and hid the intended error message.
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise FileNotFoundError(str(args.resume) + ' -- not exist')

    if args.data_path_test is None:
        raise Exception('-- No test path')
    elif not os.path.exists(args.data_path_test):
        raise FileNotFoundError(str(args.data_path_test) + ' -- no find')

    # path, date and date_time are module-level names that train() reads.
    # Use one timestamp for both so they cannot straddle midnight.
    now = time.localtime()
    date = time.strftime('%Y-%m-%d', now)
    date_time = time.strftime(' %H:%M:%S', now)
    path = os.path.join(args.output_path, date)

    os.makedirs(path, exist_ok=True)

    train(args)