File size: 8,380 Bytes
0e267a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
# Reference: https://github.com/Li-xingXiao/272-dim-Motion-Representation
# representation: 272 dim
# :2 local xz velocities of root, no heading, can recover translation
# 2:8 heading angular velocities, 6d rotation, can recover heading
# 8:8+3*njoint local position, no heading, all at xz origin
# 8+3*njoint:8+6*njoint local velocities, no heading, all at xz origin, can recover local position
# 8+6*njoint:8+12*njoint local rotations, 6d rotation, no heading, all frames z+
import numpy as np
from utils.face_z_align_util import rotation_6d_to_matrix, matrix_to_axis_angle
import copy
import torch
import os
import visualization.plot_3d_global as plot_3d
import argparse
import tqdm
# from visualization.smplx2joints import process_smplx_data
def findAllFile(base, endswith='.npy'):
    """Recursively collect file paths under ``base`` ending with ``endswith``.

    Symlinked directories are followed. Each result is ``base``-relative
    joined (``os.path.join(dirpath, filename)``) and results come back in
    ``os.walk`` traversal order, not sorted.
    """
    candidates = (
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(base, followlinks=True)
        for filename in filenames
    )
    # The suffix test is applied to the full joined path, as before.
    return [path for path in candidates if path.endswith(endswith)]
def rot_yaw(yaw):
    """Return the 3x3 rotation matrix for an angle ``yaw`` (radians) about the +y axis."""
    cos_y = np.cos(yaw)
    sin_y = np.sin(yaw)
    row_x = [cos_y, 0, sin_y]
    row_y = [0, 1, 0]
    row_z = [-sin_y, 0, cos_y]
    return np.array([row_x, row_y, row_z])
def my_quat_rotate(q, v):
    """Rotate vectors ``v`` by unit quaternions ``q``.

    Args:
        q: Tensor of shape ``(..., 4)`` holding quaternions in ``(x, y, z, w)``
            order (scalar part last). The formula below assumes unit length.
        v: Tensor of shape ``(..., 3)`` whose leading dims broadcast with ``q``.

    Returns:
        Tensor of shape ``(..., 3)`` with the rotated vectors, computed as
        ``v * (2w^2 - 1) + 2w * (q_vec x v) + 2 * q_vec * (q_vec . v)``.

    Any number of leading batch dims is accepted (``calc_heading`` indexes its
    input with ``...``); the previous ``view(shape[0], ...)``/``bmm`` form only
    worked for 2-D inputs.
    """
    q_w = q[..., -1:]  # (..., 1): keep a trailing axis so it broadcasts over xyz
    q_vec = q[..., :3]
    a = v * (2.0 * q_w ** 2 - 1.0)
    b = torch.cross(q_vec, v, dim=-1) * q_w * 2.0
    # Batched dot product q_vec . v, kept as (..., 1) for broadcasting.
    c = q_vec * (q_vec * v).sum(dim=-1, keepdim=True) * 2.0
    return a + b + c
def calc_heading(q):
    """Return the heading angle of each quaternion in ``q`` (``(x, y, z, w)`` layout).

    The +z unit vector is rotated by ``q``; the heading is the angle of the
    rotated vector in the xz plane, ``atan2(x, z)``.
    """
    forward = torch.zeros_like(q[..., 0:3])
    forward[..., 2] = 1
    rotated = my_quat_rotate(q, forward)
    return torch.atan2(rotated[..., 0], rotated[..., 2])
def calc_heading_quat_inv(q):
    """Return the inverse heading of ``q`` as an ``(angle, axis)`` pair.

    Despite the name, no quaternion is built: ``angle`` is the negated result
    of ``calc_heading(q)`` and ``axis`` is a tensor shaped like ``q[..., 0:3]``
    filled with the +y unit vector.
    """
    up = torch.zeros_like(q[..., 0:3])
    up[..., 1] = 1
    return -calc_heading(q), up
def accumulate_rotations(relative_rotations):
    """Chain per-step relative rotation matrices into absolute rotations.

    Output frame 0 equals input frame 0; every later frame ``i`` is
    ``relative_rotations[i] @ result[i - 1]`` (new rotation applied on the
    left). A new array is returned; the input is not modified. An empty
    input raises ``IndexError``.
    """
    rel = np.asarray(relative_rotations)
    acc = np.empty_like(rel)
    acc[0] = rel[0]
    for i, step in enumerate(rel[1:], start=1):
        acc[i] = np.matmul(step, acc[i - 1])
    return acc
def recover_from_local_position(final_x, njoint):
    """Recover global joint positions from the 272-dim representation.

    Combines the local joint positions (no heading, all at the xz origin),
    the per-frame heading deltas (6d) and the root xz velocities.

    Args:
        final_x: NumPy array of shape ``(nfrm, D)`` laid out as described at
            the top of this file (columns ``:2`` root xz velocity, ``2:8``
            heading delta, ``8:8+3*njoint`` local positions).
        njoint: Number of joints.

    Returns:
        NumPy array of shape ``(nfrm, njoint, 3)`` with global joint positions.
    """
    nfrm, _ = final_x.shape
    positions_no_heading = final_x[:, 8:8 + 3 * njoint].reshape(nfrm, -1, 3)  # (nfrm, njoint, 3)
    velocities_root_xy_no_heading = final_x[:, :2]  # (nfrm, 2)
    global_heading_diff_rot = final_x[:, 2:8]  # (nfrm, 6)

    # Chain the per-frame heading deltas into absolute heading matrices.
    global_heading_rot = accumulate_rotations(
        rotation_6d_to_matrix(torch.from_numpy(global_heading_diff_rot)).numpy())
    inv_global_heading_rot = np.transpose(global_heading_rot, (0, 2, 1))

    # Re-apply the heading to every joint. Broadcasting (nfrm, 1, 3, 3) against
    # (nfrm, njoint, 3, 1) gives the same result as np.repeat-ing each matrix
    # njoint times, without materialising those copies.
    positions_with_heading = np.matmul(
        inv_global_heading_rot[:, None, :, :], positions_no_heading[..., None]).squeeze(-1)

    # Lift the root xz velocity to 3D (y = 0). Frames 1.. are rotated by the
    # transposed heading of the previous frame; frame 0 is used unrotated.
    velocities_root_xyz_no_heading = np.zeros((nfrm, 3))
    velocities_root_xyz_no_heading[:, 0] = velocities_root_xy_no_heading[:, 0]
    velocities_root_xyz_no_heading[:, 2] = velocities_root_xy_no_heading[:, 1]
    velocities_root_xyz_no_heading[1:, :] = np.matmul(
        inv_global_heading_rot[:-1], velocities_root_xyz_no_heading[1:, :, None]).squeeze(-1)
    root_translation = np.cumsum(velocities_root_xyz_no_heading, axis=0)

    # Only x and z receive the integrated root translation.
    positions_with_heading[:, :, 0] += root_translation[:, 0:1]
    positions_with_heading[:, :, 2] += root_translation[:, 2:]
    return positions_with_heading
# add hip height to translation when recovering from rotation
def recover_from_local_rotation(final_x, njoint):
    """Recover SMPL-85 parameters from the 272-dim representation.

    Joint rotations come from the 6d block ``8+6*njoint:8+12*njoint``; the
    accumulated heading is applied to joint 0 only; the root translation is
    integrated from the xz velocities, and its y is taken from joint 0's local
    y coordinate.

    Args:
        final_x: NumPy array of shape ``(nfrm, D)`` in the 272-dim layout.
        njoint: Number of joints.

    Returns:
        Array from ``rotations_matrix_to_smpl85``: ``(nfrm, 3*njoint + 19)``,
        i.e. ``(nfrm, 85)`` for ``njoint == 22``.
    """
    nfrm, _ = final_x.shape
    rot_6d = final_x[:, 8 + 6 * njoint:8 + 12 * njoint]
    local_pos = final_x[:, 8:8 + 3 * njoint].reshape(nfrm, -1, 3)
    root_vel_xz = final_x[:, :2]
    heading_delta_6d = final_x[:, 2:8]

    joint_rotmats = rotation_6d_to_matrix(torch.from_numpy(rot_6d).reshape(nfrm, -1, 6)).numpy()

    heading = accumulate_rotations(rotation_6d_to_matrix(torch.from_numpy(heading_delta_6d)).numpy())
    heading_inv = heading.transpose(0, 2, 1)

    # Only the root (joint 0) carries the global heading.
    joint_rotmats[:, 0] = heading_inv @ joint_rotmats[:, 0]

    # Root velocity: xz -> xyz (y = 0); frames 1.. use the previous frame's
    # transposed heading, frame 0 is left as-is.
    root_vel = np.zeros((root_vel_xz.shape[0], 3))
    root_vel[:, [0, 2]] = root_vel_xz
    root_vel[1:] = (heading_inv[:-1] @ root_vel[1:, :, None])[..., 0]
    translation = root_vel.cumsum(axis=0)
    translation[:, 1] = local_pos[:, 0, 1]

    return rotations_matrix_to_smpl85(joint_rotmats, translation)
def rotations_matrix_to_smpl85(rotations_matrix, translation):
    """Pack per-joint rotation matrices and a translation into SMPL-85 rows.

    Args:
        rotations_matrix: NumPy array of shape ``(nfrm, njoint, 3, 3)``.
        translation: NumPy array of shape ``(nfrm, 3)``.

    Returns:
        ``(nfrm, 3*njoint + 19)`` array: axis-angle pose (``3*njoint``), 6
        zeros, translation (3), 10 zeros. That is 85 columns when njoint == 22.
    """
    nfrm, _, _, _ = rotations_matrix.shape
    pose_aa = matrix_to_axis_angle(torch.from_numpy(rotations_matrix)).numpy().reshape(nfrm, -1)
    # Zero padding; presumably the two remaining SMPL joints (6 values) and the
    # 10 shape coefficients — verify against the SMPL-85 consumer.
    pose_pad = np.zeros((nfrm, 6))
    tail_pad = np.zeros((nfrm, 10))
    return np.concatenate([pose_aa, pose_pad, translation, tail_pad], axis=-1)
def smpl85_2_smpl322(smpl_85_data):
    """Expand SMPL-85 rows to the 322-column SMPL-X layout.

    Column map (output <- input):
        [0:66]    <- [0:66]
        [66:309]  <- zeros (90 + 3 + 50 + 100 columns absent from SMPL-85)
        [309:312] <- [72:75] (translation)
        [312:]    <- [75:]
    Input columns 66:72 are dropped.
    """
    n_frames = smpl_85_data.shape[0]
    zero_fill = np.zeros((n_frames, 90 + 3 + 50 + 100))
    parts = (smpl_85_data[:, :66], zero_fill, smpl_85_data[:, 72:72 + 3], smpl_85_data[:, 75:])
    return np.concatenate(parts, axis=-1)
def visualize_smpl_85(data, title=None, output_path='visualize_result', name='', fps=30):
    """Render SMPL-85 motion as a skeleton video ``{output_path}/rot_{name}.mp4``.

    Args:
        data: SMPL-85 parameters, ``(nframe, 85)`` or ``(1, nframe, 85)``; a
            3-D input has its leading axis squeezed.
        title: Forwarded to ``plot_3d.draw_to_batch`` as ``title_batch``.
        output_path: Directory the video is written into; created if missing.
        name: Stem of the output file name.
        fps: Frame rate of the rendered video.

    Returns:
        ``output_path`` (the directory, not the video path).
    """
    # Imported lazily: the module-level import of this helper is commented out
    # in this file, so without this line the function raised NameError.
    # Presumably kept out of module scope to avoid loading SMPL-X for 'pos'.
    from visualization.smplx2joints import process_smplx_data

    smpl_85_data = data
    if smpl_85_data.ndim == 3:
        smpl_85_data = np.squeeze(smpl_85_data, axis=0)
    smpl_322_data = smpl85_2_smpl322(smpl_85_data)
    _, joints, _, _ = process_smplx_data(smpl_322_data, norm_global_orient=False, transform=False)
    # First 22 joints, with a leading batch axis for draw_to_batch.
    # Assumes `joints` is a torch tensor (detach/cpu are called on it).
    xyz = joints[:, :22, :].reshape(1, -1, 22, 3).detach().cpu().numpy()
    # The video lives *inside* output_path, so create that directory itself;
    # os.path.dirname('out') == '' made os.makedirs raise before.
    os.makedirs(output_path, exist_ok=True)
    plot_3d.draw_to_batch(xyz, title_batch=title, outname=[f'{output_path}/rot_{name}.mp4'], fps=fps)
    return output_path
def visualize_pos_xyz(xyz, title_batch=None, output_path='./', name='', fps=30):
    """Render joint positions as a skeleton video ``{output_path}/pos_{name}.mp4``.

    Args:
        xyz: Positions with a leading batch axis, e.g. ``(1, nframe, njoint, 3)``
            as built in ``__main__``. Only the first batch element is drawn;
            everything after the frame axis is reshaped to ``(-1, 3)``.
        title_batch: Forwarded to ``plot_3d.draw_to_batch``.
        output_path: Directory the video is written into; created if missing.
        name: Stem of the output file name.
        fps: Frame rate of the rendered video.

    Returns:
        ``output_path`` (the directory, not the video path).
    """
    xyz = xyz[:1]
    bs, seq = xyz.shape[:2]
    xyz = xyz.reshape(bs, seq, -1, 3)
    # Create the directory the video goes into (dirname('out') == '' used to
    # make os.makedirs raise).
    os.makedirs(output_path, exist_ok=True)
    plot_3d.draw_to_batch(xyz, title_batch, [f'{output_path}/pos_{name}.mp4'], fps=fps)
    return output_path
if __name__ == '__main__':
    # Visualize every .npy (272-dim motion) under --input_dir, recovering either
    # SMPL parameters ('rot') or joint positions ('pos').
    njoint = 22
    parser = argparse.ArgumentParser(description='Visualize new representation.')
    parser.add_argument('--input_dir', type=str, required=True, help='Input path')
    # Optional: with required=True the declared default 'rot' could never apply.
    parser.add_argument('--mode', type=str, default='rot', choices=['rot', 'pos'], help='Recover from rotation or position')
    parser.add_argument('--output_dir', type=str, required=True, help='Output path')
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    for data_path in tqdm.tqdm(findAllFile(args.input_dir, endswith='.npy')):
        data_272 = np.load(data_path)
        # File stem, independent of the OS path separator and safe for dotted names.
        name = os.path.splitext(os.path.basename(data_path))[0]
        if args.mode == 'rot':
            # recover from rotation: 85-dim SMPL data
            global_rotation = recover_from_local_rotation(data_272, njoint)
            visualize_smpl_85(global_rotation, output_path=args.output_dir, name=name)
        else:
            # recover from position
            global_position = recover_from_local_position(data_272, njoint)
            global_position = np.expand_dims(global_position, axis=0)
            visualize_pos_xyz(global_position, output_path=args.output_dir, name=name)
    # Report once instead of per file, so the tqdm bar is not interleaved.
    print(f"Visualized results are saved in {args.output_dir}")