import numpy as np
import random
from common.arguments import parse_args
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import sys
import errno
import math
from einops import rearrange, repeat
from copy import deepcopy
from common.camera import *
import collections
from common.ddhpose_3dhp import *
from common.loss import *
from common.generators_3dhp import ChunkedGenerator_Seq, UnchunkedGenerator_Seq
from time import time
from common.utils import *
from common.logging import Logger
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import scipy.io as scio

# cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

args = parse_args()
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

if args.evaluate != '':
    description = "Evaluate!"
else:
    description = "Train!"

# Initial setting
TIMESTAMP = "{0:%Y%m%dT%H-%M-%S/}".format(datetime.now())

# TensorBoard
if not args.nolog:
    writer = SummaryWriter(args.log + '_' + TIMESTAMP)
    writer.add_text('description', description)
    writer.add_text('command', 'python ' + ' '.join(sys.argv))
    # Logging setting
    logfile = os.path.join(args.log + '_' + TIMESTAMP, 'logging.log')
    sys.stdout = Logger(logfile)
print(description)
print('python ' + ' '.join(sys.argv))
print("CUDA Device Count: ", torch.cuda.device_count())
print(args)

manualSeed = 1
random.seed(manualSeed)
torch.manual_seed(manualSeed)
np.random.seed(manualSeed)
torch.cuda.manual_seed_all(manualSeed)

# If no checkpoint path is assigned, save checkpoint files into the log folder
if args.checkpoint == '':
    args.checkpoint = args.log + '_' + TIMESTAMP
try:
    # Create checkpoint directory if it does not exist
    os.makedirs(args.checkpoint)
except OSError as e:
    if e.errno != errno.EEXIST:
        raise RuntimeError('Unable to create checkpoint directory:', args.checkpoint)

# Dataset loading
print('Loading dataset...')
dataset_path = 'data/data_3d_' + args.dataset + '.npz'
if args.dataset == 'h36m':
    from common.h36m_dataset import Human36mDataset
    dataset = Human36mDataset(dataset_path)
elif args.dataset.startswith('humaneva'):
    from common.humaneva_dataset import HumanEvaDataset
    dataset = HumanEvaDataset(dataset_path)
elif args.dataset.startswith('custom'):
    from common.custom_dataset import CustomDataset
    dataset = CustomDataset('data/data_2d_' + args.dataset + '_' + args.keypoints + '.npz')
else:
    raise KeyError('Invalid dataset')

print('Preparing data...')
out_poses_3d_train = {}
out_poses_2d_train = {}
out_poses_3d_test = {}
out_poses_2d_test = {}
valid_frame = {}

kps_left, kps_right = [5, 6, 7, 11, 12, 13], [2, 3, 4, 8, 9, 10]
joints_left, joints_right = [5, 6, 7, 11, 12, 13], [2, 3, 4, 8, 9, 10]


def getbonelength(seq, boneindex):
    # seq: (batch, frames, joints, 3) -> per-bone lengths (batch, frames, bones, 1)
    bs = seq.size(0)
    ss = seq.size(1)
    seq = seq.view(-1, seq.size(2), seq.size(3))
    bone = []
    for index in boneindex:
        bone.append(seq[:, index[1]] - seq[:, index[0]])
    bone = torch.stack(bone, 1)
    bone = torch.pow(torch.pow(bone, 2).sum(2), 0.5)
    bone = bone.view(bs, ss, bone.size(1), 1)
    return bone


def getbonedirect(seq, boneindex):
    # seq: (batch, frames, joints, 3) -> unit bone directions (batch, frames, bones, 3)
    bs = seq.size(0)
    ss = seq.size(1)
    seq = seq.view(-1, seq.size(2), seq.size(3))
    bone = []
    for index in boneindex:
        bone.append(seq[:, index[1]] - seq[:, index[0]])
    bonedirect = torch.stack(bone, 1)
    bonesum = torch.pow(torch.pow(bonedirect, 2).sum(2), 0.5).unsqueeze(2)
    bonedirect = bonedirect / bonesum
    bonedirect = bonedirect.view(bs, ss, -1, 3)
    return bonedirect
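# A quick illustration of the two helpers above (a sketch with hypothetical
# shapes, not executed by the pipeline): for seq = torch.randn(2, 3, 17, 3)
# and boneindex = [[0, 1], [1, 2]],
#   getbonelength(seq, boneindex).shape  -> (2, 3, 2, 1)  one length per bone
#   getbonedirect(seq, boneindex).shape  -> (2, 3, 2, 3)  unit direction per bone
# getbonedirect divides by the norm, so each returned vector has length 1.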
data_train = np.load("./data/data_train_3dhp_ori.npz", allow_pickle=True)['data'].item()
for seq in data_train.keys():
    for cam in data_train[seq][0].keys():
        anim = data_train[seq][0][cam]

        subject_name, seq_name = seq.split(" ")

        data_3d = anim['data_3d']
        # Root-center around joint 14 (the pelvis)
        data_3d[:, :14] -= data_3d[:, 14:15]
        data_3d[:, 15:] -= data_3d[:, 14:15]
        out_poses_3d_train[(subject_name, seq_name, cam)] = data_3d

        data_2d = anim['data_2d']
        data_2d[..., :2] = normalize_screen_coordinates(data_2d[..., :2], w=2048, h=2048)
        out_poses_2d_train[(subject_name, seq_name, cam)] = data_2d

data_test = np.load("./data/data_test_3dhp_ori.npz", allow_pickle=True)['data'].item()
for seq in data_test.keys():
    anim = data_test[seq]

    valid_frame[seq] = anim["valid"]

    data_3d = anim['data_3d']
    data_3d[:, :14] -= data_3d[:, 14:15]
    data_3d[:, 15:] -= data_3d[:, 14:15]
    out_poses_3d_test[seq] = data_3d

    data_2d = anim['data_2d']
    # TS5/TS6 were captured at 1920x1080; the other test sequences at 2048x2048
    if seq == "TS5" or seq == "TS6":
        width, height = 1920, 1080
    else:
        width, height = 2048, 2048
    data_2d[..., :2] = normalize_screen_coordinates(data_2d[..., :2], w=width, h=height)
    out_poses_2d_test[seq] = data_2d

subjects_train = args.subjects_train.split(',')
subjects_semi = [] if not args.subjects_unlabeled else args.subjects_unlabeled.split(',')
if not args.render:
    subjects_test = args.subjects_test.split(',')
else:
    subjects_test = [args.viz_subject]


# NOTE: retained from the Human3.6M pipeline; unused for 3DHP (the call below is commented out).
def fetch(subjects, action_filter=None, subset=1, parse_3d_poses=True):
    out_poses_3d = []
    out_poses_2d = []
    out_camera_params = []
    for subject in subjects:
        for action in keypoints[subject].keys():
            if action_filter is not None:
                found = False
                for a in action_filter:
                    if action.startswith(a):
                        found = True
                        break
                if not found:
                    continue

            poses_2d = keypoints[subject][action]
            for i in range(len(poses_2d)):  # Iterate across cameras
                out_poses_2d.append(poses_2d[i])

            if subject in dataset.cameras():
                cams = dataset.cameras()[subject]
                assert len(cams) == len(poses_2d), 'Camera count mismatch'
                for cam in cams:
                    if 'intrinsic' in cam:
                        out_camera_params.append(cam['intrinsic'])

            if parse_3d_poses and 'positions_3d' in dataset[subject][action]:
                poses_3d = dataset[subject][action]['positions_3d']
                assert len(poses_3d) == len(poses_2d), 'Camera count mismatch'
                for i in range(len(poses_3d)):  # Iterate across cameras
                    out_poses_3d.append(poses_3d[i])

    if len(out_camera_params) == 0:
        out_camera_params = None
    if len(out_poses_3d) == 0:
        out_poses_3d = None

    stride = args.downsample
    if subset < 1:
        for i in range(len(out_poses_2d)):
            n_frames = int(round(len(out_poses_2d[i]) // stride * subset) * stride)
            start = deterministic_random(0, len(out_poses_2d[i]) - n_frames + 1, str(len(out_poses_2d[i])))
            out_poses_2d[i] = out_poses_2d[i][start:start + n_frames:stride]
            if out_poses_3d is not None:
                out_poses_3d[i] = out_poses_3d[i][start:start + n_frames:stride]
    elif stride > 1:
        # Downsample as requested
        for i in range(len(out_poses_2d)):
            out_poses_2d[i] = out_poses_2d[i][::stride]
            if out_poses_3d is not None:
                out_poses_3d[i] = out_poses_3d[i][::stride]

    return out_camera_params, out_poses_3d, out_poses_2d


action_filter = None if args.actions == '*' else args.actions.split(',')
if action_filter is not None:
    print('Selected actions:', action_filter)

# cameras_valid, poses_valid, poses_valid_2d = fetch(subjects_test, action_filter)

# Set the receptive field to the assigned number of frames
receptive_field = args.number_of_frames
print('INFO: Receptive field: {} frames'.format(receptive_field))
if not args.nolog:
    writer.add_text(args.log + '_' + TIMESTAMP + '/Receptive field', str(receptive_field))
pad = (receptive_field - 1) // 2  # Padding on each side
min_loss = args.min_loss
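# For reference, the screen-coordinate helpers imported from common.camera
# follow the VideoPose3D convention (a sketch, assuming that implementation):
#
#   def normalize_screen_coordinates(X, w, h):
#       # map pixel coordinates to [-1, 1], preserving the aspect ratio
#       return X / w * 2 - [1, h / w]
#
#   def image_coordinates(X, w, h):
#       # inverse mapping, back to pixel coordinates
#       return (X + [1, h / w]) * w / 2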
# Bone index pairs arrive as a flat comma-separated list: "p0,c0,p1,c1,..."
boneindextemp = args.boneindex_3dhp.split(',')
boneindex = []
for i in range(0, len(boneindextemp), 2):
    boneindex.append([int(boneindextemp[i]), int(boneindextemp[i + 1])])

model_pos_train = DDHPose(args, joints_left, joints_right, is_train=True)
model_pos_test_temp = DDHPose(args, joints_left, joints_right, is_train=False)
model_pos = DDHPose(args, joints_left, joints_right, is_train=False,
                    num_proposals=args.num_proposals, sampling_timesteps=args.sampling_timesteps)

#################
causal_shift = 0
model_params = 0
for parameter in model_pos.parameters():
    model_params += parameter.numel()
print('INFO: Trainable parameter count:', model_params / 1000000, 'Million')
if not args.nolog:
    writer.add_text(args.log + '_' + TIMESTAMP + '/Trainable parameter count', str(model_params / 1000000) + ' Million')

# Make model parallel
if torch.cuda.is_available():
    model_pos = nn.DataParallel(model_pos)
    model_pos = model_pos.cuda()
    model_pos_train = nn.DataParallel(model_pos_train)
    model_pos_train = model_pos_train.cuda()
    model_pos_test_temp = nn.DataParallel(model_pos_test_temp)
    model_pos_test_temp = model_pos_test_temp.cuda()

if args.resume or args.evaluate:
    chk_filename = os.path.join(args.checkpoint, args.resume if args.resume else args.evaluate)
    print('Loading checkpoint', chk_filename)
    checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
    print('This model was trained for {} epochs'.format(checkpoint['epoch']))
    model_pos_train.load_state_dict(checkpoint['model_pos'], strict=False)
    model_pos.load_state_dict(checkpoint['model_pos'], strict=False)

test_generator = UnchunkedGenerator_Seq(None, out_poses_3d_test, out_poses_2d_test,
                                        pad=pad, causal_shift=causal_shift, augment=False,
                                        kps_left=kps_left, kps_right=kps_right,
                                        joints_left=joints_left, joints_right=joints_right,
                                        valid_frame=valid_frame)
print('INFO: Testing on {} frames'.format(test_generator.num_frames()))
if not args.nolog:
    writer.add_text(args.log + '_' + TIMESTAMP + '/Testing Frames', str(test_generator.num_frames()))


def eval_data_prepare(receptive_field, inputs_2d, inputs_3d, valid_frame):
    # Split a full sequence into fixed-length chunks of `receptive_field` frames.
    assert inputs_2d.shape[:-1] == inputs_3d.shape[:-1], \
        "2d and 3d inputs must have the same shape! " + str(inputs_2d.shape) + str(inputs_3d.shape)
    inputs_2d_p = torch.squeeze(inputs_2d)
    inputs_3d_p = torch.squeeze(inputs_3d)
    valid_frame = valid_frame.unsqueeze(1)

    # Number of chunks, rounding up for a partial trailing chunk
    out_num = (inputs_2d_p.shape[0] + receptive_field - 1) // receptive_field

    eval_input_2d = torch.empty(out_num, receptive_field, inputs_2d_p.shape[1], inputs_2d_p.shape[2])
    eval_input_3d = torch.empty(out_num, receptive_field, inputs_3d_p.shape[1], inputs_3d_p.shape[2])
    eval_valid_frame = torch.empty(out_num, receptive_field, 1)

    for i in range(out_num - 1):
        eval_input_2d[i, :, :, :] = inputs_2d_p[i * receptive_field:(i + 1) * receptive_field, :, :]
        eval_input_3d[i, :, :, :] = inputs_3d_p[i * receptive_field:(i + 1) * receptive_field, :, :]
        eval_valid_frame[i, :, :] = valid_frame[i * receptive_field:(i + 1) * receptive_field, :]

    # Replicate-pad sequences shorter than one receptive field
    if inputs_2d_p.shape[0] < receptive_field:
        pad_right = receptive_field - inputs_2d_p.shape[0]
        inputs_2d_p = rearrange(inputs_2d_p, 'f j c -> j c f')
        inputs_2d_p = F.pad(inputs_2d_p, (0, pad_right), mode='replicate')
        inputs_2d_p = rearrange(inputs_2d_p, 'j c f -> f j c')
    if inputs_3d_p.shape[0] < receptive_field:
        pad_right = receptive_field - inputs_3d_p.shape[0]
        inputs_3d_p = rearrange(inputs_3d_p, 'f j c -> j c f')
        inputs_3d_p = F.pad(inputs_3d_p, (0, pad_right), mode='replicate')
        inputs_3d_p = rearrange(inputs_3d_p, 'j c f -> f j c')
    if valid_frame.shape[0] < receptive_field:
        pad_right = receptive_field - valid_frame.shape[0]
        valid_frame = rearrange(valid_frame, 'f c -> c f')
        valid_frame = F.pad(valid_frame, (0, pad_right), mode='replicate')
        valid_frame = rearrange(valid_frame, 'c f -> f c')

    # The last chunk always covers the final `receptive_field` frames
    eval_input_2d[-1, :, :, :] = inputs_2d_p[-receptive_field:, :, :]
    eval_input_3d[-1, :, :, :] = inputs_3d_p[-receptive_field:, :, :]
    eval_valid_frame[-1, :, :] = valid_frame[-receptive_field:, :]

    return eval_input_2d, eval_input_3d, eval_valid_frame


def pose_post_process(pose_pred, data_list, keys, receptive_field):
    # Stitch per-chunk predictions (chunks, t, rf, j, c) back into one sequence,
    # then transpose to (c, j, f, t) for the .mat export.
    for ii in range(pose_pred.shape[0] - 1):
        data_list[keys][:, ii * receptive_field:(ii + 1) * receptive_field] = pose_pred[ii]
    data_list[keys][:, -receptive_field:] = pose_pred[-1]
    data_list[keys] = data_list[keys].transpose(3, 2, 1, 0)
    return data_list


def cam_mm_to_pix(cam, cam_data):
    # Convert focal length and principal point from millimetres to pixels.
    # cam_data = [width_px, height_px, sensorSize_x_mm, sensorSize_y_mm]
    mx = cam_data[0] / cam_data[2]  # pixels per mm, x
    my = cam_data[1] / cam_data[3]  # pixels per mm, y
    cam[0] = cam[0] * mx
    cam[1] = cam[1] * my
    cam[2] = cam[2] * mx + cam_data[0] / 2
    cam[3] = cam[3] * my + cam_data[1] / 2
    return cam
###################
# Training start
if not args.evaluate:
    lr = args.learning_rate
    optimizer = optim.AdamW(model_pos_train.parameters(), lr=lr, weight_decay=0.1)

    lr_decay = args.lr_decay
    losses_3d_train = []
    losses_3d_pos_train = []
    losses_3d_diff_train = []
    losses_3d_train_eval = []
    losses_3d_valid = []
    losses_3d_depth_valid = []

    epoch = 0
    best_epoch = 0
    initial_momentum = 0.1
    final_momentum = 0.001

    # Get training data
    train_generator = ChunkedGenerator_Seq(args.batch_size // args.stride, None, out_poses_3d_train, out_poses_2d_train,
                                           args.number_of_frames, pad=pad, causal_shift=causal_shift,
                                           shuffle=True, augment=args.data_augmentation,
                                           kps_left=kps_left, kps_right=kps_right,
                                           joints_left=joints_left, joints_right=joints_right)
    train_generator_eval = UnchunkedGenerator_Seq(None, out_poses_3d_train, out_poses_2d_train,
                                                  pad=pad, causal_shift=causal_shift, augment=False)
    print('INFO: Training on {} frames'.format(train_generator_eval.num_frames()))
    if not args.nolog:
        writer.add_text(args.log + '_' + TIMESTAMP + '/Training Frames', str(train_generator_eval.num_frames()))

    if args.resume:
        epoch = checkpoint['epoch']
        if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])
            train_generator.set_random_state(checkpoint['random_state'])
        else:
            print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.')
        if not args.coverlr:
            lr = checkpoint['lr']

    print('** Note: reported losses are averaged over all frames.')
    print('** The final evaluation will be carried out after the last training epoch.')

    # Pos model only
    while epoch < args.epochs:
        start_time = time()
        epoch_loss_3d_train = 0
        epoch_loss_3d_pos_train = 0
        epoch_loss_3d_diff_train = 0
        epoch_loss_traj_train = 0
        epoch_loss_2d_train_unlabeled = 0
        N = 0
        N_semi = 0
        model_pos_train.train()
        iteration = 0

        num_batches = train_generator.batch_num()

        # Run a single iteration for quick debugging
        quickdebug = args.debug
        for cameras_train, batch_3d, batch_2d in train_generator.next_epoch():
            if iteration % 1000 == 0:
                print("%d/%d" % (iteration, num_batches))

            if cameras_train is not None:
                cameras_train = torch.from_numpy(cameras_train.astype('float32'))
            inputs_3d = torch.from_numpy(batch_3d.astype('float32'))
            inputs_2d = torch.from_numpy(batch_2d.astype('float32'))

            if torch.cuda.is_available():
                inputs_3d = inputs_3d.cuda()
                inputs_2d = inputs_2d.cuda()
                if cameras_train is not None:
                    cameras_train = cameras_train.cuda()
            inputs_traj = inputs_3d[:, :, 14:15].clone()
            inputs_3d[:, :, 14] = 0  # root-center on the pelvis

            optimizer.zero_grad()

            # Predict 3D poses
            predicted_3d_pos = model_pos_train(inputs_2d, inputs_3d)

            loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d)

            # Bone-length term
            inputs_3d_length = getbonelength(inputs_3d, boneindex).mean(1)
            predicted_3d_length = getbonelength(predicted_3d_pos, boneindex).mean(1)
            loss_length = args.wl * torch.pow(inputs_3d_length - predicted_3d_length, 2).mean()

            # Bone-direction term
            inputs_3d_bonedir = getbonedirect(inputs_3d, boneindex)
            predicted_bonedir = getbonedirect(predicted_3d_pos, boneindex)
            loss_dir = args.wd * torch.pow(inputs_3d_bonedir - predicted_bonedir, 2).sum(3).mean()

            # Backpropagate the scalar objective directly
            loss_total = loss_3d_pos + loss_length + loss_dir
            loss_total.backward()

            epoch_loss_3d_train += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_total.item()
            epoch_loss_3d_pos_train += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item()
            N += inputs_3d.shape[0] * inputs_3d.shape[1]

            optimizer.step()
            iteration += 1

            if quickdebug:
                if N == inputs_3d.shape[0] * inputs_3d.shape[1]:
                    break

        losses_3d_train.append(epoch_loss_3d_train / N)
        losses_3d_pos_train.append(epoch_loss_3d_pos_train / N)
        # Keep the diff-loss history aligned with the other curves (it is not
        # accumulated in this variant, so this appends the zero average).
        losses_3d_diff_train.append(epoch_loss_3d_diff_train / N)
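        # Objective recap (hedged summary of the code above):
        #   loss_total = mpjpe(pred, gt)
        #              + args.wl * mean((bone_len_gt - bone_len_pred)^2)
        #              + args.wd * mean(sum((bone_dir_gt - bone_dir_pred)^2))
        # Joint positions are supervised directly; bone length and direction act
        # as regularizers weighted by the CLI arguments wl and wd.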
        # End-of-epoch evaluation
        with torch.no_grad():
            model_pos_test_temp.load_state_dict(model_pos_train.state_dict(), strict=False)
            model_pos_test_temp.eval()

            epoch_loss_3d_valid = None
            epoch_loss_3d_depth_valid = 0
            epoch_loss_traj_valid = 0
            epoch_loss_2d_valid = 0
            epoch_loss_3d_vel = 0
            N = 0
            iteration = 0
            if not args.no_eval:
                # Evaluate on the test set
                for cam, batch, batch_2d, batch_valid, _ in test_generator.next_epoch():
                    inputs_3d = torch.from_numpy(batch.astype('float32'))
                    inputs_2d = torch.from_numpy(batch_2d.astype('float32'))
                    inputs_valid = torch.from_numpy(batch_valid.astype('float32'))

                    ##### Apply test-time augmentation (following VideoPose3D)
                    inputs_2d_flip = inputs_2d.clone()
                    inputs_2d_flip[:, :, :, 0] *= -1
                    inputs_2d_flip[:, :, kps_left + kps_right, :] = inputs_2d_flip[:, :, kps_right + kps_left, :]

                    ##### Convert size
                    inputs_3d_p = inputs_3d
                    inputs_2d, inputs_3d, valid_frame = eval_data_prepare(receptive_field, inputs_2d, inputs_3d_p, inputs_valid)
                    inputs_2d_flip, _, _ = eval_data_prepare(receptive_field, inputs_2d_flip, inputs_3d_p, inputs_valid)

                    if torch.cuda.is_available():
                        inputs_3d = inputs_3d.cuda()
                        inputs_2d = inputs_2d.cuda()
                        inputs_2d_flip = inputs_2d_flip.cuda()
                    inputs_3d[:, :, 14] = 0

                    bs = 4
                    total_batch = (inputs_3d.shape[0] + bs - 1) // bs

                    for batch_cnt in range(total_batch):
                        if (batch_cnt + 1) * bs > inputs_3d.shape[0]:
                            inputs_2d_single = inputs_2d[batch_cnt * bs:]
                            inputs_2d_flip_single = inputs_2d_flip[batch_cnt * bs:]
                            inputs_3d_single = inputs_3d[batch_cnt * bs:]
                            valid_frame_single = valid_frame[batch_cnt * bs:]
                        else:
                            inputs_2d_single = inputs_2d[batch_cnt * bs:(batch_cnt + 1) * bs]
                            inputs_2d_flip_single = inputs_2d_flip[batch_cnt * bs:(batch_cnt + 1) * bs]
                            inputs_3d_single = inputs_3d[batch_cnt * bs:(batch_cnt + 1) * bs]
                            valid_frame_single = valid_frame[batch_cnt * bs:(batch_cnt + 1) * bs]

                        predicted_3d_pos_single = model_pos_test_temp(inputs_2d_single, inputs_3d_single,
                                                                      input_2d_flip=inputs_2d_flip_single)  # b, t, h, f, j, c
                        predicted_3d_pos_single[:, :, :, :, 14] = 0

                        error = mpjpe_diffusion_3dhp(predicted_3d_pos_single, inputs_3d_single, valid_frame_single.type(torch.bool))

                        if iteration == 0:
                            epoch_loss_3d_valid = inputs_3d_single.shape[0] * inputs_3d_single.shape[1] * error.clone()
                        else:
                            epoch_loss_3d_valid += inputs_3d_single.shape[0] * inputs_3d_single.shape[1] * error.clone()

                        N += inputs_3d_single.shape[0] * inputs_3d_single.shape[1]
                        iteration += 1

                        if quickdebug:
                            if N == inputs_3d_single.shape[0] * inputs_3d_single.shape[1]:
                                break
                    if quickdebug:
                        if N == inputs_3d_single.shape[0] * inputs_3d_single.shape[1]:
                            break

                losses_3d_valid.append(epoch_loss_3d_valid / N)

        elapsed = (time() - start_time) / 60

        if args.no_eval:
            print('[%d] time %.2f lr %f 3d_train %f 3d_pos_train %f 3d_diff_train %f' % (
                epoch + 1, elapsed, lr,
                losses_3d_train[-1] * 1000,
                losses_3d_pos_train[-1] * 1000,
                losses_3d_diff_train[-1] * 1000))
            log_path = os.path.join(args.checkpoint, 'training_log.txt')
            f = open(log_path, mode='a')
            f.write('[%d] time %.2f lr %f 3d_train %f 3d_pos_train %f 3d_diff_train %f\n' % (
                epoch + 1, elapsed, lr,
                losses_3d_train[-1] * 1000,
                losses_3d_pos_train[-1] * 1000,
                losses_3d_diff_train[-1] * 1000))
            f.close()
        else:
            print('[%d] time %.2f lr %f 3d_train %f 3d_pos_train %f 3d_pos_valid %f' % (
                epoch + 1, elapsed, lr,
                losses_3d_train[-1],
                losses_3d_pos_train[-1],
                losses_3d_valid[-1][0]))
            log_path = os.path.join(args.checkpoint, 'training_log.txt')
            f = open(log_path, mode='a')
            f.write('[%d] time %.2f lr %f 3d_train %f 3d_pos_train %f 3d_pos_valid %f\n' % (
                epoch + 1, elapsed, lr,
                losses_3d_train[-1],
                losses_3d_pos_train[-1],
                losses_3d_valid[-1][0]))
            f.close()

            if not args.nolog:
                # Log only the first sampling step of the per-step validation errors
                writer.add_scalar("Loss/3d validation loss", losses_3d_valid[-1][0] * 1000, epoch + 1)
        if not args.nolog:
            writer.add_scalar("Loss/3d training loss", losses_3d_train[-1] * 1000, epoch + 1)
            writer.add_scalar("Parameters/learning rate", lr, epoch + 1)
            writer.add_scalar('Parameters/training time per epoch', elapsed, epoch + 1)
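        # Exponential schedule: after e epochs, lr = args.learning_rate * lr_decay**e.
        # For example, lr_decay = 0.99 multiplies the rate by ~0.90 every 10 epochs
        # and roughly halves it after 69 epochs (0.99**69 ≈ 0.50).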
        # Decay learning rate exponentially
        lr *= lr_decay
        for param_group in optimizer.param_groups:
            param_group['lr'] *= lr_decay
        epoch += 1

        # Decay BatchNorm momentum
        # momentum = initial_momentum * np.exp(-epoch/args.epochs * np.log(initial_momentum/final_momentum))
        # model_pos_train.set_bn_momentum(momentum)

        # Save checkpoint if necessary
        if epoch % args.checkpoint_frequency == 0:
            chk_path = os.path.join(args.checkpoint, 'epoch_{}.bin'.format(epoch))
            print('Saving checkpoint to', chk_path)

            torch.save({
                'epoch': epoch,
                'lr': lr,
                'random_state': train_generator.random_state(),
                'optimizer': optimizer.state_dict(),
                'model_pos': model_pos_train.state_dict(),
            }, chk_path)

        #### Save best checkpoint
        best_chk_path = os.path.join(args.checkpoint, 'best_epoch.bin')
        if not args.no_eval and losses_3d_valid[-1][0] < min_loss:
            min_loss = losses_3d_valid[-1][0]
            best_epoch = epoch
            print("save best checkpoint")
            torch.save({
                'epoch': epoch,
                'lr': lr,
                'random_state': train_generator.random_state(),
                'optimizer': optimizer.state_dict(),
                'model_pos': model_pos_train.state_dict(),
            }, best_chk_path)

            f = open(log_path, mode='a')
            f.write('best epoch\n')
            f.close()

        # Save training curves after every epoch, as .png images (if requested)
        if args.export_training_curves and epoch > 3:
            if 'matplotlib' not in sys.modules:
                import matplotlib
                matplotlib.use('Agg')
            import matplotlib.pyplot as plt

            plt.figure()
            epoch_x = np.arange(3, len(losses_3d_train)) + 1
            plt.plot(epoch_x, losses_3d_train[3:], '--', color='C0')
            # Only the first sampling step of each per-step validation entry is plotted
            plt.plot(epoch_x, [v[0].item() for v in losses_3d_valid[3:]], color='C1')
            plt.legend(['3d train', '3d valid (eval)'])
            plt.ylabel('MPJPE (m)')
            plt.xlabel('Epoch')
            plt.xlim((3, epoch))
            plt.savefig(os.path.join(args.checkpoint, 'loss_3d.png'))
            plt.close('all')
# Training end
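# The evaluation below scores prediction tensors of shape (b, t, h, f, j, c):
# t diffusion sampling timesteps and h pose hypotheses per frame. Four
# aggregations are reported and exported:
#   P-Agg : mean over the h hypotheses;
#   P-Best: the hypothesis with the lowest mean 3D error (oracle over h);
#   J-Best: a per-joint oracle, picking the best hypothesis joint-wise;
#   J-Agg : per-joint selection by 2D reprojection error against the input 2D pose.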
# Evaluate
def evaluate(test_generator, action=None, return_predictions=False, use_trajectory_model=False, newmodel=None):
    epoch_loss_3d_pos = torch.zeros(args.sampling_timesteps).cuda()
    epoch_loss_3d_pos_mean = torch.zeros(args.sampling_timesteps).cuda()
    with torch.no_grad():
        if newmodel is not None:
            print('Loading comparison model')
            model_eval = newmodel
            chk_file_path = '/mnt/data3/home/zjl/workspace/3dpose/PoseFormer/checkpoint/train_pf_00/epoch_60.bin'
            print('Loading evaluate checkpoint of comparison model', chk_file_path)
            checkpoint = torch.load(chk_file_path, map_location=lambda storage, loc: storage)
            model_eval.load_state_dict(checkpoint['model_pos'], strict=False)
            model_eval.eval()
        else:
            model_eval = model_pos
            if not use_trajectory_model:
                # Load the best checkpoint
                if args.evaluate == '':
                    # Matches the filename written during training
                    chk_file_path = os.path.join(args.checkpoint, 'best_epoch.bin')
                    print('Loading best checkpoint', chk_file_path)
                else:
                    chk_file_path = os.path.join(args.checkpoint, args.evaluate)
                    print('Loading evaluate checkpoint', chk_file_path)
                checkpoint = torch.load(chk_file_path, map_location=lambda storage, loc: storage)
                print('This model was trained for {} epochs'.format(checkpoint['epoch']))
                model_eval.load_state_dict(checkpoint['model_pos'])
                model_eval.eval()

        N = 0
        iteration = 0

        data_inference_all = {}
        data_inference_mean = {}
        data_inference_h_min = {}
        data_inference_joint_min = {}
        data_inference_reproj_min = {}

        # Intrinsics in mm for the two 3DHP test camera setups, converted to pixels
        cam_1 = torch.tensor([7.32506, 7.32506, -0.0322884, 0.0929296, 0, 0, 0, 0, 0])
        cam_data_1 = [2048, 2048, 10, 10]  # width, height, sensorSize_x, sensorSize_y
        cam_2 = torch.tensor([8.770747185, 8.770747185, -0.104908645, 0.104899704, 0, 0, 0, 0, 0])
        cam_data_2 = [1920, 1080, 10, 5.625]  # width, height, sensorSize_x, sensorSize_y
        cam_1 = cam_mm_to_pix(cam_1, cam_data_1)
        cam_2 = cam_mm_to_pix(cam_2, cam_data_2)

        quickdebug = args.debug
        for _, batch, batch_2d, batch_valid, keys in test_generator.next_epoch():
            inputs_2d = torch.from_numpy(batch_2d.astype('float32'))
            inputs_3d = torch.from_numpy(batch.astype('float32'))
            inputs_valid = torch.from_numpy(batch_valid.astype('float32'))

            _, f_sz, j_sz, c_sz = inputs_3d.shape
            data_inference_all[keys] = np.zeros((args.sampling_timesteps, args.num_proposals, f_sz, j_sz, c_sz))
            data_inference_mean[keys] = np.zeros((args.sampling_timesteps, f_sz, j_sz, c_sz))
            data_inference_h_min[keys] = np.zeros((args.sampling_timesteps, f_sz, j_sz, c_sz))
            data_inference_joint_min[keys] = np.zeros((args.sampling_timesteps, f_sz, j_sz, c_sz))
            data_inference_reproj_min[keys] = np.zeros((args.sampling_timesteps, f_sz, j_sz, c_sz))
            print(keys)

            ##### Apply test-time augmentation (following VideoPose3D)
            inputs_2d_flip = inputs_2d.clone()
            inputs_2d_flip[:, :, :, 0] *= -1
            inputs_2d_flip[:, :, kps_left + kps_right, :] = inputs_2d_flip[:, :, kps_right + kps_left, :]

            ##### Convert size
            inputs_3d_p = inputs_3d
            if newmodel is not None:
                def eval_data_prepare_pf(receptive_field, inputs_2d, inputs_3d):
                    # Sliding-window preparation for the PoseFormer comparison model
                    inputs_2d_p = torch.squeeze(inputs_2d)
                    inputs_3d_p = inputs_3d.permute(1, 0, 2, 3)
                    padding = int(receptive_field // 2)
                    inputs_2d_p = rearrange(inputs_2d_p, 'f j c -> j c f')
                    inputs_2d_p = F.pad(inputs_2d_p, (padding, padding), mode='replicate')
                    inputs_2d_p = rearrange(inputs_2d_p, 'j c f -> f j c')
                    out_num = inputs_2d_p.shape[0] - receptive_field + 1
                    eval_input_2d = torch.empty(out_num, receptive_field, inputs_2d_p.shape[1], inputs_2d_p.shape[2])
                    for i in range(out_num):
                        eval_input_2d[i, :, :, :] = inputs_2d_p[i:i + receptive_field, :, :]
                    return eval_input_2d, inputs_3d_p

                inputs_2d, inputs_3d = eval_data_prepare_pf(81, inputs_2d, inputs_3d_p)
                inputs_2d_flip, _ = eval_data_prepare_pf(81, inputs_2d_flip, inputs_3d_p)
            else:
                inputs_2d, inputs_3d, valid_frame = eval_data_prepare(receptive_field, inputs_2d, inputs_3d_p, inputs_valid)
                inputs_2d_flip, _, _ = eval_data_prepare(receptive_field, inputs_2d_flip, inputs_3d_p, inputs_valid)

            if torch.cuda.is_available():
                inputs_2d = inputs_2d.cuda()
                inputs_2d_flip = inputs_2d_flip.cuda()
                inputs_3d = inputs_3d.cuda()

            bs = 2
            total_batch = (inputs_3d.shape[0] + bs - 1) // bs

            for batch_cnt in range(total_batch):
                if (batch_cnt + 1) * bs > inputs_3d.shape[0]:
                    inputs_2d_single = inputs_2d[batch_cnt * bs:]
                    inputs_2d_flip_single = inputs_2d_flip[batch_cnt * bs:]
                    inputs_3d_single = inputs_3d[batch_cnt * bs:]
                    valid_frame_single = valid_frame[batch_cnt * bs:]
                else:
                    inputs_2d_single = inputs_2d[batch_cnt * bs:(batch_cnt + 1) * bs]
                    inputs_2d_flip_single = inputs_2d_flip[batch_cnt * bs:(batch_cnt + 1) * bs]
                    inputs_3d_single = inputs_3d[batch_cnt * bs:(batch_cnt + 1) * bs]
                    valid_frame_single = valid_frame[batch_cnt * bs:(batch_cnt + 1) * bs]
                traj = inputs_3d_single[:, :, 14:15].clone()
                inputs_3d_single[:, :, 14] = 0

                predicted_3d_pos_single = model_eval(inputs_2d_single, inputs_3d_single,
                                                     input_2d_flip=inputs_2d_flip_single)  # b, t, h, f, j, c
                predicted_3d_pos_single[:, :, :, :, 14] = 0

                # P-Agg: mean over hypotheses
                mean_pose = torch.mean(predicted_3d_pos_single, dim=2, keepdim=False)

                # P-Best: hypothesis with the lowest mean 3D error
                b, t, h, f, _, _ = predicted_3d_pos_single.shape  # b, t, h, f, j, c
                target = inputs_3d_single.unsqueeze(1).unsqueeze(1).repeat(1, t, h, 1, 1, 1)
                errors = torch.norm(predicted_3d_pos_single - target, dim=len(target.shape) - 1)
                errors_h = rearrange(errors, 'b t h f n -> t h b f n').reshape(t, h, -1)
                errors_h = torch.mean(errors_h, dim=-1, keepdim=True)
                h_min_indices = torch.min(errors_h, dim=1, keepdim=True).indices  # t, 1, 1
                h_min_indices = h_min_indices.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).repeat(b, 1, 1, f, j_sz, c_sz)
                h_min_pose = torch.gather(predicted_3d_pos_single, 2, h_min_indices).squeeze(2)

                # J-Best: per-joint oracle over hypotheses
                joint_min_indices = torch.min(errors, dim=2, keepdim=True).indices  # b, t, 1, f, n
                joint_min_indices = joint_min_indices.unsqueeze(-1).repeat(1, 1, 1, 1, 1, c_sz)
                joint_min_pose = torch.gather(predicted_3d_pos_single, 2, joint_min_indices).squeeze(2)

                # J-Agg: per-joint selection by 2D reprojection error
                inputs_traj_single_all = traj.unsqueeze(1).unsqueeze(1).repeat(1, t, h, 1, 1, 1)
                predicted_3d_pos_abs_single = predicted_3d_pos_single + inputs_traj_single_all
                predicted_3d_pos_abs_single = predicted_3d_pos_abs_single.reshape(b * t * h * f, j_sz, c_sz)

                if keys == "TS5" or keys == "TS6":
                    cam = cam_2.clone()
                    cam_data = cam_data_2.copy()
                    reproject_func = project_to_2d
                else:
                    cam = cam_1.clone()
                    cam_data = cam_data_1.copy()
                    reproject_func = project_to_2d_linear

                # J-Agg: reproject all hypotheses
                cam_single_all = cam.unsqueeze(0).repeat(b * t * h * f, 1).cuda()
                reproj_2d = reproject_func(predicted_3d_pos_abs_single, cam_single_all)
                reproj_2d = reproj_2d.reshape(b, t, h, f, j_sz, 2)

                # J-Agg: ground-truth reprojection (kept for reference/debugging)
                cam_single_gt = cam.unsqueeze(0).repeat(b * f, 1).cuda()
                input_3d_reproj = inputs_3d_single + traj
                input_3d_reproj = input_3d_reproj.reshape(b * f, 17, 3)
                reproj_2d_gt = reproject_func(input_3d_reproj, cam_single_gt)
                reproj_2d_gt = reproj_2d_gt.reshape(b, f, 17, 2)

                target_2d = torch.from_numpy(image_coordinates(inputs_2d_single[..., :2].cpu().numpy(),
                                                               w=cam_data[0], h=cam_data[1])).cuda()
                target_2d = target_2d.unsqueeze(1).unsqueeze(1).repeat(1, t, h, 1, 1, 1)

                errors_2d = torch.norm(reproj_2d - target_2d, dim=len(target_2d.shape) - 1)  # b, t, h, f, n
                reproj_min_indices = torch.min(errors_2d, dim=2, keepdim=True).indices  # b, t, 1, f, n
                reproj_min_indices = reproj_min_indices.unsqueeze(-1).repeat(1, 1, 1, 1, 1, c_sz)
                reproj_min_pose = torch.gather(predicted_3d_pos_single, 2, reproj_min_indices).squeeze(2)
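                # Shape note: predicted_3d_pos_single is (b, t, h, f, j, c), and each
                # torch.gather above selects along the hypothesis dim (dim=2), so the
                # index tensors are first expanded to the full (b, t, 1, f, j, c) shape;
                # squeezing dim 2 afterwards drops the singleton hypothesis axis.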
                if batch_cnt == 0:
                    all_3d_pose_pred = predicted_3d_pos_single.cpu().numpy()
                    mean_3d_pose_pred = mean_pose.cpu().numpy()
                    h_min_3d_pose_pred = h_min_pose.cpu().numpy()
                    joint_min_3d_pose_pred = joint_min_pose.cpu().numpy()
                    reproj_min_3d_pose_pred = reproj_min_pose.cpu().numpy()
                else:
                    all_3d_pose_pred = np.concatenate((all_3d_pose_pred, predicted_3d_pos_single.cpu().numpy()), axis=0)
                    mean_3d_pose_pred = np.concatenate((mean_3d_pose_pred, mean_pose.cpu().numpy()), axis=0)
                    h_min_3d_pose_pred = np.concatenate((h_min_3d_pose_pred, h_min_pose.cpu().numpy()), axis=0)
                    joint_min_3d_pose_pred = np.concatenate((joint_min_3d_pose_pred, joint_min_pose.cpu().numpy()), axis=0)
                    reproj_min_3d_pose_pred = np.concatenate((reproj_min_3d_pose_pred, reproj_min_pose.cpu().numpy()), axis=0)

                if return_predictions:
                    return predicted_3d_pos_single.squeeze().cpu().numpy()

                error = mpjpe_diffusion_3dhp(predicted_3d_pos_single, inputs_3d_single, valid_frame_single.type(torch.bool))
                error_mean = mpjpe_diffusion_3dhp(predicted_3d_pos_single, inputs_3d_single, valid_frame_single.type(torch.bool), mean_pos=True)

                epoch_loss_3d_pos += inputs_3d_single.shape[0] * inputs_3d_single.shape[1] * error.clone()
                epoch_loss_3d_pos_mean += inputs_3d_single.shape[0] * inputs_3d_single.shape[1] * error_mean.clone()

                N += inputs_3d_single.shape[0] * inputs_3d_single.shape[1]

            # Stitch per-chunk predictions back into full sequences
            for ii in range(all_3d_pose_pred.shape[0] - 1):
                data_inference_all[keys][:, :, ii * receptive_field:(ii + 1) * receptive_field] = all_3d_pose_pred[ii]
            data_inference_all[keys][:, :, -receptive_field:] = all_3d_pose_pred[-1]
            data_inference_all[keys] = data_inference_all[keys].transpose(4, 3, 2, 1, 0)

            data_inference_mean = pose_post_process(mean_3d_pose_pred, data_inference_mean, keys, receptive_field)
            data_inference_h_min = pose_post_process(h_min_3d_pose_pred, data_inference_h_min, keys, receptive_field)
            data_inference_joint_min = pose_post_process(joint_min_3d_pose_pred, data_inference_joint_min, keys, receptive_field)
            data_inference_reproj_min = pose_post_process(reproj_min_3d_pose_pred, data_inference_reproj_min, keys, receptive_field)

            log_path = os.path.join(args.checkpoint, '3dhp_test_log_H%d_K%d.txt' % (args.num_proposals, args.sampling_timesteps))
            f = open(log_path, mode='a')
            if keys is None:
                print('----------')
            else:
                print('----' + keys + '----')
                f.write('----' + keys + '----\n')

            e1 = (epoch_loss_3d_pos / N)
            e1_mean = (epoch_loss_3d_pos_mean / N)

            print('Test time augmentation:', test_generator.augment_enabled())
            for ii in range(e1.shape[0]):
                print('step %d : Protocol #1 Error (MPJPE) P_Best:' % ii, e1[ii].item(), 'mm')
                f.write('step %d : Protocol #1 Error (MPJPE) P_Best: %f mm\n' % (ii, e1[ii].item()))
                print('step %d : Protocol #1 Error (MPJPE) P_Agg:' % ii, e1_mean[ii].item(), 'mm')
                f.write('step %d : Protocol #1 Error (MPJPE) P_Agg: %f mm\n' % (ii, e1_mean[ii].item()))
            print('----------')
            f.write('----------\n')
            f.close()

            if quickdebug:
                break

        # mat_path_all = os.path.join(args.checkpoint, 'inference_data_all.mat')
        # scio.savemat(mat_path_all, data_inference_all)

        mat_path_mean = os.path.join(args.checkpoint, 'inference_data_P_Agg.mat')
        scio.savemat(mat_path_mean, data_inference_mean)
        mat_path_h_min = os.path.join(args.checkpoint, 'inference_data_P_Best.mat')
        scio.savemat(mat_path_h_min, data_inference_h_min)
        mat_path_joint_min = os.path.join(args.checkpoint, 'inference_data_J_Best.mat')
        scio.savemat(mat_path_joint_min, data_inference_joint_min)
        mat_path_reproj_min = os.path.join(args.checkpoint, 'inference_data_J_Agg.mat')
        scio.savemat(mat_path_reproj_min, data_inference_reproj_min)

    return e1, e1_mean
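# The exported .mat files can be inspected offline, e.g. (a usage sketch; the
# directory depends on args.checkpoint):
#   import scipy.io as scio
#   data = scio.loadmat('checkpoint/inference_data_P_Agg.mat')
#   data['TS1'].shape  # (c, j, f, t) after the transpose in pose_post_process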
if args.render:
    print('Rendering...')
    # NOTE: this branch is retained from the Human3.6M pipeline and assumes
    # array-shaped ground truth and visualization inputs (e.g. input_keypoints,
    # gt_noise, and prediction_pf under args.compare) defined by that tooling.

    gen = UnchunkedGenerator_Seq(None, out_poses_3d_test, out_poses_2d_test,
                                 pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation,
                                 kps_left=kps_left, kps_right=kps_right,
                                 joints_left=joints_left, joints_right=joints_right,
                                 valid_frame=valid_frame)
    ground_truth = out_poses_3d_test
    prediction = evaluate(gen, return_predictions=True)
    b_sz, t_sz, h_sz, f_sz, j_sz, c_sz = prediction.shape
    prediction2 = np.empty((t_sz, h_sz, ground_truth.shape[0], 17, 3)).astype(np.float32)

    ### Reshape the prediction to match the ground truth
    if ground_truth.shape[0] / receptive_field > ground_truth.shape[0] // receptive_field:
        batch_num = (ground_truth.shape[0] // receptive_field) + 1
        for i in range(batch_num - 1):
            prediction2[:, :, i * receptive_field:(i + 1) * receptive_field, :, :] = prediction[i, :, :, :, :, :]
        left_frames = ground_truth.shape[0] - (batch_num - 1) * receptive_field
        prediction2[:, :, -left_frames:, :, :] = prediction[-1, :, :, -left_frames:, :, :]
    elif ground_truth.shape[0] / receptive_field == ground_truth.shape[0] // receptive_field:
        batch_num = ground_truth.shape[0] // receptive_field
        for i in range(batch_num):
            prediction2[:, :, i * receptive_field:(i + 1) * receptive_field, :, :] = prediction[i, :, :, :, :, :]

    if args.viz_export is not None:
        print('Exporting joint positions to', args.viz_export)
        # Predictions are in camera space
        np.save(args.viz_export, prediction)

    f_all_sz = ground_truth.shape[0]
    if ground_truth is not None:
        # Reapply trajectory
        trajectory = ground_truth[:, :1]
        ground_truth[:, 1:] += trajectory
        trajectory = trajectory.reshape(1, 1, f_all_sz, 1, 3)
        prediction2 += trajectory
        if args.compare:
            prediction_pf += trajectory

    # Invert camera transformation
    cam = dataset.cameras()[args.viz_subject][args.viz_camera]
    if ground_truth is not None:
        if args.compare:
            prediction_pf = camera_to_world(prediction_pf, R=cam['orientation'], t=cam['translation'])
        prediction2_world = camera_to_world(prediction2, R=cam['orientation'], t=cam['translation'])
        ground_truth = camera_to_world(ground_truth, R=cam['orientation'], t=cam['translation'])
    else:
        # If the ground truth is not available, take the camera extrinsic params from a random subject.
        # They are almost the same, and anyway, we only need this for visualization purposes.
        for subject in dataset.cameras():
            if 'orientation' in dataset.cameras()[subject][args.viz_camera]:
                rot = dataset.cameras()[subject][args.viz_camera]['orientation']
                break
        if args.compare:
            prediction_pf = camera_to_world(prediction_pf, R=rot, t=0)
            prediction_pf[:, :, 2] -= np.min(prediction_pf[:, :, 2])
        prediction = camera_to_world(prediction, R=rot, t=0)
        # We don't have the trajectory, but at least we can rebase the height
        prediction[:, :, 2] -= np.min(prediction[:, :, 2])

    prediction2_reshape = prediction2.reshape(t_sz * h_sz * prediction2.shape[2], j_sz, 3)
    prediction2_reshape = torch.from_numpy(prediction2_reshape)
    cam_intri = torch.from_numpy(np.repeat(cam['intrinsic'][None], prediction2_reshape.shape[0], axis=0))
    poses_2d_reproj = project_to_2d(prediction2_reshape, cam_intri)
    poses_2d_reproj = poses_2d_reproj.reshape(t_sz, h_sz, prediction2.shape[2], j_sz, 2)

    input_keypoints_pix = image_coordinates(input_keypoints[..., :2], w=cam['res_w'], h=cam['res_h'])
    poses_2d_reproj_pix = image_coordinates(poses_2d_reproj[..., :2].numpy(), w=cam['res_w'], h=cam['res_h'])

    parents = dataset.skeleton().parents()

    plt_path = './plot/h36m/'
    if not os.path.isdir(plt_path):
        os.makedirs(plt_path)

    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    from common.visualization import draw_3d_image, draw_3d_image_select, draw_2d_image

    draw_3d_image_select(prediction2_world, gt_noise, gt_noise_flip, dataset.skeleton(), cam['azimuth'],
                         args.viz_subject, args.viz_action, args.viz_camera, input_keypoints, poses_2d_reproj)
else:
    print('Evaluating...')

    all_actions = {}
    all_actions_flatten = []
    all_actions_by_subject = {}

    def run_evaluation_all_actions(actions, action_filter=None):
        errors_p1 = []
        errors_p1_mean = []

        gen = UnchunkedGenerator_Seq(None, out_poses_3d_test, out_poses_2d_test,
                                     pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation,
                                     kps_left=kps_left, kps_right=kps_right,
                                     joints_left=joints_left, joints_right=joints_right,
                                     valid_frame=valid_frame)
        e1, e1_mean = evaluate(gen)
        errors_p1.append(e1)
        errors_p1_mean.append(e1_mean)

    if not args.by_subject:
        run_evaluation_all_actions(all_actions_flatten, action_filter)

if not args.nolog:
    writer.close()