Upload 17 files
- common/README.MD +1 -0
- common/__pycache__/camera.cpython-311.pyc +0 -0
- common/__pycache__/model_poseformer.cpython-311.pyc +0 -0
- common/__pycache__/quaternion.cpython-311.pyc +0 -0
- common/__pycache__/utils.cpython-311.pyc +0 -0
- common/arguments.py +96 -0
- common/camera.py +90 -0
- common/custom_dataset.py +66 -0
- common/generators.py +261 -0
- common/h36m_dataset.py +263 -0
- common/loss.py +108 -0
- common/mocap_dataset.py +44 -0
- common/model_poseformer.py +242 -0
- common/quaternion.py +35 -0
- common/skeleton.py +88 -0
- common/utils.py +80 -0
- common/visualization.py +216 -0
common/README.MD
ADDED
@@ -0,0 +1 @@
common/__pycache__/camera.cpython-311.pyc
ADDED
Binary file (5.11 kB)
common/__pycache__/model_poseformer.cpython-311.pyc
ADDED
Binary file (17.8 kB)
common/__pycache__/quaternion.cpython-311.pyc
ADDED
Binary file (1.69 kB)
common/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (3.78 kB)
common/arguments.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Modified by Qitao Zhao (qitaozhao@mail.sdu.edu.cn)

import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Training script')

    # General arguments
    parser.add_argument('-d', '--dataset', default='h36m', type=str, metavar='NAME', help='target dataset') # h36m or humaneva
    parser.add_argument('-k', '--keypoints', default='cpn_ft_h36m_dbb', type=str, metavar='NAME', help='2D detections to use')
    parser.add_argument('-str', '--subjects-train', default='S1,S5,S6,S7,S8', type=str, metavar='LIST',
                        help='training subjects separated by comma')
    parser.add_argument('-ste', '--subjects-test', default='S9,S11', type=str, metavar='LIST', help='test subjects separated by comma')
    parser.add_argument('-sun', '--subjects-unlabeled', default='', type=str, metavar='LIST',
                        help='unlabeled subjects separated by comma for self-supervision')
    parser.add_argument('-a', '--actions', default='*', type=str, metavar='LIST',
                        help='actions to train/test on, separated by comma, or * for all')
    parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
                        help='checkpoint directory')
    parser.add_argument('--checkpoint-frequency', default=40, type=int, metavar='N',
                        help='create a checkpoint every N epochs')
    parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME',
                        help='checkpoint to resume (file name)')
    parser.add_argument('--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
    parser.add_argument('--render', action='store_true', help='visualize a particular video')
    parser.add_argument('--by-subject', action='store_true', help='break down error by subject (on evaluation)')
    parser.add_argument('--export-training-curves', action='store_true', help='save training curves as .png images')
    parser.add_argument('-g', '--gpu', type=list, help='set gpu number')
    parser.add_argument('--local_rank', type=int, default=0, help='node rank for distributed training')
    parser.add_argument('--center-pose', type=int, default=0, help='choose fine-tuning task as 3d pose estimation')

    # Model arguments
    parser.add_argument('-s', '--stride', default=1, type=int, metavar='N', help='chunk size to use during training')
    parser.add_argument('-e', '--epochs', default=200, type=int, metavar='N', help='number of training epochs')
    parser.add_argument('-b', '--batch-size', default=1024, type=int, metavar='N', help='batch size in terms of predicted frames')
    parser.add_argument('-drop', '--dropout', default=0., type=float, metavar='P', help='dropout probability')
    parser.add_argument('-lr', '--learning-rate', default=0.0001, type=float, metavar='LR', help='initial learning rate')
    parser.add_argument('-lrd', '--lr-decay', default=0.99, type=float, metavar='LR', help='learning rate decay per epoch')
    parser.add_argument('-no-da', '--no-data-augmentation', dest='data_augmentation', action='store_false',
                        help='disable train-time flipping')
    parser.add_argument('-frame', '--number-of-frames', default=81, type=int, metavar='N',
                        help='how many frames used as input')
    parser.add_argument('-frame-kept', '--number-of-kept-frames', default=27, type=int, metavar='N',
                        help='how many frames are kept')
    parser.add_argument('-coeff-kept', '--number-of-kept-coeffs', type=int, metavar='N', help='how many coefficients are kept')
    parser.add_argument('--depth', default=4, type=int, metavar='N', help='number of transformer blocks')
    parser.add_argument('--embed-dim-ratio', default=32, type=int, metavar='N', help='dimension of embedding ratio')
    parser.add_argument('-std', type=float, default=0.0, help='the standard deviation for gaussian noise')

    # Experimental
    parser.add_argument('--subset', default=1, type=float, metavar='FRACTION', help='reduce dataset size by fraction')
    parser.add_argument('--downsample', default=1, type=int, metavar='FACTOR', help='downsample frame rate by factor (semi-supervised)')
    parser.add_argument('--warmup', default=1, type=int, metavar='N', help='warm-up epochs for semi-supervision')
    parser.add_argument('--no-eval', action='store_true', help='disable epoch evaluation while training (small speed-up)')
    parser.add_argument('--dense', action='store_true', help='use dense convolutions instead of dilated convolutions')
    parser.add_argument('--disable-optimizations', action='store_true', help='disable optimized model for single-frame predictions')
    parser.add_argument('--linear-projection', action='store_true', help='use only linear coefficients for semi-supervised projection')
    parser.add_argument('--no-bone-length', action='store_false', dest='bone_length_term',
                        help='disable bone length term in semi-supervised settings')
    parser.add_argument('--no-proj', action='store_true', help='disable projection for semi-supervised setting')

    # Visualization
    parser.add_argument('--viz-subject', type=str, metavar='STR', help='subject to render')
    parser.add_argument('--viz-action', type=str, metavar='STR', help='action to render')
    parser.add_argument('--viz-camera', type=int, default=0, metavar='N', help='camera to render')
    parser.add_argument('--viz-video', type=str, metavar='PATH', help='path to input video')
    parser.add_argument('--viz-skip', type=int, default=0, metavar='N', help='skip first N frames of input video')
    parser.add_argument('--viz-output', type=str, metavar='PATH', help='output file name (.gif or .mp4)')
    parser.add_argument('--viz-export', type=str, metavar='PATH', help='output file name for coordinates')
    parser.add_argument('--viz-bitrate', type=int, default=3000, metavar='N', help='bitrate for mp4 videos')
    parser.add_argument('--viz-no-ground-truth', action='store_true', help='do not show ground-truth poses')
    parser.add_argument('--viz-limit', type=int, default=-1, metavar='N', help='only render first N frames')
    parser.add_argument('--viz-downsample', type=int, default=1, metavar='N', help='downsample FPS by a factor N')
    parser.add_argument('--viz-size', type=int, default=5, metavar='N', help='image size')

    parser.set_defaults(bone_length_term=True)
    parser.set_defaults(data_augmentation=True)
    parser.set_defaults(test_time_augmentation=True)
    # parser.set_defaults(test_time_augmentation=False)

    args = parser.parse_args()
    # Check invalid configuration
    if args.resume and args.evaluate:
        print('Invalid flags: --resume and --evaluate cannot be set at the same time')
        exit()

    if args.export_training_curves and args.no_eval:
        print('Invalid flags: --export-training-curves and --no-eval cannot be set at the same time')
        exit()

    return args
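As a quick sanity check, a minimal sketch (not part of the commit) of exercising parse_args() outside of training; the simulated command line and the run.py name are illustrative only:

import sys
from common.arguments import parse_args

# Simulate a command line; parse_args() reads sys.argv via argparse
sys.argv = ['run.py', '-k', 'cpn_ft_h36m_dbb', '-frame', '81', '-frame-kept', '27']
args = parse_args()
print(args.number_of_frames, args.number_of_kept_frames)  # -> 81 27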
common/camera.py
ADDED
@@ -0,0 +1,90 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
import torch

from common.utils import wrap
from common.quaternion import qrot, qinverse

def normalize_screen_coordinates(X, w, h):
    assert X.shape[-1] == 2

    # Normalize so that [0, w] is mapped to [-1, 1], while preserving the aspect ratio
    return X/w*2 - [1, h/w]


def image_coordinates(X, w, h):
    assert X.shape[-1] == 2

    # Reverse camera frame normalization
    return (X + [1, h/w])*w/2


def world_to_camera(X, R, t):
    Rt = wrap(qinverse, R) # Invert rotation
    return wrap(qrot, np.tile(Rt, (*X.shape[:-1], 1)), X - t) # Rotate and translate


def camera_to_world(X, R, t):
    return wrap(qrot, np.tile(R, (*X.shape[:-1], 1)), X) + t


def project_to_2d(X, camera_params):
    """
    Project 3D points to 2D using the Human3.6M camera projection function.
    This is a differentiable and batched reimplementation of the original MATLAB script.

    Arguments:
    X -- 3D points in *camera space* to transform (N, *, 3)
    camera_params -- intrinsic parameters (N, 2+2+3+2=9)
    """
    assert X.shape[-1] == 3
    assert len(camera_params.shape) == 2
    assert camera_params.shape[-1] == 9
    assert X.shape[0] == camera_params.shape[0]

    while len(camera_params.shape) < len(X.shape):
        camera_params = camera_params.unsqueeze(1)

    f = camera_params[..., :2]
    c = camera_params[..., 2:4]
    k = camera_params[..., 4:7]
    p = camera_params[..., 7:]

    XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1)
    r2 = torch.sum(XX[..., :2]**2, dim=len(XX.shape)-1, keepdim=True)

    radial = 1 + torch.sum(k * torch.cat((r2, r2**2, r2**3), dim=len(r2.shape)-1), dim=len(r2.shape)-1, keepdim=True)
    tan = torch.sum(p*XX, dim=len(XX.shape)-1, keepdim=True)

    XXX = XX*(radial + tan) + p*r2

    return f*XXX + c

def project_to_2d_linear(X, camera_params):
    """
    Project 3D points to 2D using only linear parameters (focal length and principal point).

    Arguments:
    X -- 3D points in *camera space* to transform (N, *, 3)
    camera_params -- intrinsic parameters (N, 2+2+3+2=9)
    """
    assert X.shape[-1] == 3
    assert len(camera_params.shape) == 2
    assert camera_params.shape[-1] == 9
    assert X.shape[0] == camera_params.shape[0]

    while len(camera_params.shape) < len(X.shape):
        camera_params = camera_params.unsqueeze(1)

    f = camera_params[..., :2]
    c = camera_params[..., 2:4]

    XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1)

    return f*XX + c
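A short sketch (not part of the commit) exercising the helpers above with dummy data; all values are illustrative, and 2.29 roughly matches Human3.6M's normalized focal length (1145/1000*2):

import numpy as np
import torch
from common.camera import normalize_screen_coordinates, image_coordinates, project_to_2d

# Round-trip pixel coordinates through the normalization helpers
pix = np.array([[512.0, 515.0]])                         # (N, 2) pixel coordinates
norm = normalize_screen_coordinates(pix, w=1000, h=1002)
assert np.allclose(image_coordinates(norm, w=1000, h=1002), pix)

# Project a batch of camera-space 3D points with a 9-parameter intrinsic vector
X = torch.randn(4, 17, 3) + torch.tensor([0., 0., 5.])   # keep points in front of the camera
cam = torch.randn(4, 9) * 0.01                            # small random distortion terms
cam[:, :2] = 2.29                                         # focal length in normalized units
points_2d = project_to_2d(X, cam)
print(points_2d.shape)                                    # torch.Size([4, 17, 2])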
common/custom_dataset.py
ADDED
@@ -0,0 +1,66 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
import copy
from common.skeleton import Skeleton
from common.mocap_dataset import MocapDataset
from common.camera import normalize_screen_coordinates, image_coordinates
from common.h36m_dataset import h36m_skeleton


custom_camera_params = {
    'id': None,
    'res_w': None, # Pulled from metadata
    'res_h': None, # Pulled from metadata

    # Dummy camera parameters (taken from Human3.6M), only for visualization purposes
    'azimuth': 70, # Only used for visualization
    'orientation': [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088],
    'translation': [1841.1070556640625, 4955.28466796875, 1563.4454345703125],
}

class CustomDataset(MocapDataset):
    def __init__(self, detections_path, remove_static_joints=True):
        super().__init__(fps=None, skeleton=h36m_skeleton)

        # Load serialized dataset
        data = np.load(detections_path, allow_pickle=True)
        resolutions = data['metadata'].item()['video_metadata']

        self._cameras = {}
        self._data = {}
        for video_name, res in resolutions.items():
            cam = {}
            cam.update(custom_camera_params)
            cam['orientation'] = np.array(cam['orientation'], dtype='float32')
            cam['translation'] = np.array(cam['translation'], dtype='float32')
            cam['translation'] = cam['translation']/1000 # mm to meters

            cam['id'] = video_name
            cam['res_w'] = res['w']
            cam['res_h'] = res['h']

            self._cameras[video_name] = [cam]

            self._data[video_name] = {
                'custom': {
                    'cameras': cam
                }
            }

        if remove_static_joints:
            # Bring the skeleton to 17 joints instead of the original 32
            self.remove_joints([4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31])

            # Rewire shoulders to the correct parents
            self._skeleton._parents[11] = 8
            self._skeleton._parents[14] = 8

    def supports_semi_supervised(self):
        return False
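A sketch of constructing the dataset; the archive name below is a placeholder, but the expected NPZ layout ('metadata' containing a 'video_metadata' dict that maps each video name to its 'w'/'h') follows what __init__ reads above:

from common.custom_dataset import CustomDataset

# 'data_2d_custom_myvideos.npz' is a hypothetical detections file
dataset = CustomDataset('data_2d_custom_myvideos.npz')
for video_name in dataset.subjects():
    cam = dataset.cameras()[video_name][0]
    print(video_name, cam['res_w'], cam['res_h'])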
common/generators.py
ADDED
@@ -0,0 +1,261 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from itertools import zip_longest
import numpy as np


# def getbone(seq, boneindex):
#     bs = np.shape(seq)[0]
#     ss = np.shape(seq)[1]
#     seq = np.reshape(seq,(bs*ss,-1,3))
#     bone = []
#     for index in boneindex:
#         bone.append(seq[:,index[0]] - seq[:,index[1]])
#     bone = np.stack(bone,1)
#     bone = np.power(np.power(bone,2).sum(2),0.5)
#     bone = np.reshape(bone, (bs,ss,np.shape(bone)[1]))
#     return bone

class ChunkedGenerator:
    """
    Batched data generator, used for training.
    The sequences are split into equal-length chunks and padded as necessary.

    Arguments:
    batch_size -- the batch size to use for training
    cameras -- list of cameras, one element for each video (optional, used for semi-supervised training)
    poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training)
    poses_2d -- list of input 2D keypoints, one element for each video
    chunk_length -- number of output frames to predict for each training example (usually 1)
    pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field)
    causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad")
    shuffle -- randomly shuffle the dataset before each epoch
    random_seed -- initial seed to use for the random generator
    augment -- augment the dataset by flipping poses horizontally
    kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled
    joints_left and joints_right -- list of left/right 3D joints if flipping is enabled
    """
    def __init__(self, batch_size, cameras, poses_3d, poses_2d,
                 chunk_length, pad=0, causal_shift=0,
                 shuffle=True, random_seed=1234,
                 augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None,
                 endless=False):
        assert poses_3d is None or len(poses_3d) == len(poses_2d), (len(poses_3d), len(poses_2d))
        assert cameras is None or len(cameras) == len(poses_2d)

        # Build lineage info
        pairs = [] # (seq_idx, start_frame, end_frame, flip) tuples
        for i in range(len(poses_2d)):
            assert poses_3d is None or poses_3d[i].shape[0] == poses_2d[i].shape[0]
            n_chunks = (poses_2d[i].shape[0] + chunk_length - 1) // chunk_length
            offset = (n_chunks * chunk_length - poses_2d[i].shape[0]) // 2
            bounds = np.arange(n_chunks+1)*chunk_length - offset
            augment_vector = np.full(len(bounds) - 1, False, dtype=bool)
            pairs += zip(np.repeat(i, len(bounds) - 1), bounds[:-1], bounds[1:], augment_vector)
            if augment:
                pairs += zip(np.repeat(i, len(bounds) - 1), bounds[:-1], bounds[1:], ~augment_vector)

        # Initialize buffers
        if cameras is not None:
            self.batch_cam = np.empty((batch_size, cameras[0].shape[-1]))
        if poses_3d is not None:
            self.batch_3d = np.empty((batch_size, chunk_length, poses_3d[0].shape[-2], poses_3d[0].shape[-1]))
            # self.batch_3d = np.empty((batch_size, chunk_length + 2*pad, poses_3d[0].shape[-2], poses_3d[0].shape[-1]))
        self.batch_2d = np.empty((batch_size, chunk_length + 2*pad, poses_2d[0].shape[-2], poses_2d[0].shape[-1]))

        self.num_batches = (len(pairs) + batch_size - 1) // batch_size
        self.batch_size = batch_size
        self.random = np.random.RandomState(random_seed)
        self.pairs = pairs
        self.shuffle = shuffle
        self.pad = pad
        self.causal_shift = causal_shift
        self.endless = endless
        self.state = None

        self.cameras = cameras
        self.poses_3d = poses_3d
        self.poses_2d = poses_2d

        self.augment = augment
        self.kps_left = kps_left
        self.kps_right = kps_right
        self.joints_left = joints_left
        self.joints_right = joints_right

    def num_frames(self):
        return self.num_batches * self.batch_size

    def random_state(self):
        return self.random

    def set_random_state(self, random):
        self.random = random

    def augment_enabled(self):
        return self.augment

    def next_pairs(self):
        if self.state is None:
            if self.shuffle:
                pairs = self.random.permutation(self.pairs)
            else:
                pairs = self.pairs
            return 0, pairs
        else:
            return self.state

    def next_epoch(self):
        enabled = True
        while enabled:
            start_idx, pairs = self.next_pairs()
            for b_i in range(start_idx, self.num_batches):
                chunks = pairs[b_i*self.batch_size : (b_i+1)*self.batch_size]
                for i, (seq_i, start_3d, end_3d, flip) in enumerate(chunks):
                    start_2d = start_3d - self.pad - self.causal_shift
                    end_2d = end_3d + self.pad - self.causal_shift

                    # 2D poses
                    seq_2d = self.poses_2d[seq_i]
                    low_2d = max(start_2d, 0)
                    high_2d = min(end_2d, seq_2d.shape[0])
                    pad_left_2d = low_2d - start_2d
                    pad_right_2d = end_2d - high_2d
                    if pad_left_2d != 0 or pad_right_2d != 0:
                        self.batch_2d[i] = np.pad(seq_2d[low_2d:high_2d], ((pad_left_2d, pad_right_2d), (0, 0), (0, 0)), 'edge')
                    else:
                        self.batch_2d[i] = seq_2d[low_2d:high_2d]

                    if flip:
                        # Flip 2D keypoints
                        self.batch_2d[i, :, :, 0] *= -1
                        self.batch_2d[i, :, self.kps_left + self.kps_right] = self.batch_2d[i, :, self.kps_right + self.kps_left]

                    # 3D poses
                    if self.poses_3d is not None:
                        seq_3d = self.poses_3d[seq_i]
                        low_3d = max(start_3d, 0)
                        high_3d = min(end_3d, seq_3d.shape[0])
                        pad_left_3d = low_3d - start_3d
                        pad_right_3d = end_3d - high_3d
                        if pad_left_3d != 0 or pad_right_3d != 0:
                            # if pad_left_2d != 0 or pad_right_2d != 0:
                            #     self.batch_3d[i] = np.pad(seq_3d[low_2d:high_2d], ((pad_left_2d, pad_right_2d), (0, 0), (0, 0)), 'edge')
                            self.batch_3d[i] = np.pad(seq_3d[low_3d:high_3d], ((pad_left_3d, pad_right_3d), (0, 0), (0, 0)), 'edge')
                        else:
                            # self.batch_3d[i] = seq_3d[low_2d:high_2d]
                            self.batch_3d[i] = seq_3d[low_3d:high_3d]

                        if flip:
                            # Flip 3D joints
                            self.batch_3d[i, :, :, 0] *= -1
                            self.batch_3d[i, :, self.joints_left + self.joints_right] = \
                                self.batch_3d[i, :, self.joints_right + self.joints_left]

                    # Cameras
                    if self.cameras is not None:
                        self.batch_cam[i] = self.cameras[seq_i]
                        if flip:
                            # Flip horizontal distortion coefficients
                            self.batch_cam[i, 2] *= -1
                            self.batch_cam[i, 7] *= -1

                if self.endless:
                    self.state = (b_i + 1, pairs)
                if self.poses_3d is None and self.cameras is None:
                    yield None, None, self.batch_2d[:len(chunks)]
                elif self.poses_3d is not None and self.cameras is None:
                    yield None, self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)]
                    # yield None, self.batch_bins_3d[:len(chunks)], self.batch_2d[:len(chunks)]
                elif self.poses_3d is None:
                    yield self.batch_cam[:len(chunks)], None, self.batch_2d[:len(chunks)]
                else:
                    yield self.batch_cam[:len(chunks)], self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)]
                    # yield self.batch_cam[:len(chunks)], self.batch_bins_3d[:len(chunks)], self.batch_2d[:len(chunks)]

            if self.endless:
                self.state = None
            else:
                enabled = False


class UnchunkedGenerator:
    """
    Non-batched data generator, used for testing.
    Sequences are returned one at a time (i.e. batch size = 1), without chunking.

    If data augmentation is enabled, the batches contain two sequences (i.e. batch size = 2),
    the second of which is a mirrored version of the first.

    Arguments:
    cameras -- list of cameras, one element for each video (optional, used for semi-supervised training)
    poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training)
    poses_2d -- list of input 2D keypoints, one element for each video
    pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field)
    causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad")
    augment -- augment the dataset by flipping poses horizontally
    kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled
    joints_left and joints_right -- list of left/right 3D joints if flipping is enabled
    """

    def __init__(self, cameras, poses_3d, poses_2d, pad=0, causal_shift=0,
                 augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None):
        assert poses_3d is None or len(poses_3d) == len(poses_2d)
        assert cameras is None or len(cameras) == len(poses_2d)

        self.augment = False
        self.kps_left = kps_left
        self.kps_right = kps_right
        self.joints_left = joints_left
        self.joints_right = joints_right

        self.pad = pad
        self.causal_shift = causal_shift
        self.cameras = [] if cameras is None else cameras
        self.poses_3d = [] if poses_3d is None else poses_3d
        self.poses_2d = poses_2d

    def num_frames(self):
        count = 0
        for p in self.poses_2d:
            count += p.shape[0]
        return count

    def augment_enabled(self):
        return self.augment

    def set_augment(self, augment):
        self.augment = augment

    def next_epoch(self):
        for seq_cam, seq_3d, seq_2d in zip_longest(self.cameras, self.poses_3d, self.poses_2d):
            batch_cam = None if seq_cam is None else np.expand_dims(seq_cam, axis=0)
            batch_3d = None if seq_3d is None else np.expand_dims(seq_3d, axis=0)
            # batch_3d = np.expand_dims(np.pad(seq_3d,
            #             ((self.pad + self.causal_shift, self.pad - self.causal_shift), (0, 0), (0, 0)),
            #             'edge'), axis=0)
            batch_2d = np.expand_dims(np.pad(seq_2d,
                        ((self.pad + self.causal_shift, self.pad - self.causal_shift), (0, 0), (0, 0)),
                        'edge'), axis=0)
            if self.augment:
                # Append flipped version
                if batch_cam is not None:
                    batch_cam = np.concatenate((batch_cam, batch_cam), axis=0)
                    batch_cam[1, 2] *= -1
                    batch_cam[1, 7] *= -1

                if batch_3d is not None:
                    batch_3d = np.concatenate((batch_3d, batch_3d), axis=0)
                    batch_3d[1, :, :, 0] *= -1
                    batch_3d[1, :, self.joints_left + self.joints_right] = batch_3d[1, :, self.joints_right + self.joints_left]

                batch_2d = np.concatenate((batch_2d, batch_2d), axis=0)
                batch_2d[1, :, :, 0] *= -1
                batch_2d[1, :, self.kps_left + self.kps_right] = batch_2d[1, :, self.kps_right + self.kps_left]

            yield batch_cam, batch_3d, batch_2d
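A minimal sketch with random data showing how ChunkedGenerator windows the 2D input around each predicted frame; the shapes are assumptions for illustration, and pad would normally be derived from the model's receptive field:

import numpy as np
from common.generators import ChunkedGenerator

# Two dummy videos of 100 and 50 frames, 17 joints each
poses_2d = [np.random.randn(100, 17, 2), np.random.randn(50, 17, 2)]
poses_3d = [np.random.randn(100, 17, 3), np.random.randn(50, 17, 3)]

gen = ChunkedGenerator(batch_size=8, cameras=None, poses_3d=poses_3d, poses_2d=poses_2d,
                       chunk_length=1, pad=40)  # pad=40 -> an 81-frame 2D window per sample
for _, batch_3d, batch_2d in gen.next_epoch():
    print(batch_3d.shape, batch_2d.shape)  # (8, 1, 17, 3) (8, 81, 17, 2)
    break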
common/h36m_dataset.py
ADDED
@@ -0,0 +1,263 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
import copy
from common.skeleton import Skeleton
from common.mocap_dataset import MocapDataset
from common.camera import normalize_screen_coordinates, image_coordinates

h36m_skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12,
                                  16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30],
                         joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23],
                         joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31])

h36m_cameras_intrinsic_params = [
    {
        'id': '54138969',
        'center': [512.54150390625, 515.4514770507812],
        'focal_length': [1145.0494384765625, 1143.7811279296875],
        'radial_distortion': [-0.20709891617298126, 0.24777518212795258, -0.0030751503072679043],
        'tangential_distortion': [-0.0009756988729350269, -0.00142447161488235],
        'res_w': 1000,
        'res_h': 1002,
        'azimuth': 70, # Only used for visualization
    },
    {
        'id': '55011271',
        'center': [508.8486328125, 508.0649108886719],
        'focal_length': [1149.6756591796875, 1147.5916748046875],
        'radial_distortion': [-0.1942136287689209, 0.2404085397720337, 0.006819975562393665],
        'tangential_distortion': [-0.0016190266469493508, -0.0027408944442868233],
        'res_w': 1000,
        'res_h': 1000,
        'azimuth': -70, # Only used for visualization
    },
    {
        'id': '58860488',
        'center': [519.8158569335938, 501.40264892578125],
        'focal_length': [1149.1407470703125, 1148.7989501953125],
        'radial_distortion': [-0.2083381861448288, 0.25548800826072693, -0.0024604974314570427],
        'tangential_distortion': [0.0014843869721516967, -0.0007599993259645998],
        'res_w': 1000,
        'res_h': 1000,
        'azimuth': 110, # Only used for visualization
    },
    {
        'id': '60457274',
        'center': [514.9682006835938, 501.88201904296875],
        'focal_length': [1145.5113525390625, 1144.77392578125],
        'radial_distortion': [-0.198384091258049, 0.21832367777824402, -0.008947807364165783],
        'tangential_distortion': [-0.0005872055771760643, -0.0018133620033040643],
        'res_w': 1000,
        'res_h': 1002,
        'azimuth': -110, # Only used for visualization
    },
]

h36m_cameras_extrinsic_params = {
    'S1': [
        {
            'orientation': [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088],
            'translation': [1841.1070556640625, 4955.28466796875, 1563.4454345703125],
        },
        {
            'orientation': [0.6157187819480896, -0.764836311340332, -0.14833825826644897, 0.11794740706682205],
            'translation': [1761.278564453125, -5078.0068359375, 1606.2650146484375],
        },
        {
            'orientation': [0.14651472866535187, -0.14647851884365082, 0.7653023600578308, -0.6094175577163696],
            'translation': [-1846.7777099609375, 5215.04638671875, 1491.972412109375],
        },
        {
            'orientation': [0.5834008455276489, -0.7853162288665771, 0.14548823237419128, -0.14749594032764435],
            'translation': [-1794.7896728515625, -3722.698974609375, 1574.8927001953125],
        },
    ],
    'S2': [
        {},
        {},
        {},
        {},
    ],
    'S3': [
        {},
        {},
        {},
        {},
    ],
    'S4': [
        {},
        {},
        {},
        {},
    ],
    'S5': [
        {
            'orientation': [0.1467377245426178, -0.162370964884758, -0.7551892995834351, 0.6178938746452332],
            'translation': [2097.3916015625, 4880.94482421875, 1605.732421875],
        },
        {
            'orientation': [0.6159758567810059, -0.7626792192459106, -0.15728192031383514, 0.1189815029501915],
            'translation': [2031.7008056640625, -5167.93310546875, 1612.923095703125],
        },
        {
            'orientation': [0.14291371405124664, -0.12907841801643372, 0.7678384780883789, -0.6110143065452576],
            'translation': [-1620.5948486328125, 5171.65869140625, 1496.43701171875],
        },
        {
            'orientation': [0.5920479893684387, -0.7814217805862427, 0.1274748593568802, -0.15036417543888092],
            'translation': [-1637.1737060546875, -3867.3173828125, 1547.033203125],
        },
    ],
    'S6': [
        {
            'orientation': [0.1337897777557373, -0.15692396461963654, -0.7571090459823608, 0.6198879480361938],
            'translation': [1935.4517822265625, 4950.24560546875, 1618.0838623046875],
        },
        {
            'orientation': [0.6147197484970093, -0.7628812789916992, -0.16174767911434174, 0.11819244921207428],
            'translation': [1969.803955078125, -5128.73876953125, 1632.77880859375],
        },
        {
            'orientation': [0.1529948115348816, -0.13529130816459656, 0.7646096348762512, -0.6112781167030334],
            'translation': [-1769.596435546875, 5185.361328125, 1476.993408203125],
        },
        {
            'orientation': [0.5916101336479187, -0.7804774045944214, 0.12832270562648773, -0.1561593860387802],
            'translation': [-1721.668701171875, -3884.13134765625, 1540.4879150390625],
        },
    ],
    'S7': [
        {
            'orientation': [0.1435241848230362, -0.1631336808204651, -0.7548328638076782, 0.6188824772834778],
            'translation': [1974.512939453125, 4926.3544921875, 1597.8326416015625],
        },
        {
            'orientation': [0.6141672730445862, -0.7638262510299683, -0.1596645563840866, 0.1177929937839508],
            'translation': [1937.0584716796875, -5119.7900390625, 1631.5665283203125],
        },
        {
            'orientation': [0.14550060033798218, -0.12874816358089447, 0.7660516500473022, -0.6127139329910278],
            'translation': [-1741.8111572265625, 5208.24951171875, 1464.8245849609375],
        },
        {
            'orientation': [0.5912848114967346, -0.7821764349937439, 0.12445473670959473, -0.15196487307548523],
            'translation': [-1734.7105712890625, -3832.42138671875, 1548.5830078125],
        },
    ],
    'S8': [
        {
            'orientation': [0.14110587537288666, -0.15589867532253265, -0.7561917304992676, 0.619644045829773],
            'translation': [2150.65185546875, 4896.1611328125, 1611.9046630859375],
        },
        {
            'orientation': [0.6169601678848267, -0.7647668123245239, -0.14846350252628326, 0.11158157885074615],
            'translation': [2219.965576171875, -5148.453125, 1613.0440673828125],
        },
        {
            'orientation': [0.1471444070339203, -0.13377119600772858, 0.7670128345489502, -0.6100369691848755],
            'translation': [-1571.2215576171875, 5137.0185546875, 1498.1761474609375],
        },
        {
            'orientation': [0.5927824378013611, -0.7825870513916016, 0.12147816270589828, -0.14631995558738708],
            'translation': [-1476.913330078125, -3896.7412109375, 1547.97216796875],
        },
    ],
    'S9': [
        {
            'orientation': [0.15540587902069092, -0.15548215806484222, -0.7532095313072205, 0.6199594736099243],
            'translation': [2044.45849609375, 4935.1171875, 1481.2275390625],
        },
        {
            'orientation': [0.618784487247467, -0.7634735107421875, -0.14132238924503326, 0.11933968216180801],
            'translation': [1990.959716796875, -5123.810546875, 1568.8048095703125],
        },
        {
            'orientation': [0.13357827067375183, -0.1367100477218628, 0.7689454555511475, -0.6100738644599915],
            'translation': [-1670.9921875, 5211.98583984375, 1528.387939453125],
        },
        {
            'orientation': [0.5879399180412292, -0.7823407053947449, 0.1427614390850067, -0.14794869720935822],
            'translation': [-1696.04345703125, -3827.099853515625, 1591.4127197265625],
        },
    ],
    'S11': [
        {
            'orientation': [0.15232472121715546, -0.15442320704460144, -0.7547563314437866, 0.6191070079803467],
            'translation': [2098.440185546875, 4926.5546875, 1500.278564453125],
        },
        {
            'orientation': [0.6189449429512024, -0.7600917220115662, -0.15300633013248444, 0.1255258321762085],
            'translation': [2083.182373046875, -4912.1728515625, 1561.07861328125],
        },
        {
            'orientation': [0.14943228662014008, -0.15650227665901184, 0.7681233882904053, -0.6026304364204407],
            'translation': [-1609.8153076171875, 5177.3359375, 1537.896728515625],
        },
        {
            'orientation': [0.5894251465797424, -0.7818877100944519, 0.13991211354732513, -0.14715361595153809],
            'translation': [-1590.738037109375, -3854.1689453125, 1578.017578125],
        },
    ],
}

class Human36mDataset(MocapDataset):
    def __init__(self, path, remove_static_joints=True):
        super().__init__(fps=50, skeleton=h36m_skeleton)

        self._cameras = copy.deepcopy(h36m_cameras_extrinsic_params)
        for cameras in self._cameras.values():
            for i, cam in enumerate(cameras):
                cam.update(h36m_cameras_intrinsic_params[i])
                for k, v in cam.items():
                    if k not in ['id', 'res_w', 'res_h']:
                        cam[k] = np.array(v, dtype='float32')

                # Normalize camera frame
                cam['center'] = normalize_screen_coordinates(cam['center'], w=cam['res_w'], h=cam['res_h']).astype('float32')
                cam['focal_length'] = cam['focal_length']/cam['res_w']*2
                if 'translation' in cam:
                    cam['translation'] = cam['translation']/1000 # mm to meters

                # Add intrinsic parameters vector
                cam['intrinsic'] = np.concatenate((cam['focal_length'],
                                                   cam['center'],
                                                   cam['radial_distortion'],
                                                   cam['tangential_distortion'],
                                                   [1/cam['focal_length'][0], 0, -cam['center'][0]/cam['focal_length'][0],
                                                    0, 1/cam['focal_length'][1], -cam['center'][1]/cam['focal_length'][1],
                                                    0, 0, 1]))

                # proj_matrix = np.array([1/cam['focal_length'][0], 0, -cam['center'][0]/cam['focal_length'][0],
                #                         0, 1/cam['focal_length'][1], -cam['center'][1]/cam['focal_length'][1],
                #                         0, 0, 1])
                # cam['intrinsic'] = np.concatenate(camera_intrinsics, proj_matrix)

        # Load serialized dataset
        data = np.load(path, allow_pickle=True)['positions_3d'].item()

        self._data = {}
        for subject, actions in data.items():
            self._data[subject] = {}
            for action_name, positions in actions.items():
                self._data[subject][action_name] = {
                    'positions': positions,
                    'cameras': self._cameras[subject],
                }

        if remove_static_joints:
            # Bring the skeleton to 17 joints instead of the original 32
            self.remove_joints([4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31])

            # Rewire shoulders to the correct parents
            self._skeleton._parents[11] = 8
            self._skeleton._parents[14] = 8

    def supports_semi_supervised(self):
        return True
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt


def mpjpe(predicted, target):
    """
    Mean per-joint position error (i.e. mean Euclidean distance),
    often referred to as "Protocol #1" in many papers.
    """
    assert predicted.shape == target.shape
    return torch.mean(torch.norm(predicted - target, dim=len(target.shape)-1))

def mse(predicted, target, weights=None, gamma=0):
    loss = nn.MSELoss()
    return loss(predicted, target)

def cross_entropy(predicted, target, weights=None, gamma=0, return_weights=False):
    loss = nn.CrossEntropyLoss()
    return loss(predicted.permute(0, 4, 1, 2, 3), target)

def weighted_mpjpe(predicted, target, w):
    """
    Weighted mean per-joint position error (i.e. mean Euclidean distance)
    """
    assert predicted.shape == target.shape
    assert w.shape[0] == predicted.shape[0]
    return torch.mean(w * torch.norm(predicted - target, dim=len(target.shape)-1))

def p_mpjpe(predicted, target):
    """
    Pose error: MPJPE after rigid alignment (scale, rotation, and translation),
    often referred to as "Protocol #2" in many papers.
    """
    assert predicted.shape == target.shape

    muX = np.mean(target, axis=1, keepdims=True)
    muY = np.mean(predicted, axis=1, keepdims=True)

    X0 = target - muX
    Y0 = predicted - muY

    normX = np.sqrt(np.sum(X0**2, axis=(1, 2), keepdims=True))
    normY = np.sqrt(np.sum(Y0**2, axis=(1, 2), keepdims=True))

    X0 /= normX
    Y0 /= normY

    H = np.matmul(X0.transpose(0, 2, 1), Y0)
    U, s, Vt = np.linalg.svd(H)
    V = Vt.transpose(0, 2, 1)
    R = np.matmul(V, U.transpose(0, 2, 1))

    # Avoid improper rotations (reflections), i.e. rotations with det(R) = -1
    sign_detR = np.sign(np.expand_dims(np.linalg.det(R), axis=1))
    V[:, :, -1] *= sign_detR
    s[:, -1] *= sign_detR.flatten()
    R = np.matmul(V, U.transpose(0, 2, 1)) # Rotation

    tr = np.expand_dims(np.sum(s, axis=1, keepdims=True), axis=2)

    a = tr * normX / normY # Scale
    t = muX - a*np.matmul(muY, R) # Translation

    # Perform rigid transformation on the input
    predicted_aligned = a*np.matmul(predicted, R) + t

    # Return MPJPE
    return np.mean(np.linalg.norm(predicted_aligned - target, axis=len(target.shape)-1))

def n_mpjpe(predicted, target):
    """
    Normalized MPJPE (scale only), adapted from:
    https://github.com/hrhodin/UnsupervisedGeometryAwareRepresentationLearning/blob/master/losses/poses.py
    """
    assert predicted.shape == target.shape

    norm_predicted = torch.mean(torch.sum(predicted**2, dim=3, keepdim=True), dim=2, keepdim=True)
    norm_target = torch.mean(torch.sum(target*predicted, dim=3, keepdim=True), dim=2, keepdim=True)
    scale = norm_target / norm_predicted
    return mpjpe(scale * predicted, target) #[0]

def weighted_bonelen_loss(predict_3d_length, gt_3d_length):
    loss_length = 0.001 * torch.pow(predict_3d_length - gt_3d_length, 2).mean()
    return loss_length

def weighted_boneratio_loss(predict_3d_length, gt_3d_length):
    loss_length = 0.1 * torch.pow((predict_3d_length - gt_3d_length)/gt_3d_length, 2).mean()
    return loss_length

def mean_velocity_error(predicted, target):
    """
    Mean per-joint velocity error (i.e. mean Euclidean distance of the 1st derivative)
    """
    assert predicted.shape == target.shape

    velocity_predicted = np.diff(predicted, axis=0)
    velocity_target = np.diff(target, axis=0)

    return np.mean(np.linalg.norm(velocity_predicted - velocity_target, axis=len(target.shape)-1))
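A quick sketch (random tensors, illustrative only) of the two evaluation protocols defined above; note that mpjpe operates on torch tensors while p_mpjpe expects numpy arrays:

import torch
from common.loss import mpjpe, p_mpjpe

pred = torch.randn(8, 17, 3)
gt = torch.randn(8, 17, 3)
print(float(mpjpe(pred, gt)))        # Protocol #1: mean per-joint position error

# Protocol #2 rigidly aligns the prediction (scale, rotation, translation) first
print(p_mpjpe(pred.numpy(), gt.numpy()))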
common/mocap_dataset.py
ADDED
@@ -0,0 +1,44 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
from common.skeleton import Skeleton

class MocapDataset:
    def __init__(self, fps, skeleton):
        self._skeleton = skeleton
        self._fps = fps
        self._data = None # Must be filled by subclass
        self._cameras = None # Must be filled by subclass

    def remove_joints(self, joints_to_remove):
        kept_joints = self._skeleton.remove_joints(joints_to_remove)
        for subject in self._data.keys():
            for action in self._data[subject].keys():
                s = self._data[subject][action]
                if 'positions' in s:
                    s['positions'] = s['positions'][:, kept_joints]

    def __getitem__(self, key):
        return self._data[key]

    def subjects(self):
        return self._data.keys()

    def fps(self):
        return self._fps

    def skeleton(self):
        return self._skeleton

    def cameras(self):
        return self._cameras

    def supports_semi_supervised(self):
        # This method can be overridden
        return False
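A minimal subclass sketch (illustrative only) of the contract MocapDataset expects: a subclass fills _data and _cameras, after which the base accessors work:

import numpy as np
from common.mocap_dataset import MocapDataset
from common.h36m_dataset import h36m_skeleton

class ToyDataset(MocapDataset):
    def __init__(self):
        super().__init__(fps=50, skeleton=h36m_skeleton)
        # One subject, one action of 10 frames with 32 joints (the raw H3.6M layout)
        self._data = {'S1': {'Walking': {'positions': np.zeros((10, 32, 3))}}}
        self._cameras = {}

dataset = ToyDataset()
print(list(dataset.subjects()), dataset.fps())  # ['S1'] 50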
common/model_poseformer.py
ADDED
@@ -0,0 +1,242 @@
## Our PoseFormer model was revised from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# Written by Ce Zheng (cezheng@knights.ucf.edu)
# Modified by Qitao Zhao (qitaozhao@mail.sdu.edu.cn)

import math
import logging
from functools import partial
from einops import rearrange

import torch
import torch_dct as dct
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from timm.models.layers import DropPath


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class FreqMlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        b, f, _ = x.shape
        x = dct.dct(x.permute(0, 2, 1)).permute(0, 2, 1).contiguous()
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        x = dct.idct(x.permute(0, 2, 1)).permute(0, 2, 1).contiguous()
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class MixedBlock(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp1 = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.norm3 = norm_layer(dim)
        self.mlp2 = FreqMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        b, f, c = x.shape
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x1 = x[:, :f//2] + self.drop_path(self.mlp1(self.norm2(x[:, :f//2])))
        x2 = x[:, f//2:] + self.drop_path(self.mlp2(self.norm3(x[:, f//2:])))
        return torch.cat((x1, x2), dim=1)


class PoseTransformerV2(nn.Module):
    def __init__(self, num_frame=9, num_joints=17, in_chans=2,
                 num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=None, args=None):
        """ ##########hybrid_backbone=None, representation_size=None,
        Args:
            num_frame (int, tuple): input frame number
            num_joints (int, tuple): joints number
            in_chans (int): number of input channels, 2D joints have 2 channels: (x,y)
            embed_dim_ratio (int): embedding dimension ratio
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
        """
        super().__init__()

        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        embed_dim_ratio = args.embed_dim_ratio
|
| 156 |
+
depth = args.depth
|
| 157 |
+
embed_dim = embed_dim_ratio * num_joints #### temporal embed_dim is num_joints * spatial embedding dim ratio
|
| 158 |
+
out_dim = num_joints * 3 #### output dimension is num_joints * 3
|
| 159 |
+
self.num_frame_kept = args.number_of_kept_frames
|
| 160 |
+
self.num_coeff_kept = args.number_of_kept_coeffs if args.number_of_kept_coeffs else self.num_frame_kept
|
| 161 |
+
|
| 162 |
+
### spatial patch embedding
|
| 163 |
+
self.Joint_embedding = nn.Linear(in_chans, embed_dim_ratio)
|
| 164 |
+
self.Freq_embedding = nn.Linear(in_chans*num_joints, embed_dim)
|
| 165 |
+
|
| 166 |
+
self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio))
|
| 167 |
+
self.Temporal_pos_embed = nn.Parameter(torch.zeros(1, self.num_frame_kept, embed_dim))
|
| 168 |
+
self.Temporal_pos_embed_ = nn.Parameter(torch.zeros(1, self.num_coeff_kept, embed_dim))
|
| 169 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 170 |
+
|
| 171 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 172 |
+
|
| 173 |
+
self.Spatial_blocks = nn.ModuleList([
|
| 174 |
+
Block(
|
| 175 |
+
dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 176 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
|
| 177 |
+
for i in range(depth)])
|
| 178 |
+
|
| 179 |
+
self.blocks = nn.ModuleList([
|
| 180 |
+
MixedBlock(
|
| 181 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 182 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
|
| 183 |
+
for i in range(depth)])
|
| 184 |
+
|
| 185 |
+
self.Spatial_norm = norm_layer(embed_dim_ratio)
|
| 186 |
+
self.Temporal_norm = norm_layer(embed_dim)
|
| 187 |
+
|
| 188 |
+
####### A easy way to implement weighted mean
|
| 189 |
+
self.weighted_mean = torch.nn.Conv1d(in_channels=self.num_coeff_kept, out_channels=1, kernel_size=1)
|
| 190 |
+
self.weighted_mean_ = torch.nn.Conv1d(in_channels=self.num_frame_kept, out_channels=1, kernel_size=1)
|
| 191 |
+
|
| 192 |
+
self.head = nn.Sequential(
|
| 193 |
+
nn.LayerNorm(embed_dim*2),
|
| 194 |
+
nn.Linear(embed_dim*2, out_dim),
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
def Spatial_forward_features(self, x):
|
| 198 |
+
b, f, p, _ = x.shape ##### b is batch size, f is number of frames, p is number of joints
|
| 199 |
+
num_frame_kept = self.num_frame_kept
|
| 200 |
+
|
| 201 |
+
index = torch.arange((f-1)//2-num_frame_kept//2, (f-1)//2+num_frame_kept//2+1)
|
| 202 |
+
|
| 203 |
+
x = self.Joint_embedding(x[:, index].view(b*num_frame_kept, p, -1))
|
| 204 |
+
x += self.Spatial_pos_embed
|
| 205 |
+
x = self.pos_drop(x)
|
| 206 |
+
|
| 207 |
+
for blk in self.Spatial_blocks:
|
| 208 |
+
x = blk(x)
|
| 209 |
+
|
| 210 |
+
x = self.Spatial_norm(x)
|
| 211 |
+
x = rearrange(x, '(b f) p c -> b f (p c)', f=num_frame_kept)
|
| 212 |
+
return x
|
| 213 |
+
|
| 214 |
+
def forward_features(self, x, Spatial_feature):
|
| 215 |
+
b, f, p, _ = x.shape
|
| 216 |
+
num_coeff_kept = self.num_coeff_kept
|
| 217 |
+
|
| 218 |
+
x = dct.dct(x.permute(0, 2, 3, 1))[:, :, :, :num_coeff_kept]
|
| 219 |
+
x = x.permute(0, 3, 1, 2).contiguous().view(b, num_coeff_kept, -1)
|
| 220 |
+
x = self.Freq_embedding(x)
|
| 221 |
+
|
| 222 |
+
Spatial_feature += self.Temporal_pos_embed
|
| 223 |
+
x += self.Temporal_pos_embed_
|
| 224 |
+
x = torch.cat((x, Spatial_feature), dim=1)
|
| 225 |
+
|
| 226 |
+
for blk in self.blocks:
|
| 227 |
+
x = blk(x)
|
| 228 |
+
|
| 229 |
+
x = self.Temporal_norm(x)
|
| 230 |
+
return x
|
| 231 |
+
|
| 232 |
+
def forward(self, x):
|
| 233 |
+
b, f, p, _ = x.shape
|
| 234 |
+
x_ = x.clone()
|
| 235 |
+
|
| 236 |
+
Spatial_feature = self.Spatial_forward_features(x)
|
| 237 |
+
x = self.forward_features(x_, Spatial_feature)
|
| 238 |
+
x = torch.cat((self.weighted_mean(x[:, :self.num_coeff_kept]), self.weighted_mean_(x[:, self.num_coeff_kept:])), dim=-1)
|
| 239 |
+
|
| 240 |
+
x = self.head(x).view(b, 1, p, -1)
|
| 241 |
+
return x
|
| 242 |
+
|
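
Note: a minimal smoke test for PoseTransformerV2 could look like the sketch below. The hyperparameter values and the SimpleNamespace stand-in for the parsed arguments are illustrative assumptions, not settings shipped in this commit; running it also requires the torch_dct package.

# Hypothetical smoke test; all values here are assumptions, not repo defaults.
from types import SimpleNamespace
import torch
from common.model_poseformer import PoseTransformerV2

args = SimpleNamespace(embed_dim_ratio=32, depth=4,
                       number_of_kept_frames=27, number_of_kept_coeffs=27)
model = PoseTransformerV2(num_frame=81, num_joints=17, in_chans=2, args=args)
x = torch.randn(2, 81, 17, 2)   # (batch, frames, joints, xy)
out = model(x)                  # one 3D pose per clip, for the central frame
print(out.shape)                # torch.Size([2, 1, 17, 3])
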
common/quaternion.py
ADDED
@@ -0,0 +1,35 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch

def qrot(q, v):
    """
    Rotate vector(s) v about the rotation described by quaternion(s) q.
    Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
    where * denotes any number of dimensions.
    Returns a tensor of shape (*, 3).
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    assert q.shape[:-1] == v.shape[:-1]

    qvec = q[..., 1:]
    uv = torch.cross(qvec, v, dim=len(q.shape)-1)
    uuv = torch.cross(qvec, uv, dim=len(q.shape)-1)
    return (v + 2 * (q[..., :1] * uv + uuv))


def qinverse(q, inplace=False):
    # We assume the quaternion to be normalized
    if inplace:
        q[..., 1:] *= -1
        return q
    else:
        w = q[..., :1]
        xyz = q[..., 1:]
        return torch.cat((w, -xyz), dim=len(q.shape)-1)
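
As a sanity check, qrot and qinverse compose as expected; the sketch below (an illustration, not part of the commit) rotates a vector roughly 90 degrees about the y-axis and then undoes the rotation with the inverse quaternion.

import torch
from common.quaternion import qrot, qinverse

q = torch.tensor([[0.7071, 0.0, 0.7071, 0.0]])  # (w, x, y, z): ~90 deg about y
v = torch.tensor([[1.0, 0.0, 0.0]])
v_rot = qrot(q, v)                  # ~[0, 0, -1]
v_back = qrot(qinverse(q), v_rot)   # ~[1, 0, 0], the original v
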
common/skeleton.py
ADDED
@@ -0,0 +1,88 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np

class Skeleton:
    def __init__(self, parents, joints_left, joints_right):
        assert len(joints_left) == len(joints_right)

        self._parents = np.array(parents)
        self._joints_left = joints_left
        self._joints_right = joints_right
        self._compute_metadata()

    def num_joints(self):
        return len(self._parents)

    def parents(self):
        return self._parents

    def has_children(self):
        return self._has_children

    def children(self):
        return self._children

    def remove_joints(self, joints_to_remove):
        """
        Remove the joints specified in 'joints_to_remove'.
        """
        valid_joints = []
        for joint in range(len(self._parents)):
            if joint not in joints_to_remove:
                valid_joints.append(joint)

        for i in range(len(self._parents)):
            while self._parents[i] in joints_to_remove:
                self._parents[i] = self._parents[self._parents[i]]

        index_offsets = np.zeros(len(self._parents), dtype=int)
        new_parents = []
        for i, parent in enumerate(self._parents):
            if i not in joints_to_remove:
                new_parents.append(parent - index_offsets[parent])
            else:
                index_offsets[i:] += 1
        self._parents = np.array(new_parents)

        if self._joints_left is not None:
            new_joints_left = []
            for joint in self._joints_left:
                if joint in valid_joints:
                    new_joints_left.append(joint - index_offsets[joint])
            self._joints_left = new_joints_left
        if self._joints_right is not None:
            new_joints_right = []
            for joint in self._joints_right:
                if joint in valid_joints:
                    new_joints_right.append(joint - index_offsets[joint])
            self._joints_right = new_joints_right

        self._compute_metadata()

        return valid_joints

    def joints_left(self):
        return self._joints_left

    def joints_right(self):
        return self._joints_right

    def _compute_metadata(self):
        self._has_children = np.zeros(len(self._parents)).astype(bool)
        for i, parent in enumerate(self._parents):
            if parent != -1:
                self._has_children[parent] = True

        self._children = []
        for i, parent in enumerate(self._parents):
            self._children.append([])
        for i, parent in enumerate(self._parents):
            if parent != -1:
                self._children[parent].append(i)
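
A small illustrative example (assumed, not taken from the repo) of how remove_joints prunes a joint and compacts the remaining indices:

from common.skeleton import Skeleton

# Five-joint toy skeleton: 0 is the root, joint 4 hangs off joint 3.
sk = Skeleton(parents=[-1, 0, 1, 0, 3], joints_left=[1, 2], joints_right=[3, 4])
sk.remove_joints([4])
print(sk.parents())       # [-1  0  1  0] -- joint 4 pruned, indices re-packed
print(sk.joints_right())  # [3]
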
common/utils.py
ADDED
@@ -0,0 +1,80 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch
import numpy as np
import hashlib

def wrap(func, *args, unsqueeze=False):
    """
    Wrap a torch function so it can be called with NumPy arrays.
    Input and return types are seamlessly converted.
    """

    # Convert input types where applicable
    args = list(args)
    for i, arg in enumerate(args):
        if type(arg) == np.ndarray:
            args[i] = torch.from_numpy(arg)
            if unsqueeze:
                args[i] = args[i].unsqueeze(0)

    result = func(*args)

    # Convert output types where applicable
    if isinstance(result, tuple):
        result = list(result)
        for i, res in enumerate(result):
            if type(res) == torch.Tensor:
                if unsqueeze:
                    res = res.squeeze(0)
                result[i] = res.numpy()
        return tuple(result)
    elif type(result) == torch.Tensor:
        if unsqueeze:
            result = result.squeeze(0)
        return result.numpy()
    else:
        return result

def deterministic_random(min_value, max_value, data):
    digest = hashlib.sha256(data.encode()).digest()
    raw_value = int.from_bytes(digest[:4], byteorder='little', signed=False)
    return int(raw_value / (2**32 - 1) * (max_value - min_value)) + min_value

def load_pretrained_weights(model, checkpoint):
    """Load pretrained weights into a model.
    Incompatible layers (unmatched in name or size) will be ignored.
    Args:
    - model (nn.Module): network model, which must not be nn.DataParallel
    - checkpoint (dict): loaded checkpoint (a state_dict, or a dict containing one)
    """
    import collections
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    model_dict = model.state_dict()
    new_state_dict = collections.OrderedDict()
    matched_layers, discarded_layers = [], []
    for k, v in state_dict.items():
        # If the pretrained state_dict was saved as nn.DataParallel,
        # keys would contain "module.", which should be ignored.
        if k.startswith('module.'):
            k = k[7:]
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)
    # new_state_dict.requires_grad = False
    model_dict.update(new_state_dict)

    model.load_state_dict(model_dict)
    print('load_weight', len(matched_layers))
    # model.state_dict(model_dict).requires_grad = False
    return model
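
For illustration, wrap() lets a torch function consume and return NumPy arrays, and deterministic_random() maps a string key to a stable integer (used to make data splits reproducible). The snippet below is an assumption-level sketch, not repo code.

import numpy as np
import torch
from common.utils import wrap, deterministic_random

# torch.cumsum runs on the NumPy input; the result comes back as an ndarray.
out = wrap(lambda t: torch.cumsum(t, dim=-1), np.arange(5, dtype=np.float32))
print(type(out).__name__, out)   # ndarray [ 0.  1.  3.  6. 10.]

# Same key -> same value on every run.
print(deterministic_random(0, 100, 'S1_Walking'))
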
common/visualization.py
ADDED
@@ -0,0 +1,216 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import matplotlib

matplotlib.use('Agg')


import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, writers
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import subprocess as sp
import cv2


def get_resolution(filename):
    command = ['ffprobe', '-v', 'error', '-select_streams', 'v:0',
               '-show_entries', 'stream=width,height', '-of', 'csv=p=0', filename]
    with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe:
        for line in pipe.stdout:
            w, h = line.decode().strip().split(',')
            return int(w), int(h)


def get_fps(filename):
    command = ['ffprobe', '-v', 'error', '-select_streams', 'v:0',
               '-show_entries', 'stream=r_frame_rate', '-of', 'csv=p=0', filename]
    with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe:
        for line in pipe.stdout:
            a, b = line.decode().strip().split('/')
            return int(a) / int(b)


def read_video(filename, skip=0, limit=-1):
    # w, h = get_resolution(filename)
    w = 1000
    h = 1002

    command = ['ffmpeg',
               '-i', filename,
               '-f', 'image2pipe',
               '-pix_fmt', 'rgb24',
               '-vsync', '0',
               '-vcodec', 'rawvideo', '-']

    i = 0
    with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe:
        while True:
            data = pipe.stdout.read(w * h * 3)
            if not data:
                break
            i += 1
            if i > limit and limit != -1:
                continue
            if i > skip:
                yield np.frombuffer(data, dtype='uint8').reshape((h, w, 3))


def downsample_tensor(X, factor):
    length = X.shape[0] // factor * factor
    return np.mean(X[:length].reshape(-1, factor, *X.shape[1:]), axis=1)


def render_animation(keypoints, keypoints_metadata, poses, skeleton, fps, bitrate, azim, output, viewport,
                     limit=-1, downsample=1, size=6, input_video_path=None, input_video_skip=0):
    """
    TODO
    Render an animation. The supported output modes are:
    -- 'interactive': display an interactive figure
                      (also works on notebooks if associated with %matplotlib inline)
    -- 'html': render the animation as HTML5 video. Can be displayed in a notebook using HTML(...).
    -- 'filename.mp4': render and export the animation as an h264 video (requires ffmpeg).
    -- 'filename.gif': render and export the animation as a gif file (requires imagemagick).
    """
    plt.ioff()
    fig = plt.figure(figsize=(size * (1 + len(poses)), size))
    ax_in = fig.add_subplot(1, 1 + len(poses), 1)
    ax_in.get_xaxis().set_visible(False)
    ax_in.get_yaxis().set_visible(False)
    ax_in.set_axis_off()
    ax_in.set_title('Input')

    ax_3d = []
    lines_3d = []
    trajectories = []
    radius = 1.7
    for index, (title, data) in enumerate(poses.items()):
        ax = fig.add_subplot(1, 1 + len(poses), index + 2, projection='3d')
        ax.view_init(elev=15., azim=azim)
        ax.set_xlim3d([-radius / 2, radius / 2])
        ax.set_zlim3d([0, radius])
        ax.set_ylim3d([-radius / 2, radius / 2])
        try:
            ax.set_aspect('equal')
        except NotImplementedError:
            ax.set_aspect('auto')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])
        ax.dist = 7.5
        ax.set_title(title)  # , pad=35
        ax_3d.append(ax)
        lines_3d.append([])
        trajectories.append(data[:, 0, [0, 1]])
    poses = list(poses.values())

    # Decode video
    if input_video_path is None:
        # Black background
        all_frames = np.zeros((keypoints.shape[0], viewport[1], viewport[0]), dtype='uint8')
    else:
        # Load video using ffmpeg
        all_frames = []
        for f in read_video(input_video_path, skip=input_video_skip, limit=limit):
            all_frames.append(f)
        effective_length = min(keypoints.shape[0], len(all_frames))
        all_frames = all_frames[:effective_length]

        keypoints = keypoints[input_video_skip:]  # todo remove
        for idx in range(len(poses)):
            poses[idx] = poses[idx][input_video_skip:]

        if fps is None:
            fps = get_fps(input_video_path)

    if downsample > 1:
        keypoints = downsample_tensor(keypoints, downsample)
        all_frames = downsample_tensor(np.array(all_frames), downsample).astype('uint8')
        for idx in range(len(poses)):
            poses[idx] = downsample_tensor(poses[idx], downsample)
            trajectories[idx] = downsample_tensor(trajectories[idx], downsample)
        fps /= downsample

    initialized = False
    image = None
    lines = []
    points = None

    if limit < 1:
        limit = len(all_frames)
    else:
        limit = min(limit, len(all_frames))

    parents = skeleton.parents()

    def update_video(i):
        nonlocal initialized, image, lines, points

        for n, ax in enumerate(ax_3d):
            ax.set_xlim3d([-radius / 2 + trajectories[n][i, 0], radius / 2 + trajectories[n][i, 0]])
            ax.set_ylim3d([-radius / 2 + trajectories[n][i, 1], radius / 2 + trajectories[n][i, 1]])

        # Update 2D poses
        joints_right_2d = keypoints_metadata['keypoints_symmetry'][1]
        colors_2d = np.full(keypoints.shape[1], 'black')
        colors_2d[joints_right_2d] = 'red'
        if not initialized:
            image = ax_in.imshow(all_frames[i], aspect='equal')

            for j, j_parent in enumerate(parents):
                if j_parent == -1:
                    continue

                if len(parents) == keypoints.shape[1] and keypoints_metadata['layout_name'] != 'coco':
                    # Draw skeleton only if keypoints match (otherwise we don't have the parents definition)
                    lines.append(ax_in.plot([keypoints[i, j, 0], keypoints[i, j_parent, 0]],
                                            [keypoints[i, j, 1], keypoints[i, j_parent, 1]], color='pink'))

                col = 'red' if j in skeleton.joints_right() else 'black'
                for n, ax in enumerate(ax_3d):
                    pos = poses[n][i]
                    lines_3d[n].append(ax.plot([pos[j, 0], pos[j_parent, 0]],
                                               [pos[j, 1], pos[j_parent, 1]],
                                               [pos[j, 2], pos[j_parent, 2]], zdir='z', c=col))

            points = ax_in.scatter(*keypoints[i].T, 10, color=colors_2d, edgecolors='white', zorder=10)

            initialized = True
        else:
            image.set_data(all_frames[i])

            for j, j_parent in enumerate(parents):
                if j_parent == -1:
                    continue

                if len(parents) == keypoints.shape[1] and keypoints_metadata['layout_name'] != 'coco':
                    lines[j - 1][0].set_data([keypoints[i, j, 0], keypoints[i, j_parent, 0]],
                                             [keypoints[i, j, 1], keypoints[i, j_parent, 1]])

                for n, ax in enumerate(ax_3d):
                    pos = poses[n][i]
                    lines_3d[n][j - 1][0].set_xdata(np.array([pos[j, 0], pos[j_parent, 0]]))
                    lines_3d[n][j - 1][0].set_ydata(np.array([pos[j, 1], pos[j_parent, 1]]))
                    lines_3d[n][j - 1][0].set_3d_properties(np.array([pos[j, 2], pos[j_parent, 2]]), zdir='z')

            points.set_offsets(keypoints[i])

        print('{}/{} '.format(i, limit), end='\r')

    fig.tight_layout()

    anim = FuncAnimation(fig, update_video, frames=np.arange(0, limit), interval=1000 / fps, repeat=False)
    if output.endswith('.mp4'):
        Writer = writers['ffmpeg']
        writer = Writer(fps=fps, metadata={}, bitrate=bitrate)
        anim.save(output, writer=writer)
    elif output.endswith('.gif'):
        anim.save(output, dpi=80, writer='imagemagick')
    else:
        raise ValueError('Unsupported output format (only .mp4 and .gif are supported)')
    plt.close()
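
Finally, a hedged sketch of driving render_animation() end to end with synthetic inputs; every array and the metadata dict below are stand-ins shaped the way the function expects (they are not data from this repo), and saving to .mp4 assumes ffmpeg is on the PATH, with opencv-python installed so the module's cv2 import resolves.

import numpy as np
from common.skeleton import Skeleton
from common.visualization import render_animation

# Toy 3-joint skeleton and random detections/predictions, for shape-checking only.
skeleton = Skeleton(parents=[-1, 0, 1], joints_left=[1], joints_right=[2])
frames, joints = 30, 3
keypoints = np.random.rand(frames, joints, 2) * 500        # 2D inputs (pixels)
poses = {'Prediction': np.random.rand(frames, joints, 3)}  # 3D outputs (meters)
meta = {'layout_name': 'custom', 'keypoints_symmetry': ([1], [2])}
render_animation(keypoints, meta, poses, skeleton, fps=30, bitrate=3000,
                 azim=70, output='demo.mp4', viewport=(1000, 1002))
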