# NMR/tools/data_process/process/process.py
# Uploaded by Xxx999 (commit 45950ff, message: "upload").
import sys
sys.path.append('/mnt/shenzhen2cephfs/capybarali/codes/humanoid')
import torch, yaml, os
from tqdm import tqdm
from smplx import SMPLX
from src.utils.rotation_conversions import quaternion_to_matrix, matrix_to_rotation_6d, matrix_to_axis_angle, rotation_6d_to_matrix, matrix_to_quaternion
import numpy as np
from argparse import ArgumentParser
import joblib
from data.vis import vis_3d_motion
from data.vis_g1 import vis_3d_g1
from copy import deepcopy
# Run the SMPL-X forward pass on the GPU when one is available, else on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# State for the before/after coordinate-transform dumps written by get_g1_motion.
_debug_save_count = 0
_debug_save_dir = 'data/debug_transform'

# Command-line options: output root plus a (start, stride) slice over the path list.
parser = ArgumentParser(description="Launch MoCap processing")
parser.add_argument('--save_root', default="data/final_data", type=str)
parser.add_argument('--start_idx', default=0, type=int)
parser.add_argument('--interval', default=1, type=int)
args = parser.parse_args()

# Create <save_root> and <save_root>/motions up front (no error if they exist).
for _out_dir in (args.save_root, os.path.join(args.save_root, 'motions')):
    os.makedirs(_out_dir, exist_ok=True)

# Shared SMPL-X body model, loaded once at import time and moved to `device`.
smplx_model = SMPLX(
    model_path='checkpoints/human_model/SMPLX_NEUTRAL.npz',
    use_pca=False,
    num_expression_coeffs=100,
    num_betas=10,
    ext='npz',
).to(device)
# Rotation taking a y-up frame to a z-up frame: (x, y, z) -> (x, -z, y).
# get_g1_motion applies its inverse to bring G1 data into the y-up frame used
# for the features; extract_g1_component applies it to go back.
_Y_UP_TO_Z_UP = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])


def extract_g1_component(x):
    """Decode a G1 feature matrix (as built by get_g1_motion) back into its parts.

    Feature layout assumed (J joints, root = joint 0, 29 DOFs at the end):
    [0:2] root ground-plane velocity (x, z) in the y-up frame, [2:8] 6D root
    orientation (y-up), [8:8+3J] joint positions with the root height at
    column 9, ..., [-29:] DOFs.

    Args:
        x: feature tensor, assumed (T, 2 + 6 + 6*J + 29).

    Returns:
        (dof, rot_quat, transl):
        dof      -- (T, 29) joint DOFs,
        rot_quat -- (T, 4) root orientation in the z-up frame, wxyz order,
        transl   -- (T, 3) root trajectory in the z-up frame, starting at the
                    ground origin, with height measured from the clip's floor.
    """
    rot = _Y_UP_TO_Z_UP.to(x)  # match x's device and dtype
    dof = x[:, -29:]

    root_rot_mat = rotation_6d_to_matrix(x[:, 2:8])  # (T, 3, 3), y-up
    global_orient_mat = torch.einsum('ij,tjk->tik', rot, root_rot_mat)
    rot_quat = matrix_to_quaternion(global_orient_mat)  # (T, 4) wxyz order

    # Integrate the (x, z) ground velocity; frame 0 stores zero velocity, so the
    # trajectory starts at the origin. Root height is the root joint's y (col 9).
    ground = torch.cumsum(x[:, :2], dim=0)
    trans_xyz = torch.stack([ground[:, 0], x[:, 9], ground[:, 1]], dim=-1)
    transl = torch.einsum('ij,tj->ti', rot, trans_xyz)
    return dof, rot_quat, transl


def _axis_angle_to_matrix(axis_angle):
    """Convert axis-angle vectors (..., 3) to rotation matrices (..., 3, 3).

    Uses the matrix exponential of the skew-symmetric cross-product matrix
    (Rodrigues' formula), which is exact and well-defined for zero rotations.
    """
    ax, ay, az = axis_angle.unbind(-1)
    zero = torch.zeros_like(ax)
    skew = torch.stack([zero, -az, ay, az, zero, -ax, -ay, ax, zero], dim=-1)
    skew = skew.reshape(axis_angle.shape[:-1] + (3, 3))
    return torch.linalg.matrix_exp(skew)


def _build_motion_features(position_data, global_orient_mat, root_idx=0):
    """Pack joint positions and root orientation into the flat per-frame layout.

    Args:
        position_data: joint positions (T, J, 3) in a y-up frame.
        global_orient_mat: root orientation matrices (T, 3, 3) in the same frame.
        root_idx: index of the root joint.

    Returns:
        float tensor (T, 2 + 6 + 6*J):
        [0:2]        root ground-plane velocity (x, z); zero at frame 0
        [2:8]        root orientation in 6D form
        [8:8+3J]     joint positions: floor at y=0, x/z relative to the root
        [8+3J:8+6J]  per-joint frame-to-frame displacement; zero at frame 0
    """
    T, njoint, _ = position_data.shape
    # Displacements are translation-invariant, so they can be taken before
    # the positions are re-centred below.
    position_vel_data = position_data[1:] - position_data[:-1]

    # Put the lowest joint of the clip on the floor (y = 0) and the first-frame
    # root at the origin of the ground plane.
    ori = position_data[0, root_idx].clone()
    ori[1] = position_data[:, :, 1].min()
    position_data = position_data - ori

    root = position_data[:, root_idx]
    velocities_root = root[1:] - root[:-1]

    # Make x/z root-relative per frame. The offset is cloned and subtracted
    # out of place: an in-place `-=` whose operand aliases the root column
    # zeroes the root first and leaves the other joints un-centred.
    root_xz = position_data[:, root_idx:root_idx + 1].clone()
    root_xz[..., 1] = 0
    position_data = position_data - root_xz

    final_x = torch.zeros((T, 2 + 6 + njoint * 6))
    # y is vertical here, so the ground-plane velocity is (x, z); the root's
    # x/z position is zeroed above and survives only in these channels.
    final_x[1:, 0] = velocities_root[:, 0]
    final_x[1:, 1] = velocities_root[:, 2]
    final_x[:, 2:8] = matrix_to_rotation_6d(global_orient_mat)
    final_x[:, 8:8 + njoint * 3] = position_data.flatten(1, 2)
    final_x[1:, 8 + njoint * 3:] = position_vel_data.flatten(1, 2)
    return final_x


def get_smplx_motion(data_path):
    """Encode the SMPL-X motion stored in ``data_path`` as a feature matrix.

    Runs the shared SMPL-X model with the G1 ``betas`` and neutral face/hands,
    then packs the first 22 joints with _build_motion_features.

    Args:
        data_path: joblib file with 'smplx_transl', 'smplx_global_orient' and
            'smplx_body_pose' arrays. global_orient is assumed to be axis-angle
            (T, 3), as passed to SMPLX -- TODO confirm against the producer.

    Returns:
        float tensor (T, 140) = 2 + 6 + 22*3 + 22*3.
    """
    smplx_data = joblib.load(data_path)
    transl = torch.from_numpy(smplx_data['smplx_transl']).float()
    global_orient = torch.from_numpy(smplx_data['smplx_global_orient']).float()
    body_pose = torch.from_numpy(smplx_data['smplx_body_pose']).float()
    betas = torch.from_numpy(np.load('checkpoints/humanoid_model/g1/betas.npy')).float()
    N = transl.shape[0]

    def zeros(dim):
        return torch.zeros((N, dim), device=device)

    output = smplx_model(
        transl=transl.to(device),
        global_orient=global_orient.to(device),
        body_pose=body_pose.to(device),
        betas=betas.unsqueeze(0).repeat(N, 1).to(device),
        # Eyes, jaw, hands and expression are not in the source data: keep neutral.
        leye_pose=zeros(3),
        reye_pose=zeros(3),
        left_hand_pose=zeros(45),
        right_hand_pose=zeros(45),
        jaw_pose=zeros(3),
        expression=zeros(100),
    )
    position_data = output.joints.detach().cpu()[:, :22]  # (T, 22, 3)
    # matrix_to_rotation_6d needs rotation matrices, not axis-angle vectors.
    return _build_motion_features(position_data, _axis_angle_to_matrix(global_orient))


def get_g1_motion(data_path):
    """Encode the retargeted G1 motion stored in ``data_path`` as a feature matrix.

    The G1 data is treated as z-up and rotated into the y-up frame used by the
    SMPL-X features; the DOFs are appended unchanged. For the first three calls
    the raw and rotated inputs are dumped to ``_debug_save_dir``.

    Args:
        data_path: joblib file with 'g1_dof' (T, 29), 'g1_root_ori' (T, 4, wxyz)
            and 'g1_joints' (T, 30, 3) arrays.

    Returns:
        float tensor (T, 2 + 6 + 6*J + 29); 217 columns for 30 joints.
    """
    global _debug_save_count
    g1_data = joblib.load(data_path)
    dof = g1_data['g1_dof']  # (T, 29)
    root_quat = g1_data['g1_root_ori']  # (T, 4), wxyz
    joints = g1_data['g1_joints']  # (T, 30, 3)

    save_debug = _debug_save_count < 3
    if save_debug:
        os.makedirs(_debug_save_dir, exist_ok=True)
        np.savez(
            os.path.join(_debug_save_dir, f'motion_{_debug_save_count:02d}_before.npz'),
            g1_trans=joints,
            g1_root_rot=root_quat,
            g1_dof=dof,
        )

    # z-up -> y-up (inverse, i.e. transpose, of the y-up -> z-up rotation).
    z_up_to_y_up = _Y_UP_TO_Z_UP.T
    global_orient_mat = quaternion_to_matrix(torch.from_numpy(root_quat)).float()
    global_orient_mat = torch.einsum('ij,tjk->tik', z_up_to_y_up, global_orient_mat)
    position_data = torch.einsum(
        'ij,tkj->tki', z_up_to_y_up, torch.from_numpy(joints).float()
    )

    if save_debug:
        np.savez(
            os.path.join(_debug_save_dir, f'motion_{_debug_save_count:02d}_after.npz'),
            g1_trans=position_data.numpy(),
            g1_root_rot=matrix_to_axis_angle(global_orient_mat).numpy(),
            g1_dof=dof,
        )
        # Advance the counter so only the first three motions are dumped.
        _debug_save_count += 1

    final_x = _build_motion_features(position_data, global_orient_mat)
    return torch.cat([final_x, torch.from_numpy(dof).to(final_x.dtype)], dim=-1)
# Example usage:
#   python -m data.motionmillion.tools.process --save_root "data/motionmillion/final_data"
if __name__ == '__main__':
    # Each non-empty line is the path of one joblib file that holds both the
    # SMPL-X and the G1 motion of a clip (both encoders load the same path).
    with open('data/merge_gmr_retarget_smplx_path.txt', 'r', encoding='utf-8') as f:
        paths = f.readlines()

    # Shard the work: this process handles every `interval`-th entry from `start_idx`.
    for line in tqdm(paths[args.start_idx::args.interval]):
        data_path = line.strip()
        if not data_path:
            # Blank lines (e.g. a trailing newline) would make joblib.load('') fail.
            continue

        smplx_motion = get_smplx_motion(data_path)
        g1_motion = get_g1_motion(data_path)

        # The two streams may differ in length; keep their common prefix.
        min_len = min(smplx_motion.shape[0], g1_motion.shape[0])
        data = dict(
            g1_motion=g1_motion[:min_len],
            smplx_motion=smplx_motion[:min_len],
        )

        # Mirror the input path (minus its first two components) under
        # <save_root>/motions, with a .pkl extension.
        rel_path = '/'.join(data_path.replace('.npz', '.pkl').split('/')[2:])
        save_path = os.path.join(args.save_root, 'motions', rel_path)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        joblib.dump(data, save_path)