| | import os |
| | import random |
| | import copy |
| | import time |
| | import sys |
| | import shutil |
| | import argparse |
| | import errno |
| | import math |
| | import numpy as np |
| | from collections import defaultdict, OrderedDict |
| | import tensorboardX |
| | from tqdm import tqdm |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.optim as optim |
| | from torch.utils.data import DataLoader |
| | from torch.optim.lr_scheduler import StepLR |
| |
|
| | from lib.utils.tools import * |
| | from lib.model.loss import * |
| | from lib.model.loss_mesh import * |
| | from lib.utils.utils_mesh import * |
| | from lib.utils.utils_smpl import * |
| | from lib.utils.utils_data import * |
| | from lib.utils.learning import * |
| | from lib.data.dataset_mesh import MotionSMPL |
| | from lib.model.model_mesh import MeshRegressor |
| | from torch.utils.data import DataLoader |
| |
|
def parse_args(argv=None):
    """Parse command-line options for mesh-regression training/evaluation.

    Args:
        argv: Optional list of argument strings; defaults to ``sys.argv[1:]``
            when ``None`` (backward compatible with the original no-arg call).

    Returns:
        argparse.Namespace with config path, checkpoint directories,
        resume/evaluate filenames, print frequency and random seed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/pretrain.yaml", help="Path to the config file.")
    parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='checkpoint directory')
    parser.add_argument('-p', '--pretrained', default='checkpoint', type=str, metavar='PATH', help='pretrained checkpoint directory')
    parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', help='checkpoint to resume (file name)')
    parser.add_argument('-e', '--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
    # type=int so a user-supplied value arrives as an int (the original left it
    # untyped, forcing int(opts.print_freq) casts at every use site).
    parser.add_argument('-freq', '--print_freq', default=100, type=int, help='print frequency (in batches)')
    parser.add_argument('-ms', '--selection', default='latest_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to finetune (file name)')
    parser.add_argument('-sd', '--seed', default=0, type=int, help='random seed')
    opts = parser.parse_args(argv)
    return opts
| |
|
def set_random_seed(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs for reproducible runs."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
|
def validate(test_loader, model, criterion, dataset_name='h36m'):
    """Evaluate the mesh regressor on one dataset split.

    Runs the model over ``test_loader`` (optionally averaging with a
    horizontally-flipped pass when ``args.flip`` is set), accumulates the
    weighted loss plus per-term meters, then concatenates all predictions and
    scores them with ``evaluate_mesh``.

    NOTE(review): reads the module-level globals ``args`` (data_root, flip,
    lambda_* loss weights) and ``opts`` (print_freq) instead of receiving them
    as parameters — they must be set before calling.

    Args:
        test_loader: yields ``(batch_input, batch_gt)``; ``batch_gt`` is a
            dict holding 'theta', 'kp_3d' and 'verts' tensors.
        model: mesh regressor; called as ``model(batch_input)`` and returning
            a list whose first element maps 'theta'/'verts'/'kp_3d' to tensors.
        criterion: loss module mapping (output, gt) to a dict of loss terms.
        dataset_name: label used only in log messages.

    Returns:
        Tuple ``(avg_loss, mpjpe, pa_mpjpe, mpve, losses_dict)`` where the
        error values come from ``evaluate_mesh`` and ``losses_dict`` maps each
        loss name to its AverageMeter.
    """
    model.eval()
    print(f'===========> validating {dataset_name}')
    batch_time = AverageMeter()
    losses = AverageMeter()
    # One running meter per individual loss term emitted by the criterion.
    losses_dict = {'loss_3d_pos': AverageMeter(),
                   'loss_3d_scale': AverageMeter(),
                   'loss_3d_velocity': AverageMeter(),
                   'loss_lv': AverageMeter(),
                   'loss_lg': AverageMeter(),
                   'loss_a': AverageMeter(),
                   'loss_av': AverageMeter(),
                   'loss_pose': AverageMeter(),
                   'loss_shape': AverageMeter(),
                   'loss_norm': AverageMeter(),
                   }
    mpjpes = AverageMeter()
    mpves = AverageMeter()
    results = defaultdict(list)
    # SMPL layer used to re-pose flipped predictions; its H36M regressor maps
    # vertices to Human3.6M joints.
    smpl = SMPL(args.data_root, batch_size=1).cuda()
    J_regressor = smpl.J_regressor_h36m
    with torch.no_grad():
        end = time.time()
        for idx, (batch_input, batch_gt) in tqdm(enumerate(test_loader)):
            batch_size, clip_len = batch_input.shape[:2]
            if torch.cuda.is_available():
                batch_gt['theta'] = batch_gt['theta'].cuda().float()
                batch_gt['kp_3d'] = batch_gt['kp_3d'].cuda().float()
                batch_gt['verts'] = batch_gt['verts'].cuda().float()
                batch_input = batch_input.cuda().float()
            output = model(batch_input)
            output_final = output
            if args.flip:
                # Test-time augmentation: run the horizontally-flipped clip,
                # flip its predicted thetas back, then rebuild verts/joints
                # through SMPL so they live in the un-flipped prediction space.
                batch_input_flip = flip_data(batch_input)
                output_flip = model(batch_input_flip)
                # theta layout: 72 pose params followed by 10 shape params.
                output_flip_pose = output_flip[0]['theta'][:, :, :72]
                output_flip_shape = output_flip[0]['theta'][:, :, 72:]
                output_flip_pose = flip_thetas_batch(output_flip_pose)
                output_flip_pose = output_flip_pose.reshape(-1, 72)
                output_flip_shape = output_flip_shape.reshape(-1, 10)
                output_flip_smpl = smpl(
                    betas=output_flip_shape,
                    body_pose=output_flip_pose[:, 3:],
                    global_orient=output_flip_pose[:, :3],
                    pose2rot=True
                )
                # Scale to millimetres to match the ground-truth convention.
                output_flip_verts = output_flip_smpl.vertices.detach()*1000.0
                J_regressor_batch = J_regressor[None, :].expand(output_flip_verts.shape[0], -1, -1).to(output_flip_verts.device)
                output_flip_kp3d = torch.matmul(J_regressor_batch, output_flip_verts)
                output_flip_back = [{
                    'theta': torch.cat((output_flip_pose.reshape(batch_size, clip_len, -1), output_flip_shape.reshape(batch_size, clip_len, -1)), dim=-1),
                    'verts': output_flip_verts.reshape(batch_size, clip_len, -1, 3),
                    'kp_3d': output_flip_kp3d.reshape(batch_size, clip_len, -1, 3),
                }]
                # Average plain and flipped-back predictions key by key.
                output_final = [{}]
                for k, v in output_flip[0].items():  # v unused; keys drive the merge
                    output_final[0][k] = (output[0][k] + output_flip_back[0][k])*0.5
                output = output_final
            loss_dict = criterion(output, batch_gt)
            # Same weighted combination as the training objective.
            loss = args.lambda_3d * loss_dict['loss_3d_pos'] + \
                   args.lambda_scale * loss_dict['loss_3d_scale'] + \
                   args.lambda_3dv * loss_dict['loss_3d_velocity'] + \
                   args.lambda_lv * loss_dict['loss_lv'] + \
                   args.lambda_lg * loss_dict['loss_lg'] + \
                   args.lambda_a * loss_dict['loss_a'] + \
                   args.lambda_av * loss_dict['loss_av'] + \
                   args.lambda_shape * loss_dict['loss_shape'] + \
                   args.lambda_pose * loss_dict['loss_pose'] + \
                   args.lambda_norm * loss_dict['loss_norm']
            losses.update(loss.item(), batch_size)
            loss_str = ''
            for k, v in loss_dict.items():
                losses_dict[k].update(v.item(), batch_size)
                loss_str += '{0} {loss.val:.3f} ({loss.avg:.3f})\t'.format(k, loss=losses_dict[k])
            mpjpe, mpve = compute_error(output, batch_gt)
            mpjpes.update(mpjpe, batch_size)
            mpves.update(mpve, batch_size)
            # Stash numpy copies of predictions and GT for offline evaluation.
            for keys in output[0].keys():
                output[0][keys] = output[0][keys].detach().cpu().numpy()
                batch_gt[keys] = batch_gt[keys].detach().cpu().numpy()
            results['kp_3d'].append(output[0]['kp_3d'])
            results['verts'].append(output[0]['verts'])
            results['kp_3d_gt'].append(batch_gt['kp_3d'])
            results['verts_gt'].append(batch_gt['verts'])
            batch_time.update(time.time() - end)
            end = time.time()
            if idx % int(opts.print_freq) == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      '{2}'
                      'PVE {mpves.val:.3f} ({mpves.avg:.3f})\t'
                      'JPE {mpjpes.val:.3f} ({mpjpes.avg:.3f})'.format(
                      idx, len(test_loader), loss_str, batch_time=batch_time,
                      loss=losses, mpves=mpves, mpjpes=mpjpes))
    print(f'==> start concating results of {dataset_name}')
    for term in results.keys():
        results[term] = np.concatenate(results[term])
    print(f'==> start evaluating {dataset_name}...')
    error_dict = evaluate_mesh(results)
    err_str = ''
    for err_key, err_val in error_dict.items():
        err_str += '{}: {:.2f}mm \t'.format(err_key, err_val)
    # NOTE(review): loss_str here is the last batch's string and would raise
    # NameError if the loader were empty — confirm that is acceptable.
    print(f'=======================> {dataset_name} validation done: ', loss_str)
    print(f'=======================> {dataset_name} validation done: ', err_str)
    return losses.avg, error_dict['mpjpe'], error_dict['pa_mpjpe'], error_dict['mpve'], losses_dict
| |
|
| |
|
def train_epoch(args, opts, model, train_loader, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch):
    """Run one optimisation pass over ``train_loader``.

    Moves each batch to GPU when available, computes the criterion's loss
    terms, combines them with the ``args.lambda_*`` weights, and steps the
    optimizer. All meters (``losses_train``, ``losses_dict``, ``mpjpes``,
    ``mpves``, ``batch_time``, ``data_time``) are updated in place; progress
    is printed every ``opts.print_freq`` batches.
    """
    # (weight attr, loss key) pairs — listed in the exact order the weighted
    # terms are summed, so float accumulation matches the original objective.
    weight_terms = (
        ('lambda_3d', 'loss_3d_pos'),
        ('lambda_scale', 'loss_3d_scale'),
        ('lambda_3dv', 'loss_3d_velocity'),
        ('lambda_lv', 'loss_lv'),
        ('lambda_lg', 'loss_lg'),
        ('lambda_a', 'loss_a'),
        ('lambda_av', 'loss_av'),
        ('lambda_shape', 'loss_shape'),
        ('lambda_pose', 'loss_pose'),
        ('lambda_norm', 'loss_norm'),
    )
    model.train()
    tic = time.time()
    for step, (motion_input, target) in tqdm(enumerate(train_loader)):
        data_time.update(time.time() - tic)
        n_samples = len(motion_input)
        if torch.cuda.is_available():
            target['theta'] = target['theta'].cuda().float()
            target['kp_3d'] = target['kp_3d'].cuda().float()
            target['verts'] = target['verts'].cuda().float()
            motion_input = motion_input.cuda().float()
        pred = model(motion_input)
        optimizer.zero_grad()
        loss_dict = criterion(pred, target)
        # Accumulate the weighted objective left-to-right.
        total_loss = None
        for weight_name, loss_name in weight_terms:
            term = getattr(args, weight_name) * loss_dict[loss_name]
            total_loss = term if total_loss is None else total_loss + term
        losses_train.update(total_loss.item(), n_samples)
        loss_str = ''
        for key in loss_dict:
            losses_dict[key].update(loss_dict[key].item(), n_samples)
            loss_str += '{0} {loss.val:.3f} ({loss.avg:.3f})\t'.format(key, loss=losses_dict[key])
        mpjpe, mpve = compute_error(pred, target)
        mpjpes.update(mpjpe, n_samples)
        mpves.update(mpve, n_samples)
        total_loss.backward()
        optimizer.step()
        batch_time.update(time.time() - tic)
        tic = time.time()
        if step % int(opts.print_freq) == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  '{3}'
                  'PVE {mpves.val:.3f} ({mpves.avg:.3f})\t'
                  'JPE {mpjpes.val:.3f} ({mpjpes.avg:.3f})'.format(
                  epoch, step + 1, len(train_loader), loss_str, batch_time=batch_time,
                  data_time=data_time, loss=losses_train, mpves=mpves, mpjpes=mpjpes))
            sys.stdout.flush()
| |
|
def train_with_config(args, opts):
    """Build, (optionally) resume, train and/or evaluate the mesh regressor.

    Creates the checkpoint directory, constructs the backbone + regression
    head, builds DataLoaders for whichever datasets the config declares
    (h36m / pw3d / coco), then either runs the training loop with per-epoch
    validation and checkpointing, or — when ``opts.evaluate`` is set — only
    validates the loaded checkpoint.

    Args:
        args: config namespace loaded from the YAML file (model sizes, loss
            weights, dataset file paths, warmup schedule, learning rates...).
        opts: command-line options (checkpoint dir, resume/evaluate paths,
            print frequency...).
    """
    print(args)
    try:
        # Create the checkpoint dir and archive the config used for this run.
        os.makedirs(opts.checkpoint)
        shutil.copy(opts.config, opts.checkpoint)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise RuntimeError('Unable to create checkpoint directory:', opts.checkpoint)
    train_writer = tensorboardX.SummaryWriter(os.path.join(opts.checkpoint, "logs"))
    model_backbone = load_backbone(args)
    if args.finetune:
        if opts.resume or opts.evaluate:
            # Weights will come from the full checkpoint loaded further below.
            pass
        else:
            # Initialise the backbone from a pretrained motion checkpoint.
            chk_filename = os.path.join(opts.pretrained, opts.selection)
            print('Loading backbone', chk_filename)
            checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)['model_pos']
            model_backbone = load_pretrained_weights(model_backbone, checkpoint)
    if args.partial_train:
        # Restrict gradient updates to the selected backbone layers.
        model_backbone = partial_train_layers(model_backbone, args.partial_train)
    model = MeshRegressor(args, backbone=model_backbone, dim_rep=args.dim_rep, hidden_dim=args.hidden_dim, dropout_ratio=args.dropout, num_joints=args.num_joints)
    criterion = MeshLoss(loss_type = args.loss_type)
    best_jpe = 9999.0  # sentinel: any real MPJPE (mm) will improve on this
    model_params = 0
    for parameter in model.parameters():
        if parameter.requires_grad == True:
            model_params = model_params + parameter.numel()
    print('INFO: Trainable parameter count:', model_params)
    print('Loading dataset...')
    # Clip-level (video) loader settings.
    trainloader_params = {
        'batch_size': args.batch_size,
        'shuffle': True,
        'num_workers': 8,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True
    }
    testloader_params = {
        'batch_size': args.batch_size,
        'shuffle': False,
        'num_workers': 8,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True
    }
    # Datasets are opt-in: presence of the dt_file_* attribute on the config
    # decides which loaders exist, so later code must use matching guards.
    if hasattr(args, "dt_file_h36m"):
        mesh_train = MotionSMPL(args, data_split='train', dataset="h36m")
        mesh_val = MotionSMPL(args, data_split='test', dataset="h36m")
        train_loader = DataLoader(mesh_train, **trainloader_params)
        test_loader = DataLoader(mesh_val, **testloader_params)
        print('INFO: Training on {} batches (h36m)'.format(len(train_loader)))
    if hasattr(args, "dt_file_pw3d"):
        if args.train_pw3d:
            mesh_train_pw3d = MotionSMPL(args, data_split='train', dataset="pw3d")
            train_loader_pw3d = DataLoader(mesh_train_pw3d, **trainloader_params)
            print('INFO: Training on {} batches (pw3d)'.format(len(train_loader_pw3d)))
        # pw3d validation loader is built whenever the dataset is declared,
        # even if pw3d training itself is disabled.
        mesh_val_pw3d = MotionSMPL(args, data_split='test', dataset="pw3d")
        test_loader_pw3d = DataLoader(mesh_val_pw3d, **testloader_params)
    # Image-level (single-frame) loader settings use their own batch size.
    trainloader_img_params = {
        'batch_size': args.batch_size_img,
        'shuffle': True,
        'num_workers': 8,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True
    }
    testloader_img_params = {
        'batch_size': args.batch_size_img,
        'shuffle': False,
        'num_workers': 8,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True
    }
    if hasattr(args, "dt_file_coco"):
        mesh_train_coco = MotionSMPL(args, data_split='train', dataset="coco")
        mesh_val_coco = MotionSMPL(args, data_split='test', dataset="coco")
        train_loader_coco = DataLoader(mesh_train_coco, **trainloader_img_params)
        test_loader_coco = DataLoader(mesh_val_coco, **testloader_img_params)
        print('INFO: Training on {} batches (coco)'.format(len(train_loader_coco)))
    if torch.cuda.is_available():
        model = nn.DataParallel(model)
        model = model.cuda()
    # Auto-resume: if a previous run left latest_epoch.bin, pick it up.
    chk_filename = os.path.join(opts.checkpoint, "latest_epoch.bin")
    if os.path.exists(chk_filename):
        opts.resume = chk_filename
    if opts.resume or opts.evaluate:
        chk_filename = opts.evaluate if opts.evaluate else opts.resume
        print('Loading checkpoint', chk_filename)
        checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['model'], strict=True)
    if not opts.evaluate:
        # Separate learning rates for backbone and head.
        # NOTE(review): model.module assumes the DataParallel wrap above,
        # i.e. CUDA being available — this path fails on CPU-only machines.
        optimizer = optim.AdamW(
            [     {"params": filter(lambda p: p.requires_grad, model.module.backbone.parameters()), "lr": args.lr_backbone},
                  {"params": filter(lambda p: p.requires_grad, model.module.head.parameters()), "lr": args.lr_head},
            ],      lr=args.lr_backbone,
            weight_decay=args.weight_decay
        )
        scheduler = StepLR(optimizer, step_size=1, gamma=args.lr_decay)
        st = 0
        if opts.resume:
            st = checkpoint['epoch']
            if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None:
                optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.')
            lr = checkpoint['lr']  # NOTE(review): read but never used afterwards
            if 'best_jpe' in checkpoint and checkpoint['best_jpe'] is not None:
                best_jpe = checkpoint['best_jpe']
        for epoch in range(st, args.epochs):
            print('Training epoch %d.' % epoch)
            # Fresh meters each epoch; shared across all dataset stages below.
            losses_train = AverageMeter()
            losses_dict = {
                'loss_3d_pos': AverageMeter(),
                'loss_3d_scale': AverageMeter(),
                'loss_3d_velocity': AverageMeter(),
                'loss_lv': AverageMeter(),
                'loss_lg': AverageMeter(),
                'loss_a': AverageMeter(),
                'loss_av': AverageMeter(),
                'loss_pose': AverageMeter(),
                'loss_shape': AverageMeter(),
                'loss_norm': AverageMeter(),
            }
            mpjpes = AverageMeter()
            mpves = AverageMeter()
            batch_time = AverageMeter()
            data_time = AverageMeter()
            # Stage 1: h36m clips — only within the h36m warmup window; its
            # validation metrics are logged under test_* tags.
            if hasattr(args, "dt_file_h36m") and epoch < args.warmup_h36m:
                train_epoch(args, opts, model, train_loader, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch)
                test_loss, test_mpjpe, test_pa_mpjpe, test_mpve, test_losses_dict = validate(test_loader, model, criterion, 'h36m')
                for k, v in test_losses_dict.items():
                    train_writer.add_scalar('test_loss/'+k, v.avg, epoch + 1)
                train_writer.add_scalar('test_loss', test_loss, epoch + 1)
                train_writer.add_scalar('test_mpjpe', test_mpjpe, epoch + 1)
                train_writer.add_scalar('test_pa_mpjpe', test_pa_mpjpe, epoch + 1)
                train_writer.add_scalar('test_mpve', test_mpve, epoch + 1)
            # Stage 2: coco single frames — within the coco warmup window;
            # trained but not validated here.
            if hasattr(args, "dt_file_coco") and epoch < args.warmup_coco:
                train_epoch(args, opts, model, train_loader_coco, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch)
            # Stage 3: pw3d — trained only when enabled, but always validated
            # whenever the dataset is declared.
            if hasattr(args, "dt_file_pw3d"):
                if args.train_pw3d:
                    train_epoch(args, opts, model, train_loader_pw3d, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch)
                test_loss_pw3d, test_mpjpe_pw3d, test_pa_mpjpe_pw3d, test_mpve_pw3d, test_losses_dict_pw3d = validate(test_loader_pw3d, model, criterion, 'pw3d')
                for k, v in test_losses_dict_pw3d.items():
                    train_writer.add_scalar('test_loss_pw3d/'+k, v.avg, epoch + 1)
                train_writer.add_scalar('test_loss_pw3d', test_loss_pw3d, epoch + 1)
                train_writer.add_scalar('test_mpjpe_pw3d', test_mpjpe_pw3d, epoch + 1)
                train_writer.add_scalar('test_pa_mpjpe_pw3d', test_pa_mpjpe_pw3d, epoch + 1)
                train_writer.add_scalar('test_mpve_pw3d', test_mpve_pw3d, epoch + 1)
            for k, v in losses_dict.items():
                train_writer.add_scalar('train_loss/'+k, v.avg, epoch + 1)
            train_writer.add_scalar('train_loss', losses_train.avg, epoch + 1)
            train_writer.add_scalar('train_mpjpe', mpjpes.avg, epoch + 1)
            train_writer.add_scalar('train_mpve', mpves.avg, epoch + 1)
            scheduler.step()
            # Always overwrite the rolling "latest" checkpoint.
            chk_path = os.path.join(opts.checkpoint, 'latest_epoch.bin')
            print('Saving checkpoint to', chk_path)
            torch.save({
                'epoch': epoch+1,
                'lr': scheduler.get_last_lr(),
                'optimizer': optimizer.state_dict(),
                'model': model.state_dict(),
                'best_jpe' : best_jpe
            }, chk_path)
            # Periodic numbered snapshot.
            if (epoch+1) % args.checkpoint_frequency == 0:
                chk_path = os.path.join(opts.checkpoint, 'epoch_{}.bin'.format(epoch))
                print('Saving checkpoint to', chk_path)
                torch.save({
                    'epoch': epoch+1,
                    'lr': scheduler.get_last_lr(),
                    'optimizer': optimizer.state_dict(),
                    'model': model.state_dict(),
                    'best_jpe' : best_jpe
                }, chk_path)
            # Best-model tracking keyed on MPJPE: pw3d when present, else h36m.
            # NOTE(review): test_mpjpe is only bound when the h36m branch ran
            # this epoch (epoch < warmup_h36m); test_mpjpe_pw3d only when
            # dt_file_pw3d exists — confirm the config guarantees coverage.
            if hasattr(args, "dt_file_pw3d"):
                best_jpe_cur = test_mpjpe_pw3d
            else:
                best_jpe_cur = test_mpjpe
            # NOTE(review): the .format(epoch) on a literal without
            # placeholders is a no-op.
            best_chk_path = os.path.join(opts.checkpoint, 'best_epoch.bin'.format(epoch))
            if best_jpe_cur < best_jpe:
                best_jpe = best_jpe_cur
                print("save best checkpoint")
                torch.save({
                    'epoch': epoch+1,
                    'lr': scheduler.get_last_lr(),
                    'optimizer': optimizer.state_dict(),
                    'model': model.state_dict(),
                    'best_jpe' : best_jpe
                }, best_chk_path)
    if opts.evaluate:
        # Evaluation-only mode: validate the loaded checkpoint on each
        # declared dataset and discard the per-term loss meters.
        if hasattr(args, "dt_file_h36m"):
            test_loss, test_mpjpe, test_pa_mpjpe, test_mpve, _ = validate(test_loader, model, criterion, 'h36m')
        if hasattr(args, "dt_file_pw3d"):
            test_loss_pw3d, test_mpjpe_pw3d, test_pa_mpjpe_pw3d, test_mpve_pw3d, _ = validate(test_loader_pw3d, model, criterion, 'pw3d')
| |
|
| | if __name__ == "__main__": |
| | opts = parse_args() |
| | set_random_seed(opts.seed) |
| | args = get_config(opts.config) |
| | train_with_config(args, opts) |