Spaces:
Running on Zero
Running on Zero
| """ | |
| 推理逻辑:SMPL-X AMASS NPZ → G1 DOF/root_trans/root_rot | |
| 从 tools/inference.py 提取,路径适配 HF demo 目录结构。 | |
| """ | |
| import os | |
| import time | |
| import torch | |
| import numpy as np | |
| from copy import deepcopy | |
| from scipy import signal as sp_signal | |
| from smplx import SMPLX | |
| from mmengine.registry import MODELS | |
| from src.utils.rotation_conversions import ( | |
| axis_angle_to_matrix, matrix_to_axis_angle, matrix_to_rotation_6d, | |
| rotation_6d_to_matrix, matrix_to_quaternion, | |
| ) | |
| # 确保 mmengine 注册表加载 | |
| import src # noqa | |
# Directory containing this file; default weights/assets folders are resolved against it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# ---------- Model config (inlined so no external .py config file is required) ----------
MODEL_CFG = {
    'init_cfg': None,
    'n_embd': 512,
    'smplx_vqvae_cfg': {
        'decoder_cfg': {
            'activation': 'relu',
            'depth': 3,
            'dilation_growth_rate': 3,
            'down_t': 2,
            'input_emb_width': 140,
            'norm': None,
            'output_emb_width': 512,
            'type': 'DecoderAttn',
            'width': 512,
        },
        'encoder_cfg': {
            'activation': 'relu',
            'depth': 3,
            'dilation_growth_rate': 3,
            'down_t': 2,
            'input_emb_width': 140,
            'norm': None,
            'output_emb_width': 512,
            'stride_t': 2,
            'type': 'EncoderAttn',
            'width': 512,
        },
        'quantizer_cfg': {
            'dim': 512,
            'levels': [8, 8, 6, 5],
            'type': 'FSQ',
        },
        'type': 'VQVAE',
    },
    'transformer_cfg': {
        'block_size': 1024,
        'n_embd': 512,
        'n_head': 8,
        'n_layer': 8,
        'type': 'LLaMAHF_Fwd',
        'vocab_size': 512,
    },
    'type': 'RetargetTransformerPredMotion_no_smplvq',
}
# ---------- Inference constants ----------
FPS = 30
CHUNK_SECONDS = 4
# Window length in frames, rounded down to a multiple of 4 (120 for 30 fps x 4 s).
_RAW_CHUNK_FRAMES = FPS * CHUNK_SECONDS
CHUNK_FRAMES = _RAW_CHUNK_FRAMES - _RAW_CHUNK_FRAMES % 4
# Consecutive windows share OVERLAP_FRAMES frames, so each window starts
# STRIDE_FRAMES after the previous one.
OVERLAP_FRAMES = 32
STRIDE_FRAMES = CHUNK_FRAMES - OVERLAP_FRAMES
# ---------- Data loading ----------
def _decimate_to_fps(tensors, src_fps, target_fps):
    """Keep every ``int(src_fps / target_fps)``-th frame when the source is faster than the target."""
    if src_fps is None or src_fps <= target_fps:
        return tuple(tensors)
    step = int(src_fps / target_fps)
    return tuple(t[::step] for t in tensors)


def load_smpl_data(file_path, target_fps=30):
    """Load SMPL-X motion parameters from an NPZ file.

    Two key layouts are supported:
      - ``transl`` / ``global_orient`` / ``body_pose``: used as stored.
      - ``trans`` / ``root_orient`` / ``pose_body`` (AMASS): treated as Z-up and
        rotated into a Y-up frame.

    If the file has a ``mocap_frame_rate`` entry above ``target_fps``, frames are
    decimated with an integer step of ``int(rate / target_fps)``.

    Args:
        file_path: Path to the ``.npz`` file.
        target_fps: Positive frame rate to decimate towards (default 30).

    Returns:
        ``(transl, global_orient, body_pose, betas)`` where the first three are
        float32 tensors indexed by frame and ``betas`` is always ``None`` (this
        loader does not read per-sequence shape parameters).

    Raises:
        KeyError: If the file contains neither ``transl`` nor ``trans``, or a
            companion key of the detected layout is missing.
    """
    # NOTE(review): allow_pickle=True lets NumPy unpickle object arrays, which can
    # execute arbitrary code. Only pass files from trusted sources.
    # The context manager closes the underlying zip handle; arrays read inside it
    # are fully materialised in memory and stay valid afterwards.
    with np.load(file_path, allow_pickle=True) as data:
        if 'transl' in data:
            keys, is_z_up = ('transl', 'global_orient', 'body_pose'), False
        elif 'trans' in data:
            keys, is_z_up = ('trans', 'root_orient', 'pose_body'), True
        else:
            raise KeyError(
                f"{file_path}: expected 'transl' (SMPL-X) or 'trans' (AMASS) in NPZ")
        motion = tuple(torch.from_numpy(data[key]).float() for key in keys)
        src_fps = float(data['mocap_frame_rate']) if 'mocap_frame_rate' in data else None

    # Decimating first is equivalent to decimating last (the axis conversion
    # below is per-frame) and avoids converting frames that get dropped.
    transl, global_orient, body_pose = _decimate_to_fps(motion, src_fps, target_fps)

    if is_z_up:
        # Maps (x, y, z) -> (x, z, -y).
        rot_zup_to_yup = torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]]).float()
        transl = torch.einsum('ij,tj->ti', rot_zup_to_yup, transl)
        go_mat = axis_angle_to_matrix(global_orient)
        go_mat = torch.einsum('ij,tjk->tik', rot_zup_to_yup, go_mat)
        global_orient = matrix_to_axis_angle(go_mat)

    return transl, global_orient, body_pose, None
def preprocess_smpl(file_path, smplx_model, betas, device):
    """Convert an SMPL-X NPZ file into a per-frame motion feature tensor.

    The first 22 joints of the SMPL-X forward pass are used. With J = 22 the
    feature row is laid out as (2 + 6 + 3J + 3J = 140 columns):
      [0:2]         root x/z displacement since the previous frame (row 0 stays 0)
      [2:8]         6D encoding of the global orientation
      [8:8+3J]      joint positions, floor-aligned, with the root's x/z removed per frame
      [8+3J:8+6J]   joint displacement since the previous frame (row 0 stays 0)

    Args:
        file_path: Path passed to ``load_smpl_data``.
        smplx_model: Callable body model; its output must expose ``.joints``.
        betas: Shape coefficients, tiled to every frame (replaced by the
            sequence's own betas if the loader returns any).
        device: Device on which the body model is evaluated.

    Returns:
        CPU tensor of shape (T, 140).
    """
    transl, global_orient, body_pose, seq_betas = load_smpl_data(file_path)
    betas = betas if seq_betas is None else seq_betas
    num_frames = transl.shape[0]

    def _zero_param(width):
        # Unused SMPL-X inputs (eyes, hands, jaw, expression) are fed as zeros.
        return torch.zeros((num_frames, width), device=device)

    smplx_inputs = dict(
        transl=transl.to(device),
        global_orient=global_orient.to(device),
        body_pose=body_pose.to(device),
        betas=betas.unsqueeze(0).repeat(num_frames, 1).float().to(device),
        leye_pose=_zero_param(3),
        reye_pose=_zero_param(3),
        left_hand_pose=_zero_param(45),
        right_hand_pose=_zero_param(45),
        jaw_pose=_zero_param(3),
        expression=_zero_param(100),
    )
    with torch.no_grad():
        smplx_out = smplx_model(**smplx_inputs)
    joints = smplx_out.joints.detach().cpu()[:, :22]

    root_rot_mat = axis_angle_to_matrix(global_orient)
    # Frame-to-frame displacement of every joint (unaffected by the shift below).
    joint_deltas = joints[1:] - joints[:-1]

    root = 0
    # Origin: first-frame root in x/z, lowest joint height over the clip in y.
    origin = joints[0, root].clone()
    origin[1] = joints[:, :, 1].min()
    joints = joints - origin

    root_deltas = joints[1:, root, :] - joints[:-1, root, :]

    # Snapshot the root track first: the in-place subtraction zeroes the root's
    # own x/z, which would otherwise corrupt the value being subtracted.
    root_track = joints[:, 0:1, :].clone()
    joints[:, :, 0] -= root_track[:, :, 0]
    joints[:, :, 2] -= root_track[:, :, 2]

    T, njoint, _ = joints.shape
    pos_end = 8 + njoint * 3
    vel_end = pos_end + njoint * 3
    features = torch.zeros((T, vel_end))
    features[1:, 0] = root_deltas[:, 0]
    features[1:, 1] = root_deltas[:, 2]
    features[:, 2:8] = matrix_to_rotation_6d(root_rot_mat)
    features[:, 8:pos_end] = joints.flatten(1, 2)
    features[1:, pos_end:vel_end] = joint_deltas.flatten(1, 2)
    return features
def postprocess_g1(pred_motion, apply_filter=True):
    """Extract joint DOFs, root rotation and root translation from a G1 motion tensor.

    Columns read from ``pred_motion`` (described upstream as 217-dim):
      [0:2]    per-frame root x/z displacement (integrated with a cumulative sum)
      [2:8]    6D root rotation
      [8:98]   30 joint positions; joint 0 is taken as the root
      [-29:]   29 joint DOF values

    Translation and rotation are then mapped by the fixed matrix
    ``[[1,0,0],[0,0,-1],[0,1,0]]`` (x, y, z) -> (x, -z, y).

    Args:
        pred_motion: (T, D) CPU tensor. It is not modified.
        apply_filter: When True and the clip is long enough, apply a zero-phase
            4th-order Butterworth low-pass (5 Hz cutoff at 30 fps) to all outputs.

    Returns:
        ``(pred_dof, pred_rot_quat, pred_trans)`` with shapes (T, 29), (T, 4), (T, 3).
    """
    T = pred_motion.shape[0]
    n_joints, n_dof = 30, 29
    rot_mat = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]]).float()

    # clone(): the slice below is a view into pred_motion, so the in-place
    # accumulation would otherwise overwrite the caller's tensor.
    pred_trans = pred_motion[:, 8:8 + n_joints * 3].reshape(T, -1, 3)[:, 0].clone()
    pred_trans[:, [0, 2]] += torch.cumsum(pred_motion[:, :2], dim=0)
    pred_trans = torch.einsum('ij,tj->ti', rot_mat, pred_trans)

    pred_rot_mat = rotation_6d_to_matrix(pred_motion[:, 2:8])
    pred_rot_mat = torch.einsum('ij,tjk->tik', rot_mat, pred_rot_mat)
    pred_rot_quat = matrix_to_quaternion(pred_rot_mat)

    pred_dof = pred_motion[:, -n_dof:]

    if apply_filter:
        _b, _a = sp_signal.butter(4, 5 / (30.0 / 2), btype='low')
        # filtfilt's default edge padding is 3 * max(len(a), len(b)) samples and it
        # raises ValueError unless the signal is strictly longer than that.
        padlen = 3 * max(len(_a), len(_b))
        if T > padlen:
            pred_trans = torch.from_numpy(
                sp_signal.filtfilt(_b, _a, pred_trans.numpy(), axis=0).copy()
            ).to(pred_trans.dtype)

            # q and -q encode the same rotation, so a sign flip between frames is
            # harmless on its own but is smeared into a bogus rotation by a
            # low-pass filter. Align each frame to its predecessor's hemisphere.
            quat_np = pred_rot_quat.numpy()
            dots = np.sum(quat_np[1:] * quat_np[:-1], axis=-1)
            signs = np.concatenate([[1.0], np.cumprod(np.where(dots < 0, -1.0, 1.0))])
            quat_np = quat_np * signs[:, None]

            quat_np = sp_signal.filtfilt(_b, _a, quat_np, axis=0).copy()
            # Filtering breaks unit length; renormalise.
            quat_np /= np.linalg.norm(quat_np, axis=-1, keepdims=True)
            pred_rot_quat = torch.from_numpy(quat_np).to(pred_rot_quat.dtype)

            pred_dof = torch.from_numpy(
                sp_signal.filtfilt(_b, _a, pred_dof.numpy(), axis=0).copy()
            ).to(pred_dof.dtype)

    return pred_dof, pred_rot_quat, pred_trans
# ---------- Rotation canonicalisation helpers ----------
def _make_y_rot(theta):
    """Return the 3x3 float32 matrix for a rotation of ``theta`` radians about the Y axis.

    ``theta`` is a scalar tensor; the result is a fresh CPU tensor.
    """
    cos_t = float(torch.cos(theta))
    sin_t = float(torch.sin(theta))
    rows = [
        [cos_t, 0, sin_t],
        [0, 1, 0],
        [-sin_t, 0, cos_t],
    ]
    return torch.tensor(rows, dtype=torch.float32)
def _extract_yaw(rot_6d):
    """Return the heading angle about Y encoded by a single 6D rotation.

    The third column of the decoded rotation matrix is taken as the forward
    direction; the yaw is ``atan2(forward_x, forward_z)``.
    """
    rot = rotation_6d_to_matrix(rot_6d[None])[0]
    forward_x, forward_z = rot[0, 2], rot[2, 2]
    return torch.atan2(forward_x, forward_z)
def _rotate_motion_features(motion, R, n_joints, rotate_6d=True):
    """Apply rotation ``R`` to every directional quantity in a motion feature tensor.

    Column layout assumed for ``motion`` (T rows):
      [0:2]                 root x/z velocity - rotated using R's x/z rows and columns
      [2:8]                 6D root rotation  - left-multiplied by R when ``rotate_6d``
      [8:8+3J]              joint positions   - rotated
      [8+3J:8+6J]           joint velocities  - rotated only if the tensor is wide enough
    Any remaining columns are copied unchanged.

    Args:
        motion: (T, D) feature tensor; not modified.
        R: 3x3 rotation matrix.
        n_joints: Number of joints J in the position/velocity blocks.
        rotate_6d: Whether to rotate the 6D root orientation as well.

    Returns:
        A new tensor with the same shape as ``motion``.
    """
    rotated = motion.clone()

    vel_x, vel_z = motion[:, 0], motion[:, 1]
    rotated[:, 0] = R[0, 0] * vel_x + R[0, 2] * vel_z
    rotated[:, 1] = R[2, 0] * vel_x + R[2, 2] * vel_z

    if rotate_6d:
        heading = rotation_6d_to_matrix(motion[:, 2:8])
        rotated[:, 2:8] = matrix_to_rotation_6d(torch.einsum('ij,tjk->tik', R, heading))

    block_width = n_joints * 3
    first_block = 8
    for begin in (first_block, first_block + block_width):
        stop = begin + block_width
        # The position block is mandatory; the velocity block is optional.
        if begin != first_block and stop > motion.shape[1]:
            break
        xyz = motion[:, begin:stop].reshape(-1, n_joints, 3)
        xyz = torch.einsum('ij,tnj->tni', R, xyz)
        rotated[:, begin:stop] = xyz.reshape(-1, block_width)

    return rotated
# ---------- Inference core ----------
def _infer_chunk(smplx_motion, model, smplx_mean, smplx_std, g1_mean, g1_std, device):
    """Run the retargeting model on one window of SMPL-X features.

    The window is first rotated about Y so that its first frame has zero yaw,
    normalised with the SMPL-X statistics and passed to the model in
    ``mode='predict'``. The prediction is de-normalised with the G1 statistics
    and rotated back to the original heading.

    Args:
        smplx_motion: (T, D) CPU feature tensor for the window (22-joint layout).
        model: Network called with ``smplx_motion``, ``motion_length`` and ``mode``.
        smplx_mean, smplx_std: Input normalisation statistics.
        g1_mean, g1_std: Output de-normalisation statistics.
        device: Device on which the model is evaluated.

    Returns:
        CPU tensor with the predicted G1 motion (30-joint layout).
    """
    heading = _extract_yaw(smplx_motion[0, 2:8])

    canonical = _rotate_motion_features(smplx_motion, _make_y_rot(-heading), n_joints=22)
    normalised = (canonical - smplx_mean) / smplx_std

    batch = normalised.unsqueeze(0).float().to(device)
    lengths = torch.tensor([normalised.shape[0]]).to(device)
    with torch.no_grad():
        predictions, _ = model(smplx_motion=batch, motion_length=lengths, mode='predict')

    g1_motion = predictions[0].cpu() * g1_std + g1_mean
    return _rotate_motion_features(g1_motion, _make_y_rot(heading), n_joints=30)
def infer_single(file_path, model, smplx_model, betas,
                 smplx_mean, smplx_std, g1_mean, g1_std,
                 device, apply_filter=True):
    """Retarget one SMPL-X motion file to G1 and report timings.

    Pipeline: ``preprocess_smpl`` -> pad to a multiple of 4 frames ->
    ``_infer_chunk`` (single pass, or overlapping windows that are cross-faded)
    -> trim to the original length -> ``postprocess_g1``.

    Args:
        file_path: Path to the input NPZ file.
        model: Retargeting network.
        smplx_model: SMPL-X body model used during preprocessing.
        betas: Default shape coefficients.
        smplx_mean, smplx_std, g1_mean, g1_std: Normalisation statistics.
        device: Device for model evaluation.
        apply_filter: Forwarded to ``postprocess_g1``.

    Returns:
        ``(result, timing)``. ``result`` is a dict with NumPy arrays ``dof``,
        ``root_trans``, ``root_rot_quat`` plus ``source_path``; ``timing`` holds
        seconds spent in ``preprocess``, ``infer``, ``postprocess`` and ``total``.
        Returns ``(None, None)`` when the clip has fewer than 4 frames.
    """
    t0 = time.time()
    smplx_motion = preprocess_smpl(file_path, smplx_model, betas, device)
    T_orig = smplx_motion.shape[0]
    if T_orig < 4:
        return None, None

    # Windows are always cut to a multiple of 4 frames; pad by repeating the
    # last frame. The padding is trimmed from the output again below.
    T = ((T_orig + 3) // 4) * 4
    if T > T_orig:
        pad = smplx_motion[-1:].repeat(T - T_orig, 1)
        smplx_motion = torch.cat([smplx_motion, pad], dim=0)
    t_preprocess = time.time() - t0

    t1 = time.time()
    if T <= CHUNK_FRAMES:
        pred_motion = _infer_chunk(smplx_motion, model, smplx_mean, smplx_std, g1_mean, g1_std, device)
    else:
        chunks = []
        starts = []
        for start in range(0, T, STRIDE_FRAMES):
            end = min(start + CHUNK_FRAMES, T)
            seg_len = (end - start) // 4 * 4
            if seg_len < 4:
                break
            chunk = smplx_motion[start:start + seg_len]
            chunks.append(_infer_chunk(chunk, model, smplx_mean, smplx_std, g1_mean, g1_std, device))
            starts.append(start)
            if end >= T:
                # This window already reaches the end of the clip. Any further
                # window would lie entirely inside it, contribute no new frames,
                # and (being much shorter) would override the tail with a
                # prediction made from far less context.
                break

        # Stitch windows together with a linear cross-fade over the shared frames.
        pred_motion = chunks[0]
        prev_end = starts[0] + len(chunks[0])
        for start, chunk in zip(starts[1:], chunks[1:]):
            overlap = prev_end - start
            if overlap > 0:
                w = torch.linspace(0, 1, overlap).unsqueeze(1)
                blended = pred_motion[-overlap:] * (1 - w) + chunk[:overlap] * w
                pred_motion = torch.cat([pred_motion[:-overlap], blended, chunk[overlap:]], dim=0)
            else:
                pred_motion = torch.cat([pred_motion, chunk], dim=0)
            prev_end = start + len(chunk)
    t_infer = time.time() - t1

    pred_motion = pred_motion[:T_orig]

    t2 = time.time()
    pred_dof, pred_rot_quat, pred_trans = postprocess_g1(pred_motion, apply_filter=apply_filter)
    t_postprocess = time.time() - t2
    t_total = time.time() - t0

    timing = dict(preprocess=t_preprocess, infer=t_infer, postprocess=t_postprocess, total=t_total)
    return dict(
        dof=pred_dof.numpy(),
        root_trans=pred_trans.numpy(),
        root_rot_quat=pred_rot_quat.numpy(),
        source_path=file_path,
    ), timing
# ---------- Model loading ----------
def load_all(weights_dir=None, assets_dir=None, device=None):
    """Load the retargeting model, the SMPL-X body model and the normalisation statistics.

    Args:
        weights_dir: Folder with ``epoch_30.pth`` and the ``*_mean.npy`` /
            ``*_std.npy`` files. Defaults to ``<BASE_DIR>/weights``.
        assets_dir: Folder with ``SMPLX_NEUTRAL.npz`` and ``betas.npy``.
            Defaults to ``<BASE_DIR>/assets``.
        device: Target device. Defaults to CUDA when available, else CPU.

    Returns:
        ``(model, smplx_model, betas, smplx_mean, smplx_std, g1_mean, g1_std, device)``.

    Side effects:
        Replaces ``torch.load`` process-wide with a wrapper that defaults
        ``weights_only`` to False. The replacement is not undone.
    """
    weights_dir = os.path.join(BASE_DIR, 'weights') if weights_dir is None else weights_dir
    assets_dir = os.path.join(BASE_DIR, 'assets') if assets_dir is None else assets_dir
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # PyTorch 2.6 compatibility: default weights_only to False unless the caller
    # specifies it.
    # NOTE(review): this patch is global and permanent, and weights_only=False
    # means checkpoints are unpickled - only load trusted files.
    wrapped_load = torch.load

    def _load_with_full_pickle(*args, **kwargs):
        kwargs.setdefault('weights_only', False)
        return wrapped_load(*args, **kwargs)

    torch.load = _load_with_full_pickle

    # Build the network from the inlined config and restore its weights.
    model = MODELS.build(MODEL_CFG)
    checkpoint = torch.load(os.path.join(weights_dir, 'epoch_30.pth'), map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval().to(device)

    # SMPL-X body model.
    smplx_model = SMPLX(
        model_path=os.path.join(assets_dir, 'SMPLX_NEUTRAL.npz'),
        use_pca=False, num_expression_coeffs=100, num_betas=10, ext='npz'
    ).to(device).eval()

    def _npy_as_float_tensor(folder, filename):
        return torch.from_numpy(np.load(os.path.join(folder, filename))).float()

    betas = _npy_as_float_tensor(assets_dir, 'betas.npy')

    # Normalisation statistics.
    smplx_mean = _npy_as_float_tensor(weights_dir, 'smplx_mean.npy')
    smplx_std = _npy_as_float_tensor(weights_dir, 'smplx_std.npy')
    g1_mean = _npy_as_float_tensor(weights_dir, 'gmr_mean.npy')
    g1_std = _npy_as_float_tensor(weights_dir, 'gmr_std.npy')

    return model, smplx_model, betas, smplx_mean, smplx_std, g1_mean, g1_std, device