| |
| |
|
|
| """ |
| @Author : Peike Li |
| @Contact : peike.li@yahoo.com |
| @File : train.py |
| @Time : 8/4/19 3:36 PM |
| @Desc : |
| @License : This source code is licensed under the license found in the |
| LICENSE file in the root directory of this source tree. |
| """ |
|
|
| import os |
| import json |
| import timeit |
| import argparse |
|
|
| import torch |
| import torch.optim as optim |
| import torchvision.transforms as transforms |
| import torch.backends.cudnn as cudnn |
| from torch.utils import data |
|
|
| import networks |
| import utils.schp as schp |
| from datasets.datasets import LIPDataSet |
| from datasets.target_generation import generate_edge_tensor |
| from utils.transforms import BGR2RGB_transform |
| from utils.criterion import CriterionAll |
| from utils.encoding import DataParallelModel, DataParallelCriterion |
| from utils.warmup_scheduler import SGDRScheduler |
|
|
|
|
def get_arguments():
    """Build the command-line parser for SCHP training and parse ``sys.argv``.

    Options are registered in a fixed order from a declarative table, so the
    ``--help`` listing follows the order of the table below.

    Returns:
        argparse.Namespace: the parsed options. Dashes in flag names become
        underscores on the namespace, e.g. ``--data-dir`` -> ``args.data_dir``.
    """
    parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")

    # Each entry is (flag, keyword arguments for ``add_argument``), in registration order.
    option_specs = (
        # Network structure
        ("--arch", dict(type=str, default='resnet101')),
        # Data preference
        ("--data-dir", dict(type=str, default='./data/LIP')),
        ("--batch-size", dict(type=int, default=16)),
        ("--input-size", dict(type=str, default='473,473')),
        ("--split-name", dict(type=str, default='crop_pic')),
        ("--num-classes", dict(type=int, default=20)),
        ("--ignore-label", dict(type=int, default=255)),
        ("--random-mirror", dict(action="store_true")),
        ("--random-scale", dict(action="store_true")),
        # Training strategy
        ("--learning-rate", dict(type=float, default=7e-3)),
        ("--momentum", dict(type=float, default=0.9)),
        ("--weight-decay", dict(type=float, default=5e-4)),
        ("--gpu", dict(type=str, default='0,1,2')),
        ("--start-epoch", dict(type=int, default=0)),
        ("--epochs", dict(type=int, default=150)),
        ("--eval-epochs", dict(type=int, default=10)),
        ("--imagenet-pretrain", dict(type=str, default='./pretrain_model/resnet101-imagenet.pth')),
        ("--log-dir", dict(type=str, default='./log')),
        ("--model-restore", dict(type=str, default='./log/checkpoint.pth.tar')),
        ("--schp-start", dict(type=int, default=100, help='schp start epoch')),
        ("--cycle-epochs", dict(type=int, default=10, help='schp cyclical epoch')),
        ("--schp-restore", dict(type=str, default='./log/schp_checkpoint.pth.tar')),
        ("--lambda-s", dict(type=float, default=1, help='segmentation loss weight')),
        ("--lambda-e", dict(type=float, default=1, help='edge loss weight')),
        ("--lambda-c", dict(type=float, default=0.1, help='segmentation-edge consistency loss weight')),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()
|
|
|
|
def _configure_gpus(gpu_arg):
    """Apply the ``--gpu`` option and return how many GPUs the batch is spread over.

    Args:
        gpu_arg: the raw ``--gpu`` string, either a comma-separated list of
            device ids (e.g. ``'0,1,2'``) or the literal ``'None'``.

    Returns:
        int: the number of devices to multiply the per-GPU batch size by.

    For a device list, ``CUDA_VISIBLE_DEVICES`` is restricted to those ids as a
    side effect. For ``'None'`` the visible devices are left untouched and the
    devices torch can see are counted (at least 1). The original code called
    ``int('None')`` before checking for ``'None'`` and therefore always crashed.

    Raises:
        ValueError: if a device id in ``gpu_arg`` is not an integer.
    """
    if gpu_arg == 'None':
        return max(torch.cuda.device_count(), 1)
    # Parse first, so a malformed list fails before the environment is modified.
    gpus = [int(i) for i in gpu_arg.split(',')]
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_arg
    return len(gpus)


def _build_transform(input_space, mean, std):
    """Build the per-image preprocessing pipeline expected by the backbone.

    Images are converted with ``ToTensor`` and normalized with ``mean``/``std``.
    For an ``'RGB'`` backbone, a ``BGR2RGB_transform`` channel swap is inserted
    before normalization. For ``'BGR'`` no swap is applied. This presumes the
    dataset yields BGR images; confirm in ``LIPDataSet``.

    Raises:
        ValueError: if ``input_space`` is neither ``'BGR'`` nor ``'RGB'``.
            Previously this case left ``transform`` unbound.
    """
    if input_space == 'BGR':
        print('BGR Transformation')
        steps = [transforms.ToTensor()]
    elif input_space == 'RGB':
        print('RGB Transformation')
        steps = [transforms.ToTensor(), BGR2RGB_transform()]
    else:
        raise ValueError('Unsupported input space: {}'.format(input_space))
    steps.append(transforms.Normalize(mean=mean, std=std))
    return transforms.Compose(steps)


def _schp_soft_targets(schp_model, images):
    """Run the SCHP model without gradients and collect its soft targets.

    ``schp_model`` returns a sequence of outputs, presumably one per device
    (``DataParallelModel`` without gathering). For each output, the last
    parsing prediction (``out[0][-1]``) and the last edge prediction
    (``out[1][-1]``) are taken. Each kind is concatenated along dim 0.

    Returns:
        tuple: ``(soft_parsing, soft_edges)`` tensors.
    """
    with torch.no_grad():
        outputs = schp_model(images)
    soft_parsing = torch.cat([out[0][-1] for out in outputs], dim=0)
    soft_edges = torch.cat([out[1][-1] for out in outputs], dim=0)
    return soft_parsing, soft_edges


def main():
    """Train a human-parsing network with Self-Correction (SCHP).

    Steps:
    - Parse CLI options and write them to ``<log_dir>/args.json``.
    - Build the online model and the SCHP model, optionally resuming each from
      its checkpoint.
    - Train with SGD under an ``SGDRScheduler``.
    - Every ``--eval-epochs`` epochs, save the online model.
    - From epoch ``--schp-start`` onward, every ``--cycle-epochs`` epochs: fold
      the online weights into the SCHP model via ``schp.moving_average``,
      re-estimate its BN statistics, and checkpoint it.
    - Once at least one cycle has completed (``cycle_n >= 1``), pass the SCHP
      model's predictions to the criterion as soft targets.
    """
    args = get_arguments()
    print(args)

    start_epoch = 0
    cycle_n = 0

    # Persist the exact run configuration next to the checkpoints.
    os.makedirs(args.log_dir, exist_ok=True)
    with open(os.path.join(args.log_dir, 'args.json'), 'w') as opt_file:
        json.dump(vars(args), opt_file)

    # Done before the first .cuda() call below so the device restriction applies.
    num_gpus = _configure_gpus(args.gpu)

    input_size = list(map(int, args.input_size.split(',')))

    cudnn.enabled = True
    cudnn.benchmark = True

    # Online model, trained by SGD.
    online_net = networks.init_model(args.arch, num_classes=args.num_classes,
                                     pretrained=args.imagenet_pretrain)
    model = DataParallelModel(online_net)
    model.cuda()

    # Normalization statistics and channel order the backbone expects.
    image_mean = online_net.mean
    image_std = online_net.std
    input_space = online_net.input_space
    print('image mean: {}'.format(image_mean))
    print('image std: {}'.format(image_std))
    print('input space:{}'.format(input_space))

    # NOTE(review): the periodic checkpoints below are written as
    # 'checkpoint_{epoch}.pth.tar', while --model-restore defaults to
    # 'checkpoint.pth.tar'. Unless schp.save_checkpoint also writes that name,
    # pass the path explicitly to resume. Optimizer (momentum) state is not
    # checkpointed, so it restarts from scratch on resume.
    restore_from = args.model_restore
    if os.path.exists(restore_from):
        print('Resume training from {}'.format(restore_from))
        checkpoint = torch.load(restore_from)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']

    # SCHP model. It is updated with schp.moving_average(schp_model, model,
    # 1 / (cycle_n + 1)) at each cycle; its predictions are soft targets once cycle_n >= 1.
    schp_net = networks.init_model(args.arch, num_classes=args.num_classes,
                                   pretrained=args.imagenet_pretrain)
    schp_model = DataParallelModel(schp_net)
    schp_model.cuda()

    if os.path.exists(args.schp_restore):
        print('Resuming schp checkpoint from {}'.format(args.schp_restore))
        schp_checkpoint = torch.load(args.schp_restore)
        cycle_n = schp_checkpoint['cycle_n']
        schp_model.load_state_dict(schp_checkpoint['state_dict'])

    # Loss: segmentation, edge, and segmentation-edge consistency terms.
    criterion = CriterionAll(lambda_1=args.lambda_s, lambda_2=args.lambda_e, lambda_3=args.lambda_c,
                             num_classes=args.num_classes)
    criterion = DataParallelCriterion(criterion)
    criterion.cuda()

    transform = _build_transform(input_space, image_mean, image_std)

    train_dataset = LIPDataSet(args.data_dir, args.split_name, crop_size=input_size, transform=transform)
    # --batch-size is per GPU; the loader yields the combined batch
    # (presumably scattered across devices by DataParallelModel).
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size * num_gpus,
                                   num_workers=16, shuffle=True, pin_memory=True, drop_last=True)
    print('Total training samples: {}'.format(len(train_dataset)))

    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)

    lr_scheduler = SGDRScheduler(optimizer, total_epoch=args.epochs,
                                 eta_min=args.learning_rate / 100, warmup_epoch=10,
                                 start_cyclical=args.schp_start, cyclical_base_lr=args.learning_rate / 2,
                                 cyclical_epoch=args.cycle_epochs)

    total_iters = args.epochs * len(train_loader)
    start = timeit.default_timer()
    for epoch in range(start_epoch, args.epochs):
        lr_scheduler.step(epoch=epoch)
        lr = lr_scheduler.get_lr()[0]

        model.train()
        for i_iter, batch in enumerate(train_loader):
            # Global iteration index across epochs; used only for logging.
            i_iter += len(train_loader) * epoch

            images, labels, _ = batch
            labels = labels.cuda(non_blocking=True)

            edges = generate_edge_tensor(labels)
            labels = labels.type(torch.cuda.LongTensor)
            edges = edges.type(torch.cuda.LongTensor)

            preds = model(images)

            # Soft targets from the SCHP model exist only after the first cycle.
            if cycle_n >= 1:
                soft_preds, soft_edges = _schp_soft_targets(schp_model, images)
            else:
                soft_preds = None
                soft_edges = None

            loss = criterion(preds, [labels, edges, soft_preds, soft_edges], cycle_n)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i_iter % 100 == 0:
                print('iter = {} of {} completed, lr = {}, loss = {}'.format(i_iter, total_iters, lr,
                                                                               loss.item()))

        if (epoch + 1) % args.eval_epochs == 0:
            schp.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, False, args.log_dir, filename='checkpoint_{}.pth.tar'.format(epoch + 1))

        # Self-correction: at epoch schp_start, then every cycle_epochs epochs.
        if (epoch + 1) >= args.schp_start and (epoch + 1 - args.schp_start) % args.cycle_epochs == 0:
            print('Self-correction cycle number {}'.format(cycle_n))
            schp.moving_average(schp_model, model, 1.0 / (cycle_n + 1))
            cycle_n += 1
            schp.bn_re_estimate(train_loader, schp_model)
            schp.save_schp_checkpoint({
                'state_dict': schp_model.state_dict(),
                'cycle_n': cycle_n,
            }, False, args.log_dir, filename='schp_{}_checkpoint.pth.tar'.format(cycle_n))

        torch.cuda.empty_cache()
        end = timeit.default_timer()
        # Reports the average wall-clock seconds per epoch since training (re)started.
        print('epoch = {} of {} completed using {} s'.format(epoch, args.epochs,
                                                             (end - start) / (epoch - start_epoch + 1)))

    end = timeit.default_timer()
    print('Training Finished in {} seconds'.format(end - start))
|
|
|
|
if __name__ == "__main__":
    # Start training only when this file is executed as a script, not when imported.
    main()
|
|