H-Liu1997's picture
Upload visualization/tools/smplh.py with huggingface_hub
8539c03 verified
"""SMPL-H model loading and GPU forward pass.
Provides a self-contained SMPL-H LBS implementation using PyTorch,
adapted from HumanML3D/custom_138/recover_and_render.py.
"""
from pathlib import Path
import numpy as np
import torch
from utils.paths import PATHS
# ─── Constants ──────────────────────────────────────────────────
# Root directory of the SMPL-H model files; load_smplh_model reads
# SMPLH_MODEL_DIR / <gender> / "model.npz" from here.
SMPLH_MODEL_DIR = PATHS["deps"] / "smplh"
# Hand mean poses (axis-angle, 15 joints × 3), flattened to 45 float32 values.
# RIGHT_HAND_MEAN_AA mirrors LEFT_HAND_MEAN_AA: for every joint the x
# component is identical and the y and z components are negated.
# NOTE(review): neither array is referenced in this file chunk; presumably
# callers add them to relative hand poses — confirm before changing.
LEFT_HAND_MEAN_AA = np.array([
0.1117, 0.0429, -0.4164, 0.1088, -0.0660, -0.7562, -0.0964, -0.0909,
-0.1885, -0.1181, 0.0509, -0.5296, -0.1437, 0.0552, -0.7049, -0.0192,
-0.0923, -0.3379, -0.4570, -0.1963, -0.6255, -0.2147, -0.0660, -0.5069,
-0.3697, -0.0603, -0.0795, -0.1419, -0.0859, -0.6355, -0.3033, -0.0579,
-0.6314, -0.1761, -0.1321, -0.3734, 0.8510, 0.2769, -0.0915, -0.4998,
0.0266, 0.0529, 0.5356, 0.0460, -0.2774,
], dtype=np.float32)
# Right-hand counterpart of LEFT_HAND_MEAN_AA (y/z components negated).
RIGHT_HAND_MEAN_AA = np.array([
0.1117, -0.0429, 0.4164, 0.1088, 0.0660, 0.7562, -0.0964, 0.0909,
0.1885, -0.1181, -0.0509, 0.5296, -0.1437, -0.0552, 0.7049, -0.0192,
0.0923, 0.3379, -0.4570, 0.1963, 0.6255, -0.2147, 0.0660, 0.5069,
-0.3697, 0.0603, 0.0795, -0.1419, 0.0859, 0.6355, -0.3033, 0.0579,
0.6314, -0.1761, 0.1321, 0.3734, 0.8510, -0.2769, 0.0915, -0.4998,
-0.0266, -0.0529, 0.5356, -0.0460, 0.2774,
], dtype=np.float32)
# ─── Model loading (cached) ────────────────────────────────────
_model_cache = {}
_gpu_cache = {}
_GPU_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_smplh_model(gender="neutral"):
    """Load SMPL-H model data for given gender (cached).

    Reads ``SMPLH_MODEL_DIR / gender / "model.npz"`` the first time a gender
    is requested and keeps the converted arrays in ``_model_cache``; later
    calls return the same dict object.

    Returns dict with keys: v_template, f, shapedirs, posedirs,
    J_regressor, kintree_table, weights. Float arrays are float32, ``f`` is
    int32, ``kintree_table`` is int64, and ``J_regressor`` is always a dense
    float32 array.
    """
    if gender not in _model_cache:
        model_path = SMPLH_MODEL_DIR / gender / "model.npz"
        # allow_pickle is required when J_regressor was saved as a pickled
        # scipy sparse matrix; only load trusted model files with it.
        # The context manager closes the NpzFile handle np.load keeps open.
        with np.load(str(model_path), allow_pickle=True) as data:
            J_reg = data["J_regressor"]
            # A pickled sparse matrix comes back wrapped in a 0-d object array;
            # unwrap it so the sparse -> dense conversion below can apply.
            if J_reg.dtype == object and J_reg.ndim == 0:
                J_reg = J_reg.item()
            if hasattr(J_reg, "toarray"):
                J_reg = J_reg.toarray()
            elif hasattr(J_reg, "A"):
                J_reg = np.array(J_reg.A)
            # astype() copies, so the arrays stay valid after the file closes.
            _model_cache[gender] = {
                "v_template": data["v_template"].astype(np.float32),
                "f": data["f"].astype(np.int32),
                "shapedirs": data["shapedirs"].astype(np.float32),
                "posedirs": data["posedirs"].astype(np.float32),
                "J_regressor": np.asarray(J_reg, dtype=np.float32),
                "kintree_table": data["kintree_table"].astype(np.int64),
                "weights": data["weights"].astype(np.float32),
            }
    return _model_cache[gender]
def get_J0(model, betas):
    """Compute pelvis rest position J[0] from model and shape params.

    Used to convert pelvis absolute position to SMPLX translation convention:
        trans_smplx = pelvis_abs - J0

    Args:
        model: dict from load_smplh_model (uses v_template, shapedirs,
            J_regressor).
        betas: (B,) array-like shape parameters. B may be smaller than the
            number of shape components in ``model["shapedirs"]``; only the
            first B components are used.

    Returns:
        (3,) float32 numpy array: rest-pose location of joint 0 for this shape.
    """
    betas = np.asarray(betas, dtype=np.float32)
    # Match the shape basis to the number of provided betas.
    shapedirs = model["shapedirs"][..., : betas.shape[0]]
    v_shaped = model["v_template"] + np.einsum("vci,i->vc", shapedirs, betas)
    # Only joint 0 is needed, so regress one row instead of every joint.
    return model["J_regressor"][0] @ v_shaped
# ─── GPU helpers ────────────────────────────────────────────────
def _get_model_gpu(model, gender, device=None):
    """Get or create device-resident tensors for a model (cached).

    The cache is keyed by ``(gender, device)`` so requesting a different
    device never returns tensors that live on another one.

    Args:
        model: dict from load_smplh_model (numpy arrays).
        gender: identifies ``model`` in the cache; the first call for a given
            (gender, device) pair decides which arrays are uploaded.
        device: torch device or device string; defaults to ``_GPU_DEVICE``.

    Returns:
        dict with float32 tensors ``v_template``, ``shapedirs``, ``posedirs``,
        ``J_regressor`` and ``weights`` on the device, plus ``parents``: the
        first row of ``kintree_table`` kept as a numpy array.
    """
    dev = torch.device(device) if device is not None else _GPU_DEVICE
    key = (gender, dev)
    if key not in _gpu_cache:
        tensors = {
            name: torch.tensor(model[name], dtype=torch.float32, device=dev)
            for name in ("v_template", "shapedirs", "posedirs", "J_regressor", "weights")
        }
        # Parent indices are consumed by Python-side loops, so keep them on host.
        tensors["parents"] = model["kintree_table"][0]
        _gpu_cache[key] = tensors
    return _gpu_cache[key]
def _batch_rodrigues(aa):
"""Axis-angle (N, 3) -> rotation matrices (N, 3, 3) in PyTorch."""
angle = torch.norm(aa + 1e-8, dim=1, keepdim=True)
axis = aa / angle
c = torch.cos(angle).unsqueeze(1)
s = torch.sin(angle).unsqueeze(1)
rx, ry, rz = axis[:, 0:1], axis[:, 1:2], axis[:, 2:3]
z = torch.zeros_like(rx)
K = torch.cat([z, -rz, ry, rz, z, -rx, -ry, rx, z], 1).view(-1, 3, 3)
I = torch.eye(3, device=aa.device, dtype=aa.dtype).unsqueeze(0)
return I + s * K + (1 - c) * torch.bmm(K, K)
# ─── Forward pass ───────────────────────────────────────────────
@torch.no_grad()
def smplh_forward(model, gender, betas, poses_aa, transl):
    """SMPL-H forward pass (linear blend skinning) on ``_GPU_DEVICE``.

    Args:
        model: dict from load_smplh_model (numpy arrays).
        gender: str, one of "male", "female", "neutral"; also the key under
            which the device copy of ``model`` is cached.
        betas: (B,) numpy shape parameters (e.g. 16). Only the first B shape
            components of ``model["shapedirs"]`` are used, so B may be
            smaller than the number stored in the model.
        poses_aa: (T, J, 3) numpy axis-angle poses (a flattened (T, J*3)
            array also works), where J is the number of joints regressed by
            ``model["J_regressor"]`` (52 for SMPL-H).
        transl: (T, 3) numpy root translation (SMPLX convention), added to
            every vertex of the corresponding frame.

    Returns:
        verts: (T, V, 3) numpy float32 mesh vertices (V = 6890 for SMPL-H).
        An empty (0, V, 3) array is returned when T == 0.
    """
    dev = _GPU_DEVICE
    m = _get_model_gpu(model, gender, dev)
    parents = m["parents"]
    n_joints = m["J_regressor"].shape[0]
    n_verts = m["v_template"].shape[0]
    T = poses_aa.shape[0]
    if T == 0:
        # reshape(T, -1) below is ill-defined for zero frames; nothing to pose.
        return np.zeros((0, n_verts, 3), dtype=np.float32)

    betas_t = torch.tensor(betas, dtype=torch.float32, device=dev)
    poses_t = torch.tensor(poses_aa, dtype=torch.float32, device=dev)
    transl_t = torch.tensor(transl, dtype=torch.float32, device=dev)

    # 1. Shape blend shapes (basis truncated to the number of betas given).
    shapedirs = m["shapedirs"][..., : betas_t.shape[0]]
    v_shaped = m["v_template"] + torch.einsum("vci,i->vc", shapedirs, betas_t)

    # 2. Joint regression: rest-pose joint locations, (J, 3).
    J = m["J_regressor"] @ v_shaped

    # 3. Axis-angle -> rotation matrices, (T, J, 3, 3).
    rot_mats = _batch_rodrigues(poses_t.reshape(-1, 3)).reshape(T, n_joints, 3, 3)

    # 4. Pose blend shapes driven by all non-root joint rotations.
    I3 = torch.eye(3, device=dev, dtype=torch.float32)
    pose_feature = (rot_mats[:, 1:] - I3).reshape(T, -1)  # (T, (J-1)*9)
    v_posed = v_shaped.unsqueeze(0) + torch.einsum("vcp,tp->tvc", m["posedirs"], pose_feature)

    # 5. Forward kinematics. A parent index outside [0, J) (e.g. the root's
    # sentinel) means "no parent". The chain loop reads global_tf[:, p], so
    # parents must precede their children in joint order.
    rel_J = J.clone()
    for i in range(1, n_joints):
        p = parents[i]
        if 0 <= p < n_joints:
            rel_J[i] = J[i] - J[p]
    local_tf = torch.zeros(T, n_joints, 4, 4, device=dev, dtype=torch.float32)
    local_tf[:, :, :3, :3] = rot_mats
    local_tf[:, :, :3, 3] = rel_J.unsqueeze(0)
    local_tf[:, :, 3, 3] = 1.0
    global_tf = torch.zeros_like(local_tf)
    global_tf[:, 0] = local_tf[:, 0]
    for i in range(1, n_joints):
        p = parents[i]
        if 0 <= p < n_joints:
            global_tf[:, i] = torch.bmm(global_tf[:, p], local_tf[:, i])
        else:
            global_tf[:, i] = local_tf[:, i]

    # 6. Relative transforms: remove the rest-pose joint location, giving
    # rotation R_glob and translation t_glob - R_glob @ J per joint.
    R_glob = global_tf[:, :, :3, :3]  # (T, J, 3, 3)
    t_glob = global_tf[:, :, :3, 3]  # (T, J, 3)
    t_rel = t_glob - torch.einsum("tjcd,jd->tjc", R_glob, J)

    # 7. LBS. Blending is linear, so blend rotations and translations
    # separately instead of materialising (T, V, 4, 4) transforms and
    # homogeneous vertices; the result is the same with less peak memory.
    W = m["weights"]  # (V, J)
    R_blend = torch.einsum("vj,tjcd->tvcd", W, R_glob)  # (T, V, 3, 3)
    t_blend = torch.einsum("vj,tjc->tvc", W, t_rel)  # (T, V, 3)
    verts = torch.einsum("tvcd,tvd->tvc", R_blend, v_posed) + t_blend
    verts += transl_t.unsqueeze(1)
    result = verts.cpu().numpy()

    # Drop large intermediates (including views that keep global_tf alive)
    # before returning cached allocator blocks to the device.
    del verts, R_blend, t_blend, t_rel, R_glob, t_glob, global_tf, local_tf
    del v_posed, rot_mats, poses_t, transl_t
    if dev.type == "cuda":
        torch.cuda.empty_cache()
    return result