# Source listing: NMR / inference.py
# Last commit: 4cc0d6c ("initial push"), RayZhao
"""
推理逻辑:SMPL-X AMASS NPZ → G1 DOF/root_trans/root_rot
从 tools/inference.py 提取,路径适配 HF demo 目录结构。
"""
import os
import time
import torch
import numpy as np
from copy import deepcopy
from scipy import signal as sp_signal
from smplx import SMPLX
from mmengine.registry import MODELS
from src.utils.rotation_conversions import (
axis_angle_to_matrix, matrix_to_axis_angle, matrix_to_rotation_6d,
rotation_6d_to_matrix, matrix_to_quaternion,
)
# Make sure the mmengine registry is loaded
import src # noqa
# Absolute directory containing this file; load_all() resolves its default
# ``weights`` and ``assets`` directories relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# ---------- Model config (inlined; no external .py config file needed) ----------
# Passed unchanged to ``MODELS.build`` in ``load_all``.
MODEL_CFG = {
    'init_cfg': None,
    'n_embd': 512,
    'smplx_vqvae_cfg': {
        'decoder_cfg': {
            'activation': 'relu',
            'depth': 3,
            'dilation_growth_rate': 3,
            'down_t': 2,
            'input_emb_width': 140,
            'norm': None,
            'output_emb_width': 512,
            'type': 'DecoderAttn',
            'width': 512,
        },
        'encoder_cfg': {
            'activation': 'relu',
            'depth': 3,
            'dilation_growth_rate': 3,
            'down_t': 2,
            'input_emb_width': 140,
            'norm': None,
            'output_emb_width': 512,
            'stride_t': 2,
            'type': 'EncoderAttn',
            'width': 512,
        },
        'quantizer_cfg': {'dim': 512, 'levels': [8, 8, 6, 5], 'type': 'FSQ'},
        'type': 'VQVAE',
    },
    'transformer_cfg': {
        'block_size': 1024,
        'n_embd': 512,
        'n_head': 8,
        'n_layer': 8,
        'type': 'LLaMAHF_Fwd',
        'vocab_size': 512,
    },
    'type': 'RetargetTransformerPredMotion_no_smplvq',
}
# ---------- Inference constants ----------
FPS = 30
CHUNK_SECONDS = 4
# Window length in frames, rounded down to a multiple of 4 (120 here).
CHUNK_FRAMES = FPS * CHUNK_SECONDS - (FPS * CHUNK_SECONDS) % 4
# Frames shared by two consecutive windows.
OVERLAP_FRAMES = 32
# Distance in frames between the starts of consecutive windows.
STRIDE_FRAMES = CHUNK_FRAMES - OVERLAP_FRAMES
# ---------- Data loading ----------
def load_smpl_data(file_path, target_fps=30):
    """Load SMPL-X motion parameters from an NPZ file.

    Two key layouts are recognised:

    - ``transl`` / ``global_orient`` / ``body_pose``: used as stored.
    - ``trans`` / ``root_orient`` / ``pose_body`` (AMASS naming): treated as
      Z-up and rotated into the Y-up frame.

    When the archive contains ``mocap_frame_rate`` and it exceeds
    ``target_fps``, every k-th frame is kept, where k is the integer nearest
    to ``mocap_frame_rate / target_fps`` (so 59.94 fps is decimated by 2
    instead of being left at full rate).

    Args:
        file_path: Path to the ``.npz`` file.
        target_fps: Frame rate to decimate towards. Defaults to 30.

    Returns:
        ``(transl, global_orient, body_pose, betas)``. The first three are
        float32 tensors whose first axis is the frame index; ``betas`` is
        always ``None`` (per-sequence shape is not read from the file).

    Raises:
        KeyError: If the archive has neither of the two key layouts.
    """
    # NOTE(security): allow_pickle=True lets NumPy unpickle object arrays,
    # which can execute arbitrary code. Only load files from trusted sources.
    with np.load(file_path, allow_pickle=True) as data:
        is_zup_amass = 'transl' not in data
        if is_zup_amass:
            keys = ('trans', 'root_orient', 'pose_body')
        else:
            keys = ('transl', 'global_orient', 'body_pose')
        # Arrays are fully materialised here, so the archive can be closed.
        transl, global_orient, body_pose = (
            torch.from_numpy(data[key]).float() for key in keys
        )
        src_fps = None
        if 'mocap_frame_rate' in data:
            src_fps = float(data['mocap_frame_rate'])

    if is_zup_amass:
        # Z-up -> Y-up: (x, y, z) -> (x, z, -y), applied to the translation
        # and left-multiplied onto the root orientation.
        rot_zup_to_yup = torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]]).float()
        transl = torch.einsum('ij,tj->ti', rot_zup_to_yup, transl)
        go_mat = axis_angle_to_matrix(global_orient)
        go_mat = torch.einsum('ij,tjk->tik', rot_zup_to_yup, go_mat)
        global_orient = matrix_to_axis_angle(go_mat)

    if src_fps is not None and src_fps > target_fps:
        # Round rather than truncate so near-integer ratios (59.94 / 30)
        # still select the intended step.
        step = max(1, int(round(src_fps / target_fps)))
        if step > 1:
            transl = transl[::step]
            global_orient = global_orient[::step]
            body_pose = body_pose[::step]

    return transl, global_orient, body_pose, None
def preprocess_smpl(file_path, smplx_model, betas, device):
    """Turn a SMPL-X NPZ clip into a ``(T, 8 + 6 * J)`` motion feature tensor.

    The first ``J = 22`` joints returned by ``smplx_model`` are used, which
    gives 140 features per frame:

    - ``[0:2]``        root X/Z displacement from the previous frame
                       (zero on frame 0)
    - ``[2:8]``        6D encoding of the global orientation
    - ``[8:8+3J]``     joint positions, with X/Z relative to that frame's
                       root and Y measured from the lowest joint height
                       found anywhere in the clip
    - ``[8+3J:8+6J]``  per-joint displacement from the previous frame
                       (zero on frame 0)

    Args:
        file_path: NPZ path handed to ``load_smpl_data``.
        smplx_model: Callable body model whose output exposes ``.joints``.
        betas: 1-D shape coefficients, repeated for every frame; replaced
            by the loader's betas if it returns any.
        device: Device the body model is evaluated on.

    Returns:
        CPU tensor of shape ``(T, 8 + 6 * J)``.
    """
    transl, global_orient, body_pose, seq_betas = load_smpl_data(file_path)
    betas = betas if seq_betas is None else seq_betas
    num_frames = transl.shape[0]

    def _zeros(width):
        return torch.zeros((num_frames, width), device=device)

    smplx_inputs = {
        'transl': transl.to(device),
        'global_orient': global_orient.to(device),
        'body_pose': body_pose.to(device),
        'betas': betas.unsqueeze(0).repeat(num_frames, 1).float().to(device),
        # Face, hands and expression are held at the zero pose.
        'leye_pose': _zeros(3),
        'reye_pose': _zeros(3),
        'left_hand_pose': _zeros(45),
        'right_hand_pose': _zeros(45),
        'jaw_pose': _zeros(3),
        'expression': _zeros(100),
    }
    with torch.no_grad():
        body = smplx_model(**smplx_inputs)
    joints = body.joints.detach().cpu()[:, :22]

    orient_mat = axis_angle_to_matrix(global_orient)
    # Per-joint frame-to-frame displacement, taken before any re-centering.
    joint_vel = joints[1:] - joints[:-1]

    # Shift the clip so the first-frame root sits at X = Z = 0 and the
    # lowest joint height of the whole clip is Y = 0.
    floor_y = joints[:, :, 1].min()
    origin = joints[0, 0].clone()
    origin[1] = floor_y
    joints = joints - origin

    root_vel = joints[1:, 0, :] - joints[:-1, 0, :]

    # Remove each frame's root X/Z from every joint; heights stay untouched.
    root_offset = joints[:, 0:1, :].clone()
    root_offset[:, :, 1] = 0
    joints = joints - root_offset

    n_frames, n_joints, _ = joints.shape
    pos_end = 8 + n_joints * 3
    vel_end = pos_end + n_joints * 3
    features = torch.zeros((n_frames, vel_end))
    features[1:, 0] = root_vel[:, 0]
    features[1:, 1] = root_vel[:, 2]
    features[:, 2:8] = matrix_to_rotation_6d(orient_mat)
    features[:, 8:pos_end] = joints.flatten(1, 2)
    features[1:, pos_end:vel_end] = joint_vel.flatten(1, 2)
    return features
def postprocess_g1(pred_motion, apply_filter=True):
    """Extract DOF, root rotation and root translation from G1 motion features.

    Columns read from ``pred_motion`` (217 wide per the original note):

    - ``[0:2]``   per-frame root X/Z displacement, integrated with a cumsum
    - ``[2:8]``   6D root rotation
    - ``[8:98]``  30 body positions; the first is used as the root
    - ``[-29:]``  joint DOF values

    Translation and rotation are rotated from Y-up to Z-up
    (``(x, y, z) -> (x, -z, y)``) before being returned.

    Args:
        pred_motion: ``(T, D)`` CPU tensor. It is not modified.
        apply_filter: If true, and the clip is long enough for
            ``scipy.signal.filtfilt``, apply a zero-phase 4th-order 5 Hz
            Butterworth low-pass (30 fps assumed) to all three outputs.

    Returns:
        ``(pred_dof, pred_rot_quat, pred_trans)`` with shapes ``(T, 29)``,
        ``(T, 4)`` and ``(T, 3)``.
    """
    T = pred_motion.shape[0]
    rot_mat = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]]).float()

    # clone(): the slice/reshape/index chain can return a view of
    # pred_motion, and the in-place accumulation below would otherwise
    # write the integrated translation back into the caller's tensor.
    pred_trans = pred_motion[:, 8:8 + 30 * 3].reshape(T, -1, 3)[:, 0].clone()
    pred_trans[:, [0, 2]] += torch.cumsum(pred_motion[:, :2], dim=0)
    pred_trans = torch.einsum('ij,tj->ti', rot_mat, pred_trans)

    pred_rot_mat = rotation_6d_to_matrix(pred_motion[:, 2:8])
    pred_rot_mat = torch.einsum('ij,tjk->tik', rot_mat, pred_rot_mat)
    pred_rot_quat = matrix_to_quaternion(pred_rot_mat)

    pred_dof = pred_motion[:, -29:]

    if not apply_filter:
        return pred_dof, pred_rot_quat, pred_trans

    _b, _a = sp_signal.butter(4, 5 / (30.0 / 2), btype='low')
    # filtfilt pads both ends with 3 * max(len(a), len(b)) samples by default
    # and raises ValueError unless the signal is strictly longer than that
    # (15 for this filter). Shorter clips are returned unfiltered.
    padlen = 3 * max(len(_a), len(_b))
    if T <= padlen:
        return pred_dof, pred_rot_quat, pred_trans

    def _lowpass(values):
        return sp_signal.filtfilt(_b, _a, values, axis=0).copy()

    pred_trans = torch.from_numpy(_lowpass(pred_trans.numpy())).to(pred_trans.dtype)

    # q and -q encode the same rotation. Flip signs so neighbouring frames
    # lie in the same hemisphere; otherwise the filter would average across
    # a sign jump and produce a bogus, near-zero quaternion.
    quat_np = pred_rot_quat.numpy().copy()
    dots = np.sum(quat_np[1:] * quat_np[:-1], axis=-1)
    signs = np.cumprod(np.where(dots < 0, -1.0, 1.0)).astype(quat_np.dtype)
    quat_np[1:] *= signs[:, None]
    quat_np = _lowpass(quat_np)
    quat_np /= np.linalg.norm(quat_np, axis=-1, keepdims=True)
    pred_rot_quat = torch.from_numpy(quat_np).to(pred_rot_quat.dtype)

    pred_dof = torch.from_numpy(_lowpass(pred_dof.numpy())).to(pred_dof.dtype)
    return pred_dof, pred_rot_quat, pred_trans
# ---------- Rotation canonicalization helpers ----------
def _make_y_rot(theta):
    """Return the 3x3 float32 rotation matrix for angle ``theta`` about Y.

    ``theta`` is a single-element tensor holding the angle in radians.
    """
    cos_t = torch.cos(theta).item()
    sin_t = torch.sin(theta).item()
    rows = (
        (cos_t, 0, sin_t),
        (0, 1, 0),
        (-sin_t, 0, cos_t),
    )
    return torch.tensor(rows, dtype=torch.float32)
def _extract_yaw(rot_6d):
    """Return the heading about the Y axis encoded by one 6D rotation.

    The heading is ``atan2(x, z)`` of the third column of the rotation
    matrix decoded from ``rot_6d``.
    """
    rot = rotation_6d_to_matrix(rot_6d[None])[0]
    forward_x, forward_z = rot[0, 2], rot[2, 2]
    return torch.atan2(forward_x, forward_z)
def _rotate_motion_features(motion, R, n_joints, rotate_6d=True):
    """Apply the 3x3 matrix ``R`` to every spatial part of a feature tensor.

    Args:
        motion: ``(T, D)`` features laid out as root X/Z velocity ``[0:2]``,
            6D rotation ``[2:8]``, ``n_joints`` positions, then (if the
            tensor is wide enough) ``n_joints`` velocities.
        R: 3x3 rotation matrix. Only its X/Z entries act on the root velocity.
        n_joints: Number of 3-vectors in the position (and velocity) block.
        rotate_6d: Whether to left-multiply the 6D rotation by ``R`` as well.

    Returns:
        A new tensor; ``motion`` is left untouched. Columns after the
        velocity block are copied through unchanged.
    """
    rotated = motion.clone()

    # Planar root velocity uses the X/Z sub-block of R.
    vel_x, vel_z = motion[:, 0], motion[:, 1]
    rotated[:, 0] = R[0, 0] * vel_x + R[0, 2] * vel_z
    rotated[:, 1] = R[2, 0] * vel_x + R[2, 2] * vel_z

    if rotate_6d:
        orient = torch.einsum('ij,tjk->tik', R, rotation_6d_to_matrix(motion[:, 2:8]))
        rotated[:, 2:8] = matrix_to_rotation_6d(orient)

    width = n_joints * 3
    block_starts = [8]
    # The velocity block is rotated only when it fits inside the tensor.
    if 8 + 2 * width <= motion.shape[1]:
        block_starts.append(8 + width)

    for lo in block_starts:
        hi = lo + width
        vectors = motion[:, lo:hi].reshape(-1, n_joints, 3)
        rotated[:, lo:hi] = torch.einsum('ij,tnj->tni', R, vectors).reshape(-1, width)
    return rotated
# ---------- Inference core ----------
def _infer_chunk(smplx_motion, model, smplx_mean, smplx_std, g1_mean, g1_std, device):
    """Run the model on one window of SMPL-X features.

    The window is rotated about Y so that the heading of its first frame
    is zero, normalised with the SMPL-X statistics, passed through ``model``
    in ``'predict'`` mode, de-normalised with the G1 statistics and rotated
    back to the original heading.

    Returns:
        The predicted G1 feature tensor for this window, on the CPU.
    """
    heading = _extract_yaw(smplx_motion[0, 2:8])
    to_canonical = _make_y_rot(-heading)
    from_canonical = _make_y_rot(heading)

    canonical = _rotate_motion_features(smplx_motion, to_canonical, n_joints=22)
    normalised = (canonical - smplx_mean) / smplx_std

    batch = normalised.unsqueeze(0).float().to(device)
    lengths = torch.tensor([normalised.shape[0]]).to(device)
    with torch.no_grad():
        outputs, _ = model(smplx_motion=batch, motion_length=lengths, mode='predict')

    denormalised = outputs[0].cpu() * g1_std + g1_mean
    return _rotate_motion_features(denormalised, from_canonical, n_joints=30)
def _stitch_chunks(chunks, starts):
    """Join per-window predictions, linearly cross-fading where they overlap."""
    stitched = chunks[0]
    for prev, curr, prev_start, curr_start in zip(chunks, chunks[1:], starts, starts[1:]):
        overlap = prev_start + len(prev) - curr_start
        if overlap > 0:
            # Weight ramps from all-previous (0) to all-current (1).
            w = torch.linspace(0, 1, overlap).unsqueeze(1)
            blended = stitched[-overlap:] * (1 - w) + curr[:overlap] * w
            stitched = torch.cat([stitched[:-overlap], blended, curr[overlap:]], dim=0)
        else:
            stitched = torch.cat([stitched, curr], dim=0)
    return stitched


def infer_single(file_path, model, smplx_model, betas,
                 smplx_mean, smplx_std, g1_mean, g1_std,
                 device, apply_filter=True):
    """Retarget one SMPL-X clip to G1 and report per-stage timings.

    The clip is padded to a multiple of 4 frames by repeating its last
    frame. Clips longer than ``CHUNK_FRAMES`` are processed in windows of
    ``CHUNK_FRAMES`` frames every ``STRIDE_FRAMES`` frames and cross-faded
    over the shared frames. The padding is trimmed again before
    post-processing.

    Args:
        file_path: Path of the SMPL-X NPZ clip.
        model: Retargeting network, called by ``_infer_chunk``.
        smplx_model: SMPL-X body model used by ``preprocess_smpl``.
        betas: Default body shape coefficients.
        smplx_mean, smplx_std: Input feature normalisation statistics.
        g1_mean, g1_std: Output feature normalisation statistics.
        device: Device the models run on.
        apply_filter: Forwarded to ``postprocess_g1``.

    Returns:
        ``(result, timing)``. ``result`` is a dict with NumPy arrays ``dof``,
        ``root_trans``, ``root_rot_quat`` and the ``source_path`` string;
        ``timing`` holds the seconds spent in ``preprocess``, ``infer``,
        ``postprocess`` and ``total``. Returns ``(None, None)`` when the clip
        has fewer than 4 frames.
    """
    t0 = time.perf_counter()
    smplx_motion = preprocess_smpl(file_path, smplx_model, betas, device)
    T_orig = smplx_motion.shape[0]
    if T_orig < 4:
        return None, None

    T = ((T_orig + 3) // 4) * 4
    if T > T_orig:
        pad = smplx_motion[-1:].repeat(T - T_orig, 1)
        smplx_motion = torch.cat([smplx_motion, pad], dim=0)
    t_preprocess = time.perf_counter() - t0

    t1 = time.perf_counter()
    stats = (smplx_mean, smplx_std, g1_mean, g1_std)
    if T <= CHUNK_FRAMES:
        pred_motion = _infer_chunk(smplx_motion, model, *stats, device)
    else:
        chunks = []
        starts = []
        for start in range(0, T, STRIDE_FRAMES):
            end = min(start + CHUNK_FRAMES, T)
            seg_len = (end - start) // 4 * 4
            if seg_len < 4:
                break
            chunk = smplx_motion[start:start + seg_len]
            chunks.append(_infer_chunk(chunk, model, *stats, device))
            starts.append(start)
            if start + seg_len >= T:
                # This window already reaches the end of the clip. Any later
                # window would lie entirely inside it, cost an extra forward
                # pass, and have the cross-fade hand the final frames to a
                # prediction made from only a few frames of context.
                break
        pred_motion = _stitch_chunks(chunks, starts)
    t_infer = time.perf_counter() - t1

    # Drop the frames that were only added as padding.
    pred_motion = pred_motion[:T_orig]

    t2 = time.perf_counter()
    pred_dof, pred_rot_quat, pred_trans = postprocess_g1(pred_motion, apply_filter=apply_filter)
    t_postprocess = time.perf_counter() - t2
    t_total = time.perf_counter() - t0

    timing = dict(preprocess=t_preprocess, infer=t_infer, postprocess=t_postprocess, total=t_total)
    return dict(
        dof=pred_dof.numpy(),
        root_trans=pred_trans.numpy(),
        root_rot_quat=pred_rot_quat.numpy(),
        source_path=file_path,
    ), timing
# ---------- Model loading ----------
def _default_torch_load_to_full_unpickle():
    """Make ``torch.load`` default to ``weights_only=False``, installed once."""
    if getattr(torch.load, '_weights_only_default_patched', False):
        # Already wrapped by an earlier call; wrapping again would only
        # stack another layer of indirection on every load_all().
        return
    original_load = torch.load

    def _load(*args, **kwargs):
        kwargs.setdefault('weights_only', False)
        return original_load(*args, **kwargs)

    _load._weights_only_default_patched = True
    torch.load = _load


def load_all(weights_dir=None, assets_dir=None, device=None):
    """Load the retargeting model, the SMPL-X body model and the statistics.

    Side effect: ``torch.load`` is replaced process-wide by a wrapper that
    defaults ``weights_only`` to ``False``. The wrapper is installed at most
    once and is left in place after this function returns.

    Args:
        weights_dir: Directory holding ``epoch_30.pth`` and the
            ``smplx_*`` / ``gmr_*`` ``.npy`` statistics. Defaults to
            ``BASE_DIR/weights``.
        assets_dir: Directory holding ``SMPLX_NEUTRAL.npz`` and
            ``betas.npy``. Defaults to ``BASE_DIR/assets``.
        device: Target device. Defaults to CUDA when available, else CPU.

    Returns:
        ``(model, smplx_model, betas, smplx_mean, smplx_std, g1_mean,
        g1_std, device)``. Both models are in eval mode on ``device``; the
        remaining tensors are float32 on the CPU.
    """
    if weights_dir is None:
        weights_dir = os.path.join(BASE_DIR, 'weights')
    if assets_dir is None:
        assets_dir = os.path.join(BASE_DIR, 'assets')
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # PyTorch 2.6 compatibility: the default of ``weights_only`` became True.
    # NOTE(security): weights_only=False unpickles arbitrary objects, which
    # can execute code. Only load checkpoints from trusted sources.
    _default_torch_load_to_full_unpickle()

    def _load_npy(directory, filename):
        return torch.from_numpy(np.load(os.path.join(directory, filename))).float()

    # Build the network and load its weights.
    model = MODELS.build(MODEL_CFG)
    ckpt = torch.load(os.path.join(weights_dir, 'epoch_30.pth'), map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    model.eval().to(device)

    # SMPL-X body model.
    smplx_model = SMPLX(
        model_path=os.path.join(assets_dir, 'SMPLX_NEUTRAL.npz'),
        use_pca=False, num_expression_coeffs=100, num_betas=10, ext='npz'
    ).to(device).eval()
    betas = _load_npy(assets_dir, 'betas.npy')

    # Normalisation statistics.
    smplx_mean = _load_npy(weights_dir, 'smplx_mean.npy')
    smplx_std = _load_npy(weights_dir, 'smplx_std.npy')
    g1_mean = _load_npy(weights_dir, 'gmr_mean.npy')
    g1_std = _load_npy(weights_dir, 'gmr_std.npy')
    return model, smplx_model, betas, smplx_mean, smplx_std, g1_mean, g1_std, device