# TunaDance/data/code/pre_motion.py
import argparse
import os
from pathlib import Path
import smplx
import pickle
import torch
import sys
from tqdm import tqdm
import glob
import numpy as np
sys.path.append(os.getcwd())
from dataset.quaternion import ax_to_6v, ax_from_6v
from dataset.preprocess import Normalizer, vectorize_many
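
# This script converts raw FineDance motion files (per frame: a 3-d root translation
# followed by 52 joint rotations in 6D form, 3 + 312 = 315 values) into 319-d motion
# features: 4 binary foot-contact labels + 3-d root translation + 312-d 6D rotations.
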
def motion_feats_extract(inputs_dir, outputs_dir):
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print("extracting")
    raw_fps = 30
    data_fps = 30
    assert data_fps <= raw_fps
    if not os.path.exists(outputs_dir):
        os.makedirs(outputs_dir)
    # All motion is retargeted to this standard model.
    smplx_model = smplx.SMPLX(model_path='assets/smpl_model/smplx', ext='npz', gender='neutral',
                              num_betas=10, flat_hand_mean=True, num_expression_coeffs=10, use_pca=False).eval().to(device)
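    # With use_pca=False, hand poses are passed as full 45-d axis-angle vectors per hand,
    # and flat_hand_mean=True makes those rotations relative to a flat (open) hand pose.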
    motions = sorted(glob.glob(os.path.join(inputs_dir, "*.npy")))
    for motion in tqdm(motions):
        name = os.path.splitext(os.path.basename(motion))[0].split(".")[0]
        print("name is", name)
        data = np.load(motion, allow_pickle=True)
        print(data.shape)
        pos = data[:, :3]                                   # (T, 3) root translation
        q = data[:, 3:]                                     # (T, 312) joint rotations in 6D form
        root_pos = torch.Tensor(pos).to(device)             # (T, 3)
        length = root_pos.shape[0]
        local_q_rot6d = torch.Tensor(q).to(device)          # (T, 312)
        print("local_q_rot6d", local_q_rot6d.shape)
        local_q = local_q_rot6d.reshape(length, 52, 6).clone()
        local_q = ax_from_6v(local_q).view(length, 156)     # (T, 156) axis-angle, 52 joints x 3
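        # The SMPL-X forward pass below consumes axis-angle joint rotations (use_pca=False),
        # which is why the 6D representation is converted back to axis-angle here.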
        smplx_output = smplx_model(
            betas=torch.zeros([root_pos.shape[0], 10], device=device, dtype=torch.float32),
            transl=root_pos,                      # global translation
            global_orient=local_q[:, :3],         # root orientation
            body_pose=local_q[:, 3:66],           # 21 body joints
            jaw_pose=torch.zeros([root_pos.shape[0], 3], device=device, dtype=torch.float32),   # 1 joint
            leye_pose=torch.zeros([root_pos.shape[0], 3], device=device, dtype=torch.float32),  # 1 joint
            reye_pose=torch.zeros([root_pos.shape[0], 3], device=device, dtype=torch.float32),  # 1 joint
            left_hand_pose=local_q[:, 66:66+45],  # 15 hand joints
            right_hand_pose=local_q[:, 66+45:],   # 15 hand joints
            expression=torch.zeros([root_pos.shape[0], 10], device=device, dtype=torch.float32),
            return_verts=False
        )
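        # Foot-contact heuristic: joints 7/8 are the left/right ankles and 10/11 the
        # left/right toes in the SMPL-X joint ordering; a foot counts as in contact
        # when its frame-to-frame displacement drops below 0.01.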
        positions = smplx_output.joints.view(length, -1, 3)        # (T, J, 3) joint positions
        feet = positions[:, (7, 8, 10, 11)]                        # (T, 4, 3)
        feetv = torch.zeros(feet.shape[:2], device=device)         # (T, 4)
        feetv[:-1] = (feet[1:] - feet[:-1]).norm(dim=-1)           # per-frame foot displacement
        contacts = (feetv < 0.01).to(local_q)                      # cast to the pose dtype, (T, 4)
        mofea319 = torch.cat([contacts, root_pos, local_q_rot6d], dim=1)   # 4 + 3 + 312 = 319
        assert mofea319.shape[1] == 319
        mofea319 = mofea319.detach().cpu().numpy()
        np.save(os.path.join(outputs_dir, name + '.npy'), mofea319)
    return
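

# A minimal sketch (not part of the original pipeline) of how a saved 319-d feature
# file could be unpacked again; the function name and return convention are assumptions.
def unpack_motion_fea319(path):
    fea = np.load(path)                     # (T, 319)
    contacts = fea[:, :4]                   # 4 binary foot contacts
    root_pos = fea[:, 4:7]                  # 3-d root translation
    rot6d = fea[:, 7:]                      # (T, 312) = 52 joints x 6
    length = fea.shape[0]
    # Convert the 6D rotations back to axis-angle, mirroring the conversion above.
    axis_angle = ax_from_6v(torch.Tensor(rot6d).reshape(length, 52, 6)).view(length, 156)
    return contacts, root_pos, axis_angle
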
if __name__ == "__main__":
    motion_feats_extract("data/finedance/motion", "data/finedance/motion_fea319")