rlogh committed
Commit 73dfb75 · verified · 1 Parent(s): c45cbe3

Upload 17 files
common/README.MD ADDED
@@ -0,0 +1 @@
common/__pycache__/camera.cpython-311.pyc ADDED
Binary file (5.11 kB)

common/__pycache__/model_poseformer.cpython-311.pyc ADDED
Binary file (17.8 kB)

common/__pycache__/quaternion.cpython-311.pyc ADDED
Binary file (1.69 kB)

common/__pycache__/utils.cpython-311.pyc ADDED
Binary file (3.78 kB)

common/arguments.py ADDED
@@ -0,0 +1,96 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Modified by Qitao Zhao (qitaozhao@mail.sdu.edu.cn)

import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Training script')

    # General arguments
    parser.add_argument('-d', '--dataset', default='h36m', type=str, metavar='NAME', help='target dataset')  # h36m or humaneva
    parser.add_argument('-k', '--keypoints', default='cpn_ft_h36m_dbb', type=str, metavar='NAME', help='2D detections to use')
    parser.add_argument('-str', '--subjects-train', default='S1,S5,S6,S7,S8', type=str, metavar='LIST',
                        help='training subjects separated by comma')
    parser.add_argument('-ste', '--subjects-test', default='S9,S11', type=str, metavar='LIST', help='test subjects separated by comma')
    parser.add_argument('-sun', '--subjects-unlabeled', default='', type=str, metavar='LIST',
                        help='unlabeled subjects separated by comma for self-supervision')
    parser.add_argument('-a', '--actions', default='*', type=str, metavar='LIST',
                        help='actions to train/test on, separated by comma, or * for all')
    parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
                        help='checkpoint directory')
    parser.add_argument('--checkpoint-frequency', default=40, type=int, metavar='N',
                        help='create a checkpoint every N epochs')
    parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME',
                        help='checkpoint to resume (file name)')
    parser.add_argument('--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
    parser.add_argument('--render', action='store_true', help='visualize a particular video')
    parser.add_argument('--by-subject', action='store_true', help='break down error by subject (on evaluation)')
    parser.add_argument('--export-training-curves', action='store_true', help='save training curves as .png images')
    parser.add_argument('-g', '--gpu', type=list, help='set gpu number')
    parser.add_argument('--local_rank', type=int, default=0, help='node rank for distributed training')
    parser.add_argument('--center-pose', type=int, default=0, help='choose fine-tuning task as 3d pose estimation')

    # Model arguments
    parser.add_argument('-s', '--stride', default=1, type=int, metavar='N', help='chunk size to use during training')
    parser.add_argument('-e', '--epochs', default=200, type=int, metavar='N', help='number of training epochs')
    parser.add_argument('-b', '--batch-size', default=1024, type=int, metavar='N', help='batch size in terms of predicted frames')
    parser.add_argument('-drop', '--dropout', default=0., type=float, metavar='P', help='dropout probability')
    parser.add_argument('-lr', '--learning-rate', default=0.0001, type=float, metavar='LR', help='initial learning rate')
    parser.add_argument('-lrd', '--lr-decay', default=0.99, type=float, metavar='LR', help='learning rate decay per epoch')
    parser.add_argument('-no-da', '--no-data-augmentation', dest='data_augmentation', action='store_false',
                        help='disable train-time flipping')
    parser.add_argument('-frame', '--number-of-frames', default='81', type=int, metavar='N',
                        help='how many frames used as input')
    parser.add_argument('-frame-kept', '--number-of-kept-frames', default='27', type=int, metavar='N',
                        help='how many frames are kept')
    parser.add_argument('-coeff-kept', '--number-of-kept-coeffs', type=int, metavar='N', help='how many coefficients are kept')
    parser.add_argument('--depth', default=4, type=int, metavar='N', help='number of transformer blocks')
    parser.add_argument('--embed-dim-ratio', default=32, type=int, metavar='N', help='dimension of embedding ratio')
    parser.add_argument('-std', type=float, default=0.0, help='the standard deviation for gaussian noise')

    # Experimental
    parser.add_argument('--subset', default=1, type=float, metavar='FRACTION', help='reduce dataset size by fraction')
    parser.add_argument('--downsample', default=1, type=int, metavar='FACTOR', help='downsample frame rate by factor (semi-supervised)')
    parser.add_argument('--warmup', default=1, type=int, metavar='N', help='warm-up epochs for semi-supervision')
    parser.add_argument('--no-eval', action='store_true', help='disable epoch evaluation while training (small speed-up)')
    parser.add_argument('--dense', action='store_true', help='use dense convolutions instead of dilated convolutions')
    parser.add_argument('--disable-optimizations', action='store_true', help='disable optimized model for single-frame predictions')
    parser.add_argument('--linear-projection', action='store_true', help='use only linear coefficients for semi-supervised projection')
    parser.add_argument('--no-bone-length', action='store_false', dest='bone_length_term',
                        help='disable bone length term in semi-supervised settings')
    parser.add_argument('--no-proj', action='store_true', help='disable projection for semi-supervised setting')

    # Visualization
    parser.add_argument('--viz-subject', type=str, metavar='STR', help='subject to render')
    parser.add_argument('--viz-action', type=str, metavar='STR', help='action to render')
    parser.add_argument('--viz-camera', type=int, default=0, metavar='N', help='camera to render')
    parser.add_argument('--viz-video', type=str, metavar='PATH', help='path to input video')
    parser.add_argument('--viz-skip', type=int, default=0, metavar='N', help='skip first N frames of input video')
    parser.add_argument('--viz-output', type=str, metavar='PATH', help='output file name (.gif or .mp4)')
    parser.add_argument('--viz-export', type=str, metavar='PATH', help='output file name for coordinates')
    parser.add_argument('--viz-bitrate', type=int, default=3000, metavar='N', help='bitrate for mp4 videos')
    parser.add_argument('--viz-no-ground-truth', action='store_true', help='do not show ground-truth poses')
    parser.add_argument('--viz-limit', type=int, default=-1, metavar='N', help='only render first N frames')
    parser.add_argument('--viz-downsample', type=int, default=1, metavar='N', help='downsample FPS by a factor N')
    parser.add_argument('--viz-size', type=int, default=5, metavar='N', help='image size')

    parser.set_defaults(bone_length_term=True)
    parser.set_defaults(data_augmentation=True)
    parser.set_defaults(test_time_augmentation=True)
    # parser.set_defaults(test_time_augmentation=False)

    args = parser.parse_args()
    # Check invalid configuration
    if args.resume and args.evaluate:
        print('Invalid flags: --resume and --evaluate cannot be set at the same time')
        exit()

    if args.export_training_curves and args.no_eval:
        print('Invalid flags: --export-training-curves and --no-eval cannot be set at the same time')
        exit()

    return args
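
A minimal usage sketch (not part of the upload; the flag values are illustrative): parse_args() reads sys.argv directly, so a training script would typically be launched with these flags, e.g.

    import sys
    from common.arguments import parse_args

    # Hypothetical invocation; -frame / -frame-kept / -coeff-kept map to
    # number_of_frames / number_of_kept_frames / number_of_kept_coeffs.
    sys.argv = ['run.py', '-k', 'cpn_ft_h36m_dbb', '-frame', '81', '-frame-kept', '27', '-coeff-kept', '27']
    args = parse_args()
    print(args.number_of_frames, args.number_of_kept_frames, args.number_of_kept_coeffs)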
common/camera.py ADDED
@@ -0,0 +1,90 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
import torch

from common.utils import wrap
from common.quaternion import qrot, qinverse

def normalize_screen_coordinates(X, w, h):
    assert X.shape[-1] == 2

    # Normalize so that [0, w] is mapped to [-1, 1], while preserving the aspect ratio
    return X/w*2 - [1, h/w]


def image_coordinates(X, w, h):
    assert X.shape[-1] == 2

    # Reverse camera frame normalization
    return (X + [1, h/w])*w/2


def world_to_camera(X, R, t):
    Rt = wrap(qinverse, R)  # Invert rotation
    return wrap(qrot, np.tile(Rt, (*X.shape[:-1], 1)), X - t)  # Rotate and translate


def camera_to_world(X, R, t):
    return wrap(qrot, np.tile(R, (*X.shape[:-1], 1)), X) + t


def project_to_2d(X, camera_params):
    """
    Project 3D points to 2D using the Human3.6M camera projection function.
    This is a differentiable and batched reimplementation of the original MATLAB script.

    Arguments:
    X -- 3D points in *camera space* to transform (N, *, 3)
    camera_params -- intrinsic parameters (N, 2+2+3+2=9)
    """
    assert X.shape[-1] == 3
    assert len(camera_params.shape) == 2
    assert camera_params.shape[-1] == 9
    assert X.shape[0] == camera_params.shape[0]

    while len(camera_params.shape) < len(X.shape):
        camera_params = camera_params.unsqueeze(1)

    f = camera_params[..., :2]
    c = camera_params[..., 2:4]
    k = camera_params[..., 4:7]
    p = camera_params[..., 7:]

    XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1)
    r2 = torch.sum(XX[..., :2]**2, dim=len(XX.shape)-1, keepdim=True)

    radial = 1 + torch.sum(k * torch.cat((r2, r2**2, r2**3), dim=len(r2.shape)-1), dim=len(r2.shape)-1, keepdim=True)
    tan = torch.sum(p*XX, dim=len(XX.shape)-1, keepdim=True)

    XXX = XX*(radial + tan) + p*r2

    return f*XXX + c

def project_to_2d_linear(X, camera_params):
    """
    Project 3D points to 2D using only linear parameters (focal length and principal point).

    Arguments:
    X -- 3D points in *camera space* to transform (N, *, 3)
    camera_params -- intrinsic parameters (N, 2+2+3+2=9)
    """
    assert X.shape[-1] == 3
    assert len(camera_params.shape) == 2
    assert camera_params.shape[-1] == 9
    assert X.shape[0] == camera_params.shape[0]

    while len(camera_params.shape) < len(X.shape):
        camera_params = camera_params.unsqueeze(1)

    f = camera_params[..., :2]
    c = camera_params[..., 2:4]

    XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1)

    return f*XX + c
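
A quick sanity-check sketch (not from the upload): normalize_screen_coordinates and image_coordinates are inverses of each other for a given resolution, e.g. the 1000x1002 Human3.6M cameras; the keypoint values below are arbitrary pixels.

    import numpy as np
    from common.camera import normalize_screen_coordinates, image_coordinates

    kps = np.array([[512.5, 515.4], [100.0, 900.0]], dtype='float32')   # pixel coordinates
    norm = normalize_screen_coordinates(kps, w=1000, h=1002)            # x mapped to roughly [-1, 1]
    back = image_coordinates(norm, w=1000, h=1002)                      # back to pixels
    assert np.allclose(kps, back, atol=1e-3)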
common/custom_dataset.py ADDED
@@ -0,0 +1,66 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
import copy
from common.skeleton import Skeleton
from common.mocap_dataset import MocapDataset
from common.camera import normalize_screen_coordinates, image_coordinates
from common.h36m_dataset import h36m_skeleton


custom_camera_params = {
    'id': None,
    'res_w': None,  # Pulled from metadata
    'res_h': None,  # Pulled from metadata

    # Dummy camera parameters (taken from Human3.6M), only for visualization purposes
    'azimuth': 70,  # Only used for visualization
    'orientation': [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088],
    'translation': [1841.1070556640625, 4955.28466796875, 1563.4454345703125],
}

class CustomDataset(MocapDataset):
    def __init__(self, detections_path, remove_static_joints=True):
        super().__init__(fps=None, skeleton=h36m_skeleton)

        # Load serialized dataset
        data = np.load(detections_path, allow_pickle=True)
        resolutions = data['metadata'].item()['video_metadata']

        self._cameras = {}
        self._data = {}
        for video_name, res in resolutions.items():
            cam = {}
            cam.update(custom_camera_params)
            cam['orientation'] = np.array(cam['orientation'], dtype='float32')
            cam['translation'] = np.array(cam['translation'], dtype='float32')
            cam['translation'] = cam['translation']/1000  # mm to meters

            cam['id'] = video_name
            cam['res_w'] = res['w']
            cam['res_h'] = res['h']

            self._cameras[video_name] = [cam]

            self._data[video_name] = {
                'custom': {
                    'cameras': cam
                }
            }

        if remove_static_joints:
            # Bring the skeleton to 17 joints instead of the original 32
            self.remove_joints([4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31])

            # Rewire shoulders to the correct parents
            self._skeleton._parents[11] = 8
            self._skeleton._parents[14] = 8

    def supports_semi_supervised(self):
        return False
common/generators.py ADDED
@@ -0,0 +1,261 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from itertools import zip_longest
import numpy as np


# def getbone(seq, boneindex):
#     bs = np.shape(seq)[0]
#     ss = np.shape(seq)[1]
#     seq = np.reshape(seq,(bs*ss,-1,3))
#     bone = []
#     for index in boneindex:
#         bone.append(seq[:,index[0]] - seq[:,index[1]])
#     bone = np.stack(bone,1)
#     bone = np.power(np.power(bone,2).sum(2),0.5)
#     bone = np.reshape(bone, (bs,ss,np.shape(bone)[1]))
#     return bone

class ChunkedGenerator:
    """
    Batched data generator, used for training.
    The sequences are split into equal-length chunks and padded as necessary.

    Arguments:
    batch_size -- the batch size to use for training
    cameras -- list of cameras, one element for each video (optional, used for semi-supervised training)
    poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training)
    poses_2d -- list of input 2D keypoints, one element for each video
    chunk_length -- number of output frames to predict for each training example (usually 1)
    pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field)
    causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad")
    shuffle -- randomly shuffle the dataset before each epoch
    random_seed -- initial seed to use for the random generator
    augment -- augment the dataset by flipping poses horizontally
    kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled
    joints_left and joints_right -- list of left/right 3D joints if flipping is enabled
    """
    def __init__(self, batch_size, cameras, poses_3d, poses_2d,
                 chunk_length, pad=0, causal_shift=0,
                 shuffle=True, random_seed=1234,
                 augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None,
                 endless=False):
        assert poses_3d is None or len(poses_3d) == len(poses_2d), (len(poses_3d), len(poses_2d))
        assert cameras is None or len(cameras) == len(poses_2d)

        # Build lineage info
        pairs = []  # (seq_idx, start_frame, end_frame, flip) tuples
        for i in range(len(poses_2d)):
            assert poses_3d is None or poses_3d[i].shape[0] == poses_2d[i].shape[0]
            n_chunks = (poses_2d[i].shape[0] + chunk_length - 1) // chunk_length
            offset = (n_chunks * chunk_length - poses_2d[i].shape[0]) // 2
            bounds = np.arange(n_chunks+1)*chunk_length - offset
            augment_vector = np.full(len(bounds) - 1, False, dtype=bool)
            pairs += zip(np.repeat(i, len(bounds) - 1), bounds[:-1], bounds[1:], augment_vector)
            if augment:
                pairs += zip(np.repeat(i, len(bounds) - 1), bounds[:-1], bounds[1:], ~augment_vector)

        # Initialize buffers
        if cameras is not None:
            self.batch_cam = np.empty((batch_size, cameras[0].shape[-1]))
        if poses_3d is not None:
            self.batch_3d = np.empty((batch_size, chunk_length, poses_3d[0].shape[-2], poses_3d[0].shape[-1]))
            # self.batch_3d = np.empty((batch_size, chunk_length + 2*pad, poses_3d[0].shape[-2], poses_3d[0].shape[-1]))
        self.batch_2d = np.empty((batch_size, chunk_length + 2*pad, poses_2d[0].shape[-2], poses_2d[0].shape[-1]))

        self.num_batches = (len(pairs) + batch_size - 1) // batch_size
        self.batch_size = batch_size
        self.random = np.random.RandomState(random_seed)
        self.pairs = pairs
        self.shuffle = shuffle
        self.pad = pad
        self.causal_shift = causal_shift
        self.endless = endless
        self.state = None

        self.cameras = cameras
        self.poses_3d = poses_3d
        self.poses_2d = poses_2d

        self.augment = augment
        self.kps_left = kps_left
        self.kps_right = kps_right
        self.joints_left = joints_left
        self.joints_right = joints_right

    def num_frames(self):
        return self.num_batches * self.batch_size

    def random_state(self):
        return self.random

    def set_random_state(self, random):
        self.random = random

    def augment_enabled(self):
        return self.augment

    def next_pairs(self):
        if self.state is None:
            if self.shuffle:
                pairs = self.random.permutation(self.pairs)
            else:
                pairs = self.pairs
            return 0, pairs
        else:
            return self.state

    def next_epoch(self):
        enabled = True
        while enabled:
            start_idx, pairs = self.next_pairs()
            for b_i in range(start_idx, self.num_batches):
                chunks = pairs[b_i*self.batch_size : (b_i+1)*self.batch_size]
                for i, (seq_i, start_3d, end_3d, flip) in enumerate(chunks):
                    start_2d = start_3d - self.pad - self.causal_shift
                    end_2d = end_3d + self.pad - self.causal_shift

                    # 2D poses
                    seq_2d = self.poses_2d[seq_i]
                    low_2d = max(start_2d, 0)
                    high_2d = min(end_2d, seq_2d.shape[0])
                    pad_left_2d = low_2d - start_2d
                    pad_right_2d = end_2d - high_2d
                    if pad_left_2d != 0 or pad_right_2d != 0:
                        self.batch_2d[i] = np.pad(seq_2d[low_2d:high_2d], ((pad_left_2d, pad_right_2d), (0, 0), (0, 0)), 'edge')
                    else:
                        self.batch_2d[i] = seq_2d[low_2d:high_2d]

                    if flip:
                        # Flip 2D keypoints
                        self.batch_2d[i, :, :, 0] *= -1
                        self.batch_2d[i, :, self.kps_left + self.kps_right] = self.batch_2d[i, :, self.kps_right + self.kps_left]

                    # 3D poses
                    if self.poses_3d is not None:
                        seq_3d = self.poses_3d[seq_i]
                        low_3d = max(start_3d, 0)
                        high_3d = min(end_3d, seq_3d.shape[0])
                        pad_left_3d = low_3d - start_3d
                        pad_right_3d = end_3d - high_3d
                        if pad_left_3d != 0 or pad_right_3d != 0:
                            # if pad_left_2d != 0 or pad_right_2d != 0:
                            #     self.batch_3d[i] = np.pad(seq_3d[low_2d:high_2d], ((pad_left_2d, pad_right_2d), (0, 0), (0, 0)), 'edge')
                            self.batch_3d[i] = np.pad(seq_3d[low_3d:high_3d], ((pad_left_3d, pad_right_3d), (0, 0), (0, 0)), 'edge')
                        else:
                            # self.batch_3d[i] = seq_3d[low_2d:high_2d]
                            self.batch_3d[i] = seq_3d[low_3d:high_3d]

                        if flip:
                            # Flip 3D joints
                            self.batch_3d[i, :, :, 0] *= -1
                            self.batch_3d[i, :, self.joints_left + self.joints_right] = \
                                self.batch_3d[i, :, self.joints_right + self.joints_left]

                    # Cameras
                    if self.cameras is not None:
                        self.batch_cam[i] = self.cameras[seq_i]
                        if flip:
                            # Flip horizontal distortion coefficients
                            self.batch_cam[i, 2] *= -1
                            self.batch_cam[i, 7] *= -1

                if self.endless:
                    self.state = (b_i + 1, pairs)
                if self.poses_3d is None and self.cameras is None:
                    yield None, None, self.batch_2d[:len(chunks)]
                elif self.poses_3d is not None and self.cameras is None:
                    yield None, self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)]
                    # yield None, self.batch_bins_3d[:len(chunks)], self.batch_2d[:len(chunks)]
                elif self.poses_3d is None:
                    yield self.batch_cam[:len(chunks)], None, self.batch_2d[:len(chunks)]
                else:
                    yield self.batch_cam[:len(chunks)], self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)]
                    # yield self.batch_cam[:len(chunks)], self.batch_bins_3d[:len(chunks)], self.batch_2d[:len(chunks)]

            if self.endless:
                self.state = None
            else:
                enabled = False


class UnchunkedGenerator:
    """
    Non-batched data generator, used for testing.
    Sequences are returned one at a time (i.e. batch size = 1), without chunking.

    If data augmentation is enabled, the batches contain two sequences (i.e. batch size = 2),
    the second of which is a mirrored version of the first.

    Arguments:
    cameras -- list of cameras, one element for each video (optional, used for semi-supervised training)
    poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training)
    poses_2d -- list of input 2D keypoints, one element for each video
    pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field)
    causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad")
    augment -- augment the dataset by flipping poses horizontally
    kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled
    joints_left and joints_right -- list of left/right 3D joints if flipping is enabled
    """

    def __init__(self, cameras, poses_3d, poses_2d, pad=0, causal_shift=0,
                 augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None):
        assert poses_3d is None or len(poses_3d) == len(poses_2d)
        assert cameras is None or len(cameras) == len(poses_2d)

        self.augment = False
        self.kps_left = kps_left
        self.kps_right = kps_right
        self.joints_left = joints_left
        self.joints_right = joints_right

        self.pad = pad
        self.causal_shift = causal_shift
        self.cameras = [] if cameras is None else cameras
        self.poses_3d = [] if poses_3d is None else poses_3d
        self.poses_2d = poses_2d

    def num_frames(self):
        count = 0
        for p in self.poses_2d:
            count += p.shape[0]
        return count

    def augment_enabled(self):
        return self.augment

    def set_augment(self, augment):
        self.augment = augment

    def next_epoch(self):
        for seq_cam, seq_3d, seq_2d in zip_longest(self.cameras, self.poses_3d, self.poses_2d):
            batch_cam = None if seq_cam is None else np.expand_dims(seq_cam, axis=0)
            batch_3d = None if seq_3d is None else np.expand_dims(seq_3d, axis=0)
            # batch_3d = np.expand_dims(np.pad(seq_3d,
            #                                  ((self.pad + self.causal_shift, self.pad - self.causal_shift), (0, 0), (0, 0)),
            #                                  'edge'), axis=0)
            batch_2d = np.expand_dims(np.pad(seq_2d,
                                             ((self.pad + self.causal_shift, self.pad - self.causal_shift), (0, 0), (0, 0)),
                                             'edge'), axis=0)
            if self.augment:
                # Append flipped version
                if batch_cam is not None:
                    batch_cam = np.concatenate((batch_cam, batch_cam), axis=0)
                    batch_cam[1, 2] *= -1
                    batch_cam[1, 7] *= -1

                if batch_3d is not None:
                    batch_3d = np.concatenate((batch_3d, batch_3d), axis=0)
                    batch_3d[1, :, :, 0] *= -1
                    batch_3d[1, :, self.joints_left + self.joints_right] = batch_3d[1, :, self.joints_right + self.joints_left]

                batch_2d = np.concatenate((batch_2d, batch_2d), axis=0)
                batch_2d[1, :, :, 0] *= -1
                batch_2d[1, :, self.kps_left + self.kps_right] = batch_2d[1, :, self.kps_right + self.kps_left]

            yield batch_cam, batch_3d, batch_2d
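
A minimal sketch of the training generator on random data (shapes assumed: 17 joints, 2D inputs, 3D targets; pad=40 gives the 81-frame receptive field matching the default arguments):

    import numpy as np
    from common.generators import ChunkedGenerator

    poses_2d = [np.random.randn(100, 17, 2).astype('float32')]
    poses_3d = [np.random.randn(100, 17, 3).astype('float32')]
    gen = ChunkedGenerator(batch_size=8, cameras=None, poses_3d=poses_3d, poses_2d=poses_2d,
                           chunk_length=1, pad=40)
    for _, batch_3d, batch_2d in gen.next_epoch():
        print(batch_3d.shape, batch_2d.shape)   # (8, 1, 17, 3) and (8, 81, 17, 2)
        break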
common/h36m_dataset.py ADDED
@@ -0,0 +1,263 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
import copy
from common.skeleton import Skeleton
from common.mocap_dataset import MocapDataset
from common.camera import normalize_screen_coordinates, image_coordinates

h36m_skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12,
                                  16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30],
                         joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23],
                         joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31])

h36m_cameras_intrinsic_params = [
    {
        'id': '54138969',
        'center': [512.54150390625, 515.4514770507812],
        'focal_length': [1145.0494384765625, 1143.7811279296875],
        'radial_distortion': [-0.20709891617298126, 0.24777518212795258, -0.0030751503072679043],
        'tangential_distortion': [-0.0009756988729350269, -0.00142447161488235],
        'res_w': 1000,
        'res_h': 1002,
        'azimuth': 70,  # Only used for visualization
    },
    {
        'id': '55011271',
        'center': [508.8486328125, 508.0649108886719],
        'focal_length': [1149.6756591796875, 1147.5916748046875],
        'radial_distortion': [-0.1942136287689209, 0.2404085397720337, 0.006819975562393665],
        'tangential_distortion': [-0.0016190266469493508, -0.0027408944442868233],
        'res_w': 1000,
        'res_h': 1000,
        'azimuth': -70,  # Only used for visualization
    },
    {
        'id': '58860488',
        'center': [519.8158569335938, 501.40264892578125],
        'focal_length': [1149.1407470703125, 1148.7989501953125],
        'radial_distortion': [-0.2083381861448288, 0.25548800826072693, -0.0024604974314570427],
        'tangential_distortion': [0.0014843869721516967, -0.0007599993259645998],
        'res_w': 1000,
        'res_h': 1000,
        'azimuth': 110,  # Only used for visualization
    },
    {
        'id': '60457274',
        'center': [514.9682006835938, 501.88201904296875],
        'focal_length': [1145.5113525390625, 1144.77392578125],
        'radial_distortion': [-0.198384091258049, 0.21832367777824402, -0.008947807364165783],
        'tangential_distortion': [-0.0005872055771760643, -0.0018133620033040643],
        'res_w': 1000,
        'res_h': 1002,
        'azimuth': -110,  # Only used for visualization
    },
]

h36m_cameras_extrinsic_params = {
    'S1': [
        {
            'orientation': [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088],
            'translation': [1841.1070556640625, 4955.28466796875, 1563.4454345703125],
        },
        {
            'orientation': [0.6157187819480896, -0.764836311340332, -0.14833825826644897, 0.11794740706682205],
            'translation': [1761.278564453125, -5078.0068359375, 1606.2650146484375],
        },
        {
            'orientation': [0.14651472866535187, -0.14647851884365082, 0.7653023600578308, -0.6094175577163696],
            'translation': [-1846.7777099609375, 5215.04638671875, 1491.972412109375],
        },
        {
            'orientation': [0.5834008455276489, -0.7853162288665771, 0.14548823237419128, -0.14749594032764435],
            'translation': [-1794.7896728515625, -3722.698974609375, 1574.8927001953125],
        },
    ],
    'S2': [
        {},
        {},
        {},
        {},
    ],
    'S3': [
        {},
        {},
        {},
        {},
    ],
    'S4': [
        {},
        {},
        {},
        {},
    ],
    'S5': [
        {
            'orientation': [0.1467377245426178, -0.162370964884758, -0.7551892995834351, 0.6178938746452332],
            'translation': [2097.3916015625, 4880.94482421875, 1605.732421875],
        },
        {
            'orientation': [0.6159758567810059, -0.7626792192459106, -0.15728192031383514, 0.1189815029501915],
            'translation': [2031.7008056640625, -5167.93310546875, 1612.923095703125],
        },
        {
            'orientation': [0.14291371405124664, -0.12907841801643372, 0.7678384780883789, -0.6110143065452576],
            'translation': [-1620.5948486328125, 5171.65869140625, 1496.43701171875],
        },
        {
            'orientation': [0.5920479893684387, -0.7814217805862427, 0.1274748593568802, -0.15036417543888092],
            'translation': [-1637.1737060546875, -3867.3173828125, 1547.033203125],
        },
    ],
    'S6': [
        {
            'orientation': [0.1337897777557373, -0.15692396461963654, -0.7571090459823608, 0.6198879480361938],
            'translation': [1935.4517822265625, 4950.24560546875, 1618.0838623046875],
        },
        {
            'orientation': [0.6147197484970093, -0.7628812789916992, -0.16174767911434174, 0.11819244921207428],
            'translation': [1969.803955078125, -5128.73876953125, 1632.77880859375],
        },
        {
            'orientation': [0.1529948115348816, -0.13529130816459656, 0.7646096348762512, -0.6112781167030334],
            'translation': [-1769.596435546875, 5185.361328125, 1476.993408203125],
        },
        {
            'orientation': [0.5916101336479187, -0.7804774045944214, 0.12832270562648773, -0.1561593860387802],
            'translation': [-1721.668701171875, -3884.13134765625, 1540.4879150390625],
        },
    ],
    'S7': [
        {
            'orientation': [0.1435241848230362, -0.1631336808204651, -0.7548328638076782, 0.6188824772834778],
            'translation': [1974.512939453125, 4926.3544921875, 1597.8326416015625],
        },
        {
            'orientation': [0.6141672730445862, -0.7638262510299683, -0.1596645563840866, 0.1177929937839508],
            'translation': [1937.0584716796875, -5119.7900390625, 1631.5665283203125],
        },
        {
            'orientation': [0.14550060033798218, -0.12874816358089447, 0.7660516500473022, -0.6127139329910278],
            'translation': [-1741.8111572265625, 5208.24951171875, 1464.8245849609375],
        },
        {
            'orientation': [0.5912848114967346, -0.7821764349937439, 0.12445473670959473, -0.15196487307548523],
            'translation': [-1734.7105712890625, -3832.42138671875, 1548.5830078125],
        },
    ],
    'S8': [
        {
            'orientation': [0.14110587537288666, -0.15589867532253265, -0.7561917304992676, 0.619644045829773],
            'translation': [2150.65185546875, 4896.1611328125, 1611.9046630859375],
        },
        {
            'orientation': [0.6169601678848267, -0.7647668123245239, -0.14846350252628326, 0.11158157885074615],
            'translation': [2219.965576171875, -5148.453125, 1613.0440673828125],
        },
        {
            'orientation': [0.1471444070339203, -0.13377119600772858, 0.7670128345489502, -0.6100369691848755],
            'translation': [-1571.2215576171875, 5137.0185546875, 1498.1761474609375],
        },
        {
            'orientation': [0.5927824378013611, -0.7825870513916016, 0.12147816270589828, -0.14631995558738708],
            'translation': [-1476.913330078125, -3896.7412109375, 1547.97216796875],
        },
    ],
    'S9': [
        {
            'orientation': [0.15540587902069092, -0.15548215806484222, -0.7532095313072205, 0.6199594736099243],
            'translation': [2044.45849609375, 4935.1171875, 1481.2275390625],
        },
        {
            'orientation': [0.618784487247467, -0.7634735107421875, -0.14132238924503326, 0.11933968216180801],
            'translation': [1990.959716796875, -5123.810546875, 1568.8048095703125],
        },
        {
            'orientation': [0.13357827067375183, -0.1367100477218628, 0.7689454555511475, -0.6100738644599915],
            'translation': [-1670.9921875, 5211.98583984375, 1528.387939453125],
        },
        {
            'orientation': [0.5879399180412292, -0.7823407053947449, 0.1427614390850067, -0.14794869720935822],
            'translation': [-1696.04345703125, -3827.099853515625, 1591.4127197265625],
        },
    ],
    'S11': [
        {
            'orientation': [0.15232472121715546, -0.15442320704460144, -0.7547563314437866, 0.6191070079803467],
            'translation': [2098.440185546875, 4926.5546875, 1500.278564453125],
        },
        {
            'orientation': [0.6189449429512024, -0.7600917220115662, -0.15300633013248444, 0.1255258321762085],
            'translation': [2083.182373046875, -4912.1728515625, 1561.07861328125],
        },
        {
            'orientation': [0.14943228662014008, -0.15650227665901184, 0.7681233882904053, -0.6026304364204407],
            'translation': [-1609.8153076171875, 5177.3359375, 1537.896728515625],
        },
        {
            'orientation': [0.5894251465797424, -0.7818877100944519, 0.13991211354732513, -0.14715361595153809],
            'translation': [-1590.738037109375, -3854.1689453125, 1578.017578125],
        },
    ],
}

class Human36mDataset(MocapDataset):
    def __init__(self, path, remove_static_joints=True):
        super().__init__(fps=50, skeleton=h36m_skeleton)

        self._cameras = copy.deepcopy(h36m_cameras_extrinsic_params)
        for cameras in self._cameras.values():
            for i, cam in enumerate(cameras):
                cam.update(h36m_cameras_intrinsic_params[i])
                for k, v in cam.items():
                    if k not in ['id', 'res_w', 'res_h']:
                        cam[k] = np.array(v, dtype='float32')

                # Normalize camera frame
                cam['center'] = normalize_screen_coordinates(cam['center'], w=cam['res_w'], h=cam['res_h']).astype('float32')
                cam['focal_length'] = cam['focal_length']/cam['res_w']*2
                if 'translation' in cam:
                    cam['translation'] = cam['translation']/1000  # mm to meters

                # Add intrinsic parameters vector
                cam['intrinsic'] = np.concatenate((cam['focal_length'],
                                                   cam['center'],
                                                   cam['radial_distortion'],
                                                   cam['tangential_distortion'],
                                                   [1/cam['focal_length'][0], 0, -cam['center'][0]/cam['focal_length'][0],
                                                    0, 1/cam['focal_length'][1], -cam['center'][1]/cam['focal_length'][1],
                                                    0, 0, 1]))

                # proj_matrix = np.array([1/cam['focal_length'][0], 0, -cam['center'][0]/cam['focal_length'][0],
                #                         0, 1/cam['focal_length'][1], -cam['center'][1]/cam['focal_length'][1],
                #                         0, 0, 1])
                # cam['intrinsic'] = np.concatenate(camera_intrinsics, proj_matrix)

        # Load serialized dataset
        data = np.load(path, allow_pickle=True)['positions_3d'].item()

        self._data = {}
        for subject, actions in data.items():
            self._data[subject] = {}
            for action_name, positions in actions.items():
                self._data[subject][action_name] = {
                    'positions': positions,
                    'cameras': self._cameras[subject],
                }

        if remove_static_joints:
            # Bring the skeleton to 17 joints instead of the original 32
            self.remove_joints([4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31])

            # Rewire shoulders to the correct parents
            self._skeleton._parents[11] = 8
            self._skeleton._parents[14] = 8

    def supports_semi_supervised(self):
        return True
common/loss.py ADDED
@@ -0,0 +1,108 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt


def mpjpe(predicted, target):
    """
    Mean per-joint position error (i.e. mean Euclidean distance),
    often referred to as "Protocol #1" in many papers.
    """
    assert predicted.shape == target.shape
    return torch.mean(torch.norm(predicted - target, dim=len(target.shape)-1))

def mse(predicted, target, weights=None, gamma=0):
    loss = nn.MSELoss()
    return loss(predicted, target)

def cross_entropy(predicted, target, weights=None, gamma=0, return_weights=False):
    loss = nn.CrossEntropyLoss()
    return loss(predicted.permute(0, 4, 1, 2, 3), target)

def weighted_mpjpe(predicted, target, w):
    """
    Weighted mean per-joint position error (i.e. mean Euclidean distance)
    """
    assert predicted.shape == target.shape
    assert w.shape[0] == predicted.shape[0]
    return torch.mean(w * torch.norm(predicted - target, dim=len(target.shape)-1))

def p_mpjpe(predicted, target):
    """
    Pose error: MPJPE after rigid alignment (scale, rotation, and translation),
    often referred to as "Protocol #2" in many papers.
    """
    assert predicted.shape == target.shape

    muX = np.mean(target, axis=1, keepdims=True)
    muY = np.mean(predicted, axis=1, keepdims=True)

    X0 = target - muX
    Y0 = predicted - muY

    normX = np.sqrt(np.sum(X0**2, axis=(1, 2), keepdims=True))
    normY = np.sqrt(np.sum(Y0**2, axis=(1, 2), keepdims=True))

    X0 /= normX
    Y0 /= normY

    H = np.matmul(X0.transpose(0, 2, 1), Y0)
    U, s, Vt = np.linalg.svd(H)
    V = Vt.transpose(0, 2, 1)
    R = np.matmul(V, U.transpose(0, 2, 1))

    # Avoid improper rotations (reflections), i.e. rotations with det(R) = -1
    sign_detR = np.sign(np.expand_dims(np.linalg.det(R), axis=1))
    V[:, :, -1] *= sign_detR
    s[:, -1] *= sign_detR.flatten()
    R = np.matmul(V, U.transpose(0, 2, 1))  # Rotation

    tr = np.expand_dims(np.sum(s, axis=1, keepdims=True), axis=2)

    a = tr * normX / normY  # Scale
    t = muX - a*np.matmul(muY, R)  # Translation

    # Perform rigid transformation on the input
    predicted_aligned = a*np.matmul(predicted, R) + t

    # Return MPJPE
    return np.mean(np.linalg.norm(predicted_aligned - target, axis=len(target.shape)-1))

def n_mpjpe(predicted, target):
    """
    Normalized MPJPE (scale only), adapted from:
    https://github.com/hrhodin/UnsupervisedGeometryAwareRepresentationLearning/blob/master/losses/poses.py
    """
    assert predicted.shape == target.shape

    norm_predicted = torch.mean(torch.sum(predicted**2, dim=3, keepdim=True), dim=2, keepdim=True)
    norm_target = torch.mean(torch.sum(target*predicted, dim=3, keepdim=True), dim=2, keepdim=True)
    scale = norm_target / norm_predicted
    return mpjpe(scale * predicted, target)  # [0]

def weighted_bonelen_loss(predict_3d_length, gt_3d_length):
    loss_length = 0.001 * torch.pow(predict_3d_length - gt_3d_length, 2).mean()
    return loss_length

def weighted_boneratio_loss(predict_3d_length, gt_3d_length):
    loss_length = 0.1 * torch.pow((predict_3d_length - gt_3d_length)/gt_3d_length, 2).mean()
    return loss_length

def mean_velocity_error(predicted, target):
    """
    Mean per-joint velocity error (i.e. mean Euclidean distance of the 1st derivative)
    """
    assert predicted.shape == target.shape

    velocity_predicted = np.diff(predicted, axis=0)
    velocity_target = np.diff(target, axis=0)

    return np.mean(np.linalg.norm(velocity_predicted - velocity_target, axis=len(target.shape)-1))
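
A small sketch (random values) of the two main evaluation protocols: mpjpe expects torch tensors, p_mpjpe expects NumPy arrays, both of shape (batch, joints, 3).

    import torch
    from common.loss import mpjpe, p_mpjpe

    pred, gt = torch.rand(4, 17, 3), torch.rand(4, 17, 3)
    print(mpjpe(pred, gt).item())               # Protocol #1: mean per-joint position error
    print(p_mpjpe(pred.numpy(), gt.numpy()))    # Protocol #2: error after rigid alignment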
common/mocap_dataset.py ADDED
@@ -0,0 +1,44 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
from common.skeleton import Skeleton

class MocapDataset:
    def __init__(self, fps, skeleton):
        self._skeleton = skeleton
        self._fps = fps
        self._data = None  # Must be filled by subclass
        self._cameras = None  # Must be filled by subclass

    def remove_joints(self, joints_to_remove):
        kept_joints = self._skeleton.remove_joints(joints_to_remove)
        for subject in self._data.keys():
            for action in self._data[subject].keys():
                s = self._data[subject][action]
                if 'positions' in s:
                    s['positions'] = s['positions'][:, kept_joints]


    def __getitem__(self, key):
        return self._data[key]

    def subjects(self):
        return self._data.keys()

    def fps(self):
        return self._fps

    def skeleton(self):
        return self._skeleton

    def cameras(self):
        return self._cameras

    def supports_semi_supervised(self):
        # This method can be overridden
        return False
common/model_poseformer.py ADDED
@@ -0,0 +1,242 @@
## Our PoseFormer model was revised from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# Written by Ce Zheng (cezheng@knights.ucf.edu)
# Modified by Qitao Zhao (qitaozhao@mail.sdu.edu.cn)

import math
import logging
from functools import partial
from einops import rearrange

import torch
import torch_dct as dct
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from timm.models.layers import DropPath


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class FreqMlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        b, f, _ = x.shape
        x = dct.dct(x.permute(0, 2, 1)).permute(0, 2, 1).contiguous()
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        x = dct.idct(x.permute(0, 2, 1)).permute(0, 2, 1).contiguous()
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class MixedBlock(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp1 = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.norm3 = norm_layer(dim)
        self.mlp2 = FreqMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        b, f, c = x.shape
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x1 = x[:, :f//2] + self.drop_path(self.mlp1(self.norm2(x[:, :f//2])))
        x2 = x[:, f//2:] + self.drop_path(self.mlp2(self.norm3(x[:, f//2:])))
        return torch.cat((x1, x2), dim=1)


class PoseTransformerV2(nn.Module):
    def __init__(self, num_frame=9, num_joints=17, in_chans=2,
                 num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=None, args=None):
        """ ##########hybrid_backbone=None, representation_size=None,
        Args:
            num_frame (int, tuple): input frame number
            num_joints (int, tuple): joints number
            in_chans (int): number of input channels, 2D joints have 2 channels: (x,y)
            embed_dim_ratio (int): embedding dimension ratio
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
        """
        super().__init__()

        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        embed_dim_ratio = args.embed_dim_ratio
        depth = args.depth
        embed_dim = embed_dim_ratio * num_joints   #### temporal embed_dim is num_joints * spatial embedding dim ratio
        out_dim = num_joints * 3                   #### output dimension is num_joints * 3
        self.num_frame_kept = args.number_of_kept_frames
        self.num_coeff_kept = args.number_of_kept_coeffs if args.number_of_kept_coeffs else self.num_frame_kept

        ### spatial patch embedding
        self.Joint_embedding = nn.Linear(in_chans, embed_dim_ratio)
        self.Freq_embedding = nn.Linear(in_chans*num_joints, embed_dim)

        self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio))
        self.Temporal_pos_embed = nn.Parameter(torch.zeros(1, self.num_frame_kept, embed_dim))
        self.Temporal_pos_embed_ = nn.Parameter(torch.zeros(1, self.num_coeff_kept, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        self.Spatial_blocks = nn.ModuleList([
            Block(
                dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])

        self.blocks = nn.ModuleList([
            MixedBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])

        self.Spatial_norm = norm_layer(embed_dim_ratio)
        self.Temporal_norm = norm_layer(embed_dim)

        ####### An easy way to implement weighted mean
        self.weighted_mean = torch.nn.Conv1d(in_channels=self.num_coeff_kept, out_channels=1, kernel_size=1)
        self.weighted_mean_ = torch.nn.Conv1d(in_channels=self.num_frame_kept, out_channels=1, kernel_size=1)

        self.head = nn.Sequential(
            nn.LayerNorm(embed_dim*2),
            nn.Linear(embed_dim*2, out_dim),
        )

    def Spatial_forward_features(self, x):
        b, f, p, _ = x.shape  ##### b is batch size, f is number of frames, p is number of joints
        num_frame_kept = self.num_frame_kept

        index = torch.arange((f-1)//2-num_frame_kept//2, (f-1)//2+num_frame_kept//2+1)

        x = self.Joint_embedding(x[:, index].view(b*num_frame_kept, p, -1))
        x += self.Spatial_pos_embed
        x = self.pos_drop(x)

        for blk in self.Spatial_blocks:
            x = blk(x)

        x = self.Spatial_norm(x)
        x = rearrange(x, '(b f) p c -> b f (p c)', f=num_frame_kept)
        return x

    def forward_features(self, x, Spatial_feature):
        b, f, p, _ = x.shape
        num_coeff_kept = self.num_coeff_kept

        x = dct.dct(x.permute(0, 2, 3, 1))[:, :, :, :num_coeff_kept]
        x = x.permute(0, 3, 1, 2).contiguous().view(b, num_coeff_kept, -1)
        x = self.Freq_embedding(x)

        Spatial_feature += self.Temporal_pos_embed
        x += self.Temporal_pos_embed_
        x = torch.cat((x, Spatial_feature), dim=1)

        for blk in self.blocks:
            x = blk(x)

        x = self.Temporal_norm(x)
        return x

    def forward(self, x):
        b, f, p, _ = x.shape
        x_ = x.clone()

        Spatial_feature = self.Spatial_forward_features(x)
        x = self.forward_features(x_, Spatial_feature)
        x = torch.cat((self.weighted_mean(x[:, :self.num_coeff_kept]), self.weighted_mean_(x[:, self.num_coeff_kept:])), dim=-1)

        x = self.head(x).view(b, 1, p, -1)
        return x
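
A minimal instantiation sketch: the constructor reads embed_dim_ratio, depth, number_of_kept_frames and number_of_kept_coeffs from the args namespace produced by common/arguments.py; here a SimpleNamespace stands in for it, and the values are simply the defaults listed above.

    from types import SimpleNamespace
    import torch
    from common.model_poseformer import PoseTransformerV2

    args = SimpleNamespace(embed_dim_ratio=32, depth=4,
                           number_of_kept_frames=27, number_of_kept_coeffs=27)
    model = PoseTransformerV2(num_frame=81, num_joints=17, in_chans=2, args=args)
    x = torch.randn(2, 81, 17, 2)      # (batch, frames, joints, x/y)
    print(model(x).shape)              # torch.Size([2, 1, 17, 3]) -- one 3D pose per sequence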
common/quaternion.py ADDED
@@ -0,0 +1,35 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch

def qrot(q, v):
    """
    Rotate vector(s) v about the rotation described by quaternion(s) q.
    Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
    where * denotes any number of dimensions.
    Returns a tensor of shape (*, 3).
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    assert q.shape[:-1] == v.shape[:-1]

    qvec = q[..., 1:]
    uv = torch.cross(qvec, v, dim=len(q.shape)-1)
    uuv = torch.cross(qvec, uv, dim=len(q.shape)-1)
    return (v + 2 * (q[..., :1] * uv + uuv))


def qinverse(q, inplace=False):
    # We assume the quaternion to be normalized
    if inplace:
        q[..., 1:] *= -1
        return q
    else:
        w = q[..., :1]
        xyz = q[..., 1:]
        return torch.cat((w, -xyz), dim=len(q.shape)-1)
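
A short sketch: rotating by q and then by its inverse recovers the original vector (unit quaternion assumed, w-first convention).

    import torch
    from common.quaternion import qrot, qinverse

    q = torch.tensor([[0.7071, 0.0, 0.7071, 0.0]])   # ~90 degrees about the y axis
    v = torch.tensor([[1.0, 0.0, 0.0]])
    rotated = qrot(q, v)                              # ~[0, 0, -1]
    restored = qrot(qinverse(q), rotated)             # ~[1, 0, 0]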
common/skeleton.py ADDED
@@ -0,0 +1,88 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np

class Skeleton:
    def __init__(self, parents, joints_left, joints_right):
        assert len(joints_left) == len(joints_right)

        self._parents = np.array(parents)
        self._joints_left = joints_left
        self._joints_right = joints_right
        self._compute_metadata()

    def num_joints(self):
        return len(self._parents)

    def parents(self):
        return self._parents

    def has_children(self):
        return self._has_children

    def children(self):
        return self._children

    def remove_joints(self, joints_to_remove):
        """
        Remove the joints specified in 'joints_to_remove'.
        """
        valid_joints = []
        for joint in range(len(self._parents)):
            if joint not in joints_to_remove:
                valid_joints.append(joint)

        for i in range(len(self._parents)):
            while self._parents[i] in joints_to_remove:
                self._parents[i] = self._parents[self._parents[i]]

        index_offsets = np.zeros(len(self._parents), dtype=int)
        new_parents = []
        for i, parent in enumerate(self._parents):
            if i not in joints_to_remove:
                new_parents.append(parent - index_offsets[parent])
            else:
                index_offsets[i:] += 1
        self._parents = np.array(new_parents)


        if self._joints_left is not None:
            new_joints_left = []
            for joint in self._joints_left:
                if joint in valid_joints:
                    new_joints_left.append(joint - index_offsets[joint])
            self._joints_left = new_joints_left
        if self._joints_right is not None:
            new_joints_right = []
            for joint in self._joints_right:
                if joint in valid_joints:
                    new_joints_right.append(joint - index_offsets[joint])
            self._joints_right = new_joints_right

        self._compute_metadata()

        return valid_joints

    def joints_left(self):
        return self._joints_left

    def joints_right(self):
        return self._joints_right

    def _compute_metadata(self):
        self._has_children = np.zeros(len(self._parents)).astype(bool)
        for i, parent in enumerate(self._parents):
            if parent != -1:
                self._has_children[parent] = True

        self._children = []
        for i, parent in enumerate(self._parents):
            self._children.append([])
        for i, parent in enumerate(self._parents):
            if parent != -1:
                self._children[parent].append(i)
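
A toy sketch of remove_joints on a hypothetical 5-joint chain: removing joint 2 re-parents its child to joint 1 and re-indexes everything after it.

    from common.skeleton import Skeleton

    sk = Skeleton(parents=[-1, 0, 1, 2, 3], joints_left=[1], joints_right=[3])
    kept = sk.remove_joints([2])
    print(kept)          # [0, 1, 3, 4]
    print(sk.parents())  # [-1  0  1  2]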
common/utils.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2018-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+
8
+ import torch
9
+ import numpy as np
10
+ import hashlib
11
+
12
+ def wrap(func, *args, unsqueeze=False):
13
+ """
14
+ Wrap a torch function so it can be called with NumPy arrays.
15
+ Input and return types are seamlessly converted.
16
+ """
17
+
18
+ # Convert input types where applicable
19
+ args = list(args)
20
+ for i, arg in enumerate(args):
21
+ if type(arg) == np.ndarray:
22
+ args[i] = torch.from_numpy(arg)
23
+ if unsqueeze:
24
+ args[i] = args[i].unsqueeze(0)
25
+
26
+ result = func(*args)
27
+
28
+ # Convert output types where applicable
29
+ if isinstance(result, tuple):
30
+ result = list(result)
31
+ for i, res in enumerate(result):
32
+ if type(res) == torch.Tensor:
33
+ if unsqueeze:
34
+ res = res.squeeze(0)
35
+ result[i] = res.numpy()
36
+ return tuple(result)
37
+ elif type(result) == torch.Tensor:
38
+ if unsqueeze:
39
+ result = result.squeeze(0)
40
+ return result.numpy()
41
+ else:
42
+ return result
43
+
44
+ def deterministic_random(min_value, max_value, data):
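+ # Hash 'data' with SHA-256 and scale the first four bytes into the requested range, so the same string always yields the same value.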
45
+ digest = hashlib.sha256(data.encode()).digest()
46
+ raw_value = int.from_bytes(digest[:4], byteorder='little', signed=False)
47
+ return int(raw_value / (2**32 - 1) * (max_value - min_value)) + min_value
48
+
49
+ def load_pretrained_weights(model, checkpoint):
50
+ """Load pretrianed weights to model
51
+ Incompatible layers (unmatched in name or size) will be ignored
52
+ Args:
53
+ - model (nn.Module): network model, which must not be nn.DataParallel
54
+ - checkpoint (dict): loaded checkpoint, either a dict containing a 'state_dict' entry or a raw state_dict
55
+ """
56
+ import collections
57
+ if 'state_dict' in checkpoint:
58
+ state_dict = checkpoint['state_dict']
59
+ else:
60
+ state_dict = checkpoint
61
+ model_dict = model.state_dict()
62
+ new_state_dict = collections.OrderedDict()
63
+ matched_layers, discarded_layers = [], []
64
+ for k, v in state_dict.items():
65
+ # If the pretrained state_dict was saved as nn.DataParallel,
66
+ # keys would contain "module.", which should be ignored.
67
+ if k.startswith('module.'):
68
+ k = k[7:]
69
+ if k in model_dict and model_dict[k].size() == v.size():
70
+ new_state_dict[k] = v
71
+ matched_layers.append(k)
72
+ else:
73
+ discarded_layers.append(k)
74
+ # new_state_dict.requires_grad = False
75
+ model_dict.update(new_state_dict)
76
+
77
+ model.load_state_dict(model_dict)
78
+ print('load_weight', len(matched_layers))
79
+ # model.state_dict(model_dict).requires_grad = False
80
+ return model
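
A short usage sketch (not part of the uploaded file; the tensors, key string, and checkpoint below are made up for illustration):

import numpy as np
import torch
from common.utils import wrap, deterministic_random, load_pretrained_weights

# wrap: call a torch function directly on NumPy arrays
a = np.random.randn(3, 4).astype('float32')
b = np.random.randn(4, 5).astype('float32')
c = wrap(torch.matmul, a, b)
print(type(c), c.shape)                       # <class 'numpy.ndarray'> (3, 5)

# deterministic_random: the same key string always yields the same offset
print(deterministic_random(0, 100, 'S1_Walking.54138969'))

# load_pretrained_weights: only layers matching in name and shape are copied
model = torch.nn.Linear(2, 2)
ckpt = {'state_dict': {'module.weight': torch.zeros(2, 2), 'bias': torch.ones(3)}}
model = load_pretrained_weights(model, ckpt)  # 'bias' is discarded (shape mismatch)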
common/visualization.py ADDED
@@ -0,0 +1,216 @@
1
+ # Copyright (c) 2018-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+
8
+ import matplotlib
9
+
10
+ matplotlib.use('Agg')
11
+
12
+
13
+ import matplotlib.pyplot as plt
14
+ from matplotlib.animation import FuncAnimation, writers
15
+ from mpl_toolkits.mplot3d import Axes3D
16
+ import numpy as np
17
+ import subprocess as sp
18
+ import cv2
19
+
20
+
21
+ def get_resolution(filename):
22
+ command = ['ffprobe', '-v', 'error', '-select_streams', 'v:0',
23
+ '-show_entries', 'stream=width,height', '-of', 'csv=p=0', filename]
24
+ with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe:
25
+ for line in pipe.stdout:
26
+ w, h = line.decode().strip().split(',')
27
+ return int(w), int(h)
28
+
29
+
30
+ def get_fps(filename):
31
+ command = ['ffprobe', '-v', 'error', '-select_streams', 'v:0',
32
+ '-show_entries', 'stream=r_frame_rate', '-of', 'csv=p=0', filename]
33
+ with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe:
34
+ for line in pipe.stdout:
35
+ a, b = line.decode().strip().split('/')
36
+ return int(a) / int(b)
37
+
38
+
39
+ def read_video(filename, skip=0, limit=-1):
40
+ # w, h = get_resolution(filename)
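+ # NOTE: the frame size is hardcoded below; restore the get_resolution() call above if the input video has a different resolution.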
41
+ w = 1000
42
+ h = 1002
43
+
44
+ command = ['ffmpeg',
45
+ '-i', filename,
46
+ '-f', 'image2pipe',
47
+ '-pix_fmt', 'rgb24',
48
+ '-vsync', '0',
49
+ '-vcodec', 'rawvideo', '-']
50
+
51
+ i = 0
52
+ with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe:
53
+ while True:
54
+ data = pipe.stdout.read(w * h * 3)
55
+ if not data:
56
+ break
57
+ i += 1
58
+ if i > limit and limit != -1:
59
+ continue
60
+ if i > skip:
61
+ yield np.frombuffer(data, dtype='uint8').reshape((h, w, 3))
62
+
63
+
64
+ def downsample_tensor(X, factor):
65
+ length = X.shape[0] // factor * factor
66
+ return np.mean(X[:length].reshape(-1, factor, *X.shape[1:]), axis=1)
67
+
68
+
69
+ def render_animation(keypoints, keypoints_metadata, poses, skeleton, fps, bitrate, azim, output, viewport,
70
+ limit=-1, downsample=1, size=6, input_video_path=None, input_video_skip=0):
71
+ """
72
+ TODO
73
+ Render an animation. The supported output modes are:
74
+ -- 'interactive': display an interactive figure
75
+ (also works on notebooks if associated with %matplotlib inline)
76
+ -- 'html': render the animation as HTML5 video. Can be displayed in a notebook using HTML(...).
77
+ -- 'filename.mp4': render and export the animation as an h264 video (requires ffmpeg).
78
+ -- 'filename.gif': render and export the animation a gif file (requires imagemagick).
79
+ """
80
+ plt.ioff()
81
+ fig = plt.figure(figsize=(size * (1 + len(poses)), size))
82
+ ax_in = fig.add_subplot(1, 1 + len(poses), 1)
83
+ ax_in.get_xaxis().set_visible(False)
84
+ ax_in.get_yaxis().set_visible(False)
85
+ ax_in.set_axis_off()
86
+ ax_in.set_title('Input')
87
+
88
+ ax_3d = []
89
+ lines_3d = []
90
+ trajectories = []
91
+ radius = 1.7
92
+ for index, (title, data) in enumerate(poses.items()):
93
+ ax = fig.add_subplot(1, 1 + len(poses), index + 2, projection='3d')
94
+ ax.view_init(elev=15., azim=azim)
95
+ ax.set_xlim3d([-radius / 2, radius / 2])
96
+ ax.set_zlim3d([0, radius])
97
+ ax.set_ylim3d([-radius / 2, radius / 2])
98
+ try:
99
+ ax.set_aspect('equal')
100
+ except NotImplementedError:
101
+ ax.set_aspect('auto')
102
+ ax.set_xticklabels([])
103
+ ax.set_yticklabels([])
104
+ ax.set_zticklabels([])
105
+ ax.dist = 7.5
106
+ ax.set_title(title) # , pad=35
107
+ ax_3d.append(ax)
108
+ lines_3d.append([])
109
+ trajectories.append(data[:, 0, [0, 1]])
110
+ poses = list(poses.values())
111
+
112
+ # Decode video
113
+ if input_video_path is None:
114
+ # Black background
115
+ all_frames = np.zeros((keypoints.shape[0], viewport[1], viewport[0]), dtype='uint8')
116
+ else:
117
+ # Load video using ffmpeg
118
+ all_frames = []
119
+ for f in read_video(input_video_path, skip=input_video_skip, limit=limit):
120
+ all_frames.append(f)
121
+ effective_length = min(keypoints.shape[0], len(all_frames))
122
+ all_frames = all_frames[:effective_length]
123
+
124
+ keypoints = keypoints[input_video_skip:] # todo remove
125
+ for idx in range(len(poses)):
126
+ poses[idx] = poses[idx][input_video_skip:]
127
+
128
+ if fps is None:
129
+ fps = get_fps(input_video_path)
130
+
131
+ if downsample > 1:
132
+ keypoints = downsample_tensor(keypoints, downsample)
133
+ all_frames = downsample_tensor(np.array(all_frames), downsample).astype('uint8')
134
+ for idx in range(len(poses)):
135
+ poses[idx] = downsample_tensor(poses[idx], downsample)
136
+ trajectories[idx] = downsample_tensor(trajectories[idx], downsample)
137
+ fps /= downsample
138
+
139
+ initialized = False
140
+ image = None
141
+ lines = []
142
+ points = None
143
+
144
+ if limit < 1:
145
+ limit = len(all_frames)
146
+ else:
147
+ limit = min(limit, len(all_frames))
148
+
149
+ parents = skeleton.parents()
150
+
151
+ def update_video(i):
152
+ nonlocal initialized, image, lines, points
153
+
154
+ for n, ax in enumerate(ax_3d):
155
+ ax.set_xlim3d([-radius / 2 + trajectories[n][i, 0], radius / 2 + trajectories[n][i, 0]])
156
+ ax.set_ylim3d([-radius / 2 + trajectories[n][i, 1], radius / 2 + trajectories[n][i, 1]])
157
+
158
+ # Update 2D poses
159
+ joints_right_2d = keypoints_metadata['keypoints_symmetry'][1]
160
+ colors_2d = np.full(keypoints.shape[1], 'black')
161
+ colors_2d[joints_right_2d] = 'red'
162
+ if not initialized:
163
+ image = ax_in.imshow(all_frames[i], aspect='equal')
164
+
165
+ for j, j_parent in enumerate(parents):
166
+ if j_parent == -1:
167
+ continue
168
+
169
+ if len(parents) == keypoints.shape[1] and keypoints_metadata['layout_name'] != 'coco':
170
+ # Draw skeleton only if keypoints match (otherwise we don't have the parents definition)
171
+ lines.append(ax_in.plot([keypoints[i, j, 0], keypoints[i, j_parent, 0]],
172
+ [keypoints[i, j, 1], keypoints[i, j_parent, 1]], color='pink'))
173
+
174
+ col = 'red' if j in skeleton.joints_right() else 'black'
175
+ for n, ax in enumerate(ax_3d):
176
+ pos = poses[n][i]
177
+ lines_3d[n].append(ax.plot([pos[j, 0], pos[j_parent, 0]],
178
+ [pos[j, 1], pos[j_parent, 1]],
179
+ [pos[j, 2], pos[j_parent, 2]], zdir='z', c=col))
180
+
181
+ points = ax_in.scatter(*keypoints[i].T, 10, color=colors_2d, edgecolors='white', zorder=10)
182
+
183
+ initialized = True
184
+ else:
185
+ image.set_data(all_frames[i])
186
+
187
+ for j, j_parent in enumerate(parents):
188
+ if j_parent == -1:
189
+ continue
190
+
191
+ if len(parents) == keypoints.shape[1] and keypoints_metadata['layout_name'] != 'coco':
192
+ lines[j - 1][0].set_data([keypoints[i, j, 0], keypoints[i, j_parent, 0]],
193
+ [keypoints[i, j, 1], keypoints[i, j_parent, 1]])
194
+
195
+ for n, ax in enumerate(ax_3d):
196
+ pos = poses[n][i]
197
+ lines_3d[n][j - 1][0].set_xdata(np.array([pos[j, 0], pos[j_parent, 0]]))
198
+ lines_3d[n][j - 1][0].set_ydata(np.array([pos[j, 1], pos[j_parent, 1]]))
199
+ lines_3d[n][j - 1][0].set_3d_properties(np.array([pos[j, 2], pos[j_parent, 2]]), zdir='z')
200
+
201
+ points.set_offsets(keypoints[i])
202
+
203
+ print('{}/{} '.format(i, limit), end='\r')
204
+
205
+ fig.tight_layout()
206
+
207
+ anim = FuncAnimation(fig, update_video, frames=np.arange(0, limit), interval=1000 / fps, repeat=False)
208
+ if output.endswith('.mp4'):
209
+ Writer = writers['ffmpeg']
210
+ writer = Writer(fps=fps, metadata={}, bitrate=bitrate)
211
+ anim.save(output, writer=writer)
212
+ elif output.endswith('.gif'):
213
+ anim.save(output, dpi=80, writer='imagemagick')
214
+ else:
215
+ raise ValueError('Unsupported output format (only .mp4 and .gif are supported)')
216
+ plt.close()
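
To illustrate the expected inputs, a minimal end-to-end sketch (not part of the uploaded file; the joint layout, metadata, and output name are hypothetical, and the .mp4 export assumes ffmpeg is installed):

import numpy as np
from common.skeleton import Skeleton
from common.visualization import render_animation

frames = 30
skeleton = Skeleton(parents=[-1, 0, 1, 0, 3], joints_left=[1, 2], joints_right=[3, 4])
keypoints = np.random.rand(frames, 5, 2) * 400             # fake 2D detections, in pixels
poses = {'Reconstruction': np.random.rand(frames, 5, 3)}   # fake 3D poses, in meters
metadata = {'layout_name': 'custom', 'keypoints_symmetry': ([1, 2], [3, 4])}

render_animation(keypoints, metadata, poses, skeleton, fps=25, bitrate=3000, azim=70.0,
                 output='demo.mp4', viewport=(400, 400), size=4)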