|
|
import copy |
|
|
import os |
|
|
import joblib |
|
|
import numpy as np |
|
|
from scipy.spatial.transform import Slerp, Rotation |
|
|
import torch |
|
|
|
|
|
from hawor.utils.process import run_mano, run_mano_left |
|
|
from hawor.utils.rotation import angle_axis_to_quaternion, angle_axis_to_rotation_matrix, quaternion_to_rotation_matrix, rotation_matrix_to_angle_axis |
|
|
from lib.utils.geometry import rotmat_to_rot6d |
|
|
from lib.utils.geometry import rot6d_to_rotmat |
|
|
|
|
|
def slerp_interpolation_aa(pos, valid):
    """Fill invalid frames of per-joint axis-angle tracks by spherical interpolation.

    Parameters
    ----------
    pos : np.ndarray
        Axis-angle rotations of shape ``(B, T, N, 3)``.
    valid : array-like
        ``(B, T)`` mask; truthy entries mark frames whose rotations are trusted.
        It is converted to ``bool`` so 0/1 integer masks behave like boolean ones.

    Returns
    -------
    np.ndarray
        A new ``(B, T, N, 3)`` array; ``pos`` itself is not modified. Valid frames
        are copied unchanged. Invalid frames strictly between two valid frames are
        slerped. Invalid frames before the first (after the last) valid frame copy
        the first (last) valid rotation; this also covers sequences with a single
        valid frame. Sequences with no valid frame are left as-is.
    """
    B, T, N, _ = pos.shape
    pos_interp = pos.copy()
    valid = np.asarray(valid, dtype=bool)

    for b in range(B):
        # The mask only depends on b, so split the missing frames once per sequence.
        mask = valid[b]
        missing = np.flatnonzero(~mask)
        known = np.flatnonzero(mask)
        if missing.size == 0 or known.size == 0:
            continue
        first, last = known[0], known[-1]
        before = missing[missing < first]
        after = missing[missing > last]
        inside = missing[(missing > first) & (missing < last)]

        for n in range(N):
            track = pos[b, :, n, :]
            pos_interp[b, before, n, :] = track[first]
            pos_interp[b, after, n, :] = track[last]
            # Slerp needs at least two key rotations; `inside` is empty otherwise.
            if inside.size:
                slerp = Slerp(known, Rotation.from_rotvec(track[known]))
                pos_interp[b, inside, n, :] = slerp(inside).as_rotvec()

    return pos_interp
|
|
|
|
|
def slerp_interpolation_quat(pos, valid):
    """Fill invalid frames of per-joint quaternion tracks by spherical interpolation.

    Parameters
    ----------
    pos : np.ndarray
        Quaternions of shape ``(B, T, N, 4)`` in scalar-first ``(w, x, y, z)`` order.
    valid : array-like
        ``(B, T)`` mask; truthy entries mark frames whose rotations are trusted.
        It is converted to ``bool`` so 0/1 integer masks behave like boolean ones.

    Returns
    -------
    np.ndarray
        A new ``(B, T, N, 4)`` array in ``(w, x, y, z)`` order; ``pos`` is not
        modified. Valid frames are copied unchanged. Invalid frames strictly
        between two valid frames are slerped. Invalid frames before the first
        (after the last) valid frame copy the first (last) valid quaternion; this
        also covers sequences with a single valid frame. Sequences with no valid
        frame are left as-is.
    """
    # scipy's Rotation is scalar-last (x, y, z, w); fancy indexing also copies.
    xyzw = pos[:, :, :, [1, 2, 3, 0]]
    B, T, N, _ = xyzw.shape
    out = xyzw.copy()
    valid = np.asarray(valid, dtype=bool)

    for b in range(B):
        mask = valid[b]
        missing = np.flatnonzero(~mask)
        known = np.flatnonzero(mask)
        if missing.size == 0 or known.size == 0:
            continue
        first, last = known[0], known[-1]
        before = missing[missing < first]
        after = missing[missing > last]
        inside = missing[(missing > first) & (missing < last)]

        for n in range(N):
            track = xyzw[b, :, n, :]
            out[b, before, n, :] = track[first]
            out[b, after, n, :] = track[last]
            # Slerp needs at least two key rotations; `inside` is empty otherwise.
            if inside.size:
                slerp = Slerp(known, Rotation.from_quat(track[known]))
                out[b, inside, n, :] = slerp(inside).as_quat()

    # Back to scalar-first (w, x, y, z).
    return out[:, :, :, [3, 0, 1, 2]]
|
|
|
|
|
|
|
|
def linear_interpolation_nd(pos, valid):
    """Fill invalid frames of per-frame features by linear interpolation over time.

    Parameters
    ----------
    pos : np.ndarray
        Array of shape ``(B, T, ...)``; every trailing element is interpolated
        independently along the time axis ``T``.
    valid : array-like
        ``(B, T)`` mask; truthy entries mark frames whose values are trusted.
        It is converted to ``bool`` so 0/1 integer masks behave like boolean ones.

    Returns
    -------
    np.ndarray
        A new array with the same shape as ``pos``; ``pos`` itself is NOT modified.
        Invalid frames between valid ones are linearly interpolated. Invalid frames
        before the first (after the last) valid frame take the first (last) valid
        value, because ``np.interp`` clamps; this also covers a single valid frame.
        Sequences with no valid frame are left as-is.
    """
    # Interpolate into a private copy so the caller's array is never written to.
    pos_interp = pos.copy()
    valid = np.asarray(valid, dtype=bool)
    B, T = pos_interp.shape[:2]
    flat = pos_interp.reshape(B, T, -1)

    for b in range(B):
        # The mask only depends on b; compute it once per sequence.
        mask = valid[b]
        missing = np.flatnonzero(~mask)
        known = np.flatnonzero(mask)
        if missing.size == 0 or known.size == 0:
            continue
        for d in range(flat.shape[2]):
            flat[b, missing, d] = np.interp(missing, known, flat[b, known, d])

    return flat.reshape(pos_interp.shape)
|
|
|
|
|
def world2canonical_convert(R_c2w_sla, t_c2w_sla, data_out, handedness):
    """Apply a per-frame rigid transform ``(R, t)`` to a MANO hand sequence.

    Despite its name, this applies whatever transform it is given. In this file it
    is used both for world->canonical and for canonical->world.

    Parameters
    ----------
    R_c2w_sla : torch.Tensor
        Per-frame rotations, ``(T, 3, 3)`` (indexed as ``tij`` in the einsums).
    t_c2w_sla : torch.Tensor
        Per-frame translations, ``(T, 3)``; broadcast over the leading batch axis.
    data_out : dict
        ``"init_trans"`` ``(B, T, 3)``, ``"init_root_orient"`` ``(B, T, 3, 3)``
        rotation matrices, ``"init_hand_pose"`` rotation matrices and
        ``"init_betas"``, forwarded to the MANO helper.
    handedness : str
        ``"left"`` (uses ``run_mano_left``) or ``"right"`` (uses ``run_mano``).

    Returns
    -------
    dict
        ``"init_root_orient"``: transformed root orientation, axis-angle.
        ``"init_hand_pose"``: joint rotations as axis-angle; unaffected by the
        transform.
        ``"init_trans"``: transformed translation.
        ``"init_betas"``: passed through unchanged.

    Raises
    ------
    ValueError
        If ``handedness`` is not ``"left"`` or ``"right"``.
    """
    # Reject unknown values up front instead of failing later on an unbound result.
    if handedness == "left":
        mano_fn = run_mano_left
    elif handedness == "right":
        mano_fn = run_mano
    else:
        raise ValueError(f"handedness must be 'left' or 'right', got {handedness!r}")

    root_orient_mat = data_out["init_root_orient"]
    # einsum allocates a fresh tensor, so the input dict is left untouched.
    new_root_orient = rotation_matrix_to_angle_axis(
        torch.einsum("tij,btjk->btik", R_c2w_sla, root_orient_mat)
    )

    root_orient_aa = rotation_matrix_to_angle_axis(root_orient_mat)
    hand_pose_aa = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])

    init_trans = data_out["init_trans"]
    outputs = mano_fn(init_trans, root_orient_aa, hand_pose_aa, betas=data_out["init_betas"])

    # The transform is applied to the MANO root joint (joint 0), and the
    # translation-to-root offset is re-added afterwards. Presumably this works
    # because MANO rotates about the root joint, so the offset does not depend on
    # global orientation -- TODO confirm.
    root_loc = outputs["joints"][..., 0, :].cpu()
    offset = init_trans - root_loc
    new_trans = (
        torch.einsum("tij,btj->bti", R_c2w_sla, root_loc)
        + t_c2w_sla[None, :]
        + offset
    )

    return {
        "init_root_orient": new_root_orient,
        "init_hand_pose": hand_pose_aa,
        "init_trans": new_trans,
        "init_betas": data_out["init_betas"],
    }
|
|
|
|
|
def _canonicalize_hand(hand_idx, handedness, global_trans, global_rot, hand_pose, betas):
    """Express one hand's world-frame sequence in its own frame-0 canonical frame.

    The world->canonical rotation is the transpose of this hand's frame-0 root
    orientation. The translation is chosen so that, after
    ``world2canonical_convert``, the hand's frame-0 translation maps to the origin.

    Returns ``(data_canonical, R_w2c, t_w2c)``. ``R_w2c`` is ``(T, 3, 3)`` and
    ``t_w2c`` is ``(T, 3)``: the same frame-0 transform repeated for every frame.
    """
    num_frames = global_trans.shape[1]
    span = slice(hand_idx, hand_idx + 1)  # keep a leading hand axis of size 1

    # Frame-0 root orientation maps canonical -> world; its transpose inverts it.
    R_w2c = angle_axis_to_rotation_matrix(torch.from_numpy(global_rot[hand_idx, 0])).t()

    data_world = {
        "init_trans": torch.from_numpy(global_trans[span]),
        "init_root_orient": angle_axis_to_rotation_matrix(torch.from_numpy(global_rot[span])),
        "init_hand_pose": angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose[span])),
        "init_betas": torch.from_numpy(betas[span]),
    }

    mano_fn = run_mano_left if handedness == "left" else run_mano
    outputs = mano_fn(
        data_world["init_trans"],
        rotation_matrix_to_angle_axis(data_world["init_root_orient"]),
        rotation_matrix_to_angle_axis(data_world["init_hand_pose"]),
        betas=data_world["init_betas"],
    )

    # Offset between the MANO translation and the MANO root joint at frame 0.
    trans0 = data_world["init_trans"][0, 0]
    root0 = outputs["joints"][0, 0, 0, :].cpu()
    offset = trans0 - root0
    t_w2c = -torch.einsum("ij,j->i", R_w2c, root0) - offset

    R_w2c = R_w2c.repeat(num_frames, 1, 1)
    t_w2c = t_w2c.repeat(num_frames, 1)
    data_canonical = world2canonical_convert(R_w2c, t_w2c, data_world, handedness)
    return data_canonical, R_w2c, t_w2c


def filling_preprocess(item):
    """Canonicalize a two-hand sequence and build the infiller's input vector.

    Each hand is moved into its own canonical frame: rotated by the inverse of its
    frame-0 root orientation and translated so its frame-0 translation is zero.
    Invalid frames are then filled (linear for translation and betas, slerp for
    rotations), and rotations are converted to the 6D representation.

    Parameters
    ----------
    item : dict
        NumPy arrays with hand index 0 = left and 1 = right (``run_mano_left`` is
        used for index 0):
        ``'trans'`` ``(N, T, 3)``, ``'rot'`` axis-angle root orientation
        ``(N, T, 3)``, ``'hand_pose'`` reshapeable to ``(N, T, 15, 3)`` axis-angle,
        ``'betas'`` ``(N, T, D)``, ``'valid'`` ``(N, T)`` frame mask.
        NOTE(review): N is assumed to be exactly 2 -- TODO confirm with callers.

    Returns
    -------
    global_pose_vec_input : np.ndarray
        ``(T, N * F)``, where per hand ``F`` concatenates the canonical translation
        (3), betas (D), 6D root rotation (6) and 6D joint rotations (15 * 6).
    transform_w_canon : dict
        Per-hand ``(T, 3, 3)`` rotations and ``(T, 3)`` translations in both
        directions (``R/t_w2canon_*`` and ``R/t_canon2w_*``).
    """
    num_joints = 15

    global_trans = item['trans']
    global_rot = item['rot']
    hand_pose = item['hand_pose']
    betas = item['betas']
    valid = item['valid']

    N, T, _ = global_trans.shape
    hand_pose = hand_pose.reshape(N, T, num_joints, 3)

    data_canonical_left, R_world2canonical_left, t_world2canonical_left = _canonicalize_hand(
        0, "left", global_trans, global_rot, hand_pose, betas
    )
    data_canonical_right, R_world2canonical_right, t_world2canonical_right = _canonicalize_hand(
        1, "right", global_trans, global_rot, hand_pose, betas
    )

    global_rot = torch.cat(
        (data_canonical_left['init_root_orient'], data_canonical_right['init_root_orient'])
    ).reshape(N, T, 1, 3).numpy()
    global_trans = torch.cat(
        (data_canonical_left['init_trans'], data_canonical_right['init_trans'])
    ).numpy()

    # Fill invalid frames in the canonical frame.
    global_trans_lerped = linear_interpolation_nd(global_trans, valid)
    betas_lerped = linear_interpolation_nd(betas, valid)
    global_rot_slerped = slerp_interpolation_aa(global_rot, valid)
    hand_pose_slerped = slerp_interpolation_aa(hand_pose, valid)

    # Axis-angle -> rotation matrix -> 6D.
    global_rot_slerped_mat = angle_axis_to_rotation_matrix(
        torch.from_numpy(global_rot_slerped.reshape(N * T, -1))
    )
    global_rot_slerped_rot6d = rotmat_to_rot6d(global_rot_slerped_mat).reshape(N, T, -1).numpy()
    hand_pose_slerped_mat = angle_axis_to_rotation_matrix(
        torch.from_numpy(hand_pose_slerped.reshape(N * T * num_joints, -1))
    )
    hand_pose_slerped_rot6d = rotmat_to_rot6d(hand_pose_slerped_mat).reshape(N, T, -1).numpy()

    # (N, T, F) -> (T, N, F) -> (T, N * F): both hands' features side by side per frame.
    global_pose_vec_input = np.concatenate(
        (global_trans_lerped, betas_lerped, global_rot_slerped_rot6d, hand_pose_slerped_rot6d),
        axis=-1,
    ).transpose(1, 0, 2).reshape(T, -1)

    # Inverse rigid transforms: R^T and -R^T t.
    R_canon2w_left = R_world2canonical_left.transpose(-1, -2)
    t_canon2w_left = -torch.einsum("tij,tj->ti", R_canon2w_left, t_world2canonical_left)
    R_canon2w_right = R_world2canonical_right.transpose(-1, -2)
    t_canon2w_right = -torch.einsum("tij,tj->ti", R_canon2w_right, t_world2canonical_right)

    transform_w_canon = {
        "R_w2canon_left": R_world2canonical_left,
        "t_w2canon_left": t_world2canonical_left,
        "R_canon2w_left": R_canon2w_left,
        "t_canon2w_left": t_canon2w_left,

        "R_w2canon_right": R_world2canonical_right,
        "t_w2canon_right": t_world2canonical_right,
        "R_canon2w_right": R_canon2w_right,
        "t_canon2w_right": t_canon2w_right,
    }

    return global_pose_vec_input, transform_w_canon
|
|
|
|
|
def custom_rot6d_to_rotmat(rot6d):
    """Convert 6D rotations with arbitrary leading dimensions to 3x3 matrices.

    ``rot6d`` has shape ``(..., 6)``. The leading dimensions are flattened for
    ``rot6d_to_rotmat`` and restored on the result, which has shape ``(..., 3, 3)``.
    """
    *lead_dims, _ = rot6d.shape
    matrices = rot6d_to_rotmat(rot6d.reshape(-1, 6))
    return matrices.reshape(*lead_dims, 3, 3)
|
|
|
|
|
def filling_postprocess(output, transform_w_canon):
    """Decode infiller output and map both hands from canonical back to world frame.

    Parameters
    ----------
    output : torch.Tensor
        Output laid out as ``(T, N, D)``; it is permuted to ``(N, T, D)`` here.
        Hand index 0 is left and 1 is right. Per hand, the features read are:
        ``[0:3]`` canonical translation, ``[3:13]`` betas, ``[13:19]`` 6D root
        rotation, and ``[19:109]`` 15 joint rotations in 6D.
    transform_w_canon : dict
        Transform dict whose ``"R_canon2w_left"``, ``"t_canon2w_left"``,
        ``"R_canon2w_right"`` and ``"t_canon2w_right"`` entries are used.

    Returns
    -------
    dict
        NumPy arrays stacked left-then-right along the first axis:
        ``"trans"`` (world translation), ``"rot"`` (world root orientation,
        axis-angle), ``"hand_pose"`` (joint rotations as axis-angle with the last
        two axes flattened) and ``"betas"``.
    """
    per_hand = output.permute(1, 0, 2)
    num_hands, num_frames, _ = per_hand.shape

    trans_canon = per_hand[:, :, 0:3]
    betas = per_hand[:, :, 3:13]
    root_mat = custom_rot6d_to_rotmat(per_hand[:, :, 13:19])
    pose_mat = custom_rot6d_to_rotmat(
        per_hand[:, :, 19:109].reshape(num_hands, num_frames, 15, 6)
    )

    hands = (
        (0, "left", "R_canon2w_left", "t_canon2w_left"),
        (1, "right", "R_canon2w_right", "t_canon2w_right"),
    )
    world = []
    for hand_idx, side, r_key, t_key in hands:
        pick = [hand_idx]  # list index keeps the hand axis with size 1
        canonical = {
            "init_trans": trans_canon[pick],
            "init_root_orient": root_mat[pick],
            "init_hand_pose": pose_mat[pick],
            "init_betas": betas[pick],
        }
        world.append(
            world2canonical_convert(
                transform_w_canon[r_key], transform_w_canon[t_key], canonical, side
            )
        )

    return {
        "trans": torch.cat([w['init_trans'] for w in world]).numpy(),
        "rot": torch.cat([w['init_root_orient'] for w in world]).numpy(),
        "hand_pose": rotation_matrix_to_angle_axis(pose_mat).flatten(-2).numpy(),
        "betas": betas.numpy(),
    }
|
|
|
|
|
|