Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| import pickle | |
| from scipy.interpolate import interp1d | |
| #############Import fast smplx(modified from original ver) | |
| local_smplx_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', 'deps/smplx')) | |
| sys.path.insert(0, local_smplx_path) | |
| import smplx_fast | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
| from utils.transforms import matrix_to_axis_angle, rotation_6d_to_matrix | |
| from utils.constants import pelvis_shift, relaxed_hand_pose, SELECTED_JOINTS24 | |
| ###########This model is used to predict the initial pose for the optimization########### | |
class JointsToSMPLX(nn.Module):
    """MLP that predicts an initial pose vector from flattened joint positions.

    The network is two hidden blocks (Linear -> BatchNorm1d -> ReLU) followed
    by a linear output layer.  Extra keyword arguments are accepted and
    ignored.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, **kwargs):
        super().__init__()
        # The attribute name ``layers`` and the order of the modules define
        # the state-dict keys (``layers.0.weight`` ...), so both are kept
        # stable for checkpoint loading.
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result."""
        prediction = self.layers(x)
        return prediction
def get_j2s_model(ckpt_path,
                  input_dim=72,
                  output_dim=132,
                  hidden_dim=64,
                  device='cpu'):
    """Build a ``JointsToSMPLX`` network and load its weights from disk.

    Args:
        ckpt_path: Path to a checkpoint holding the network's state dict.
        input_dim: Size of the flattened joint input.
        output_dim: Size of the predicted pose vector.
        hidden_dim: Width of the hidden layers.
        device: Used as ``map_location`` when reading the checkpoint.

    Returns:
        The network in evaluation mode with the checkpoint weights loaded.
    """
    net = JointsToSMPLX(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_dim=hidden_dim,
    )
    # ``device`` only controls where the checkpoint tensors are mapped while
    # loading; the network itself is not moved here.
    target = torch.device('cpu') if device == 'cpu' else device
    state = torch.load(ckpt_path, map_location=target)
    net.load_state_dict(state)
    net.eval()
    return net
| ###########This routine refines the predicted initial pose by optimizing the SMPL-X joints to match the target joints########### | |
def optimize_smpl(pose_pred, joints, joints_ind, smplx_path, print_loss=True,
                  num_steps=120, lr=0.05):
    """Refine a predicted pose so the body-model joints match target joints.

    A joint-only SMPL-X model is created with one batch entry per frame and
    the pose and a per-frame translation are optimized with Adam against an
    MSE loss between the selected model joints and ``joints``.  Hand poses
    are held fixed at ``relaxed_hand_pose``.

    Args:
        pose_pred: Initial pose per frame; columns ``[:3]`` are used as the
            global orientation and columns ``[3:]`` as the body pose.  The
            tensor is not modified.
        joints: Target joints; reshaped to ``(num_frames, -1, 3)`` and
            offset by ``pelvis_shift`` before fitting.
        joints_ind: Indices selecting the model joints compared to the
            targets.
        smplx_path: Model path handed to ``smplx_fast.create``.
        print_loss: When true, print the loss of the last optimization step.
        num_steps: Number of Adam steps (default 120).
        lr: Adam learning rate (default 0.05).

    Returns:
        Tuple ``(pose, transl, left_hand, right_hand, vertices)``.  The first
        four are NumPy arrays; ``vertices`` is always ``None`` and is kept
        only so the return shape stays the same for callers.
    """
    device = joints.device
    # Named ``num_frames`` rather than ``len`` to avoid shadowing the builtin.
    num_frames = joints.shape[0]
    smpl_model = smplx_fast.create(smplx_path,
                                   model_type='smplx_joint_only',
                                   gender='male', ext='npz',
                                   num_betas=10,
                                   use_pca=False,
                                   create_global_orient=True,
                                   create_body_pose=True,
                                   create_betas=True,
                                   create_left_hand_pose=True,
                                   create_right_hand_pose=True,
                                   create_expression=True,
                                   create_jaw_pose=True,
                                   create_leye_pose=True,
                                   create_reye_pose=True,
                                   create_transl=True,
                                   batch_size=num_frames,
                                   ).to(device)
    smpl_model.eval()

    joints = joints.reshape(num_frames, -1, 3) + torch.tensor(pelvis_shift).to(device)

    # ``detach()`` alone shares storage with ``pose_pred``, so the optimizer's
    # in-place updates would silently overwrite the caller's tensor.  Clone
    # to optimize a private copy.
    pose_input = torch.nn.Parameter(pose_pred.detach().clone(), requires_grad=True)
    batch = pose_pred.shape[0]
    transl = torch.nn.Parameter(torch.zeros(batch, 3).to(device), requires_grad=True)

    # First 45 values are the left hand, the remainder the right hand; the
    # same pose is repeated for every frame and is not optimized.
    left_hand = torch.from_numpy(
        relaxed_hand_pose[:45].reshape(1, -1).repeat(batch, axis=0)).to(device)
    right_hand = torch.from_numpy(
        relaxed_hand_pose[45:].reshape(1, -1).repeat(batch, axis=0)).to(device)

    optimizer = torch.optim.Adam(params=[pose_input, transl], lr=lr)
    loss_fn = nn.MSELoss()
    vertices_output = None
    loss = None
    for _ in range(num_steps):
        smpl_output = smpl_model(transl=transl,
                                 body_pose=pose_input[:, 3:],
                                 global_orient=pose_input[:, :3],
                                 return_verts=True,
                                 left_hand_pose=left_hand,
                                 right_hand_pose=right_hand,
                                 )
        joints_output = smpl_output[:, joints_ind].reshape(num_frames, -1, 3)
        loss = loss_fn(joints, joints_output)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # ``loss`` stays ``None`` when ``num_steps`` is 0; nothing to report then.
    if print_loss and loss is not None:
        print(loss.item(), flush=True)

    return pose_input.detach().cpu().numpy(), \
        transl.detach().cpu().numpy(), \
        left_hand.detach().cpu().numpy(), \
        right_hand.detach().cpu().numpy(), \
        vertices_output
def joints_to_smpl(model, joints, joints_ind, interp_s, smplx_path, print_loss=True):
    """Convert a joint sequence into pose parameters.

    The joints are resampled in time, expressed relative to joint 0, passed
    through ``model`` to obtain an initial pose (6D rotations converted to
    axis-angle), and then refined by ``optimize_smpl``.

    Args:
        model: Network mapping flattened root-relative joints to 6D rotations.
        joints: Joint positions, one row per frame.
        joints_ind: Model joint indices forwarded to ``optimize_smpl``.
        interp_s: Temporal resampling factor passed to ``interpolate_joints``.
        smplx_path: Model path forwarded to ``optimize_smpl``.
        print_loss: Forwarded to ``optimize_smpl``.

    Returns:
        Tuple ``(pose, transl, left_hand, right_hand, vertices)`` as returned
        by ``optimize_smpl``, with ``transl`` shifted by the trajectory of
        joint 0 minus ``pelvis_shift``.
    """
    joints = interpolate_joints(joints, scale=interp_s)
    num_frames = joints.shape[0]
    per_joint = joints.reshape(num_frames, -1, 3)

    # Per-frame position of joint 0; removed before fitting and added back
    # to the optimized translation afterwards.
    root = per_joint[:, 0]
    trans_np = root.detach().cpu().numpy()
    rooted = (per_joint - root.unsqueeze(1)).reshape(num_frames, -1)

    rot_6d = model(rooted).reshape(-1, 6)
    axis_angle = matrix_to_axis_angle(rotation_6d_to_matrix(rot_6d))
    pose_pred = axis_angle.reshape(num_frames, -1)

    pose_output, transl, left_hand, right_hand, vertices = optimize_smpl(
        pose_pred,
        rooted,
        joints_ind,
        smplx_path,
        print_loss=print_loss,
    )
    transl = trans_np - np.array(pelvis_shift) + transl
    return pose_output, transl, left_hand, right_hand, vertices
def interpolate_joints(joints, scale):
    """Linearly resample a joint sequence along the frame (first) axis.

    Args:
        joints: Tensor whose first dimension indexes frames.
        scale: Resampling factor.  ``1`` returns ``joints`` unchanged (same
            object, no reshape or dtype conversion).

    Returns:
        For ``scale != 1``: a float32 tensor on the input's device with
        shape ``(int(num_frames * scale), -1)``, sampled at evenly spaced
        positions between the first and last input frame.
    """
    if scale == 1:
        return joints

    device = joints.device
    frames = joints.detach().cpu().numpy()
    in_len = frames.shape[0]
    out_len = int(in_len * scale)
    frames = frames.reshape(in_len, -1)

    if in_len == 1:
        # With a single sample there is nothing to interpolate between, and
        # the linear formula would divide by a zero-length interval.  The
        # only sensible resampling is the frame itself, repeated.
        resampled = np.repeat(frames, out_len, axis=0)
    else:
        x = np.arange(in_len)
        xnew = np.linspace(0, in_len - 1, out_len)
        resampled = interp1d(x, frames, axis=0)(xnew)

    return torch.from_numpy(resampled).to(device).float()
def _optional_entry(data, name):
    """Return ``data[name]``, or ``None`` when the entry is not present."""
    try:
        return data[name]
    except (KeyError, IndexError):
        return None


def process_file(file_path,  # input dir
                 file_name,  # input file
                 save_path,  # output dir
                 JointsToSMPLX_model_path,  # JointsToSMPLX weight
                 smplx_path,  # smplx weight
                 key_list=('generated_samples', 'original_samples'),
                 joints_ind=SELECTED_JOINTS24,
                 interp_s=2,  # 2*10=20 fps
                 ):
    """Fit pose parameters to the joint sequences stored in one input file.

    For every name in ``key_list`` that is present in the loaded file, the
    joints are converted with ``joints_to_smpl`` and the result is pickled to
    ``<save_path>/<key>/<file_name>``.

    Args:
        file_path: Directory containing the input file.
        file_name: Input file name; reused as the output file name.
        save_path: Root output directory; one sub-directory per key.
        JointsToSMPLX_model_path: Checkpoint for the initial-pose network.
        smplx_path: Model path forwarded to ``joints_to_smpl``.
        key_list: Names of the joint arrays to process.
        joints_ind: Model joint indices forwarded to ``joints_to_smpl``.
        interp_s: Temporal resampling factor forwarded to ``joints_to_smpl``.
    """
    # NOTE(security): allow_pickle=True lets the file execute arbitrary code
    # while loading; only use this on trusted input files.
    data = np.load(os.path.join(file_path, file_name), allow_pickle=True)
    try:
        model = get_j2s_model(ckpt_path=JointsToSMPLX_model_path, device='cpu')
        # The text entry does not depend on the key, so read it once.
        data_text = _optional_entry(data, 'text')

        for key in key_list:
            if key not in data:
                continue

            # 72 values per frame, i.e. 24 joints x 3 coordinates.
            joints = torch.tensor(data[key], dtype=torch.float32).reshape(-1, 72)
            # The fitting loss is only reported for generated samples.
            is_generated = key == 'generated_samples'
            pose, transl, left_hand, right_hand, vertices = joints_to_smpl(
                model,
                joints,
                joints_ind,
                interp_s,
                smplx_path,
                print_loss=is_generated,
            )

            output_data = {
                'body_pose': pose[:, 3:],
                'global_orient': pose[:, :3],
                'transl': transl,
                'left_hand': left_hand,
                'right_hand': right_hand,
                'vertices': vertices,
                'text': data_text,
            }
            if is_generated:
                output_data['mask'] = _optional_entry(data, 'mask')

            out_dir = os.path.join(save_path, key)
            # exist_ok avoids the check-then-create race of a separate
            # os.path.exists() test.
            os.makedirs(out_dir, exist_ok=True)
            with open(os.path.join(out_dir, file_name), 'wb') as file:
                pickle.dump(output_data, file)
    finally:
        # np.load returns an archive object holding an open file for .npz
        # inputs; plain arrays have no close() and are skipped.
        close = getattr(data, 'close', None)
        if close is not None:
            close()